diff --git a/Cargo.lock b/Cargo.lock index 5c27e11..59a8936 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1093,6 +1093,7 @@ dependencies = [ "chrono", "chrono-tz", "dotenv", + "lazy_static", "log", "regex", "regex_command_attr", diff --git a/Cargo.toml b/Cargo.toml index 9b117df..777b9a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ async-trait = "0.1.36" log = "0.4.11" chrono = "0.4" chrono-tz = "0.5" +lazy_static = "1.4.0" [dependencies.regex_command_attr] path = "./regex_command_attr" diff --git a/src/commands/info_cmds.rs b/src/commands/info_cmds.rs index 9df491a..7d2f93e 100644 --- a/src/commands/info_cmds.rs +++ b/src/commands/info_cmds.rs @@ -14,6 +14,7 @@ use crate::THEME_COLOR; #[command] +#[can_blacklist(false)] async fn help(ctx: &Context, msg: &Message, _args: String) -> CommandResult { msg.channel_id.send_message(ctx, |m| m .embed(|e| e diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index f110821..f2c4e88 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -10,11 +10,17 @@ use serenity::{ framework::standard::CommandResult, }; +use regex::Regex; + use crate::{ models::ChannelData, SQLPool, }; +lazy_static! { + static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap(); +} + #[command] #[supports_dm(false)] #[permission_level(Restricted)] @@ -23,7 +29,15 @@ async fn blacklist(ctx: &Context, msg: &Message, args: String) -> CommandResult let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let mut channel = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), pool.clone()).await.unwrap(); + let capture_opt = REGEX_CHANNEL.captures(&args).map(|cap| cap.get(1)).flatten(); + + let mut channel = match capture_opt { + Some(capture) => + ChannelData::from_id(capture.as_str().parse::().unwrap(), pool.clone()).await.unwrap(), + + None => + ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), pool.clone()).await.unwrap(), + }; channel.blacklisted = !channel.blacklisted; channel.commit_changes(pool).await; diff --git a/src/framework.rs b/src/framework.rs index 712ca05..6a1a79d 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -320,7 +320,7 @@ impl Framework for RegexFramework { let command = self.commands.get(full_match.name("cmd").unwrap().as_str()).unwrap(); let channel_data = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), pool).await; - if !command.can_blacklist || channel_data.map(|c| c.blacklisted).unwrap_or(false) { + if !command.can_blacklist || !channel_data.map(|c| c.blacklisted).unwrap_or(false) { let args = full_match.name("args") .map(|m| m.as_str()) .unwrap_or("") diff --git a/src/main.rs b/src/main.rs index 4c22000..d1c46dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,6 @@ +#[macro_use] +extern crate lazy_static; + mod models; mod framework; mod commands; diff --git a/src/models.rs b/src/models.rs index 1251feb..94873f5 100644 --- a/src/models.rs +++ b/src/models.rs @@ -59,6 +59,15 @@ SELECT id, guild, name, prefix FROM guilds WHERE guild = ? } impl ChannelData { + pub async fn from_id(channel_id: u64, pool: MySqlPool) -> Option { + sqlx::query_as_unchecked!(Self, + " +SELECT * FROM channels WHERE channel = ? + ", channel_id) + .fetch_one(&pool) + .await.ok() + } + pub async fn from_channel(channel: Channel, pool: MySqlPool) -> Result> { let channel_id = channel.id().as_u64().clone();