check roles

This commit is contained in:
jude 2020-08-09 21:01:50 +01:00
parent 0bda479d01
commit 6de542264a
3 changed files with 79 additions and 5 deletions
.idea/dictionaries
src

View File

@ -2,6 +2,7 @@
<dictionary name="jude">
<words>
<w>reqwest</w>
<w>webhooks</w>
</words>
</dictionary>
</component>

View File

@ -9,7 +9,10 @@ use serenity::{
},
},
model::{
guild::Guild,
guild::{
Guild,
Member,
},
channel::{
Channel, GuildChannel, Message,
}
@ -33,6 +36,7 @@ use std::{
};
use crate::SQLPool;
use serenity::model::id::RoleId;
#[derive(Debug)]
pub enum PermissionLevel {
@ -49,6 +53,70 @@ pub struct Command {
pub func: CommandFn,
}
impl Command {
async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool {
guild.member_permissions(&member.user).manage_guild() || match self.required_perms {
PermissionLevel::Unrestricted => true,
PermissionLevel::Managed => {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
match sqlx::query!("
SELECT
role
FROM
roles
INNER JOIN
command_restrictions ON roles.id = command_restrictions.role_id
WHERE
command_restrictions.command = ? AND
command_restrictions.guild_id = (
SELECT
id
FROM
guilds
WHERE
guild = ?
)", self.name, guild.id.as_u64())
.fetch_all(&pool)
.await {
Ok(rows) => {
let role_ids = member.roles.iter().map(|r| *r.as_u64()).collect::<Vec<u64>>();
for row in rows {
if role_ids.contains(&row.role) {
return true
}
}
false
}
Err(sqlx::Error::RowNotFound) => {
false
}
Err(e) => {
warn!("Unexpected error occurred querying command_restrictions: {:?}", e);
false
}
}
}
PermissionLevel::Restricted => {
false
}
}
}
}
impl fmt::Debug for Command {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Command")
@ -232,6 +300,8 @@ impl Framework for RegexFramework {
// Guild Command
else if let (Some(guild), Some(Channel::Guild(channel))) = (msg.guild(&ctx).await, msg.channel(&ctx).await) {
let member = guild.member(&ctx, &msg.author).await.unwrap();
if let Some(full_match) = self.regex_matcher.captures(&msg.content[..]) {
if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
@ -249,7 +319,9 @@ impl Framework for RegexFramework {
&[]
);
(command.func)(&ctx, &msg, args).await;
if command.check_permissions(&ctx, &guild, &member).await {
(command.func)(&ctx, &msg, args).await;
}
}
PermissionCheck::Basic => {

View File

@ -56,8 +56,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let framework = RegexFramework::new(env::var("CLIENT_ID").expect("Missing CLIENT_ID from environment").parse()?)
.ignore_bots(true)
.default_prefix("$")
.add_command("help".to_string(), &HELP_COMMAND)
.add_command("h".to_string(), &HELP_COMMAND)
.add_command("look".to_string(), &LOOK_COMMAND)
.build();
let mut client = Client::new(&env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment"))
@ -80,7 +79,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
}
#[command]
async fn help(_ctx: &Context, _msg: &Message, _args: Args) -> CommandResult {
#[permission_level(Managed)]
#[supports_dm(false)]
async fn look(_ctx: &Context, _msg: &Message, _args: Args) -> CommandResult {
println!("Help command called");
Ok(())