From 09a7608429c757d59a1b8c6000d96d9129d2477a Mon Sep 17 00:00:00 2001 From: jude Date: Sun, 11 Oct 2020 18:56:27 +0100 Subject: [PATCH] changed permission chekc to be more manual since built in one isnt working --- create.sql | 2 +- src/commands/moderation_cmds.rs | 19 ++++++++------ src/commands/reminder_cmds.rs | 34 +++++++++++++++++++++++-- src/framework.rs | 44 ++++++++++++++++++++++----------- src/main.rs | 28 +-------------------- 5 files changed, 75 insertions(+), 52 deletions(-) diff --git a/create.sql b/create.sql index 9d3b4a2..1eb668c 100644 --- a/create.sql +++ b/create.sql @@ -49,7 +49,7 @@ CREATE TABLE reminders.users ( dm_channel INT UNSIGNED UNIQUE NOT NULL, language VARCHAR(2) DEFAULT 'EN' NOT NULL, - timezone VARCHAR(32), # nullable s.t it can default to server timezone + timezone VARCHAR(32) DEFAULT 'UTC' NOT NULL, allowed_dm BOOLEAN DEFAULT 1 NOT NULL, patreon BOOL NOT NULL DEFAULT 0, diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index e81b8cd..4b015ff 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -199,14 +199,14 @@ async fn restrict(ctx: &Context, msg: &Message, args: String) -> CommandResult { let role_opt = role_id.to_role_cached(&ctx).await; if let Some(role) = role_opt { - if commands.is_empty() { - let _ = sqlx::query!( - " + let _ = sqlx::query!( + " DELETE FROM command_restrictions WHERE role_id = (SELECT id FROM roles WHERE role = ?) - ", role.id.as_u64()) - .execute(&pool) - .await; + ", role.id.as_u64()) + .execute(&pool) + .await; + if commands.is_empty() { let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/disabled").await).await; } else { @@ -226,7 +226,12 @@ INSERT INTO command_restrictions (role_id, command) VALUES ((SELECT id FROM role .await; if res.is_err() { - let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/failure").await).await; + println!("{:?}", res); + + let content = user_data.response(&pool, "restrict/failure").await + .replacen("{command}", &command, 1); + + let _ = msg.channel_id.say(&ctx, content).await; } } diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index d9e6e9e..0dbf1bc 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -30,6 +30,9 @@ use crate::{ MAX_TIME, LOCAL_TIMEZONE, CHARACTERS, + DAY, + HOUR, + MINUTE, }, models::{ ChannelData, @@ -41,7 +44,6 @@ use crate::{ time_parser::TimeParser, framework::SendIterator, check_subscription_on_message, - shorthand_displacement, longhand_displacement }; use chrono::{NaiveDateTime, offset::TimeZone}; @@ -79,6 +81,29 @@ use regex::Regex; use serde_json::json; +fn shorthand_displacement(seconds: u64) -> String { + let (hours, seconds) = seconds.div_rem(&HOUR); + let (minutes, seconds) = seconds.div_rem(&MINUTE); + + format!("{:02}:{:02}:{:02}", hours, minutes, seconds) +} + +fn longhand_displacement(seconds: u64) -> String { + let (days, seconds) = seconds.div_rem(&DAY); + let (hours, seconds) = seconds.div_rem(&HOUR); + let (minutes, seconds) = seconds.div_rem(&MINUTE); + + let mut sections = vec![]; + + for (var, name) in [days, hours, minutes, seconds].iter().zip(["days", "hours", "minutes", "seconds"].iter()) { + if *var > 0 { + sections.push(format!("{} {}", var, name)); + } + } + + sections.join(", ") +} + #[command] #[supports_dm(false)] #[permission_level(Restricted)] @@ -357,7 +382,11 @@ LIMIT user_data.timezone().timestamp(reminder.time as i64, 0).format("%Y-%m-%D %H:%M:%S").to_string() }, TimeDisplayType::Relative => { - longhand_displacement(reminder.time as u64) + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap().as_secs(); + + longhand_displacement(reminder.time as u64 - now) }, }; @@ -793,6 +822,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem } #[command] +#[permission_level(Managed)] async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); diff --git a/src/framework.rs b/src/framework.rs index 0bdbb65..0cc76b0 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -23,7 +23,6 @@ use serenity::{ use log::{ warn, error, - debug, info, }; @@ -47,7 +46,7 @@ use crate::consts::MAX_MESSAGE_LENGTH; type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, CommandResult>; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum PermissionLevel { Unrestricted, Managed, @@ -65,14 +64,29 @@ pub struct Command { impl Command { async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool { - guild.member_permissions(&member.user).manage_guild() || match self.required_perms { - PermissionLevel::Unrestricted => true, + if self.required_perms == PermissionLevel::Unrestricted { + true + } + else { + for role_id in &member.roles { + let role = role_id.to_role_cached(&ctx).await; - PermissionLevel::Managed => { + if let Some(cached_role) = role { + if cached_role.permissions.manage_guild() { + return true + } + else if self.required_perms == PermissionLevel::Managed && cached_role.permissions.manage_messages() { + return true + } + } + } + + if self.required_perms == PermissionLevel::Managed { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - match sqlx::query!(" + match sqlx::query!( + " SELECT role FROM @@ -87,13 +101,12 @@ WHERE FROM guilds WHERE - guild = ? - )", self.name, guild.id.as_u64()) + guild = ?) + ", self.name, guild.id.as_u64()) .fetch_all(&pool) .await { Ok(rows) => { - let role_ids = member.roles.iter().map(|r| *r.as_u64()).collect::>(); for row in rows { @@ -114,15 +127,12 @@ WHERE false } - } } - - PermissionLevel::Restricted => { + else { false } } - } } @@ -340,8 +350,6 @@ impl Framework for RegexFramework { if check_prefix(&ctx, &guild, full_match.name("prefix")).await { - debug!("Prefix matched on {}", msg.content); - match check_self_permissions(&ctx, &guild, &channel).await { Ok(perms) => match perms { PermissionCheck::All => { @@ -360,6 +368,12 @@ impl Framework for RegexFramework { if command.check_permissions(&ctx, &guild, &member).await { (command.func)(&ctx, &msg, args).await.unwrap(); } + else if command.required_perms == PermissionLevel::Restricted { + let _ = msg.channel_id.say(&ctx, "You must have permission level `Manage Server` or greater to use this command.").await; + } + else if command.required_perms == PermissionLevel::Managed { + let _ = msg.channel_id.say(&ctx, "You must have `Manage Messages` or have a role capable of sending reminders to that channel. Please talk to your server admin, and ask them to use the `{prefix}restrict` command to specify allowed roles.").await; + } } } diff --git a/src/main.rs b/src/main.rs index ccf1e19..57e1981 100644 --- a/src/main.rs +++ b/src/main.rs @@ -45,8 +45,7 @@ use std::{ use crate::{ framework::RegexFramework, consts::{ - PREFIX, DAY, HOUR, MINUTE, - SUBSCRIPTION_ROLES, CNC_GUILD, + PREFIX, SUBSCRIPTION_ROLES, CNC_GUILD, }, commands::{ info_cmds, @@ -56,7 +55,6 @@ use crate::{ }, }; -use num_integer::Integer; use serenity::futures::TryFutureExt; struct SQLPool; @@ -154,7 +152,6 @@ async fn main() -> Result<(), Box> { pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into) -> bool { - if let Some(subscription_guild) = *CNC_GUILD { let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await; @@ -177,26 +174,3 @@ pub async fn check_subscription_on_message(cache_http: impl CacheHttp + AsRef String { - let (hours, seconds) = seconds.div_rem(&HOUR); - let (minutes, seconds) = seconds.div_rem(&MINUTE); - - format!("{:02}:{:02}:{:02}", hours, minutes, seconds) -} - -pub fn longhand_displacement(seconds: u64) -> String { - let (days, seconds) = seconds.div_rem(&DAY); - let (hours, seconds) = seconds.div_rem(&HOUR); - let (minutes, seconds) = seconds.div_rem(&MINUTE); - - let mut sections = vec![]; - - for (var, name) in [days, hours, minutes, seconds].iter().zip(["days", "hours", "minutes", "seconds"].iter()) { - if *var > 0 { - sections.push(format!("{} {}", var, name)); - } - } - - sections.join(", ") -}