diff --git a/.idea/dictionaries/jude.xml b/.idea/dictionaries/jude.xml index 864a310..38d72d3 100644 --- a/.idea/dictionaries/jude.xml +++ b/.idea/dictionaries/jude.xml @@ -2,6 +2,7 @@ reqwest + webhooks \ No newline at end of file diff --git a/src/framework.rs b/src/framework.rs index fba244a..46a030f 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -9,7 +9,10 @@ use serenity::{ }, }, model::{ - guild::Guild, + guild::{ + Guild, + Member, + }, channel::{ Channel, GuildChannel, Message, } @@ -33,6 +36,7 @@ use std::{ }; use crate::SQLPool; +use serenity::model::id::RoleId; #[derive(Debug)] pub enum PermissionLevel { @@ -49,6 +53,70 @@ pub struct Command { pub func: CommandFn, } +impl Command { + async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool { + + guild.member_permissions(&member.user).manage_guild() || match self.required_perms { + PermissionLevel::Unrestricted => true, + + PermissionLevel::Managed => { + let pool = ctx.data.read().await + .get::().cloned().expect("Could not get SQLPool from data"); + + match sqlx::query!(" +SELECT + role +FROM + roles +INNER JOIN + command_restrictions ON roles.id = command_restrictions.role_id +WHERE + command_restrictions.command = ? AND + command_restrictions.guild_id = ( + SELECT + id + FROM + guilds + WHERE + guild = ? + )", self.name, guild.id.as_u64()) + .fetch_all(&pool) + .await { + + Ok(rows) => { + + let role_ids = member.roles.iter().map(|r| *r.as_u64()).collect::>(); + + for row in rows { + if role_ids.contains(&row.role) { + return true + } + } + + false + } + + Err(sqlx::Error::RowNotFound) => { + false + } + + Err(e) => { + warn!("Unexpected error occurred querying command_restrictions: {:?}", e); + + false + } + + } + } + + PermissionLevel::Restricted => { + false + } + } + + } +} + impl fmt::Debug for Command { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Command") @@ -232,6 +300,8 @@ impl Framework for RegexFramework { // Guild Command else if let (Some(guild), Some(Channel::Guild(channel))) = (msg.guild(&ctx).await, msg.channel(&ctx).await) { + let member = guild.member(&ctx, &msg.author).await.unwrap(); + if let Some(full_match) = self.regex_matcher.captures(&msg.content[..]) { if check_prefix(&ctx, &guild, full_match.name("prefix")).await { @@ -249,7 +319,9 @@ impl Framework for RegexFramework { &[] ); - (command.func)(&ctx, &msg, args).await; + if command.check_permissions(&ctx, &guild, &member).await { + (command.func)(&ctx, &msg, args).await; + } } PermissionCheck::Basic => { diff --git a/src/main.rs b/src/main.rs index 7d80ddd..a56ca3d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -56,8 +56,7 @@ async fn main() -> Result<(), Box> { let framework = RegexFramework::new(env::var("CLIENT_ID").expect("Missing CLIENT_ID from environment").parse()?) .ignore_bots(true) .default_prefix("$") - .add_command("help".to_string(), &HELP_COMMAND) - .add_command("h".to_string(), &HELP_COMMAND) + .add_command("look".to_string(), &LOOK_COMMAND) .build(); let mut client = Client::new(&env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment")) @@ -80,7 +79,9 @@ async fn main() -> Result<(), Box> { } #[command] -async fn help(_ctx: &Context, _msg: &Message, _args: Args) -> CommandResult { +#[permission_level(Managed)] +#[supports_dm(false)] +async fn look(_ctx: &Context, _msg: &Message, _args: Args) -> CommandResult { println!("Help command called"); Ok(())