ratelimit commands

This commit is contained in:
jellywx 2021-05-13 18:50:22 +01:00
parent 0a9624d12d
commit a0da4dcf00
2 changed files with 177 additions and 121 deletions

View File

@ -21,7 +21,8 @@ use std::{collections::HashMap, fmt};
use crate::language_manager::LanguageManager; use crate::language_manager::LanguageManager;
use crate::models::{CtxGuildData, GuildData, UserData}; use crate::models::{CtxGuildData, GuildData, UserData};
use crate::{models::ChannelData, SQLPool}; use crate::{models::ChannelData, CurrentlyExecuting, SQLPool};
use std::time::Duration;
type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, ()>; type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, ()>;
@ -345,144 +346,191 @@ impl Framework for RegexFramework {
// gate to prevent analysing messages unnecessarily // gate to prevent analysing messages unnecessarily
if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() { if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() {
} } else {
// Guild Command let currently_executing = ctx
else if let (Some(guild), Some(Channel::Guild(channel))) = .data
(msg.guild(&ctx).await, msg.channel(&ctx).await) .read()
{ .await
let data = ctx.data.read().await; .get::<CurrentlyExecuting>()
let pool = data
.get::<SQLPool>()
.cloned() .cloned()
.expect("Could not get SQLPool from data"); .unwrap();
if let Some(full_match) = self.command_matcher.captures(&msg.content) { let user_is_executing;
if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
let lm = data.get::<LanguageManager>().unwrap();
let language = UserData::language_of(&msg.author, &pool); {
let mut lock = currently_executing.lock().unwrap();
match check_self_permissions(&ctx, &guild, &channel).await { user_is_executing = lock.contains(&msg.author.id);
Ok(perms) => match perms { lock.insert(msg.author.id);
PermissionCheck::All => { }
let command = self
.commands
.get(&full_match.name("cmd").unwrap().as_str().to_lowercase())
.unwrap();
let channel_data = ChannelData::from_channel( if !user_is_executing {
msg.channel(&ctx).await.unwrap(), // Guild Command
&pool, if let (Some(guild), Some(Channel::Guild(channel))) =
) (msg.guild(&ctx).await, msg.channel(&ctx).await)
.await {
.unwrap(); let data = ctx.data.read().await;
if !command.can_blacklist || !channel_data.blacklisted { let pool = data
let args = full_match .get::<SQLPool>()
.name("args") .cloned()
.map(|m| m.as_str()) .expect("Could not get SQLPool from data");
.unwrap_or("")
.to_string();
let member = guild.member(&ctx, &msg.author).await.unwrap(); if let Some(full_match) = self.command_matcher.captures(&msg.content) {
if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
let lm = data.get::<LanguageManager>().unwrap();
if command.check_permissions(&ctx, &guild, &member).await { let language = UserData::language_of(&msg.author, &pool);
dbg!(command.name);
{ match check_self_permissions(&ctx, &guild, &channel).await {
let guild_id = guild.id.as_u64().to_owned(); Ok(perms) => match perms {
PermissionCheck::All => {
let command = self
.commands
.get(
&full_match
.name("cmd")
.unwrap()
.as_str()
.to_lowercase(),
)
.unwrap();
GuildData::from_guild(guild, &pool).await.expect( let channel_data = ChannelData::from_channel(
&format!( msg.channel(&ctx).await.unwrap(),
"Failed to create new guild object for {}", &pool,
guild_id )
), .await
); .unwrap();
if !command.can_blacklist || !channel_data.blacklisted {
let args = full_match
.name("args")
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
let member =
guild.member(&ctx, &msg.author).await.unwrap();
if command
.check_permissions(&ctx, &guild, &member)
.await
{
dbg!(command.name);
{
let guild_id = guild.id.as_u64().to_owned();
GuildData::from_guild(guild, &pool)
.await
.expect(&format!(
"Failed to create new guild object for {}",
guild_id
));
}
(command.func)(&ctx, &msg, args).await;
} else if command.required_perms
== PermissionLevel::Restricted
{
let _ = msg
.channel_id
.say(
&ctx,
lm.get(
&language.await,
"no_perms_restricted",
),
)
.await;
} else if command.required_perms
== PermissionLevel::Managed
{
let _ = msg
.channel_id
.say(
&ctx,
lm.get(&language.await, "no_perms_managed")
.replace(
"{prefix}",
&ctx.prefix(msg.guild_id).await,
),
)
.await;
}
} }
(command.func)(&ctx, &msg, args).await;
} else if command.required_perms == PermissionLevel::Restricted
{
let _ = msg
.channel_id
.say(
&ctx,
lm.get(&language.await, "no_perms_restricted"),
)
.await;
} else if command.required_perms == PermissionLevel::Managed {
let _ = msg
.channel_id
.say(
&ctx,
lm.get(&language.await, "no_perms_managed")
.replace(
"{prefix}",
&ctx.prefix(msg.guild_id).await,
),
)
.await;
} }
PermissionCheck::Basic(
manage_webhooks,
embed_links,
add_reactions,
manage_messages,
) => {
let response = lm
.get(&language.await, "no_perms_general")
.replace(
"{manage_webhooks}",
if manage_webhooks { "" } else { "" },
)
.replace(
"{embed_links}",
if embed_links { "" } else { "" },
)
.replace(
"{add_reactions}",
if add_reactions { "" } else { "" },
)
.replace(
"{manage_messages}",
if manage_messages { "" } else { "" },
);
let _ = msg.channel_id.say(&ctx, response).await;
}
PermissionCheck::None => {
warn!("Missing enough permissions for guild {}", guild.id);
}
},
Err(e) => {
error!(
"Error occurred getting permissions in guild {}: {:?}",
guild.id, e
);
} }
} }
PermissionCheck::Basic(
manage_webhooks,
embed_links,
add_reactions,
manage_messages,
) => {
let response = lm
.get(&language.await, "no_perms_general")
.replace(
"{manage_webhooks}",
if manage_webhooks { "" } else { "" },
)
.replace("{embed_links}", if embed_links { "" } else { "" })
.replace(
"{add_reactions}",
if add_reactions { "" } else { "" },
)
.replace(
"{manage_messages}",
if manage_messages { "" } else { "" },
);
let _ = msg.channel_id.say(&ctx, response).await;
}
PermissionCheck::None => {
warn!("Missing enough permissions for guild {}", guild.id);
}
},
Err(e) => {
error!(
"Error occurred getting permissions in guild {}: {:?}",
guild.id, e
);
} }
} }
} }
} // DM Command
} else if self.dm_enabled {
// DM Command if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) {
else if self.dm_enabled { let command = self
if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) { .commands
let command = self .get(&full_match.name("cmd").unwrap().as_str().to_lowercase())
.commands .unwrap();
.get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) let args = full_match
.unwrap(); .name("args")
let args = full_match .map(|m| m.as_str())
.name("args") .unwrap_or("")
.map(|m| m.as_str()) .to_string();
.unwrap_or("")
.to_string();
dbg!(command.name); dbg!(command.name);
(command.func)(&ctx, &msg, args).await; (command.func)(&ctx, &msg, args).await;
}
}
{
// wait 500 ms before allowing the user to execute a command again
tokio::time::sleep(Duration::from_millis(500)).await;
let mut lock = currently_executing.lock().unwrap();
lock.remove(&msg.author.id);
}
} }
} }
} }

View File

@ -47,6 +47,8 @@ use dashmap::DashMap;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use chrono_tz::Tz; use chrono_tz::Tz;
use std::collections::HashSet;
use std::sync::Mutex;
struct GuildDataCache; struct GuildDataCache;
@ -78,6 +80,12 @@ impl TypeMapKey for PopularTimezones {
type Value = Arc<Vec<Tz>>; type Value = Arc<Vec<Tz>>;
} }
struct CurrentlyExecuting;
impl TypeMapKey for CurrentlyExecuting {
type Value = Arc<Mutex<HashSet<UserId>>>;
}
struct Handler; struct Handler;
#[async_trait] #[async_trait]
@ -309,7 +317,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut data = client.data.write().await; let mut data = client.data.write().await;
data.insert::<GuildDataCache>(Arc::new(guild_data_cache)); data.insert::<GuildDataCache>(Arc::new(guild_data_cache));
data.insert::<CurrentlyExecuting>(Arc::new(Mutex::new(HashSet::new())));
data.insert::<SQLPool>(pool); data.insert::<SQLPool>(pool);
data.insert::<PopularTimezones>(Arc::new(popular_timezones)); data.insert::<PopularTimezones>(Arc::new(popular_timezones));
data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new())); data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new()));