diff --git a/.idea/dictionaries/jude.xml b/.idea/dictionaries/jude.xml
index 864a310..38d72d3 100644
--- a/.idea/dictionaries/jude.xml
+++ b/.idea/dictionaries/jude.xml
@@ -2,6 +2,7 @@
reqwest
+ webhooks
\ No newline at end of file
diff --git a/src/framework.rs b/src/framework.rs
index fba244a..46a030f 100644
--- a/src/framework.rs
+++ b/src/framework.rs
@@ -9,7 +9,10 @@ use serenity::{
},
},
model::{
- guild::Guild,
+ guild::{
+ Guild,
+ Member,
+ },
channel::{
Channel, GuildChannel, Message,
}
@@ -33,6 +36,7 @@ use std::{
};
use crate::SQLPool;
+use serenity::model::id::RoleId;
#[derive(Debug)]
pub enum PermissionLevel {
@@ -49,6 +53,70 @@ pub struct Command {
pub func: CommandFn,
}
+impl Command {
+ async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool {
+
+ guild.member_permissions(&member.user).manage_guild() || match self.required_perms {
+ PermissionLevel::Unrestricted => true,
+
+ PermissionLevel::Managed => {
+ let pool = ctx.data.read().await
+ .get::().cloned().expect("Could not get SQLPool from data");
+
+ match sqlx::query!("
+SELECT
+ role
+FROM
+ roles
+INNER JOIN
+ command_restrictions ON roles.id = command_restrictions.role_id
+WHERE
+ command_restrictions.command = ? AND
+ command_restrictions.guild_id = (
+ SELECT
+ id
+ FROM
+ guilds
+ WHERE
+ guild = ?
+ )", self.name, guild.id.as_u64())
+ .fetch_all(&pool)
+ .await {
+
+ Ok(rows) => {
+
+ let role_ids = member.roles.iter().map(|r| *r.as_u64()).collect::>();
+
+ for row in rows {
+ if role_ids.contains(&row.role) {
+ return true
+ }
+ }
+
+ false
+ }
+
+ Err(sqlx::Error::RowNotFound) => {
+ false
+ }
+
+ Err(e) => {
+ warn!("Unexpected error occurred querying command_restrictions: {:?}", e);
+
+ false
+ }
+
+ }
+ }
+
+ PermissionLevel::Restricted => {
+ false
+ }
+ }
+
+ }
+}
+
impl fmt::Debug for Command {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Command")
@@ -232,6 +300,8 @@ impl Framework for RegexFramework {
// Guild Command
else if let (Some(guild), Some(Channel::Guild(channel))) = (msg.guild(&ctx).await, msg.channel(&ctx).await) {
+ let member = guild.member(&ctx, &msg.author).await.unwrap();
+
if let Some(full_match) = self.regex_matcher.captures(&msg.content[..]) {
if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
@@ -249,7 +319,9 @@ impl Framework for RegexFramework {
&[]
);
- (command.func)(&ctx, &msg, args).await;
+ if command.check_permissions(&ctx, &guild, &member).await {
+ (command.func)(&ctx, &msg, args).await;
+ }
}
PermissionCheck::Basic => {
diff --git a/src/main.rs b/src/main.rs
index 7d80ddd..a56ca3d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -56,8 +56,7 @@ async fn main() -> Result<(), Box> {
let framework = RegexFramework::new(env::var("CLIENT_ID").expect("Missing CLIENT_ID from environment").parse()?)
.ignore_bots(true)
.default_prefix("$")
- .add_command("help".to_string(), &HELP_COMMAND)
- .add_command("h".to_string(), &HELP_COMMAND)
+ .add_command("look".to_string(), &LOOK_COMMAND)
.build();
let mut client = Client::new(&env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment"))
@@ -80,7 +79,9 @@ async fn main() -> Result<(), Box> {
}
#[command]
-async fn help(_ctx: &Context, _msg: &Message, _args: Args) -> CommandResult {
+#[permission_level(Managed)]
+#[supports_dm(false)]
+async fn look(_ctx: &Context, _msg: &Message, _args: Args) -> CommandResult {
println!("Help command called");
Ok(())