diff --git a/src/framework.rs b/src/framework.rs index 328d528..1b43f85 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -14,9 +14,13 @@ use serenity::{ use log::{ warn, error, + debug, + info, }; -use regex::Regex; +use regex::{ + Regex, Match +}; use std::{ collections::HashMap, @@ -24,6 +28,7 @@ use std::{ }; use serenity::framework::standard::CommandFn; +use crate::SQLPool; #[derive(Debug)] pub enum PermissionLevel { @@ -90,11 +95,20 @@ impl RegexFramework { } pub fn build(mut self) -> Self { - let command_names = self.commands - .keys() - .map(|k| &k[..]) - .collect::>() - .join("|"); + let command_names; + + { + let mut command_names_vec = self.commands + .keys() + .map(|k| &k[..]) + .collect::>(); + + command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len())); + + command_names = command_names_vec.join("|"); + } + + info!("Command names: {}", command_names); let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(?P\S{1,5}?))(?PCOMMANDS)(?:$|\s+(?P.*))$"# .replace("COMMANDS", command_names.as_str()) @@ -135,6 +149,34 @@ impl Framework for RegexFramework { }) } + async fn check_prefix(ctx: &Context, guild_id: u64, prefix_opt: Option>) -> bool { + if let Some(prefix) = prefix_opt { + let pool = ctx.data.read().await + .get::().cloned().expect("Could not get SQLPool from data"); + + match sqlx::query!("SELECT prefix FROM guilds WHERE id = ?", guild_id) + .fetch_one(&pool) + .await { + Ok(row) => { + prefix.as_str() == row.prefix + } + + Err(sqlx::Error::RowNotFound) => { + prefix.as_str() == "$" + } + + Err(e) => { + warn!("Unexpected error in prefix query: {:?}", e); + + false + } + } + } + else { + true + } + } + // gate to prevent analysing messages unnecessarily if (msg.author.bot && self.ignore_bots) || msg.tts || @@ -147,25 +189,26 @@ impl Framework for RegexFramework { // Guild Command else if let (Some(guild), Some(Channel::Guild(channel))) = (msg.guild(&ctx).await, msg.channel(&ctx).await) { - if let Some(full_match) = self.regex_matcher.captures(msg.content.as_str()) { + if let Some(full_match) = self.regex_matcher.captures(&msg.content[..]) { - match check_self_permissions(&ctx, &guild, &channel).await { - Ok(perms) => match perms { - PermissionCheck::All => { + if check_prefix(&ctx, *guild.id.as_u64(), full_match.name("prefix")).await { + debug!("Prefix matched on {}", msg.content); + + match check_self_permissions(&ctx, &guild, &channel).await { + Ok(perms) => match perms { + PermissionCheck::All => {} + + PermissionCheck::Basic => {} + + PermissionCheck::None => { + warn!("Missing enough permissions for guild {}", guild.id); + } } - PermissionCheck::Basic => { - + Err(e) => { + error!("Error occurred getting permissions in guild {}: {:?}", guild.id, e); } - - PermissionCheck::None => { - warn!("Missing enough permissions for guild {}", guild.id); - } - } - - Err(e) => { - error!("Error occurred getting permissions in guild {}: {:?}", guild.id, e); } } } diff --git a/src/main.rs b/src/main.rs index dc4b7d8..7d80ddd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,6 +21,7 @@ use regex_command_attr::command; use sqlx::{ Pool, mysql::{ + MySqlPool, MySqlConnection, } }; @@ -46,14 +47,12 @@ impl TypeMapKey for ReqwestClient { type Value = Arc; } -static THEME_COLOR: u32 = 0x00e0f3; +static THEME_COLOR: u32 = 0x8fb677; #[tokio::main] async fn main() -> Result<(), Box> { dotenv()?; - println!("{:?}", HELP_COMMAND); - let framework = RegexFramework::new(env::var("CLIENT_ID").expect("Missing CLIENT_ID from environment").parse()?) .ignore_bots(true) .default_prefix("$") @@ -66,6 +65,15 @@ async fn main() -> Result<(), Box> { .framework(framework) .await.expect("Error occurred creating client"); + { + let pool = MySqlPool::new(&env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment")).await.unwrap(); + + let mut data = client.data.write().await; + + data.insert::(pool); + data.insert::(Arc::new(reqwest::Client::new())); + } + client.start_autosharded().await?; Ok(())