From a0baaabe8ab7092fc594e7de6c95de2dace08a5b Mon Sep 17 00:00:00 2001 From: jude Date: Wed, 2 Sep 2020 17:13:17 +0100 Subject: [PATCH] restrict command --- src/commands/moderation_cmds.rs | 79 +++++++++++++++++++++++++++++++++ src/framework.rs | 2 +- src/main.rs | 1 + src/models.rs | 2 +- 4 files changed, 82 insertions(+), 2 deletions(-) diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index 607231b..3972bf8 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -3,6 +3,7 @@ use regex_command_attr::command; use serenity::{ client::Context, model::{ + id::RoleId, channel::{ Message, }, @@ -149,6 +150,84 @@ async fn prefix(ctx: &Context, msg: &Message, args: String) -> CommandResult { #[supports_dm(false)] #[permission_level(Restricted)] async fn restrict(ctx: &Context, msg: &Message, args: String) -> CommandResult { + let pool = ctx.data.read().await + .get::().cloned().expect("Could not get SQLPool from data"); + + let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap(); + + let role_tag_match = REGEX_ROLE.find(&args); + + if let Some(role_tag) = role_tag_match { + let commands = REGEX_COMMANDS.find_iter(&args.to_lowercase()).map(|c| c.as_str().to_string()).collect::>(); + let role_id = RoleId(role_tag.as_str()[3..role_tag.as_str().len()-1].parse::().unwrap()); + + let role_opt = role_id.to_role_cached(&ctx).await; + + if let Some(role) = role_opt { + if commands.len() == 0 { + let _ = sqlx::query!( + " +DELETE FROM command_restrictions WHERE role_id = (SELECT id FROM roles WHERE role = ?) + ", role.id.as_u64()) + .execute(&pool) + .await; + + let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/disabled").await).await; + } + else { + let _ = sqlx::query!( + " +INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, ?, ?) + ", role.id.as_u64(), role.name, guild_data.id) + .execute(&pool) + .await; + + for command in commands { + let res = sqlx::query!( + " +INSERT INTO command_restrictions (role_id, command) VALUES ((SELECT id FROM roles WHERE role = ?), ?) + ", role.id.as_u64(), command) + .execute(&pool) + .await; + + if res.is_err() { + let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/failure").await).await; + } + } + + let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/enabled").await).await; + } + } + } + else if args.len() == 0 { + let guild_id = msg.guild_id.unwrap().as_u64().clone(); + + let rows = sqlx::query!( + " +SELECT + roles.role, command_restrictions.command +FROM + command_restrictions +INNER JOIN + roles +ON + roles.id = command_restrictions.role_id +WHERE + roles.guild_id = (SELECT id FROM guilds WHERE guild = ?) + ", guild_id) + .fetch_all(&pool) + .await + .unwrap(); + + let display_inner = rows.iter().map(|row| format!("<@&{}> can use {}", row.role, row.command)).collect::>().join("\n"); + let display = user_data.response(&pool, "restrict/allowed").await.replacen("{}", &display_inner, 1); + + let _ = msg.channel_id.say(&ctx, display).await; + } + else { + let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/help").await).await; + } Ok(()) } diff --git a/src/framework.rs b/src/framework.rs index 9dd9b21..4e6705b 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -75,7 +75,7 @@ INNER JOIN command_restrictions ON roles.id = command_restrictions.role_id WHERE command_restrictions.command = ? AND - command_restrictions.guild_id = ( + roles.guild_id = ( SELECT id FROM diff --git a/src/main.rs b/src/main.rs index 1fc76cb..2c1de97 100644 --- a/src/main.rs +++ b/src/main.rs @@ -70,6 +70,7 @@ async fn main() -> Result<(), Box> { .add_command("todo", &todo_cmds::TODO_PARSE_COMMAND) .add_command("blacklist", &moderation_cmds::BLACKLIST_COMMAND) + .add_command("restrict", &moderation_cmds::RESTRICT_COMMAND) .add_command("timezone", &moderation_cmds::TIMEZONE_COMMAND) .add_command("prefix", &moderation_cmds::PREFIX_COMMAND) .add_command("lang", &moderation_cmds::LANGUAGE_COMMAND) diff --git a/src/models.rs b/src/models.rs index cce8946..1470ffc 100644 --- a/src/models.rs +++ b/src/models.rs @@ -8,8 +8,8 @@ use serenity::{ }; use sqlx::MySqlPool; -use chrono::NaiveDateTime; +use chrono::NaiveDateTime; use chrono_tz::Tz; pub struct GuildData {