diff --git a/src/framework.rs b/src/framework.rs index 1b43f85..fba244a 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -2,7 +2,12 @@ use async_trait::async_trait; use serenity::{ client::Context, - framework::Framework, + framework::{ + Framework, + standard::{ + Args, CommandFn, + }, + }, model::{ guild::Guild, channel::{ @@ -27,7 +32,6 @@ use std::{ fmt, }; -use serenity::framework::standard::CommandFn; use crate::SQLPool; #[derive(Debug)] @@ -60,6 +64,7 @@ impl fmt::Debug for Command { pub struct RegexFramework { commands: HashMap, regex_matcher: Regex, + dm_regex_matcher: Regex, default_prefix: String, client_id: u64, ignore_bots: bool, @@ -70,6 +75,7 @@ impl RegexFramework { Self { commands: HashMap::new(), regex_matcher: Regex::new(r#"^$"#).unwrap(), + dm_regex_matcher: Regex::new(r#"^$"#).unwrap(), default_prefix: String::from("$"), client_id, ignore_bots: true, @@ -95,26 +101,59 @@ impl RegexFramework { } pub fn build(mut self) -> Self { - let command_names; - { - let mut command_names_vec = self.commands - .keys() - .map(|k| &k[..]) - .collect::>(); + let command_names; - command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len())); + { + let mut command_names_vec = self.commands + .keys() + .map(|k| &k[..]) + .collect::>(); - command_names = command_names_vec.join("|"); + command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len())); + + command_names = command_names_vec.join("|"); + } + + info!("Command names: {}", command_names); + + { + let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(?P\S{1,5}?))(?PCOMMANDS)(?:$|\s+(?P.*))$"# + .replace("COMMANDS", command_names.as_str()) + .replace("ID", self.client_id.to_string().as_str()); + + self.regex_matcher = Regex::new(match_string.as_str()).unwrap(); + } } - info!("Command names: {}", command_names); + { + let dm_command_names; - let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(?P\S{1,5}?))(?PCOMMANDS)(?:$|\s+(?P.*))$"# - .replace("COMMANDS", command_names.as_str()) - .replace("ID", self.client_id.to_string().as_str()); + { + let mut command_names_vec = self.commands + .iter() + .filter_map(|(key, command)| { + if command.supports_dm { + Some(&key[..]) + } else { + None + } + }) + .collect::>(); - self.regex_matcher = Regex::new(match_string.as_str()).unwrap(); + command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len())); + + dm_command_names = command_names_vec.join("|"); + } + + { + let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(\$)|())(?PCOMMANDS)(?:$|\s+(?P.*))$"# + .replace("COMMANDS", dm_command_names.as_str()) + .replace("ID", self.client_id.to_string().as_str()); + + self.dm_regex_matcher = Regex::new(match_string.as_str()).unwrap(); + } + } self } @@ -130,7 +169,7 @@ enum PermissionCheck { impl Framework for RegexFramework { async fn dispatch(&self, ctx: Context, msg: Message) { - async fn check_self_permissions(ctx: &Context, guild: &Guild, channel: &GuildChannel) -> Result> { + async fn check_self_permissions(ctx: &Context, guild: &Guild, channel: &GuildChannel) -> Result> { let user_id = ctx.cache.current_user_id().await; let guild_perms = guild.member_permissions(user_id); @@ -149,12 +188,12 @@ impl Framework for RegexFramework { }) } - async fn check_prefix(ctx: &Context, guild_id: u64, prefix_opt: Option>) -> bool { + async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option>) -> bool { if let Some(prefix) = prefix_opt { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - match sqlx::query!("SELECT prefix FROM guilds WHERE id = ?", guild_id) + match sqlx::query!("SELECT prefix FROM guilds WHERE id = ?", guild.id.as_u64()) .fetch_one(&pool) .await { Ok(row) => { @@ -162,6 +201,10 @@ impl Framework for RegexFramework { } Err(sqlx::Error::RowNotFound) => { + sqlx::query!("INSERT INTO guilds (guild, name) VALUES (?, ?)", guild.id.as_u64(), guild.name) + .execute(&pool) + .await; + prefix.as_str() == "$" } @@ -191,15 +234,27 @@ impl Framework for RegexFramework { if let Some(full_match) = self.regex_matcher.captures(&msg.content[..]) { - if check_prefix(&ctx, *guild.id.as_u64(), full_match.name("prefix")).await { + if check_prefix(&ctx, &guild, full_match.name("prefix")).await { debug!("Prefix matched on {}", msg.content); match check_self_permissions(&ctx, &guild, &channel).await { Ok(perms) => match perms { - PermissionCheck::All => {} + PermissionCheck::All => { + let command = self.commands.get(full_match.name("cmd").unwrap().as_str()).unwrap(); + let args = Args::new( + full_match.name("args") + .map(|m| m.as_str()) + .unwrap_or(""), + &[] + ); - PermissionCheck::Basic => {} + (command.func)(&ctx, &msg, args).await; + } + + PermissionCheck::Basic => { + msg.channel_id.say(&ctx, "Not enough perms").await; + } PermissionCheck::None => { warn!("Missing enough permissions for guild {}", guild.id); @@ -216,7 +271,17 @@ impl Framework for RegexFramework { // DM Command else { + if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) { + let command = self.commands.get(full_match.name("cmd").unwrap().as_str()).unwrap(); + let args = Args::new( + full_match.name("args") + .map(|m| m.as_str()) + .unwrap_or(""), + &[] + ); + (command.func)(&ctx, &msg, args).await; + } } } }