changed permission chekc to be more manual since built in one isnt working

This commit is contained in:
jude 2020-10-11 18:56:27 +01:00
parent 7cfe62d18f
commit 09a7608429
5 changed files with 75 additions and 52 deletions

View File

@ -49,7 +49,7 @@ CREATE TABLE reminders.users (
dm_channel INT UNSIGNED UNIQUE NOT NULL, dm_channel INT UNSIGNED UNIQUE NOT NULL,
language VARCHAR(2) DEFAULT 'EN' NOT NULL, language VARCHAR(2) DEFAULT 'EN' NOT NULL,
timezone VARCHAR(32), # nullable s.t it can default to server timezone timezone VARCHAR(32) DEFAULT 'UTC' NOT NULL,
allowed_dm BOOLEAN DEFAULT 1 NOT NULL, allowed_dm BOOLEAN DEFAULT 1 NOT NULL,
patreon BOOL NOT NULL DEFAULT 0, patreon BOOL NOT NULL DEFAULT 0,

View File

@ -199,14 +199,14 @@ async fn restrict(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let role_opt = role_id.to_role_cached(&ctx).await; let role_opt = role_id.to_role_cached(&ctx).await;
if let Some(role) = role_opt { if let Some(role) = role_opt {
if commands.is_empty() { let _ = sqlx::query!(
let _ = sqlx::query!( "
"
DELETE FROM command_restrictions WHERE role_id = (SELECT id FROM roles WHERE role = ?) DELETE FROM command_restrictions WHERE role_id = (SELECT id FROM roles WHERE role = ?)
", role.id.as_u64()) ", role.id.as_u64())
.execute(&pool) .execute(&pool)
.await; .await;
if commands.is_empty() {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/disabled").await).await; let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/disabled").await).await;
} }
else { else {
@ -226,7 +226,12 @@ INSERT INTO command_restrictions (role_id, command) VALUES ((SELECT id FROM role
.await; .await;
if res.is_err() { if res.is_err() {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/failure").await).await; println!("{:?}", res);
let content = user_data.response(&pool, "restrict/failure").await
.replacen("{command}", &command, 1);
let _ = msg.channel_id.say(&ctx, content).await;
} }
} }

View File

@ -30,6 +30,9 @@ use crate::{
MAX_TIME, MAX_TIME,
LOCAL_TIMEZONE, LOCAL_TIMEZONE,
CHARACTERS, CHARACTERS,
DAY,
HOUR,
MINUTE,
}, },
models::{ models::{
ChannelData, ChannelData,
@ -41,7 +44,6 @@ use crate::{
time_parser::TimeParser, time_parser::TimeParser,
framework::SendIterator, framework::SendIterator,
check_subscription_on_message, check_subscription_on_message,
shorthand_displacement, longhand_displacement
}; };
use chrono::{NaiveDateTime, offset::TimeZone}; use chrono::{NaiveDateTime, offset::TimeZone};
@ -79,6 +81,29 @@ use regex::Regex;
use serde_json::json; use serde_json::json;
fn shorthand_displacement(seconds: u64) -> String {
let (hours, seconds) = seconds.div_rem(&HOUR);
let (minutes, seconds) = seconds.div_rem(&MINUTE);
format!("{:02}:{:02}:{:02}", hours, minutes, seconds)
}
fn longhand_displacement(seconds: u64) -> String {
let (days, seconds) = seconds.div_rem(&DAY);
let (hours, seconds) = seconds.div_rem(&HOUR);
let (minutes, seconds) = seconds.div_rem(&MINUTE);
let mut sections = vec![];
for (var, name) in [days, hours, minutes, seconds].iter().zip(["days", "hours", "minutes", "seconds"].iter()) {
if *var > 0 {
sections.push(format!("{} {}", var, name));
}
}
sections.join(", ")
}
#[command] #[command]
#[supports_dm(false)] #[supports_dm(false)]
#[permission_level(Restricted)] #[permission_level(Restricted)]
@ -357,7 +382,11 @@ LIMIT
user_data.timezone().timestamp(reminder.time as i64, 0).format("%Y-%m-%D %H:%M:%S").to_string() user_data.timezone().timestamp(reminder.time as i64, 0).format("%Y-%m-%D %H:%M:%S").to_string()
}, },
TimeDisplayType::Relative => { TimeDisplayType::Relative => {
longhand_displacement(reminder.time as u64) let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap().as_secs();
longhand_displacement(reminder.time as u64 - now)
}, },
}; };
@ -793,6 +822,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
} }
#[command] #[command]
#[permission_level(Managed)]
async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult { async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .get::<SQLPool>().cloned().expect("Could not get SQLPool from data");

View File

@ -23,7 +23,6 @@ use serenity::{
use log::{ use log::{
warn, warn,
error, error,
debug,
info, info,
}; };
@ -47,7 +46,7 @@ use crate::consts::MAX_MESSAGE_LENGTH;
type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, CommandResult>; type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, CommandResult>;
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub enum PermissionLevel { pub enum PermissionLevel {
Unrestricted, Unrestricted,
Managed, Managed,
@ -65,14 +64,29 @@ pub struct Command {
impl Command { impl Command {
async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool { async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool {
guild.member_permissions(&member.user).manage_guild() || match self.required_perms { if self.required_perms == PermissionLevel::Unrestricted {
PermissionLevel::Unrestricted => true, true
}
else {
for role_id in &member.roles {
let role = role_id.to_role_cached(&ctx).await;
PermissionLevel::Managed => { if let Some(cached_role) = role {
if cached_role.permissions.manage_guild() {
return true
}
else if self.required_perms == PermissionLevel::Managed && cached_role.permissions.manage_messages() {
return true
}
}
}
if self.required_perms == PermissionLevel::Managed {
let pool = ctx.data.read().await let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
match sqlx::query!(" match sqlx::query!(
"
SELECT SELECT
role role
FROM FROM
@ -87,13 +101,12 @@ WHERE
FROM FROM
guilds guilds
WHERE WHERE
guild = ? guild = ?)
)", self.name, guild.id.as_u64()) ", self.name, guild.id.as_u64())
.fetch_all(&pool) .fetch_all(&pool)
.await { .await {
Ok(rows) => { Ok(rows) => {
let role_ids = member.roles.iter().map(|r| *r.as_u64()).collect::<Vec<u64>>(); let role_ids = member.roles.iter().map(|r| *r.as_u64()).collect::<Vec<u64>>();
for row in rows { for row in rows {
@ -114,15 +127,12 @@ WHERE
false false
} }
} }
} }
else {
PermissionLevel::Restricted => {
false false
} }
} }
} }
} }
@ -340,8 +350,6 @@ impl Framework for RegexFramework {
if check_prefix(&ctx, &guild, full_match.name("prefix")).await { if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
debug!("Prefix matched on {}", msg.content);
match check_self_permissions(&ctx, &guild, &channel).await { match check_self_permissions(&ctx, &guild, &channel).await {
Ok(perms) => match perms { Ok(perms) => match perms {
PermissionCheck::All => { PermissionCheck::All => {
@ -360,6 +368,12 @@ impl Framework for RegexFramework {
if command.check_permissions(&ctx, &guild, &member).await { if command.check_permissions(&ctx, &guild, &member).await {
(command.func)(&ctx, &msg, args).await.unwrap(); (command.func)(&ctx, &msg, args).await.unwrap();
} }
else if command.required_perms == PermissionLevel::Restricted {
let _ = msg.channel_id.say(&ctx, "You must have permission level `Manage Server` or greater to use this command.").await;
}
else if command.required_perms == PermissionLevel::Managed {
let _ = msg.channel_id.say(&ctx, "You must have `Manage Messages` or have a role capable of sending reminders to that channel. Please talk to your server admin, and ask them to use the `{prefix}restrict` command to specify allowed roles.").await;
}
} }
} }

View File

@ -45,8 +45,7 @@ use std::{
use crate::{ use crate::{
framework::RegexFramework, framework::RegexFramework,
consts::{ consts::{
PREFIX, DAY, HOUR, MINUTE, PREFIX, SUBSCRIPTION_ROLES, CNC_GUILD,
SUBSCRIPTION_ROLES, CNC_GUILD,
}, },
commands::{ commands::{
info_cmds, info_cmds,
@ -56,7 +55,6 @@ use crate::{
}, },
}; };
use num_integer::Integer;
use serenity::futures::TryFutureExt; use serenity::futures::TryFutureExt;
struct SQLPool; struct SQLPool;
@ -154,7 +152,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool { pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
if let Some(subscription_guild) = *CNC_GUILD { if let Some(subscription_guild) = *CNC_GUILD {
let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await; let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await;
@ -177,26 +174,3 @@ pub async fn check_subscription_on_message(cache_http: impl CacheHttp + AsRef<Ca
check_subscription(&cache_http, &msg.author).await || check_subscription(&cache_http, &msg.author).await ||
if let Some(guild) = msg.guild(&cache_http).await { check_subscription(&cache_http, guild.owner_id).await } else { false } if let Some(guild) = msg.guild(&cache_http).await { check_subscription(&cache_http, guild.owner_id).await } else { false }
} }
pub fn shorthand_displacement(seconds: u64) -> String {
let (hours, seconds) = seconds.div_rem(&HOUR);
let (minutes, seconds) = seconds.div_rem(&MINUTE);
format!("{:02}:{:02}:{:02}", hours, minutes, seconds)
}
pub fn longhand_displacement(seconds: u64) -> String {
let (days, seconds) = seconds.div_rem(&DAY);
let (hours, seconds) = seconds.div_rem(&HOUR);
let (minutes, seconds) = seconds.div_rem(&MINUTE);
let mut sections = vec![];
for (var, name) in [days, hours, minutes, seconds].iter().zip(["days", "hours", "minutes", "seconds"].iter()) {
if *var > 0 {
sections.push(format!("{} {}", var, name));
}
}
sections.join(", ")
}