From a0da4dcf0081391ecde759c17ce49ba5e25532df Mon Sep 17 00:00:00 2001 From: jellywx Date: Thu, 13 May 2021 18:50:22 +0100 Subject: [PATCH] ratelimit commands --- src/framework.rs | 288 +++++++++++++++++++++++++++-------------------- src/main.rs | 10 +- 2 files changed, 177 insertions(+), 121 deletions(-) diff --git a/src/framework.rs b/src/framework.rs index 2dc62e0..465235e 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -21,7 +21,8 @@ use std::{collections::HashMap, fmt}; use crate::language_manager::LanguageManager; use crate::models::{CtxGuildData, GuildData, UserData}; -use crate::{models::ChannelData, SQLPool}; +use crate::{models::ChannelData, CurrentlyExecuting, SQLPool}; +use std::time::Duration; type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, ()>; @@ -345,144 +346,191 @@ impl Framework for RegexFramework { // gate to prevent analysing messages unnecessarily if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() { - } - // Guild Command - else if let (Some(guild), Some(Channel::Guild(channel))) = - (msg.guild(&ctx).await, msg.channel(&ctx).await) - { - let data = ctx.data.read().await; - - let pool = data - .get::() + } else { + let currently_executing = ctx + .data + .read() + .await + .get::() .cloned() - .expect("Could not get SQLPool from data"); + .unwrap(); - if let Some(full_match) = self.command_matcher.captures(&msg.content) { - if check_prefix(&ctx, &guild, full_match.name("prefix")).await { - let lm = data.get::().unwrap(); + let user_is_executing; - let language = UserData::language_of(&msg.author, &pool); + { + let mut lock = currently_executing.lock().unwrap(); - match check_self_permissions(&ctx, &guild, &channel).await { - Ok(perms) => match perms { - PermissionCheck::All => { - let command = self - .commands - .get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) - .unwrap(); + user_is_executing = lock.contains(&msg.author.id); + lock.insert(msg.author.id); + } - let channel_data = ChannelData::from_channel( - msg.channel(&ctx).await.unwrap(), - &pool, - ) - .await - .unwrap(); + if !user_is_executing { + // Guild Command + if let (Some(guild), Some(Channel::Guild(channel))) = + (msg.guild(&ctx).await, msg.channel(&ctx).await) + { + let data = ctx.data.read().await; - if !command.can_blacklist || !channel_data.blacklisted { - let args = full_match - .name("args") - .map(|m| m.as_str()) - .unwrap_or("") - .to_string(); + let pool = data + .get::() + .cloned() + .expect("Could not get SQLPool from data"); - let member = guild.member(&ctx, &msg.author).await.unwrap(); + if let Some(full_match) = self.command_matcher.captures(&msg.content) { + if check_prefix(&ctx, &guild, full_match.name("prefix")).await { + let lm = data.get::().unwrap(); - if command.check_permissions(&ctx, &guild, &member).await { - dbg!(command.name); + let language = UserData::language_of(&msg.author, &pool); - { - let guild_id = guild.id.as_u64().to_owned(); + match check_self_permissions(&ctx, &guild, &channel).await { + Ok(perms) => match perms { + PermissionCheck::All => { + let command = self + .commands + .get( + &full_match + .name("cmd") + .unwrap() + .as_str() + .to_lowercase(), + ) + .unwrap(); - GuildData::from_guild(guild, &pool).await.expect( - &format!( - "Failed to create new guild object for {}", - guild_id - ), - ); + let channel_data = ChannelData::from_channel( + msg.channel(&ctx).await.unwrap(), + &pool, + ) + .await + .unwrap(); + + if !command.can_blacklist || !channel_data.blacklisted { + let args = full_match + .name("args") + .map(|m| m.as_str()) + .unwrap_or("") + .to_string(); + + let member = + guild.member(&ctx, &msg.author).await.unwrap(); + + if command + .check_permissions(&ctx, &guild, &member) + .await + { + dbg!(command.name); + + { + let guild_id = guild.id.as_u64().to_owned(); + + GuildData::from_guild(guild, &pool) + .await + .expect(&format!( + "Failed to create new guild object for {}", + guild_id + )); + } + + (command.func)(&ctx, &msg, args).await; + } else if command.required_perms + == PermissionLevel::Restricted + { + let _ = msg + .channel_id + .say( + &ctx, + lm.get( + &language.await, + "no_perms_restricted", + ), + ) + .await; + } else if command.required_perms + == PermissionLevel::Managed + { + let _ = msg + .channel_id + .say( + &ctx, + lm.get(&language.await, "no_perms_managed") + .replace( + "{prefix}", + &ctx.prefix(msg.guild_id).await, + ), + ) + .await; + } } - - (command.func)(&ctx, &msg, args).await; - } else if command.required_perms == PermissionLevel::Restricted - { - let _ = msg - .channel_id - .say( - &ctx, - lm.get(&language.await, "no_perms_restricted"), - ) - .await; - } else if command.required_perms == PermissionLevel::Managed { - let _ = msg - .channel_id - .say( - &ctx, - lm.get(&language.await, "no_perms_managed") - .replace( - "{prefix}", - &ctx.prefix(msg.guild_id).await, - ), - ) - .await; } + + PermissionCheck::Basic( + manage_webhooks, + embed_links, + add_reactions, + manage_messages, + ) => { + let response = lm + .get(&language.await, "no_perms_general") + .replace( + "{manage_webhooks}", + if manage_webhooks { "✅" } else { "❌" }, + ) + .replace( + "{embed_links}", + if embed_links { "✅" } else { "❌" }, + ) + .replace( + "{add_reactions}", + if add_reactions { "✅" } else { "❌" }, + ) + .replace( + "{manage_messages}", + if manage_messages { "✅" } else { "❌" }, + ); + + let _ = msg.channel_id.say(&ctx, response).await; + } + + PermissionCheck::None => { + warn!("Missing enough permissions for guild {}", guild.id); + } + }, + + Err(e) => { + error!( + "Error occurred getting permissions in guild {}: {:?}", + guild.id, e + ); } } - - PermissionCheck::Basic( - manage_webhooks, - embed_links, - add_reactions, - manage_messages, - ) => { - let response = lm - .get(&language.await, "no_perms_general") - .replace( - "{manage_webhooks}", - if manage_webhooks { "✅" } else { "❌" }, - ) - .replace("{embed_links}", if embed_links { "✅" } else { "❌" }) - .replace( - "{add_reactions}", - if add_reactions { "✅" } else { "❌" }, - ) - .replace( - "{manage_messages}", - if manage_messages { "✅" } else { "❌" }, - ); - - let _ = msg.channel_id.say(&ctx, response).await; - } - - PermissionCheck::None => { - warn!("Missing enough permissions for guild {}", guild.id); - } - }, - - Err(e) => { - error!( - "Error occurred getting permissions in guild {}: {:?}", - guild.id, e - ); } } } - } - } - // DM Command - else if self.dm_enabled { - if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) { - let command = self - .commands - .get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) - .unwrap(); - let args = full_match - .name("args") - .map(|m| m.as_str()) - .unwrap_or("") - .to_string(); + // DM Command + else if self.dm_enabled { + if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) { + let command = self + .commands + .get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) + .unwrap(); + let args = full_match + .name("args") + .map(|m| m.as_str()) + .unwrap_or("") + .to_string(); - dbg!(command.name); + dbg!(command.name); - (command.func)(&ctx, &msg, args).await; + (command.func)(&ctx, &msg, args).await; + } + } + + { + // wait 500 ms before allowing the user to execute a command again + tokio::time::sleep(Duration::from_millis(500)).await; + + let mut lock = currently_executing.lock().unwrap(); + lock.remove(&msg.author.id); + } } } } diff --git a/src/main.rs b/src/main.rs index ad03cb5..5fa3b83 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,6 +47,8 @@ use dashmap::DashMap; use tokio::sync::RwLock; use chrono_tz::Tz; +use std::collections::HashSet; +use std::sync::Mutex; struct GuildDataCache; @@ -78,6 +80,12 @@ impl TypeMapKey for PopularTimezones { type Value = Arc>; } +struct CurrentlyExecuting; + +impl TypeMapKey for CurrentlyExecuting { + type Value = Arc>>; +} + struct Handler; #[async_trait] @@ -309,7 +317,7 @@ async fn main() -> Result<(), Box> { let mut data = client.data.write().await; data.insert::(Arc::new(guild_data_cache)); - + data.insert::(Arc::new(Mutex::new(HashSet::new()))); data.insert::(pool); data.insert::(Arc::new(popular_timezones)); data.insert::(Arc::new(reqwest::Client::new()));