From 9b5333dc875da04fcc245081f9ec640044f1d783 Mon Sep 17 00:00:00 2001 From: jellywx Date: Sat, 11 Sep 2021 00:14:23 +0100 Subject: [PATCH] more commands. fixed an issue with text only commands --- regex_command_attr/src/lib.rs | 20 +- regex_command_attr/src/structures.rs | 60 +-- rustfmt.toml | 1 + src/commands/info_cmds.rs | 12 +- src/commands/mod.rs | 4 +- src/commands/moderation_cmds.rs | 606 +++++++++------------------ src/commands/reminder_cmds.rs | 200 ++++----- src/framework.rs | 187 ++++----- src/main.rs | 107 ++--- src/models/channel_data.rs | 66 +-- src/models/mod.rs | 54 ++- src/models/reminder/builder.rs | 16 +- src/models/reminder/content.rs | 7 +- src/models/reminder/errors.rs | 60 ++- src/models/reminder/helper.rs | 14 +- src/models/reminder/mod.rs | 12 +- src/models/user_data.rs | 16 +- src/time_parser.rs | 17 +- 18 files changed, 562 insertions(+), 897 deletions(-) diff --git a/regex_command_attr/src/lib.rs b/regex_command_attr/src/lib.rs index 4302bf9..be7bdab 100644 --- a/regex_command_attr/src/lib.rs +++ b/regex_command_attr/src/lib.rs @@ -53,13 +53,9 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { let name = &name[..]; match name { - "arg" => options - .cmd_args - .push(propagate_err!(attributes::parse(values))), + "arg" => options.cmd_args.push(propagate_err!(attributes::parse(values))), "example" => { - options - .examples - .push(propagate_err!(attributes::parse(values))); + options.examples.push(propagate_err!(attributes::parse(values))); } "description" => { let line: String = propagate_err!(attributes::parse(values)); @@ -105,20 +101,14 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { let arg_idents = cmd_args .iter() .map(|arg| { - n.with_suffix(arg.name.replace(" ", "_").replace("-", "_").as_str()) - .with_suffix(ARG) + n.with_suffix(arg.name.replace(" ", "_").replace("-", "_").as_str()).with_suffix(ARG) }) .collect::>(); let mut tokens = cmd_args .iter_mut() .map(|arg| { - let Arg { - name, - description, - kind, - required, - } = arg; + let Arg { name, description, kind, required } = arg; let an = n.with_suffix(name.as_str()).with_suffix(ARG); @@ -141,7 +131,7 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { let variant = if args.len() == 2 { quote!(crate::framework::CommandFnType::Multi) } else { - let string: Type = parse_quote!(std::string::String); + let string: Type = parse_quote!(String); let final_arg = args.get(2).unwrap(); diff --git a/regex_command_attr/src/structures.rs b/regex_command_attr/src/structures.rs index 7fcf15b..a983eb4 100644 --- a/regex_command_attr/src/structures.rs +++ b/regex_command_attr/src/structures.rs @@ -20,41 +20,28 @@ fn parse_argument(arg: FnArg) -> Result { let name = id.ident; let mutable = id.mutability; - Ok(Argument { - mutable, - name, - kind: *kind, - }) + Ok(Argument { mutable, name, kind: *kind }) } Pat::Wild(wild) => { let token = wild.underscore_token; let name = Ident::new("_", token.spans[0]); - Ok(Argument { - mutable: None, - name, - kind: *kind, - }) + Ok(Argument { mutable: None, name, kind: *kind }) } - _ => Err(Error::new( - pat.span(), - format_args!("unsupported pattern: {:?}", pat), - )), + _ => Err(Error::new(pat.span(), format_args!("unsupported pattern: {:?}", pat))), } } - FnArg::Receiver(_) => Err(Error::new( - arg.span(), - format_args!("`self` arguments are prohibited: {:?}", arg), - )), + FnArg::Receiver(_) => { + Err(Error::new(arg.span(), format_args!("`self` arguments are prohibited: {:?}", arg))) + } } } /// Test if the attribute is cooked. fn is_cooked(attr: &Attribute) -> bool { - const COOKED_ATTRIBUTE_NAMES: &[&str] = &[ - "cfg", "cfg_attr", "derive", "inline", "allow", "warn", "deny", "forbid", - ]; + const COOKED_ATTRIBUTE_NAMES: &[&str] = + &["cfg", "cfg_attr", "derive", "inline", "allow", "warn", "deny", "forbid"]; COOKED_ATTRIBUTE_NAMES.iter().any(|n| attr.path.is_ident(n)) } @@ -115,32 +102,15 @@ impl Parse for CommandFun { braced!(bcont in input); let body = bcont.call(Block::parse_within)?; - let args = args - .into_iter() - .map(parse_argument) - .collect::>>()?; + let args = args.into_iter().map(parse_argument).collect::>>()?; - Ok(Self { - attributes, - cooked, - visibility, - name, - args, - body, - }) + Ok(Self { attributes, cooked, visibility, name, args, body }) } } impl ToTokens for CommandFun { fn to_tokens(&self, stream: &mut TokenStream2) { - let Self { - attributes: _, - cooked, - visibility, - name, - args, - body, - } = self; + let Self { attributes: _, cooked, visibility, name, args, body } = self; stream.extend(quote! { #(#cooked)* @@ -211,6 +181,7 @@ pub(crate) enum ApplicationCommandOptionType { Channel, Role, Mentionable, + Number, Unknown, } @@ -226,6 +197,7 @@ impl ApplicationCommandOptionType { "Channel" => Self::Channel, "Role" => Self::Role, "Mentionable" => Self::Mentionable, + "Number" => Self::Number, _ => Self::Unknown, } } @@ -246,6 +218,7 @@ impl ToTokens for ApplicationCommandOptionType { ApplicationCommandOptionType::Channel => quote!(Channel), ApplicationCommandOptionType::Role => quote!(Role), ApplicationCommandOptionType::Mentionable => quote!(Mentionable), + ApplicationCommandOptionType::Number => quote!(Number), ApplicationCommandOptionType::Unknown => quote!(Unknown), }; @@ -289,9 +262,6 @@ pub(crate) struct Options { impl Options { #[inline] pub fn new() -> Self { - Self { - group: "Other".to_string(), - ..Default::default() - } + Self { group: "Other".to_string(), ..Default::default() } } } diff --git a/rustfmt.toml b/rustfmt.toml index 455c820..6cd7eb7 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,2 +1,3 @@ imports_granularity = "Crate" group_imports = "StdExternalCrate" +use_small_heuristics = "Max" diff --git a/src/commands/info_cmds.rs b/src/commands/info_cmds.rs index c282139..cb0a6c9 100644 --- a/src/commands/info_cmds.rs +++ b/src/commands/info_cmds.rs @@ -40,15 +40,15 @@ async fn info(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { .description(format!( "Default prefix: `{default_prefix}` Reset prefix: `@{user} prefix {default_prefix}` -Help: `{prefix}help` - -**Welcome to Reminder Bot!** +Help: `{prefix}help`**Welcome \ + to Reminder Bot!** Developer: <@203532103185465344> Icon: <@253202252821430272> -Find me on https://discord.jellywx.com and on https://github.com/JellyWX :) +Find me on https://discord.jellywx.com \ + and on https://github.com/JellyWX :) -Invite the bot: https://invite.reminder-bot.com/ -Use our dashboard: https://reminder-bot.com/", +Invite the bot: https://invite.reminder-bot.com/Use our dashboard: \ + https://reminder-bot.com/", default_prefix = *DEFAULT_PREFIX, user = current_user.name, prefix = prefix diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 31582d1..6bb21cf 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -1,4 +1,4 @@ pub mod info_cmds; -//pub mod moderation_cmds; -//pub mod reminder_cmds; +pub mod moderation_cmds; +pub mod reminder_cmds; //pub mod todo_cmds; diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index ebf6b8c..0b852a1 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -2,134 +2,126 @@ use std::{collections::HashMap, iter}; use chrono::offset::Utc; use chrono_tz::{Tz, TZ_VARIANTS}; -use inflector::Inflector; use levenshtein::levenshtein; +use regex::Regex; use regex_command_attr::command; use serenity::{ - builder::CreateActionRow, client::Context, - framework::Framework, model::{ channel::Message, + guild::ActionRole::Create, id::{ChannelId, MessageId, RoleId}, interactions::message_component::ButtonStyle, + misc::Mentionable, }, }; use crate::{ - command_help, - consts::{REGEX_ALIAS, REGEX_CHANNEL, REGEX_COMMANDS, REGEX_ROLE, THEME_COLOR}, - framework::SendIterator, - get_ctx_data, + consts::{REGEX_ALIAS, REGEX_COMMANDS, THEME_COLOR}, + framework::{CommandInvoke, CreateGenericResponse, PermissionLevel}, models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData, CtxData}, - FrameworkCtx, PopularTimezones, + PopularTimezones, RegexFramework, SQLPool, }; -#[command] +#[command("blacklist")] +#[description("Block channels from using bot commands")] +#[arg( + name = "channel", + description = "The channel to blacklist", + kind = "Channel", + required = false +)] #[supports_dm(false)] -#[permission_level(Restricted)] +#[required_permissions(Restricted)] #[can_blacklist(false)] -async fn blacklist(ctx: &Context, msg: &Message, args: String) { - let (pool, lm) = get_ctx_data(&ctx).await; +async fn blacklist( + ctx: &Context, + invoke: &(dyn CommandInvoke + Send + Sync), + args: HashMap, +) { + let pool = ctx.data.read().await.get::().cloned().unwrap(); - let language = UserData::language_of(&msg.author, &pool).await; + let channel = match args.get("channel") { + Some(channel_id) => ChannelId(channel_id.parse::().unwrap()), - let capture_opt = REGEX_CHANNEL - .captures(&args) - .map(|cap| cap.get(1)) - .flatten(); + None => invoke.channel_id(), + } + .to_channel_cached(&ctx) + .unwrap(); - let (channel, local) = match capture_opt { - Some(capture) => ( - ChannelId(capture.as_str().parse::().unwrap()).to_channel_cached(&ctx), - false, - ), - - None => (msg.channel(&ctx).await.ok(), true), - }; - - let mut channel_data = ChannelData::from_channel(channel.unwrap(), &pool) - .await - .unwrap(); + let mut channel_data = ChannelData::from_channel(&channel, &pool).await.unwrap(); channel_data.blacklisted = !channel_data.blacklisted; channel_data.commit_changes(&pool).await; if channel_data.blacklisted { - if local { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "blacklist/added")) - .await; - } else { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "blacklist/added_from")) - .await; - } - } else if local { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "blacklist/removed")) + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content(format!("{} has been blacklisted", channel.mention())), + ) .await; } else { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "blacklist/removed_from")) + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content(format!("{} has been removed from the blacklist", channel.mention())), + ) .await; } } -#[command] -async fn timezone(ctx: &Context, msg: &Message, args: String) { - let (pool, lm) = get_ctx_data(&ctx).await; +#[command("timezone")] +#[description("Select your timezone")] +#[arg( + name = "timezone", + description = "Timezone to use from this list: https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee", + kind = "String", + required = false +)] +async fn timezone( + ctx: &Context, + invoke: &(dyn CommandInvoke + Send + Sync), + args: HashMap, +) { + let pool = ctx.data.read().await.get::().cloned().unwrap(); + let mut user_data = ctx.user_data(invoke.author_id()).await.unwrap(); - let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); + let footer_text = format!("Current timezone: {}", user_data.timezone); - let footer_text = lm.get(&user_data.language, "timezone/footer").replacen( - "{timezone}", - &user_data.timezone, - 1, - ); - - if !args.is_empty() { - match args.parse::() { - Ok(_) => { - user_data.timezone = args; + if let Some(timezone) = args.get("timezone") { + match timezone.parse::() { + Ok(tz) => { + user_data.timezone = timezone.clone(); user_data.commit_changes(&pool).await; - let now = Utc::now().with_timezone(&user_data.timezone()); + let now = Utc::now().with_timezone(&tz); - let content = lm - .get(&user_data.language, "timezone/set_p") - .replacen("{timezone}", &user_data.timezone, 1) - .replacen("{time}", &now.format("%H:%M").to_string(), 1); - - let _ = - msg.channel_id - .send_message(&ctx, |m| { - m.embed(|e| { - e.title(lm.get(&user_data.language, "timezone/set_p_title")) - .description(content) - .color(*THEME_COLOR) - .footer(|f| { - f.text( - lm.get(&user_data.language, "timezone/footer") - .replacen("{timezone}", &user_data.timezone, 1), - ) - }) - }) - }) - .await; + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().embed(|e| { + e.title("Timezone Set") + .description(format!( + "Timezone has been set to **{}**. Your current time should be `{}`", + timezone, + now.format("%H:%M").to_string() + )) + .color(*THEME_COLOR) + }), + ) + .await; } Err(_) => { let filtered_tz = TZ_VARIANTS .iter() .filter(|tz| { - args.contains(&tz.to_string()) - || tz.to_string().contains(&args) - || levenshtein(&tz.to_string(), &args) < 4 + timezone.contains(&tz.to_string()) + || tz.to_string().contains(timezone) + || levenshtein(&tz.to_string(), timezone) < 4 }) .take(25) .map(|t| t.to_owned()) @@ -146,371 +138,164 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) { ) }); - let _ = msg - .channel_id - .send_message(&ctx, |m| { - m.embed(|e| { - e.title(lm.get(&user_data.language, "timezone/no_timezone_title")) - .description(lm.get(&user_data.language, "timezone/no_timezone")) + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().embed(|e| { + e.title("Timezone Not Recognized") + .description("Possibly you meant one of the following timezones, otherwise click [here](https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee):") .color(*THEME_COLOR) .fields(fields) .footer(|f| f.text(footer_text)) .url("https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee") - }).components(|c| { - for row in filtered_tz.as_slice().chunks(5) { - let mut action_row = CreateActionRow::default(); - for timezone in row { - action_row.create_button(|b| { - b.style(ButtonStyle::Secondary) - .label(timezone.to_string()) - .custom_id(format!("timezone:{}", timezone.to_string())) - }); - } - - c.add_action_row(action_row); - } - - c - }) - }) + }), + ) .await; } } } else { - let content = lm - .get(&user_data.language, "timezone/no_argument") - .replace("{prefix}", &ctx.prefix(msg.guild_id).await); - - let popular_timezones = ctx - .data - .read() - .await - .get::() - .cloned() - .unwrap(); + let popular_timezones = ctx.data.read().await.get::().cloned().unwrap(); let popular_timezones_iter = popular_timezones.iter().map(|t| { ( t.to_string(), - format!( - "🕗 `{}`", - Utc::now().with_timezone(t).format("%H:%M").to_string() - ), + format!("🕗 `{}`", Utc::now().with_timezone(t).format("%H:%M").to_string()), true, ) }); - let _ = msg - .channel_id - .send_message(&ctx, |m| { - m.embed(|e| { - e.title(lm.get(&user_data.language, "timezone/no_argument_title")) - .description(content) + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().embed(|e| { + e.title("Timezone Usage") + .description( + "**Usage:** +`/timezone Name` + +**Example:** +`/timezone Europe/London` + +You may want to use one of the popular timezones below, otherwise click [here](https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee):", + ) .color(*THEME_COLOR) .fields(popular_timezones_iter) .footer(|f| f.text(footer_text)) .url("https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee") - }) - .components(|c| { - for row in popular_timezones.as_slice().chunks(5) { - let mut action_row = CreateActionRow::default(); - for timezone in row { - action_row.create_button(|b| { - b.style(ButtonStyle::Secondary) - .label(timezone.to_string()) - .custom_id(format!("timezone:{}", timezone.to_string())) - }); - } - - c.add_action_row(action_row); - } - - c - }) - }) - .await; - } -} - -#[command("lang")] -async fn language(ctx: &Context, msg: &Message, args: String) { - let (pool, lm) = get_ctx_data(&ctx).await; - - let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); - - if !args.is_empty() { - match lm.get_language(&args) { - Some(lang) => { - user_data.language = lang.to_string(); - - user_data.commit_changes(&pool).await; - - let _ = msg - .channel_id - .send_message(&ctx, |m| { - m.embed(|e| { - e.title(lm.get(&user_data.language, "lang/set_p_title")) - .color(*THEME_COLOR) - .description(lm.get(&user_data.language, "lang/set_p")) - }) - }) - .await; - } - - None => { - let language_codes = lm.all_languages().map(|(k, v)| { - ( - format!("{} {}", lm.get(k, "flag"), v.to_title_case()), - format!("`$lang {}`", k.to_uppercase()), - true, - ) - }); - - let _ = msg - .channel_id - .send_message(&ctx, |m| { - m.embed(|e| { - e.title(lm.get(&user_data.language, "lang/invalid_title")) - .color(*THEME_COLOR) - .description(lm.get(&user_data.language, "lang/invalid")) - .fields(language_codes) - }) - .components(|c| { - for row in lm - .all_languages() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect::>() - .as_slice() - .chunks(5) - { - let mut action_row = CreateActionRow::default(); - for (code, name) in row { - action_row.create_button(|b| { - b.style(ButtonStyle::Primary) - .label(name.to_title_case()) - .custom_id(format!("lang:{}", code.to_uppercase())) - }); - } - - c.add_action_row(action_row); - } - - c - }) - }) - .await; - } - } - } else { - let language_codes = lm.all_languages().map(|(k, v)| { - ( - format!("{} {}", lm.get(k, "flag"), v.to_title_case()), - format!("`$lang {}`", k.to_uppercase()), - true, + }), ) - }); - - let _ = msg - .channel_id - .send_message(&ctx, |m| { - m.embed(|e| { - e.title(lm.get(&user_data.language, "lang/select_title")) - .color(*THEME_COLOR) - .description(lm.get(&user_data.language, "lang/select")) - .fields(language_codes) - }) - .components(|c| { - for row in lm - .all_languages() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect::>() - .as_slice() - .chunks(5) - { - let mut action_row = CreateActionRow::default(); - for (code, name) in row { - action_row.create_button(|b| { - b.style(ButtonStyle::Primary) - .label(name.to_title_case()) - .custom_id(format!("lang:{}", code.to_uppercase())) - }); - } - - c.add_action_row(action_row); - } - - c - }) - }) .await; } } -#[command] +#[command("prefix")] +#[description("Configure a prefix for text-based commands (deprecated)")] #[supports_dm(false)] -#[permission_level(Restricted)] -async fn prefix(ctx: &Context, msg: &Message, args: String) { - let (pool, lm) = get_ctx_data(&ctx).await; +#[required_permissions(Restricted)] +async fn prefix(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: String) { + let pool = ctx.data.read().await.get::().cloned().unwrap(); - let guild_data = ctx.guild_data(msg.guild_id.unwrap()).await.unwrap(); - let language = UserData::language_of(&msg.author, &pool).await; + let guild_data = ctx.guild_data(invoke.guild_id().unwrap()).await.unwrap(); if args.len() > 5 { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "prefix/too_long")) + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content("Please select a prefix under 5 characters"), + ) .await; } else if args.is_empty() { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "prefix/no_argument")) + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content("Please use this command as `@reminder-bot prefix `"), + ) .await; } else { guild_data.write().await.prefix = args; - guild_data.read().await.commit_changes(&pool).await; - let content = lm.get(&language, "prefix/success").replacen( - "{prefix}", - &guild_data.read().await.prefix, - 1, - ); - - let _ = msg.channel_id.say(&ctx, content).await; - } -} - -#[command] -#[supports_dm(false)] -#[permission_level(Restricted)] -async fn restrict(ctx: &Context, msg: &Message, args: String) { - let (pool, lm) = get_ctx_data(&ctx).await; - - let language = UserData::language_of(&msg.author, &pool).await; - let guild_data = GuildData::from_guild(msg.guild(&ctx).unwrap(), &pool) - .await - .unwrap(); - - let role_tag_match = REGEX_ROLE.find(&args); - - if let Some(role_tag) = role_tag_match { - let commands = REGEX_COMMANDS - .find_iter(&args.to_lowercase()) - .map(|c| c.as_str().to_string()) - .collect::>(); - let role_id = RoleId( - role_tag.as_str()[3..role_tag.as_str().len() - 1] - .parse::() - .unwrap(), - ); - - let role_opt = role_id.to_role_cached(&ctx); - - if let Some(role) = role_opt { - let _ = sqlx::query!( - " -DELETE FROM command_restrictions WHERE role_id = (SELECT id FROM roles WHERE role = ?) - ", - role.id.as_u64() + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content(format!("Prefix changed to {}", guild_data.read().await.prefix)), ) - .execute(&pool) .await; - - if commands.is_empty() { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "restrict/disabled")) - .await; - } else { - let _ = sqlx::query!( - " -INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, ?, ?) - ", - role.id.as_u64(), - role.name, - guild_data.id - ) - .execute(&pool) - .await; - - for command in commands { - let res = sqlx::query!( - " -INSERT INTO command_restrictions (role_id, command) VALUES ((SELECT id FROM roles WHERE role = ?), ?) - ", role.id.as_u64(), command) - .execute(&pool) - .await; - - if res.is_err() { - println!("{:?}", res); - - let content = lm.get(&language, "restrict/failure").replacen( - "{command}", - &command, - 1, - ); - - let _ = msg.channel_id.say(&ctx, content).await; - } - } - - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "restrict/enabled")) - .await; - } - } - } else if args.is_empty() { - let guild_id = msg.guild_id.unwrap().as_u64().to_owned(); - - let rows = sqlx::query!( - " -SELECT - roles.role, command_restrictions.command -FROM - command_restrictions -INNER JOIN - roles -ON - roles.id = command_restrictions.role_id -WHERE - roles.guild_id = (SELECT id FROM guilds WHERE guild = ?) - ", - guild_id - ) - .fetch_all(&pool) - .await - .unwrap(); - - let mut commands_roles: HashMap<&str, Vec> = HashMap::new(); - - rows.iter().for_each(|row| { - if let Some(vec) = commands_roles.get_mut(&row.command.as_str()) { - vec.push(format!("<@&{}>", row.role)); - } else { - commands_roles.insert(&row.command, vec![format!("<@&{}>", row.role)]); - } - }); - - let fields = commands_roles - .iter() - .map(|(key, value)| (key.to_title_case(), value.join("\n"), true)); - - let title = lm.get(&language, "restrict/title"); - - let _ = msg - .channel_id - .send_message(&ctx, |m| { - m.embed(|e| e.title(title).fields(fields).color(*THEME_COLOR)) - }) - .await; - } else { - let prefix = ctx.prefix(msg.guild_id).await; - - command_help(ctx, msg, lm, &prefix, &language, "restrict").await; } } +#[command("restrict")] +#[description("Configure which roles can use commands on the bot")] +#[arg( + name = "role", + description = "The role to configure command permissions for", + kind = "Role", + required = true +)] +#[supports_dm(false)] +#[required_permissions(Restricted)] +async fn restrict( + ctx: &Context, + invoke: &(dyn CommandInvoke + Send + Sync), + args: HashMap, +) { + let pool = ctx.data.read().await.get::().cloned().unwrap(); + let framework = ctx.data.read().await.get::().cloned().unwrap(); + + let role = RoleId(args.get("role").unwrap().parse::().unwrap()); + + let restricted_commands = + sqlx::query!("SELECT command FROM command_restrictions WHERE role_id = ?", role.0) + .fetch_all(&pool) + .await + .unwrap() + .iter() + .map(|row| row.command.clone()) + .collect::>(); + + let restrictable_commands = framework + .commands + .iter() + .filter(|c| c.required_permissions == PermissionLevel::Managed) + .map(|c| c.names[0].to_string()) + .collect::>(); + + let len = restrictable_commands.len(); + + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content(format!("Select the commands to allow to {} from below:", role.mention())) + .components(|c| { + c.create_action_row(|row| { + row.create_select_menu(|select| { + select + .custom_id("test_id") + .options(|options| { + for command in restrictable_commands { + options.create_option(|opt| { + opt.label(&command).value(&command).default_selection( + restricted_commands.contains(&command), + ) + }); + } + + options + }) + .min_values(0) + .max_values(len as u64) + }) + }) + }), + ) + .await + .unwrap(); +} + +/* #[command("alias")] #[supports_dm(false)] #[permission_level(Managed)] @@ -638,3 +423,4 @@ SELECT command FROM command_aliases WHERE guild_id = (SELECT id FROM guilds WHER command_help(ctx, msg, lm, &prefix, &language, "alias").await; } } +*/ diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index e87cb81..4a0aa98 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -1,4 +1,5 @@ use std::{ + collections::HashMap, default::Default, string::ToString, time::{SystemTime, UNIX_EPOCH}, @@ -13,13 +14,12 @@ use serenity::{ }; use crate::{ - check_subscription_on_message, command_help, + check_subscription_on_message, consts::{ REGEX_CHANNEL_USER, REGEX_NATURAL_COMMAND_1, REGEX_NATURAL_COMMAND_2, REGEX_REMIND_COMMAND, THEME_COLOR, }, - framework::SendIterator, - get_ctx_data, + framework::{CommandInvoke, CreateGenericResponse}, models::{ channel_data::ChannelData, guild_data::GuildData, @@ -34,44 +34,35 @@ use crate::{ CtxData, }, time_parser::{natural_parser, TimeParser}, + SQLPool, }; -#[command] +#[command("pause")] +#[description("Pause all reminders on the current channel until a certain time or indefinitely")] +#[arg( + name = "until", + description = "When to pause until (hint: try 'next Wednesday', or '10 minutes')", + kind = "String", + required = false +)] #[supports_dm(false)] -#[permission_level(Restricted)] -async fn pause(ctx: &Context, msg: &Message, args: String) { - let (pool, lm) = get_ctx_data(&ctx).await; +#[required_permissions(Restricted)] +async fn pause( + ctx: &Context, + invoke: &(dyn CommandInvoke + Send + Sync), + args: HashMap, +) { + let pool = ctx.data.read().await.get::().cloned().unwrap(); - let language = UserData::language_of(&msg.author, &pool).await; - let timezone = UserData::timezone_of(&msg.author, &pool).await; + let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; - let mut channel = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool) - .await - .unwrap(); + let mut channel = ctx.channel_data(invoke.channel_id()).await.unwrap(); - if args.is_empty() { - channel.paused = !channel.paused; - channel.paused_until = None; + match args.get("until") { + Some(until) => { + let parsed = natural_parser(until, &timezone.to_string()).await; - channel.commit_changes(&pool).await; - - if channel.paused { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "pause/paused_indefinite")) - .await; - } else { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "pause/unpaused")) - .await; - } - } else { - let parser = TimeParser::new(&args, timezone); - let pause_until = parser.timestamp(); - - match pause_until { - Ok(timestamp) => { + if let Some(timestamp) = parsed { let dt = NaiveDateTime::from_timestamp(timestamp, 0); channel.paused = true; @@ -79,23 +70,53 @@ async fn pause(ctx: &Context, msg: &Message, args: String) { channel.commit_changes(&pool).await; - let content = lm - .get(&language, "pause/paused_until") - .replace("{}", &format!("", timestamp)); - - let _ = msg.channel_id.say(&ctx, content).await; + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(format!( + "Reminders in this channel have been silenced until ****", + timestamp + )), + ) + .await; + } else { + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content("Time could not be processed. Please write the time as clearly as possible"), + ) + .await; } + } + None => { + channel.paused = !channel.paused; + channel.paused_until = None; - Err(_) => { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "pause/invalid_time")) + channel.commit_changes(&pool).await; + + if channel.paused { + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content("Reminders in this channel have been silenced indefinitely"), + ) + .await; + } else { + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content("Reminders in this channel have been unsilenced"), + ) .await; } } } } +/* #[command] #[permission_level(Restricted)] async fn offset(ctx: &Context, msg: &Message, args: String) { @@ -150,10 +171,8 @@ UPDATE reminders SET `utc_time` = `utc_time` + ? WHERE reminders.channel_id = ? let _ = msg.channel_id.say(&ctx, response).await; } else { - let _ = msg - .channel_id - .say(&ctx, lm.get(&user_data.language, "offset/invalid_time")) - .await; + let _ = + msg.channel_id.say(&ctx, lm.get(&user_data.language, "offset/invalid_time")).await; } } } @@ -166,9 +185,8 @@ async fn nudge(ctx: &Context, msg: &Message, args: String) { let language = UserData::language_of(&msg.author, &pool).await; let timezone = UserData::timezone_of(&msg.author, &pool).await; - let mut channel = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool) - .await - .unwrap(); + let mut channel = + ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool).await.unwrap(); if args.is_empty() { let content = lm @@ -183,10 +201,7 @@ async fn nudge(ctx: &Context, msg: &Message, args: String) { match nudge_time { Ok(displacement) => { if displacement < i16::MIN as i64 || displacement > i16::MAX as i64 { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "nudge/invalid_time")) - .await; + let _ = msg.channel_id.say(&ctx, lm.get(&language, "nudge/invalid_time")).await; } else { channel.nudge = displacement as i16; @@ -203,10 +218,7 @@ async fn nudge(ctx: &Context, msg: &Message, args: String) { } Err(_) => { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "nudge/invalid_time")) - .await; + let _ = msg.channel_id.say(&ctx, lm.get(&language, "nudge/invalid_time")).await; } } } @@ -236,14 +248,9 @@ async fn look(ctx: &Context, msg: &Message, args: String) { let reminders = Reminder::from_channel(ctx, channel_id, &flags).await; if reminders.is_empty() { - let _ = msg - .channel_id - .say(&ctx, "No reminders on specified channel") - .await; + let _ = msg.channel_id.say(&ctx, "No reminders on specified channel").await; } else { - let display = reminders - .iter() - .map(|reminder| reminder.display(&flags, &timezone)); + let display = reminders.iter().map(|reminder| reminder.display(&flags, &timezone)); let _ = msg.channel_id.say_lines(&ctx, display).await; } @@ -256,10 +263,7 @@ async fn delete(ctx: &Context, msg: &Message, _args: String) { let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); - let _ = msg - .channel_id - .say(&ctx, lm.get(&user_data.language, "del/listing")) - .await; + let _ = msg.channel_id.say(&ctx, lm.get(&user_data.language, "del/listing")).await; let mut reminder_ids: Vec = vec![]; @@ -278,23 +282,13 @@ async fn delete(ctx: &Context, msg: &Message, _args: String) { }); let _ = msg.channel_id.say_lines(&ctx, enumerated_reminders).await; - let _ = msg - .channel_id - .say(&ctx, lm.get(&user_data.language, "del/listed")) - .await; + let _ = msg.channel_id.say(&ctx, lm.get(&user_data.language, "del/listed")).await; - let reply = msg - .channel_id - .await_reply(&ctx) - .author_id(msg.author.id) - .channel_id(msg.channel_id) - .await; + let reply = + msg.channel_id.await_reply(&ctx).author_id(msg.author.id).channel_id(msg.channel_id).await; if let Some(content) = reply.map(|m| m.content.replace(",", " ")) { - let parts = content - .split(' ') - .filter(|i| !i.is_empty()) - .collect::>(); + let parts = content.split(' ').filter(|i| !i.is_empty()).collect::>(); let valid_parts = parts .iter() @@ -352,9 +346,7 @@ INSERT INTO events (event_name, bulk_count, guild_id, user_id) VALUES ('delete', let _ = msg.channel_id.say(&ctx, content).await; } else { - let content = lm - .get(&user_data.language, "del/count") - .replacen("{}", "0", 1); + let content = lm.get(&user_data.language, "del/count").replacen("{}", "0", 1); let _ = msg.channel_id.say(&ctx, content).await; } @@ -365,10 +357,7 @@ INSERT INTO events (event_name, bulk_count, guild_id, user_id) VALUES ('delete', #[permission_level(Managed)] async fn timer(ctx: &Context, msg: &Message, args: String) { fn time_difference(start_time: NaiveDateTime) -> String { - let unix_time = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i64; + let unix_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; let now = NaiveDateTime::from_timestamp(unix_time, 0); let delta = (now - start_time).num_seconds(); @@ -415,10 +404,7 @@ async fn timer(ctx: &Context, msg: &Message, args: String) { let count = Timer::count_from_owner(owner, &pool).await; if count >= 25 { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "timer/limit")) - .await; + let _ = msg.channel_id.say(&ctx, lm.get(&language, "timer/limit")).await; } else { let name = args_iter .next() @@ -428,10 +414,7 @@ async fn timer(ctx: &Context, msg: &Message, args: String) { if name.len() <= 32 { Timer::create(&name, owner, &pool).await; - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "timer/success")) - .await; + let _ = msg.channel_id.say(&ctx, lm.get(&language, "timer/success")).await; } else { let _ = msg .channel_id @@ -469,21 +452,12 @@ DELETE FROM timers WHERE owner = ? AND name = ? .await .unwrap(); - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "timer/deleted")) - .await; + let _ = msg.channel_id.say(&ctx, lm.get(&language, "timer/deleted")).await; } else { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "timer/not_found")) - .await; + let _ = msg.channel_id.say(&ctx, lm.get(&language, "timer/not_found")).await; } } else { - let _ = msg - .channel_id - .say(&ctx, lm.get(&language, "timer/help")) - .await; + let _ = msg.channel_id.say(&ctx, lm.get(&language, "timer/help")).await; } } @@ -547,9 +521,8 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem let time_parser = TimeParser::new(captures.name("time").unwrap().as_str(), timezone); - let expires_parser = captures - .name("expires") - .map(|mat| TimeParser::new(mat.as_str(), timezone)); + let expires_parser = + captures.name("expires").map(|mat| TimeParser::new(mat.as_str(), timezone)); let interval_parser = captures .name("interval") @@ -854,3 +827,4 @@ async fn natural(ctx: &Context, msg: &Message, args: String) { } } } +*/ diff --git a/src/framework.rs b/src/framework.rs index 15c4540..cd38330 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -31,7 +31,7 @@ use serenity::{ }; use crate::{ - models::{channel_data::ChannelData, guild_data::GuildData, CtxData}, + models::{channel_data::ChannelData, CtxData}, LimitExecutors, SQLPool, }; @@ -50,11 +50,7 @@ pub struct CreateGenericResponse { impl CreateGenericResponse { pub fn new() -> Self { - Self { - content: "".to_string(), - embed: None, - components: None, - } + Self { content: "".to_string(), embed: None, components: None } } pub fn content(mut self, content: D) -> Self { @@ -227,8 +223,8 @@ impl CommandInvoke for ApplicationCommandInteraction { generic_response: CreateGenericResponse, ) -> SerenityResult<()> { self.create_interaction_response(http, |r| { - r.kind(InteractionResponseType::ChannelMessageWithSource) - .interaction_response_data(|d| { + r.kind(InteractionResponseType::ChannelMessageWithSource).interaction_response_data( + |d| { d.content(generic_response.content); if let Some(embed) = generic_response.embed { @@ -243,7 +239,8 @@ impl CommandInvoke for ApplicationCommandInteraction { } d - }) + }, + ) }) .await .map(|_| ()) @@ -305,10 +302,10 @@ pub enum CommandFnType { } impl CommandFnType { - pub fn text(&self) -> Option<&TextCommandFn> { + pub fn is_slash(&self) -> bool { match self { - CommandFnType::Text(t) => Some(t), - _ => None, + CommandFnType::Text(_) => false, + _ => true, } } } @@ -391,11 +388,8 @@ WHERE .await { Ok(rows) => { - let role_ids = member - .roles - .iter() - .map(|r| *r.as_u64()) - .collect::>(); + let role_ids = + member.roles.iter().map(|r| *r.as_u64()).collect::>(); for row in rows { if role_ids.contains(&row.role) { @@ -409,10 +403,7 @@ WHERE Err(sqlx::Error::RowNotFound) => false, Err(e) => { - warn!( - "Unexpected error occurred querying command_restrictions: {:?}", - e - ); + warn!("Unexpected error occurred querying command_restrictions: {:?}", e); false } @@ -513,11 +504,8 @@ impl RegexFramework { let command_names; { - let mut command_names_vec = self - .commands_map - .keys() - .map(|k| &k[..]) - .collect::>(); + let mut command_names_vec = + self.commands_map.keys().map(|k| &k[..]).collect::>(); command_names_vec.sort_unstable_by_key(|a| a.len()); @@ -527,9 +515,10 @@ impl RegexFramework { info!("Command names: {}", command_names); { - let match_string = r#"^(?:(?:<@ID>\s*)|(?:<@!ID>\s*)|(?P\S{1,5}?))(?PCOMMANDS)(?:$|\s+(?P.*))$"# - .replace("COMMANDS", command_names.as_str()) - .replace("ID", self.client_id.to_string().as_str()); + let match_string = + r#"^(?:(?:<@ID>\s*)|(?:<@!ID>\s*)|(?P\S{1,5}?))(?PCOMMANDS)(?:$|\s+(?P.*))$"# + .replace("COMMANDS", command_names.as_str()) + .replace("ID", self.client_id.to_string().as_str()); self.command_matcher = RegexBuilder::new(match_string.as_str()) .case_insensitive(self.case_insensitive) @@ -546,13 +535,9 @@ impl RegexFramework { let mut command_names_vec = self .commands_map .iter() - .filter_map(|(key, command)| { - if command.supports_dm { - Some(&key[..]) - } else { - None - } - }) + .filter_map( + |(key, command)| if command.supports_dm { Some(&key[..]) } else { None }, + ) .collect::>(); command_names_vec.sort_unstable_by_key(|a| a.len()); @@ -583,30 +568,7 @@ impl RegexFramework { None => { ApplicationCommand::set_global_application_commands(&http, |commands| { for command in &self.commands { - commands.create_application_command(|c| { - c.name(command.names[0]).description(command.desc); - - for arg in command.args { - c.create_option(|o| { - o.name(arg.name) - .description(arg.description) - .kind(arg.kind) - .required(arg.required) - }); - } - - c - }); - } - - commands - }) - .await; - } - Some(debug_guild) => { - debug_guild - .set_application_commands(&http, |commands| { - for command in &self.commands { + if command.fun.is_slash() { commands.create_application_command(|c| { c.name(command.names[0]).description(command.desc); @@ -622,10 +584,38 @@ impl RegexFramework { c }); } + } + + commands + }) + .await; + } + Some(debug_guild) => { + debug_guild + .set_application_commands(&http, |commands| { + for command in &self.commands { + if command.fun.is_slash() { + commands.create_application_command(|c| { + c.name(command.names[0]).description(command.desc); + + for arg in command.args { + c.create_option(|o| { + o.name(arg.name) + .description(arg.description) + .kind(arg.kind) + .required(arg.required) + }); + } + + c + }); + } + } commands }) - .await; + .await + .unwrap(); } } @@ -636,10 +626,7 @@ impl RegexFramework { let command = { self.commands_map .get(&interaction.data.name) - .expect(&format!( - "Received invalid command: {}", - interaction.data.name - )) + .expect(&format!("Received invalid command: {}", interaction.data.name)) }; let guild = interaction.guild(ctx.cache.clone()).unwrap(); @@ -648,12 +635,7 @@ impl RegexFramework { if command.check_permissions(&ctx, &guild, &member).await { let mut args = HashMap::new(); - for arg in interaction - .data - .options - .iter() - .filter(|o| o.value.is_some()) - { + for arg in interaction.data.options.iter().filter(|o| o.value.is_some()) { args.insert( arg.name.clone(), match arg.value.clone().unwrap() { @@ -696,7 +678,9 @@ impl RegexFramework { .respond( ctx.http.clone(), CreateGenericResponse::new().content( - "You must have `Manage Messages` or have a role capable of sending reminders to that channel. Please talk to your server admin, and ask them to use the `/restrict` command to specify allowed roles.", + "You must have `Manage Messages` or have a role capable of sending reminders to that channel. \ + Please talk to your server admin, and ask them to use the `/restrict` command to specify \ + allowed roles.", ), ) .await; @@ -725,18 +709,13 @@ impl Framework for RegexFramework { let basic_perms = channel_perms.send_messages(); - Ok( - if basic_perms && guild_perms.manage_webhooks() && channel_perms.embed_links() { - PermissionCheck::All - } else if basic_perms { - PermissionCheck::Basic( - guild_perms.manage_webhooks(), - channel_perms.embed_links(), - ) - } else { - PermissionCheck::None - }, - ) + Ok(if basic_perms && guild_perms.manage_webhooks() && channel_perms.embed_links() { + PermissionCheck::All + } else if basic_perms { + PermissionCheck::Basic(guild_perms.manage_webhooks(), channel_perms.embed_links()) + } else { + PermissionCheck::None + }) } async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option>) -> bool { @@ -758,10 +737,7 @@ impl Framework for RegexFramework { { let data = ctx.data.read().await; - let pool = data - .get::() - .cloned() - .expect("Could not get SQLPool from data"); + let pool = data.get::().cloned().expect("Could not get SQLPool from data"); if let Some(full_match) = self.command_matcher.captures(&msg.content) { if check_prefix(&ctx, &guild, full_match.name("prefix")).await { @@ -779,12 +755,10 @@ impl Framework for RegexFramework { ) .unwrap(); - let channel_data = ChannelData::from_channel( - msg.channel(&ctx).await.unwrap(), - &pool, - ) - .await - .unwrap(); + let channel = msg.channel(&ctx).await.unwrap(); + + let channel_data = + ChannelData::from_channel(&channel, &pool).await.unwrap(); if !command.can_blacklist || !channel_data.blacklisted { let args = full_match @@ -796,19 +770,6 @@ impl Framework for RegexFramework { let member = guild.member(&ctx, &msg.author).await.unwrap(); if command.check_permissions(&ctx, &guild, &member).await { - { - let guild_id = guild.id.as_u64().to_owned(); - - GuildData::from_guild(guild, &pool) - .await - .unwrap_or_else(|_| { - panic!( - "Failed to create new guild object for {}", - guild_id - ) - }); - } - if msg.id == MessageId(0) || !ctx.check_executing(msg.author.id).await { @@ -840,7 +801,10 @@ impl Framework for RegexFramework { .channel_id .say( &ctx, - "You must have `Manage Messages` or have a role capable of sending reminders to that channel. Please talk to your server admin, and ask them to use the `/restrict` command to specify allowed roles.", + "You must have `Manage Messages` or have a role capable of \ + sending reminders to that channel. Please talk to your server \ + admin, and ask them to use the `/restrict` command to specify \ + allowed roles.", ) .await; } @@ -887,11 +851,8 @@ impl Framework for RegexFramework { .commands_map .get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) .unwrap(); - let args = full_match - .name("args") - .map(|m| m.as_str()) - .unwrap_or("") - .to_string(); + let args = + full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string(); if msg.id == MessageId(0) || !ctx.check_executing(msg.author.id).await { ctx.set_executing(msg.author.id).await; diff --git a/src/main.rs b/src/main.rs index f3b2a78..925f8aa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,7 +33,7 @@ use sqlx::mysql::MySqlPool; use tokio::sync::RwLock; use crate::{ - commands::info_cmds, + commands::{info_cmds, moderation_cmds, reminder_cmds}, consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR}, framework::RegexFramework, models::guild_data::GuildData, @@ -79,28 +79,17 @@ trait LimitExecutors { #[async_trait] impl LimitExecutors for Context { async fn check_executing(&self, user: UserId) -> bool { - let currently_executing = self - .data - .read() - .await - .get::() - .cloned() - .unwrap(); + let currently_executing = + self.data.read().await.get::().cloned().unwrap(); let lock = currently_executing.read().await; - lock.get(&user) - .map_or(false, |now| now.elapsed().as_secs() < 4) + lock.get(&user).map_or(false, |now| now.elapsed().as_secs() < 4) } async fn set_executing(&self, user: UserId) { - let currently_executing = self - .data - .read() - .await - .get::() - .cloned() - .unwrap(); + let currently_executing = + self.data.read().await.get::().cloned().unwrap(); let mut lock = currently_executing.write().await; @@ -108,13 +97,8 @@ impl LimitExecutors for Context { } async fn drop_executing(&self, user: UserId) { - let currently_executing = self - .data - .read() - .await - .get::() - .cloned() - .unwrap(); + let currently_executing = + self.data.read().await.get::().cloned().unwrap(); let mut lock = currently_executing.write().await; @@ -171,11 +155,9 @@ DELETE FROM channels WHERE channel = ? .cloned() .expect("Could not get SQLPool from data"); - GuildData::from_guild(guild, &pool) - .await - .unwrap_or_else(|_| { - panic!("Failed to create new guild object for {}", guild_id) - }); + GuildData::from_guild(guild, &pool).await.unwrap_or_else(|_| { + panic!("Failed to create new guild object for {}", guild_id) + }); } if let Ok(token) = env::var("DISCORDBOTS_TOKEN") { @@ -236,13 +218,7 @@ DELETE FROM channels WHERE channel = ? .cloned() .expect("Could not get SQLPool from data"); - let guild_data_cache = ctx - .data - .read() - .await - .get::() - .cloned() - .unwrap(); + let guild_data_cache = ctx.data.read().await.get::().cloned().unwrap(); guild_data_cache.remove(&deleted_guild.id); sqlx::query!( @@ -292,10 +268,7 @@ async fn main() -> Result<(), Box> { let http = Http::new_with_token(&token); - let logged_in_id = http - .get_current_user() - .map_ok(|user| user.id.as_u64().to_owned()) - .await?; + let logged_in_id = http.get_current_user().map_ok(|user| user.id.as_u64().to_owned()).await?; let application_id = http.get_current_application_info().await?.id; let dm_enabled = env::var("DM_ENABLED").map_or(true, |var| var == "1"); @@ -305,9 +278,7 @@ async fn main() -> Result<(), Box> { .case_insensitive(env::var("CASE_INSENSITIVE").map_or(true, |var| var == "1")) .ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1")) .debug_guild(env::var("DEBUG_GUILD").map_or(None, |g| { - Some(GuildId( - g.parse::().expect("DEBUG_GUILD must be a guild ID"), - )) + Some(GuildId(g.parse::().expect("DEBUG_GUILD must be a guild ID"))) })) .dm_enabled(dm_enabled) // info commands @@ -329,6 +300,11 @@ async fn main() -> Result<(), Box> { // management commands .add_command("look", &reminder_cmds::LOOK_COMMAND) .add_command("del", &reminder_cmds::DELETE_COMMAND) + */ + .add_command(&reminder_cmds::PAUSE_COMMAND) + /* + .add_command("offset", &reminder_cmds::OFFSET_COMMAND) + .add_command("nudge", &reminder_cmds::NUDGE_COMMAND) // to-do commands .add_command("todo", &todo_cmds::TODO_USER_COMMAND) .add_command("todo user", &todo_cmds::TODO_USER_COMMAND) @@ -337,15 +313,13 @@ async fn main() -> Result<(), Box> { .add_command("todos", &todo_cmds::TODO_GUILD_COMMAND) .add_command("todo server", &todo_cmds::TODO_GUILD_COMMAND) .add_command("todo guild", &todo_cmds::TODO_GUILD_COMMAND) + */ // moderation commands - .add_command("blacklist", &moderation_cmds::BLACKLIST_COMMAND) - .add_command("restrict", &moderation_cmds::RESTRICT_COMMAND) - .add_command("timezone", &moderation_cmds::TIMEZONE_COMMAND) - .add_command("prefix", &moderation_cmds::PREFIX_COMMAND) - .add_command("lang", &moderation_cmds::LANGUAGE_COMMAND) - .add_command("pause", &reminder_cmds::PAUSE_COMMAND) - .add_command("offset", &reminder_cmds::OFFSET_COMMAND) - .add_command("nudge", &reminder_cmds::NUDGE_COMMAND) + .add_command(&moderation_cmds::BLACKLIST_COMMAND) + .add_command(&moderation_cmds::RESTRICT_COMMAND) + .add_command(&moderation_cmds::TIMEZONE_COMMAND) + .add_command(&moderation_cmds::PREFIX_COMMAND) + /* .add_command("alias", &moderation_cmds::ALIAS_COMMAND) .add_command("a", &moderation_cmds::ALIAS_COMMAND) */ @@ -397,9 +371,8 @@ async fn main() -> Result<(), Box> { } if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| { - let mut split = sr - .split(',') - .map(|val| val.parse::().expect("SHARD_RANGE not an integer")); + let mut split = + sr.split(',').map(|val| val.parse::().expect("SHARD_RANGE not an integer")); (split.next(), split.next()) }) { @@ -409,24 +382,14 @@ async fn main() -> Result<(), Box> { .flatten() .expect("No SHARD_COUNT provided, but SHARD_RANGE was provided"); - assert!( - lower < upper, - "SHARD_RANGE lower limit is not less than the upper limit" - ); + assert!(lower < upper, "SHARD_RANGE lower limit is not less than the upper limit"); - info!( - "Starting client fragment with shards {}-{}/{}", - lower, upper, total_shards - ); + info!("Starting client fragment with shards {}-{}/{}", lower, upper, total_shards); - client - .start_shard_range([lower, upper], total_shards) - .await?; - } else if let Ok(total_shards) = env::var("SHARD_COUNT").map(|shard_count| { - shard_count - .parse::() - .expect("SHARD_COUNT not an integer") - }) { + client.start_shard_range([lower, upper], total_shards).await?; + } else if let Ok(total_shards) = env::var("SHARD_COUNT") + .map(|shard_count| shard_count.parse::().expect("SHARD_COUNT not an integer")) + { info!("Starting client with {} shards", total_shards); client.start_shards(total_shards).await?; @@ -441,9 +404,7 @@ async fn main() -> Result<(), Box> { pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into) -> bool { if let Some(subscription_guild) = *CNC_GUILD { - let guild_member = GuildId(subscription_guild) - .member(cache_http, user_id) - .await; + let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await; if let Ok(member) = guild_member { for role in member.roles { diff --git a/src/models/channel_data.rs b/src/models/channel_data.rs index c431ec7..114ca9d 100644 --- a/src/models/channel_data.rs +++ b/src/models/channel_data.rs @@ -15,51 +15,67 @@ pub struct ChannelData { impl ChannelData { pub async fn from_channel( - channel: Channel, + channel: &Channel, pool: &MySqlPool, ) -> Result> { let channel_id = channel.id().as_u64().to_owned(); - if let Ok(c) = sqlx::query_as_unchecked!(Self, + if let Ok(c) = sqlx::query_as_unchecked!( + Self, " SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ? - ", channel_id) - .fetch_one(pool) - .await { - + ", + channel_id + ) + .fetch_one(pool) + .await + { Ok(c) - } - else { - let props = channel.guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name)); + } else { + let props = channel.to_owned().guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name)); - let (guild_id, channel_name) = if let Some((a, b)) = props { - (Some(a), Some(b)) - } else { - (None, None) - }; + let (guild_id, channel_name) = if let Some((a, b)) = props { (Some(a), Some(b)) } else { (None, None) }; sqlx::query!( " INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?)) - ", channel_id, channel_name, guild_id) - .execute(&pool.clone()) - .await?; + ", + channel_id, + channel_name, + guild_id + ) + .execute(&pool.clone()) + .await?; - Ok(sqlx::query_as_unchecked!(Self, + Ok(sqlx::query_as_unchecked!( + Self, " SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ? - ", channel_id) - .fetch_one(pool) - .await?) + ", + channel_id + ) + .fetch_one(pool) + .await?) } } pub async fn commit_changes(&self, pool: &MySqlPool) { sqlx::query!( " -UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until = ? WHERE id = ? - ", self.name, self.nudge, self.blacklisted, self.webhook_id, self.webhook_token, self.paused, self.paused_until, self.id) - .execute(pool) - .await.unwrap(); +UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until \ + = ? WHERE id = ? + ", + self.name, + self.nudge, + self.blacklisted, + self.webhook_id, + self.webhook_token, + self.paused, + self.paused_until, + self.id + ) + .execute(pool) + .await + .unwrap(); } } diff --git a/src/models/mod.rs b/src/models/mod.rs index a80fd4c..e8afa15 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -6,15 +6,18 @@ pub mod user_data; use std::sync::Arc; -use guild_data::GuildData; use serenity::{ async_trait, - model::id::{GuildId, UserId}, + model::id::{ChannelId, GuildId, UserId}, prelude::Context, }; use tokio::sync::RwLock; -use crate::{consts::DEFAULT_PREFIX, models::user_data::UserData, GuildDataCache, SQLPool}; +use crate::{ + consts::DEFAULT_PREFIX, + models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData}, + GuildDataCache, SQLPool, +}; #[async_trait] pub trait CtxData { @@ -23,12 +26,17 @@ pub trait CtxData { guild_id: G, ) -> Result>, sqlx::Error>; + async fn prefix + Send + Sync>(&self, guild_id: Option) -> String; + async fn user_data + Send + Sync>( &self, user_id: U, ) -> Result>; - async fn prefix + Send + Sync>(&self, guild_id: Option) -> String; + async fn channel_data + Send + Sync>( + &self, + channel_id: C, + ) -> Result>; } #[async_trait] @@ -41,13 +49,7 @@ impl CtxData for Context { let guild = guild_id.to_guild_cached(&self.cache).unwrap(); - let guild_cache = self - .data - .read() - .await - .get::() - .cloned() - .unwrap(); + let guild_cache = self.data.read().await.get::().cloned().unwrap(); let x = if let Some(guild_data) = guild_cache.get(&guild_id) { Ok(guild_data.clone()) @@ -70,6 +72,14 @@ impl CtxData for Context { x } + async fn prefix + Send + Sync>(&self, guild_id: Option) -> String { + if let Some(guild_id) = guild_id { + self.guild_data(guild_id).await.unwrap().read().await.prefix.clone() + } else { + DEFAULT_PREFIX.clone() + } + } + async fn user_data + Send + Sync>( &self, user_id: U, @@ -82,17 +92,15 @@ impl CtxData for Context { UserData::from_user(&user, &self, &pool).await } - async fn prefix + Send + Sync>(&self, guild_id: Option) -> String { - if let Some(guild_id) = guild_id { - self.guild_data(guild_id) - .await - .unwrap() - .read() - .await - .prefix - .clone() - } else { - DEFAULT_PREFIX.clone() - } + async fn channel_data + Send + Sync>( + &self, + channel_id: C, + ) -> Result> { + let channel_id = channel_id.into(); + let pool = self.data.read().await.get::().cloned().unwrap(); + + let channel = channel_id.to_channel_cached(&self).unwrap(); + + ChannelData::from_channel(&channel, &pool).await } } diff --git a/src/models/reminder/builder.rs b/src/models/reminder/builder.rs index beea19e..8bf9ec5 100644 --- a/src/models/reminder/builder.rs +++ b/src/models/reminder/builder.rs @@ -38,10 +38,7 @@ async fn create_webhook( include_bytes!(concat!( env!("CARGO_MANIFEST_DIR"), "/assets/", - env!( - "WEBHOOK_AVATAR", - "WEBHOOK_AVATAR not provided for compilation" - ) + env!("WEBHOOK_AVATAR", "WEBHOOK_AVATAR not provided for compilation") )) as &[u8], env!("WEBHOOK_AVATAR"), ), @@ -230,14 +227,7 @@ impl<'a> MultiReminderBuilder<'a> { } pub async fn build(mut self) -> (HashSet, HashSet) { - let pool = self - .ctx - .data - .read() - .await - .get::() - .cloned() - .unwrap(); + let pool = self.ctx.data.read().await.get::().cloned().unwrap(); let mut errors = HashSet::new(); @@ -296,7 +286,7 @@ impl<'a> MultiReminderBuilder<'a> { Err(ReminderError::InvalidTag) } else { let mut channel_data = - ChannelData::from_channel(channel, &pool).await.unwrap(); + ChannelData::from_channel(&channel, &pool).await.unwrap(); if channel_data.webhook_id.is_none() || channel_data.webhook_token.is_none() diff --git a/src/models/reminder/content.rs b/src/models/reminder/content.rs index 7af093c..0c9e6f2 100644 --- a/src/models/reminder/content.rs +++ b/src/models/reminder/content.rs @@ -12,12 +12,7 @@ pub struct Content { impl Content { pub fn new() -> Self { - Self { - content: "".to_string(), - tts: false, - attachment: None, - attachment_name: None, - } + Self { content: "".to_string(), tts: false, attachment: None, attachment_name: None } } pub async fn build(content: S, message: &Message) -> Result { diff --git a/src/models/reminder/errors.rs b/src/models/reminder/errors.rs index 3eabc8d..d3e8ad5 100644 --- a/src/models/reminder/errors.rs +++ b/src/models/reminder/errors.rs @@ -42,22 +42,50 @@ pub enum ReminderError { impl ReminderError { pub fn display(&self, is_natural: bool) -> String { match self { - ReminderError::LongTime => "That time is too far in the future. Please specify a shorter time.".to_string(), - ReminderError::LongInterval => format!("Please ensure the interval specified is less than {max_time} days", max_time = *MAX_TIME / 86_400), - ReminderError::PastTime => "Please ensure the time provided is in the future. If the time should be in the future, please be more specific with the definition.".to_string(), - ReminderError::ShortInterval => format!("Please ensure the interval provided is longer than {min_interval} seconds", min_interval = *MIN_INTERVAL), - ReminderError::InvalidTag => "Couldn't find a location by your tag. Your tag must be either a channel or a user (not a role)".to_string(), - ReminderError::InvalidTime => if is_natural { - "Your time failed to process. Please make it as clear as possible, for example `\"16th of july\"` or `\"in 20 minutes\"`".to_string() - } else { - "Make sure the time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] etc. or `day/month/year-hour:minute:second`".to_string() - }, - ReminderError::InvalidExpiration => if is_natural { - "Your expiration time failed to process. Please make it as clear as possible, for example `\"16th of july\"` or `\"in 20 minutes\"`".to_string() - } else { - "Make sure the expiration time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] etc. or `day/month/year-hour:minute:second`".to_string() - }, - ReminderError::DiscordError(s) => format!("A Discord error occurred: **{}**", s) + ReminderError::LongTime => { + "That time is too far in the future. Please specify a shorter time.".to_string() + } + ReminderError::LongInterval => format!( + "Please ensure the interval specified is less than {max_time} days", + max_time = *MAX_TIME / 86_400 + ), + ReminderError::PastTime => { + "Please ensure the time provided is in the future. If the time should be in \ + the future, please be more specific with the definition." + .to_string() + } + ReminderError::ShortInterval => format!( + "Please ensure the interval provided is longer than {min_interval} seconds", + min_interval = *MIN_INTERVAL + ), + ReminderError::InvalidTag => { + "Couldn't find a location by your tag. Your tag must be either a channel or \ + a user (not a role)" + .to_string() + } + ReminderError::InvalidTime => { + if is_natural { + "Your time failed to process. Please make it as clear as possible, for example `\"16th of july\"` \ + or `\"in 20 minutes\"`" + .to_string() + } else { + "Make sure the time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] etc. or \ + `day/month/year-hour:minute:second`" + .to_string() + } + } + ReminderError::InvalidExpiration => { + if is_natural { + "Your expiration time failed to process. Please make it as clear as possible, for example `\"16th \ + of july\"` or `\"in 20 minutes\"`" + .to_string() + } else { + "Make sure the expiration time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] \ + etc. or `day/month/year-hour:minute:second`" + .to_string() + } + } + ReminderError::DiscordError(s) => format!("A Discord error occurred: **{}**", s), } } } diff --git a/src/models/reminder/helper.rs b/src/models/reminder/helper.rs index 05edcde..3156f52 100644 --- a/src/models/reminder/helper.rs +++ b/src/models/reminder/helper.rs @@ -10,9 +10,8 @@ pub fn longhand_displacement(seconds: u64) -> String { let mut sections = vec![]; - for (var, name) in [days, hours, minutes, seconds] - .iter() - .zip(["days", "hours", "minutes", "seconds"].iter()) + for (var, name) in + [days, hours, minutes, seconds].iter().zip(["days", "hours", "minutes", "seconds"].iter()) { if *var > 0 { sections.push(format!("{} {}", var, name)); @@ -26,14 +25,7 @@ pub fn generate_uid() -> String { let mut generator: OsRng = Default::default(); (0..64) - .map(|_| { - CHARACTERS - .chars() - .choose(&mut generator) - .unwrap() - .to_owned() - .to_string() - }) + .map(|_| CHARACTERS.chars().choose(&mut generator).unwrap().to_owned().to_string()) .collect::>() .join("") } diff --git a/src/models/reminder/mod.rs b/src/models/reminder/mod.rs index 4197d6b..ac30cfe 100644 --- a/src/models/reminder/mod.rs +++ b/src/models/reminder/mod.rs @@ -329,18 +329,14 @@ WHERE self.display_content(), time_display, longhand_displacement(interval as u64), - self.set_by - .map(|i| format!("<@{}>", i)) - .unwrap_or_else(|| "unknown".to_string()) + self.set_by.map(|i| format!("<@{}>", i)).unwrap_or_else(|| "unknown".to_string()) ) } else { format!( "'{}' *occurs next at* **{}** (set by {})", self.display_content(), time_display, - self.set_by - .map(|i| format!("<@{}>", i)) - .unwrap_or_else(|| "unknown".to_string()) + self.set_by.map(|i| format!("<@{}>", i)).unwrap_or_else(|| "unknown".to_string()) ) } } @@ -380,9 +376,7 @@ WHERE pub fn signed_action>(&self, member_id: U, action: ReminderAction) -> String { let s_key = hmac::Key::new( hmac::HMAC_SHA256, - env::var("SECRET_KEY") - .expect("No SECRET_KEY provided") - .as_bytes(), + env::var("SECRET_KEY").expect("No SECRET_KEY provided").as_bytes(), ); let mut context = hmac::Context::with_key(&s_key); diff --git a/src/models/user_data.rs b/src/models/user_data.rs index f06eef2..76d730c 100644 --- a/src/models/user_data.rs +++ b/src/models/user_data.rs @@ -52,7 +52,8 @@ SELECT timezone FROM users WHERE user = ? " SELECT id, user, name, dm_channel, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ? ", - *LOCAL_TIMEZONE, user_id + *LOCAL_TIMEZONE, + user_id ) .fetch_one(pool) .await @@ -77,9 +78,14 @@ INSERT IGNORE INTO channels (channel) VALUES (?) sqlx::query!( " INSERT INTO users (user, name, dm_channel, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?) - ", user_id, user.name, dm_id, *LOCAL_TIMEZONE) - .execute(&pool_c) - .await?; + ", + user_id, + user.name, + dm_id, + *LOCAL_TIMEZONE + ) + .execute(&pool_c) + .await?; Ok(sqlx::query_as_unchecked!( Self, @@ -96,7 +102,7 @@ SELECT id, user, name, dm_channel, timezone FROM users WHERE user = ? error!("Error querying for user: {:?}", e); Err(Box::new(e)) - }, + } } } diff --git a/src/time_parser.rs b/src/time_parser.rs index 0e21a06..1dd912b 100644 --- a/src/time_parser.rs +++ b/src/time_parser.rs @@ -98,10 +98,7 @@ impl TimeParser { } fn process_explicit(&self) -> Result { - let mut time = Utc::now() - .with_timezone(&self.timezone) - .with_second(0) - .unwrap(); + let mut time = Utc::now().with_timezone(&self.timezone).with_second(0).unwrap(); let mut segments = self.time_string.rsplit('-'); // this segment will always exist even if split fails @@ -109,11 +106,9 @@ impl TimeParser { let h_m_s = hms.split(':'); - for (t, setter) in h_m_s.take(3).zip(&[ - DateTime::with_hour, - DateTime::with_minute, - DateTime::with_second, - ]) { + for (t, setter) in + h_m_s.take(3).zip(&[DateTime::with_hour, DateTime::with_minute, DateTime::with_second]) + { time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorHMS)?) .map_or_else(|| Err(InvalidTime::ParseErrorHMS), Ok)?; } @@ -125,9 +120,7 @@ impl TimeParser { let month = d_m_y.next(); let year = d_m_y.next(); - for (t, setter) in [day, month] - .iter() - .zip(&[DateTime::with_day, DateTime::with_month]) + for (t, setter) in [day, month].iter().zip(&[DateTime::with_day, DateTime::with_month]) { if let Some(t) = t { time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorDMY)?)