diff --git a/Cargo.lock b/Cargo.lock index d737f43..481cca3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1213,6 +1213,7 @@ dependencies = [ "proc-macro2", "quote", "syn", + "uuid", ] [[package]] @@ -2016,6 +2017,15 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "uuid" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" +dependencies = [ + "getrandom 0.2.3", +] + [[package]] name = "uwl" version = "0.6.0" diff --git a/README.md b/README.md index adcf466..52ee94d 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,6 @@ Reminder Bot can be built by running `cargo build --release` in the top level di These environment variables must be provided when compiling the bot * `DATABASE_URL` - the URL of your MySQL database (`mysql://user[:password]@domain/database`) * `WEBHOOK_AVATAR` - accepts the name of an image file located in `$CARGO_MANIFEST_DIR/assets/` to be used as the avatar when creating webhooks. **IMPORTANT: image file must be 128x128 or smaller in size** -* `STRINGS_FILE` - accepts the name of a compiled strings file located in `$CARGO_MANIFEST_DIR/assets/` to be used for creating messages. Compiled string files can be generated with `compile.py` at https://github.com/reminder-bot/languages ### Setting up Python Reminder Bot by default looks for a venv within it's working directory to run Python out of. To set up a venv, install `python3-venv` and run `python3 -m venv venv`. Then, run `source venv/bin/activate` to activate the venv, and do `pip install dateparser` to install the required library @@ -29,14 +28,12 @@ __Required Variables__ __Other Variables__ * `MIN_INTERVAL` - default `600`, defines the shortest interval the bot should accept -* `MAX_TIME` - default `1576800000`, defines the maximum time ahead that reminders can be set for * `LOCAL_TIMEZONE` - default `UTC`, necessary for calculations in the natural language processor * `DEFAULT_PREFIX` - default `$`, used for the default prefix on new guilds * `SUBSCRIPTION_ROLES` - default `None`, accepts a list of Discord role IDs that are given to subscribed users * `CNC_GUILD` - default `None`, accepts a single Discord guild ID for the server that the subscription roles belong to * `IGNORE_BOTS` - default `1`, if `1`, Reminder Bot will ignore all other bots * `PYTHON_LOCATION` - default `venv/bin/python3`. Can be changed if your Python executable is located somewhere else -* `LOCAL_LANGUAGE` - default `EN`. Specifies the string set to fall back to if a string cannot be found (and to be used with new users) * `THEME_COLOR` - default `8fb677`. Specifies the hex value of the color to use on info message embeds * `CASE_INSENSITIVE` - default `1`, if `1`, commands will be treated with case insensitivity (so both `$help` and `$HELP` will work) * `SHARD_COUNT` - default `None`, accepts the number of shards that are being ran diff --git a/command_attributes/Cargo.toml b/command_attributes/Cargo.toml index 24e5fa5..a4ce1d3 100644 --- a/command_attributes/Cargo.toml +++ b/command_attributes/Cargo.toml @@ -13,3 +13,4 @@ proc-macro = true quote = "^1.0" syn = { version = "^1.0", features = ["full", "derive", "extra-traits"] } proc-macro2 = "1.0" +uuid = { version = "0.8", features = ["v4"] } diff --git a/command_attributes/src/attributes.rs b/command_attributes/src/attributes.rs index 3b93fd8..1293186 100644 --- a/command_attributes/src/attributes.rs +++ b/command_attributes/src/attributes.rs @@ -8,7 +8,7 @@ use syn::{ }; use crate::{ - structures::{ApplicationCommandOptionType, Arg, PermissionLevel}, + structures::{ApplicationCommandOptionType, Arg}, util::{AsOption, LitExt}, }; @@ -46,24 +46,15 @@ impl fmt::Display for ValueKind { fn to_ident(p: Path) -> Result { if p.segments.is_empty() { - return Err(Error::new( - p.span(), - "cannot convert an empty path to an identifier", - )); + return Err(Error::new(p.span(), "cannot convert an empty path to an identifier")); } if p.segments.len() > 1 { - return Err(Error::new( - p.span(), - "the path must not have more than one segment", - )); + return Err(Error::new(p.span(), "the path must not have more than one segment")); } if !p.segments[0].arguments.is_empty() { - return Err(Error::new( - p.span(), - "the singular path segment must not have any arguments", - )); + return Err(Error::new(p.span(), "the singular path segment must not have any arguments")); } Ok(p.segments[0].ident.clone()) @@ -85,12 +76,7 @@ impl Values { literals: Vec<(Option, Lit)>, span: Span, ) -> Self { - Values { - name, - literals, - kind, - span, - } + Values { name, literals, kind, span } } } @@ -145,11 +131,7 @@ pub fn parse_values(attr: &Attribute) -> Result { } } - let kind = if lits.len() == 1 { - ValueKind::SingleList - } else { - ValueKind::List - }; + let kind = if lits.len() == 1 { ValueKind::SingleList } else { ValueKind::List }; Ok(Values::new(name, kind, lits, attr.span())) } else { @@ -183,12 +165,7 @@ pub fn parse_values(attr: &Attribute) -> Result { let name = to_ident(meta.path)?; let lit = meta.lit; - Ok(Values::new( - name, - ValueKind::Equals, - vec![(None, lit)], - attr.span(), - )) + Ok(Values::new(name, ValueKind::Equals, vec![(None, lit)], attr.span())) } } } @@ -231,10 +208,7 @@ fn validate(values: &Values, forms: &[ValueKind]) -> Result<()> { return Err(Error::new( values.span, // Using the `_args` version here to avoid an allocation. - format_args!( - "the attribute must be in of these forms:\n{}", - DisplaySlice(forms) - ), + format_args!("the attribute must be in of these forms:\n{}", DisplaySlice(forms)), )); } @@ -254,11 +228,7 @@ impl AttributeOption for Vec { fn parse(values: Values) -> Result { validate(&values, &[ValueKind::List])?; - Ok(values - .literals - .into_iter() - .map(|(_, l)| l.to_str()) - .collect()) + Ok(values.literals.into_iter().map(|(_, l)| l.to_str()).collect()) } } @@ -294,37 +264,18 @@ impl AttributeOption for Vec { fn parse(values: Values) -> Result { validate(&values, &[ValueKind::List])?; - Ok(values - .literals - .into_iter() - .map(|(_, l)| l.to_ident()) - .collect()) + Ok(values.literals.into_iter().map(|(_, l)| l.to_ident()).collect()) } } impl AttributeOption for Option { fn parse(values: Values) -> Result { - validate( - &values, - &[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList], - )?; + validate(&values, &[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList])?; Ok(values.literals.get(0).map(|(_, l)| l.to_str())) } } -impl AttributeOption for PermissionLevel { - fn parse(values: Values) -> Result { - validate(&values, &[ValueKind::SingleList])?; - - Ok(values - .literals - .get(0) - .map(|(_, l)| PermissionLevel::from_str(&*l.to_str()).unwrap()) - .unwrap()) - } -} - impl AttributeOption for Arg { fn parse(values: Values) -> Result { validate(&values, &[ValueKind::EqualsList])?; diff --git a/command_attributes/src/consts.rs b/command_attributes/src/consts.rs index f3e2533..1ed969f 100644 --- a/command_attributes/src/consts.rs +++ b/command_attributes/src/consts.rs @@ -2,6 +2,8 @@ pub mod suffixes { pub const COMMAND: &str = "COMMAND"; pub const ARG: &str = "ARG"; pub const SUBCOMMAND: &str = "SUBCOMMAND"; + pub const CHECK: &str = "CHECK"; + pub const HOOK: &str = "HOOK"; } pub use self::suffixes::*; diff --git a/command_attributes/src/lib.rs b/command_attributes/src/lib.rs index 26133c0..2ccec4d 100644 --- a/command_attributes/src/lib.rs +++ b/command_attributes/src/lib.rs @@ -5,6 +5,7 @@ use proc_macro::TokenStream; use proc_macro2::Ident; use quote::quote; use syn::{parse::Error, parse_macro_input, parse_quote, spanned::Spanned, Lit, Type}; +use uuid::Uuid; pub(crate) mod attributes; pub(crate) mod consts; @@ -43,6 +44,7 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { fun.name.to_string() }; + let mut hooks: Vec = Vec::new(); let mut options = Options::new(); for attribute in &fun.attributes { @@ -76,11 +78,13 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { util::append_line(&mut options.description, line); } } + "hook" => { + hooks.push(propagate_err!(attributes::parse(values))); + } _ => { match_options!(name, values, options, span => [ aliases; group; - required_permissions; can_blacklist; supports_dm ]); @@ -93,7 +97,6 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { description, group, examples, - required_permissions, can_blacklist, supports_dm, mut cmd_args, @@ -235,10 +238,10 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { desc: #description, group: #group, examples: &[#(#examples),*], - required_permissions: #required_permissions, can_blacklist: #can_blacklist, supports_dm: #supports_dm, args: &[#(&#arg_idents),*], + hooks: &[#(&#hooks),*], }; }); @@ -256,3 +259,44 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { tokens.into() } + +#[proc_macro_attribute] +pub fn check(_attr: TokenStream, input: TokenStream) -> TokenStream { + let mut fun = parse_macro_input!(input as CommandFun); + + let n = fun.name.clone(); + let name = n.with_suffix(HOOK); + let fn_name = n.with_suffix(CHECK); + let visibility = fun.visibility; + + let cooked = fun.cooked; + let body = fun.body; + let ret = fun.ret; + populate_fut_lifetimes_on_refs(&mut fun.args); + let args = fun.args; + + let hook_path = quote!(crate::framework::Hook); + let uuid = Uuid::new_v4().as_u128(); + + (quote! { + #(#cooked)* + #[allow(missing_docs)] + #visibility fn #fn_name<'fut>(#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, #ret> { + use ::serenity::futures::future::FutureExt; + + async move { + let _output: #ret = { #(#body)* }; + #[allow(unreachable_code)] + _output + }.boxed() + } + + #(#cooked)* + #[allow(missing_docs)] + pub static #name: #hook_path = #hook_path { + fun: #fn_name, + uuid: #uuid, + }; + }) + .into() +} diff --git a/command_attributes/src/structures.rs b/command_attributes/src/structures.rs index dc781ee..1985a3a 100644 --- a/command_attributes/src/structures.rs +++ b/command_attributes/src/structures.rs @@ -4,7 +4,7 @@ use syn::{ braced, parse::{Error, Parse, ParseStream, Result}, spanned::Spanned, - Attribute, Block, FnArg, Ident, Pat, Stmt, Token, Visibility, + Attribute, Block, FnArg, Ident, Pat, ReturnType, Stmt, Token, Type, Visibility, }; use crate::util::{Argument, Parenthesised}; @@ -78,6 +78,7 @@ pub struct CommandFun { pub visibility: Visibility, pub name: Ident, pub args: Vec, + pub ret: Type, pub body: Vec, } @@ -97,6 +98,11 @@ impl Parse for CommandFun { // (...) let Parenthesised(args) = input.parse::>()?; + let ret = match input.parse::()? { + ReturnType::Type(_, t) => (*t).clone(), + ReturnType::Default => Type::Verbatim(quote!(())), + }; + // { ... } let bcont; braced!(bcont in input); @@ -104,72 +110,23 @@ impl Parse for CommandFun { let args = args.into_iter().map(parse_argument).collect::>>()?; - Ok(Self { attributes, cooked, visibility, name, args, body }) + Ok(Self { attributes, cooked, visibility, name, args, ret, body }) } } impl ToTokens for CommandFun { fn to_tokens(&self, stream: &mut TokenStream2) { - let Self { attributes: _, cooked, visibility, name, args, body } = self; + let Self { attributes: _, cooked, visibility, name, args, ret, body } = self; stream.extend(quote! { #(#cooked)* - #visibility async fn #name (#(#args),*) { + #visibility async fn #name (#(#args),*) -> #ret { #(#body)* } }); } } -#[derive(Debug)] -pub enum PermissionLevel { - Unrestricted, - Managed, - Restricted, -} - -impl Default for PermissionLevel { - fn default() -> Self { - Self::Unrestricted - } -} - -impl PermissionLevel { - pub fn from_str(s: &str) -> Option { - Some(match s.to_uppercase().as_str() { - "UNRESTRICTED" => Self::Unrestricted, - "MANAGED" => Self::Managed, - "RESTRICTED" => Self::Restricted, - _ => return None, - }) - } -} - -impl ToTokens for PermissionLevel { - fn to_tokens(&self, stream: &mut TokenStream2) { - let path = quote!(crate::framework::PermissionLevel); - let variant; - - match self { - Self::Unrestricted => { - variant = quote!(Unrestricted); - } - - Self::Managed => { - variant = quote!(Managed); - } - - Self::Restricted => { - variant = quote!(Restricted); - } - } - - stream.extend(quote! { - #path::#variant - }); - } -} - #[derive(Debug)] pub(crate) enum ApplicationCommandOptionType { SubCommand, @@ -272,7 +229,6 @@ pub(crate) struct Options { pub description: String, pub group: String, pub examples: Vec, - pub required_permissions: PermissionLevel, pub can_blacklist: bool, pub supports_dm: bool, pub cmd_args: Vec, diff --git a/migration/02-macro.sql b/migration/02-macro.sql new file mode 100644 index 0000000..1018f97 --- /dev/null +++ b/migration/02-macro.sql @@ -0,0 +1,11 @@ +CREATE TABLE macro ( + id INT UNSIGNED AUTO_INCREMENT, + guild_id INT UNSIGNED NOT NULL, + + name VARCHAR(100) NOT NULL, + description VARCHAR(100), + commands TEXT, + + FOREIGN KEY (guild_id) REFERENCES guilds(id), + PRIMARY KEY (id) +); diff --git a/src/commands/info_cmds.rs b/src/commands/info_cmds.rs index 7eb4295..9388d65 100644 --- a/src/commands/info_cmds.rs +++ b/src/commands/info_cmds.rs @@ -27,7 +27,7 @@ fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEm #[aliases("invite")] #[description("Get information about the bot")] #[group("Info")] -async fn info(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { +async fn info(ctx: &Context, invoke: CommandInvoke) { let prefix = ctx.prefix(invoke.guild_id()).await; let current_user = ctx.cache.current_user(); let footer = footer(ctx); @@ -61,7 +61,7 @@ Use our dashboard: https://reminder-bot.com/", #[command] #[description("Details on supporting the bot and Patreon benefits")] #[group("Info")] -async fn donate(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { +async fn donate(ctx: &Context, invoke: CommandInvoke) { let footer = footer(ctx); let _ = invoke @@ -94,7 +94,7 @@ Just $2 USD/month! #[command] #[description("Get the link to the online dashboard")] #[group("Info")] -async fn dashboard(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { +async fn dashboard(ctx: &Context, invoke: CommandInvoke) { let footer = footer(ctx); let _ = invoke @@ -113,7 +113,7 @@ async fn dashboard(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { #[command] #[description("View the current time in your selected timezone")] #[group("Info")] -async fn clock(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { +async fn clock(ctx: &Context, invoke: CommandInvoke) { let ud = ctx.user_data(&invoke.author_id()).await.unwrap(); let now = Utc::now().with_timezone(&ud.timezone()); diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index 2d0ae7f..43962a4 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -2,16 +2,21 @@ use chrono::offset::Utc; use chrono_tz::{Tz, TZ_VARIANTS}; use levenshtein::levenshtein; use regex_command_attr::command; -use serenity::{client::Context, model::misc::Mentionable}; +use serenity::{ + client::Context, + model::{ + interactions::InteractionResponseType, misc::Mentionable, + prelude::InteractionApplicationCommandCallbackDataFlags, + }, +}; use crate::{ component_models::{ComponentDataModel, Restrict}, consts::THEME_COLOR, - framework::{ - CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue, PermissionLevel, - }, - models::{channel_data::ChannelData, CtxData}, - PopularTimezones, RegexFramework, SQLPool, + framework::{CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue}, + hooks::{CHECK_GUILD_PERMISSIONS_HOOK, CHECK_MANAGED_PERMISSIONS_HOOK}, + models::{channel_data::ChannelData, command_macro::CommandMacro, CtxData}, + PopularTimezones, RecordingMacros, RegexFramework, SQLPool, }; #[command("blacklist")] @@ -23,13 +28,9 @@ use crate::{ required = false )] #[supports_dm(false)] -#[required_permissions(Restricted)] +#[hook(CHECK_GUILD_PERMISSIONS_HOOK)] #[can_blacklist(false)] -async fn blacklist( - ctx: &Context, - invoke: &(dyn CommandInvoke + Send + Sync), - args: CommandOptions, -) { +async fn blacklist(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let channel = match args.get("channel") { @@ -72,7 +73,7 @@ async fn blacklist( kind = "String", required = false )] -async fn timezone(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { +async fn timezone(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let mut user_data = ctx.user_data(invoke.author_id()).await.unwrap(); @@ -178,8 +179,8 @@ You may want to use one of the popular timezones below, otherwise click [here](h #[command("prefix")] #[description("Configure a prefix for text-based commands (deprecated)")] #[supports_dm(false)] -#[required_permissions(Restricted)] -async fn prefix(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: String) { +#[hook(CHECK_GUILD_PERMISSIONS_HOOK)] +async fn prefix(ctx: &Context, invoke: CommandInvoke, args: String) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let guild_data = ctx.guild_data(invoke.guild_id().unwrap()).await.unwrap(); @@ -222,8 +223,8 @@ async fn prefix(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: required = true )] #[supports_dm(false)] -#[required_permissions(Restricted)] -async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { +#[hook(CHECK_GUILD_PERMISSIONS_HOOK)] +async fn restrict(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let framework = ctx.data.read().await.get::().cloned().unwrap(); @@ -240,7 +241,7 @@ async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), arg let restrictable_commands = framework .commands .iter() - .filter(|c| c.required_permissions == PermissionLevel::Managed) + .filter(|c| c.hooks.contains(&&CHECK_MANAGED_PERMISSIONS_HOOK)) .map(|c| c.names[0].to_string()) .collect::>(); @@ -289,6 +290,132 @@ async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), arg } } +#[command("macro")] +#[description("Record and replay command sequences")] +#[subcommand("record")] +#[description("Start recording up to 5 commands to replay")] +#[arg(name = "name", description = "Name for the new macro", kind = "String", required = true)] +#[arg( + name = "description", + description = "Description for the new macro", + kind = "String", + required = false +)] +#[subcommand("finish")] +#[description("Finish current recording")] +#[subcommand("list")] +#[description("List recorded macros")] +#[subcommand("run")] +#[description("Run a recorded macro")] +#[arg(name = "name", description = "Name of the macro to run", kind = "String", required = true)] +#[supports_dm(false)] +#[hook(CHECK_MANAGED_PERMISSIONS_HOOK)] +async fn macro_cmd(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { + let interaction = invoke.interaction().unwrap(); + + match args.subcommand.clone().unwrap().as_str() { + "record" => { + let macro_buffer = ctx.data.read().await.get::().cloned().unwrap(); + + { + let mut lock = macro_buffer.write().await; + + let guild_id = interaction.guild_id.unwrap(); + + lock.insert( + (guild_id, interaction.user.id), + CommandMacro { + guild_id, + name: args.get("name").unwrap().to_string(), + description: args.get("description").map(|d| d.to_string()), + commands: vec![], + }, + ); + } + + let _ = interaction + .create_interaction_response(&ctx, |r| { + r.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|d| { + d.flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL) + .create_embed(|e| { + e + .title("Macro Recording Started") + .description( +"Run up to 5 commands, or type `/macro finish` to stop at any point. +Any commands ran as part of recording will be inconsequential") + .color(*THEME_COLOR) + }) + }) + }) + .await; + } + "finish" => { + let key = (interaction.guild_id.unwrap(), interaction.user.id); + let macro_buffer = ctx.data.read().await.get::().cloned().unwrap(); + + { + let lock = macro_buffer.read().await; + let contained = lock.get(&key); + + if contained.map_or(true, |cmacro| cmacro.commands.is_empty()) { + let _ = interaction + .create_interaction_response(&ctx, |r| { + r.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|d| { + d.create_embed(|e| { + e.title("No Macro Recorded") + .description( + "Use `/macro record` to start recording a macro", + ) + .color(*THEME_COLOR) + }) + }) + }) + .await; + } else { + let pool = ctx.data.read().await.get::().cloned().unwrap(); + + let command_macro = contained.unwrap(); + let json = serde_json::to_string(&command_macro.commands).unwrap(); + + sqlx::query!( + "INSERT INTO macro (guild_id, name, description, commands) VALUES ((SELECT id FROM guilds WHERE guild = ?), ?, ?, ?)", + command_macro.guild_id.0, + command_macro.name, + command_macro.description, + json + ) + .execute(&pool) + .await + .unwrap(); + + let _ = interaction + .create_interaction_response(&ctx, |r| { + r.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|d| { + d.create_embed(|e| { + e.title("Macro Recorded") + .description("Use `/macro run` to execute the macro") + .color(*THEME_COLOR) + }) + }) + }) + .await; + } + } + + { + let mut lock = macro_buffer.write().await; + lock.remove(&key); + } + } + "list" => {} + "run" => {} + _ => {} + } +} + /* #[command("alias")] #[supports_dm(false)] diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index 2fe8371..60222de 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -5,16 +5,13 @@ use std::{ }; use chrono::NaiveDateTime; +use chrono_tz::Tz; use num_integer::Integer; use regex_command_attr::command; use serenity::{ - builder::CreateEmbed, + builder::{CreateEmbed, CreateInteractionResponse}, client::Context, - model::{ - channel::Channel, - id::{GuildId, UserId}, - interactions::InteractionResponseType, - }, + model::{channel::Channel, interactions::InteractionResponseType}, }; use crate::{ @@ -28,6 +25,7 @@ use crate::{ REGEX_NATURAL_COMMAND_2, THEME_COLOR, }, framework::{CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue}, + hooks::{CHECK_GUILD_PERMISSIONS_HOOK, CHECK_MANAGED_PERMISSIONS_HOOK}, models::{ channel_data::ChannelData, guild_data::GuildData, @@ -55,8 +53,8 @@ use crate::{ required = false )] #[supports_dm(false)] -#[required_permissions(Restricted)] -async fn pause(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { +#[hook(CHECK_GUILD_PERMISSIONS_HOOK)] +async fn pause(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; @@ -141,8 +139,8 @@ async fn pause(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: kind = "Integer", required = false )] -#[required_permissions(Restricted)] -async fn offset(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { +#[hook(CHECK_GUILD_PERMISSIONS_HOOK)] +async fn offset(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let combined_time = args.get("hours").map_or(0, |h| h.as_i64().unwrap() * 3600) @@ -218,8 +216,8 @@ WHERE FIND_IN_SET(channels.`channel`, ?)", kind = "Integer", required = false )] -#[required_permissions(Restricted)] -async fn nudge(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { +#[hook(CHECK_GUILD_PERMISSIONS_HOOK)] +async fn nudge(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let combined_time = args.get("minutes").map_or(0, |m| m.as_i64().unwrap() * 60) @@ -270,8 +268,8 @@ async fn nudge(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: kind = "Boolean", required = false )] -#[required_permissions(Managed)] -async fn look(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { +#[hook(CHECK_MANAGED_PERMISSIONS_HOOK)] +async fn look(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; @@ -363,103 +361,132 @@ async fn look(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: C #[command("del")] #[description("Delete reminders")] -#[required_permissions(Managed)] -async fn delete(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { - let pool = ctx.data.read().await.get::().cloned().unwrap(); +#[hook(CHECK_MANAGED_PERMISSIONS_HOOK)] +async fn delete(ctx: &Context, invoke: CommandInvoke, _args: CommandOptions) { + let interaction = invoke.interaction().unwrap(); - let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; + let timezone = ctx.timezone(interaction.user.id).await; - let reminders = Reminder::from_guild(ctx, invoke.guild_id(), invoke.author_id()).await; + let reminders = Reminder::from_guild(ctx, interaction.guild_id, interaction.user.id).await; - if reminders.is_empty() { - let _ = invoke - .respond( - ctx.http.clone(), - CreateGenericResponse::new().content("No reminders to delete!"), - ) - .await; - } else { - let mut char_count = 0; + let resp = show_delete_page(&reminders, 0, timezone).await; - let (shown_reminders, display_vec): (Vec<&Reminder>, Vec) = reminders - .iter() - .enumerate() - .map(|(count, reminder)| (reminder, reminder.display_del(count, &timezone))) - .take_while(|(_, p)| { - char_count += p.len(); - - char_count < EMBED_DESCRIPTION_MAX_LENGTH - }) - .unzip(); - - let display = display_vec.join("\n"); - - let pages = reminders - .iter() - .enumerate() - .map(|(count, reminder)| reminder.display_del(count, &timezone)) - .fold(0, |t, r| t + r.len()) - .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH); - - let pager = DelPager::new(timezone); - - let del_selector = ComponentDataModel::DelSelector(DelSelector { page: 0, timezone }); - - invoke - .respond( - ctx.http.clone(), - CreateGenericResponse::new() - .embed(|e| { - e.title("Delete Reminders") - .description(display) - .footer(|f| f.text(format!("Page {} of {}", 1, pages))) - .color(*THEME_COLOR) - }) - .components(|comp| { - pager.create_button_row(pages, comp); - - comp.create_action_row(|row| { - row.create_select_menu(|menu| { - menu.custom_id(del_selector.to_custom_id()).options(|opt| { - for (count, reminder) in shown_reminders.iter().enumerate() { - opt.create_option(|o| { - o.label(count + 1).value(reminder.id).description({ - let c = reminder.display_content(); - - if c.len() > 100 { - format!( - "{}...", - reminder - .display_content() - .chars() - .take(97) - .collect::() - ) - } else { - c.to_string() - } - }) - }); - } - - opt - }) - }) - }) - }), - ) - .await - .unwrap(); - } + let _ = interaction + .create_interaction_response(&ctx, |r| { + *r = resp; + r + }) + .await; } -async fn show_delete_page( - ctx: &Context, - guild_id: Option, - user_id: UserId, +pub fn max_delete_page(reminders: &Vec, timezone: &Tz) -> usize { + reminders + .iter() + .enumerate() + .map(|(count, reminder)| reminder.display_del(count, timezone)) + .fold(0, |t, r| t + r.len()) + .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH) +} + +pub async fn show_delete_page( + reminders: &Vec, page: usize, timezone: Tz, -) { +) -> CreateInteractionResponse { + let pager = DelPager::new(timezone); + + if reminders.is_empty() { + let mut embed = CreateEmbed::default(); + embed.title("Delete Reminders").description("No Reminders").color(*THEME_COLOR); + + let mut response = CreateInteractionResponse::default(); + response.kind(InteractionResponseType::UpdateMessage).interaction_response_data( + |response| { + response.embeds(vec![embed]).components(|comp| { + pager.create_button_row(0, comp); + comp + }) + }, + ); + + return response; + } + + let pages = max_delete_page(&reminders, &timezone); + + let mut page = page; + if page >= pages { + page = pages - 1; + } + + let mut char_count = 0; + let mut skip_char_count = 0; + let mut first_num = 0; + + let (shown_reminders, display_vec): (Vec<&Reminder>, Vec) = reminders + .iter() + .enumerate() + .map(|(count, reminder)| (reminder, reminder.display_del(count, &timezone))) + .skip_while(|(_, p)| { + first_num += 1; + skip_char_count += p.len(); + + skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * page + }) + .take_while(|(_, p)| { + char_count += p.len(); + + char_count < EMBED_DESCRIPTION_MAX_LENGTH + }) + .unzip(); + + let display = display_vec.join("\n"); + + let del_selector = ComponentDataModel::DelSelector(DelSelector { page, timezone }); + + let mut embed = CreateEmbed::default(); + embed + .title("Delete Reminders") + .description(display) + .footer(|f| f.text(format!("Page {} of {}", page + 1, pages))) + .color(*THEME_COLOR); + + let mut response = CreateInteractionResponse::default(); + response.kind(InteractionResponseType::UpdateMessage).interaction_response_data(|d| { + d.embeds(vec![embed]).components(|comp| { + pager.create_button_row(pages, comp); + + comp.create_action_row(|row| { + row.create_select_menu(|menu| { + menu.custom_id(del_selector.to_custom_id()).options(|opt| { + for (count, reminder) in shown_reminders.iter().enumerate() { + opt.create_option(|o| { + o.label(count + first_num).value(reminder.id).description({ + let c = reminder.display_content(); + + if c.len() > 100 { + format!( + "{}...", + reminder + .display_content() + .chars() + .take(97) + .collect::() + ) + } else { + c.to_string() + } + }) + }); + } + + opt + }) + }) + }) + }) + }); + response } #[command("timer")] @@ -472,8 +499,8 @@ async fn show_delete_page( #[subcommand("delete")] #[description("Delete a timer")] #[arg(name = "name", description = "Name of the timer to delete", kind = "String", required = true)] -#[required_permissions(Managed)] -async fn timer(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { +#[hook(CHECK_MANAGED_PERMISSIONS_HOOK)] +async fn timer(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { fn time_difference(start_time: NaiveDateTime) -> String { let unix_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; let now = NaiveDateTime::from_timestamp(unix_time, 0); @@ -638,8 +665,8 @@ DELETE FROM timers WHERE owner = ? AND name = ? kind = "Boolean", required = false )] -#[required_permissions(Managed)] -async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { +#[hook(CHECK_MANAGED_PERMISSIONS_HOOK)] +async fn remind(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { let interaction = invoke.interaction().unwrap(); // defer response since processing times can take some time @@ -650,7 +677,7 @@ async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: .await .unwrap(); - let user_data = ctx.user_data(invoke.author_id()).await.unwrap(); + let user_data = ctx.user_data(interaction.user.id).await.unwrap(); let timezone = user_data.timezone(); let time = { @@ -675,7 +702,7 @@ async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: .unwrap_or(vec![]); if list.is_empty() { - vec![ReminderScope::Channel(invoke.channel_id().0)] + vec![ReminderScope::Channel(interaction.channel_id.0)] } else { list } @@ -698,7 +725,7 @@ async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: } }; - let mut builder = MultiReminderBuilder::new(ctx, invoke.guild_id()) + let mut builder = MultiReminderBuilder::new(ctx, interaction.guild_id) .author(user_data) .content(content) .time(time) diff --git a/src/component_models/mod.rs b/src/component_models/mod.rs index 8adba69..22780f1 100644 --- a/src/component_models/mod.rs +++ b/src/component_models/mod.rs @@ -17,6 +17,7 @@ use serenity::{ }; use crate::{ + commands::reminder_cmds::{max_delete_page, show_delete_page}, component_models::pager::{DelPager, LookPager, Pager}, consts::{EMBED_DESCRIPTION_MAX_LENGTH, THEME_COLOR}, models::reminder::Reminder, @@ -165,98 +166,15 @@ INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, \"Role\", (SELECT id let reminders = Reminder::from_guild(ctx, component.guild_id, component.user.id).await; - let pages = reminders - .iter() - .enumerate() - .map(|(count, reminder)| reminder.display_del(count, &pager.timezone)) - .fold(0, |t, r| t + r.len()) - .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH); + let max_pages = max_delete_page(&reminders, &pager.timezone); - let next_page = pager.next_page(pages); + let resp = + show_delete_page(&reminders, pager.next_page(max_pages), pager.timezone).await; - let mut char_count = 0; - let mut skip_char_count = 0; - let mut first_num = 0; - - let (shown_reminders, display_vec): (Vec<&Reminder>, Vec) = reminders - .iter() - .enumerate() - .map(|(count, reminder)| { - (reminder, reminder.display_del(count, &pager.timezone)) - }) - .skip_while(|(_, p)| { - first_num += 1; - skip_char_count += p.len(); - - skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * next_page - }) - .take_while(|(_, p)| { - char_count += p.len(); - - char_count < EMBED_DESCRIPTION_MAX_LENGTH - }) - .unzip(); - - let display = display_vec.join("\n"); - - let del_selector = ComponentDataModel::DelSelector(DelSelector { - page: next_page, - timezone: pager.timezone, - }); - - let mut embed = CreateEmbed::default(); - embed - .title("Delete Reminders") - .description(display) - .footer(|f| f.text(format!("Page {} of {}", next_page + 1, pages))) - .color(*THEME_COLOR); - - component - .create_interaction_response(&ctx, |r| { - r.kind(InteractionResponseType::UpdateMessage).interaction_response_data( - |response| { - response.embeds(vec![embed]).components(|comp| { - pager.create_button_row(pages, comp); - - comp.create_action_row(|row| { - row.create_select_menu(|menu| { - menu.custom_id(del_selector.to_custom_id()).options( - |opt| { - for (count, reminder) in - shown_reminders.iter().enumerate() - { - opt.create_option(|o| { - o.label(count + first_num) - .value(reminder.id) - .description({ - let c = - reminder.display_content(); - - if c.len() > 100 { - format!( - "{}...", - reminder - .display_content() - .chars() - .take(97) - .collect::( - ) - ) - } else { - c.to_string() - } - }) - }); - } - - opt - }, - ) - }) - }) - }) - }, - ) + let _ = component + .create_interaction_response(&ctx, move |r| { + *r = resp; + r }) .await; } @@ -272,119 +190,12 @@ INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, \"Role\", (SELECT id let reminders = Reminder::from_guild(ctx, component.guild_id, component.user.id).await; - if reminders.is_empty() { - let mut embed = CreateEmbed::default(); - embed.title("Delete Reminders").description("No Reminders").color(*THEME_COLOR); + let resp = show_delete_page(&reminders, selector.page, selector.timezone).await; - component - .create_interaction_response(&ctx, |r| { - r.kind(InteractionResponseType::UpdateMessage) - .interaction_response_data(|response| { - response.embeds(vec![embed]).components(|comp| comp) - }) - }) - .await; - - return; - } - - let pages = reminders - .iter() - .enumerate() - .map(|(count, reminder)| reminder.display_del(count, &selector.timezone)) - .fold(0, |t, r| t + r.len()) - .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH); - - let mut page = selector.page; - if page >= pages { - page = pages - 1; - } - - let mut char_count = 0; - let mut skip_char_count = 0; - let mut first_num = 0; - - let (shown_reminders, display_vec): (Vec<&Reminder>, Vec) = reminders - .iter() - .enumerate() - .map(|(count, reminder)| { - (reminder, reminder.display_del(count, &selector.timezone)) - }) - .skip_while(|(_, p)| { - first_num += 1; - skip_char_count += p.len(); - - skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * page - }) - .take_while(|(_, p)| { - char_count += p.len(); - - char_count < EMBED_DESCRIPTION_MAX_LENGTH - }) - .unzip(); - - let display = display_vec.join("\n"); - - let pager = DelPager::new(selector.timezone); - - let del_selector = ComponentDataModel::DelSelector(DelSelector { - page, - timezone: selector.timezone, - }); - - let mut embed = CreateEmbed::default(); - embed - .title("Delete Reminders") - .description(display) - .footer(|f| f.text(format!("Page {} of {}", page + 1, pages))) - .color(*THEME_COLOR); - - component - .create_interaction_response(&ctx, |r| { - r.kind(InteractionResponseType::UpdateMessage).interaction_response_data( - |response| { - response.embeds(vec![embed]).components(|comp| { - pager.create_button_row(pages, comp); - - comp.create_action_row(|row| { - row.create_select_menu(|menu| { - menu.custom_id(del_selector.to_custom_id()).options( - |opt| { - for (count, reminder) in - shown_reminders.iter().enumerate() - { - opt.create_option(|o| { - o.label(count + first_num) - .value(reminder.id) - .description({ - let c = - reminder.display_content(); - - if c.len() > 100 { - format!( - "{}...", - reminder - .display_content() - .chars() - .take(97) - .collect::( - ) - ) - } else { - c.to_string() - } - }) - }); - } - - opt - }, - ) - }) - }) - }) - }, - ) + let _ = component + .create_interaction_response(&ctx, move |r| { + *r = resp; + r }) .await; } diff --git a/src/framework.rs b/src/framework.rs index 60a3e22..e79c50a 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -6,8 +6,9 @@ use std::{ sync::Arc, }; -use log::{error, info, warn}; +use log::info; use regex::{Match, Regex, RegexBuilder}; +use serde::{Deserialize, Serialize}; use serenity::{ async_trait, builder::{CreateApplicationCommands, CreateComponents, CreateEmbed}, @@ -15,9 +16,9 @@ use serenity::{ client::Context, framework::Framework, futures::prelude::future::BoxFuture, - http::Http, + http::{CacheHttp, Http}, model::{ - channel::{Channel, GuildChannel, Message}, + channel::Message, guild::{Guild, Member}, id::{ChannelId, GuildId, MessageId, RoleId, UserId}, interactions::{ @@ -29,20 +30,10 @@ use serenity::{ prelude::application_command::ApplicationCommandInteractionDataOption, }, prelude::TypeMapKey, - FutureExt, Result as SerenityResult, + Result as SerenityResult, }; -use crate::{ - models::{channel_data::ChannelData, CtxData}, - LimitExecutors, SQLPool, -}; - -#[derive(Debug, PartialEq)] -pub enum PermissionLevel { - Unrestricted, - Managed, - Restricted, -} +use crate::{models::CtxData, LimitExecutors}; pub struct CreateGenericResponse { content: String, @@ -81,196 +72,135 @@ impl CreateGenericResponse { } } -#[async_trait] -pub trait CommandInvoke { - fn channel_id(&self) -> ChannelId; - fn guild_id(&self) -> Option; - fn guild(&self, cache: Arc) -> Option; - fn author_id(&self) -> UserId; - async fn member(&self, context: &Context) -> SerenityResult; - fn msg(&self) -> Option; - fn interaction(&self) -> Option; - async fn respond( - &self, - http: Arc, - generic_response: CreateGenericResponse, - ) -> SerenityResult<()>; - async fn followup( - &self, - http: Arc, - generic_response: CreateGenericResponse, - ) -> SerenityResult<()>; +enum InvokeModel { + Slash(ApplicationCommandInteraction), + Text(Message), } -#[async_trait] -impl CommandInvoke for Message { - fn channel_id(&self) -> ChannelId { - self.channel_id - } - - fn guild_id(&self) -> Option { - self.guild_id - } - - fn guild(&self, cache: Arc) -> Option { - self.guild(cache) - } - - fn author_id(&self) -> UserId { - self.author.id - } - - async fn member(&self, context: &Context) -> SerenityResult { - self.member(context).await - } - - fn msg(&self) -> Option { - Some(self.clone()) - } - - fn interaction(&self) -> Option { - None - } - - async fn respond( - &self, - http: Arc, - generic_response: CreateGenericResponse, - ) -> SerenityResult<()> { - self.channel_id - .send_message(http, |m| { - m.content(generic_response.content); - - if let Some(embed) = generic_response.embed { - m.set_embed(embed.clone()); - } - - if let Some(components) = generic_response.components { - m.components(|c| { - *c = components; - c - }); - } - - m - }) - .await - .map(|_| ()) - } - - async fn followup( - &self, - http: Arc, - generic_response: CreateGenericResponse, - ) -> SerenityResult<()> { - self.channel_id - .send_message(http, |m| { - m.content(generic_response.content); - - if let Some(embed) = generic_response.embed { - m.set_embed(embed.clone()); - } - - if let Some(components) = generic_response.components { - m.components(|c| { - *c = components; - c - }); - } - - m - }) - .await - .map(|_| ()) - } +pub struct CommandInvoke { + model: InvokeModel, + already_responded: bool, } -#[async_trait] -impl CommandInvoke for ApplicationCommandInteraction { - fn channel_id(&self) -> ChannelId { - self.channel_id +impl CommandInvoke { + fn slash(interaction: ApplicationCommandInteraction) -> Self { + Self { model: InvokeModel::Slash(interaction), already_responded: false } } - fn guild_id(&self) -> Option { - self.guild_id + fn msg(msg: Message) -> Self { + Self { model: InvokeModel::Text(msg), already_responded: false } } - fn guild(&self, cache: Arc) -> Option { - if let Some(guild_id) = self.guild_id { - guild_id.to_guild_cached(cache) - } else { - None + pub fn interaction(self) -> Option { + match self.model { + InvokeModel::Slash(i) => Some(i), + InvokeModel::Text(_) => None, } } - fn author_id(&self) -> UserId { - self.member.as_ref().unwrap().user.id + pub fn channel_id(&self) -> ChannelId { + match &self.model { + InvokeModel::Slash(i) => i.channel_id, + InvokeModel::Text(m) => m.channel_id, + } } - async fn member(&self, _: &Context) -> SerenityResult { - Ok(self.member.clone().unwrap()) + pub fn guild_id(&self) -> Option { + match &self.model { + InvokeModel::Slash(i) => i.guild_id, + InvokeModel::Text(m) => m.guild_id, + } } - fn msg(&self) -> Option { - None + pub fn guild(&self, cache: impl AsRef) -> Option { + self.guild_id().map(|id| id.to_guild_cached(cache)).flatten() } - fn interaction(&self) -> Option { - Some(self.clone()) + pub fn author_id(&self) -> UserId { + match &self.model { + InvokeModel::Slash(i) => i.user.id, + InvokeModel::Text(m) => m.author.id, + } } - async fn respond( + pub async fn member(&self, cache_http: impl CacheHttp) -> Option { + match &self.model { + InvokeModel::Slash(i) => i.member.clone(), + InvokeModel::Text(m) => m.member(cache_http).await.ok(), + } + } + + pub async fn respond( &self, - http: Arc, + http: impl AsRef, generic_response: CreateGenericResponse, ) -> SerenityResult<()> { - self.create_interaction_response(http, |r| { - r.kind(InteractionResponseType::ChannelMessageWithSource).interaction_response_data( - |d| { - d.content(generic_response.content); + match &self.model { + InvokeModel::Slash(i) => { + if !self.already_responded { + i.create_interaction_response(http, |r| { + r.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|d| { + d.content(generic_response.content); + + if let Some(embed) = generic_response.embed { + d.add_embed(embed.clone()); + } + + if let Some(components) = generic_response.components { + d.components(|c| { + *c = components; + c + }); + } + + d + }) + }) + .await + .map(|_| ()) + } else { + i.create_followup_message(http, |d| { + d.content(generic_response.content); + + if let Some(embed) = generic_response.embed { + d.add_embed(embed.clone()); + } + + if let Some(components) = generic_response.components { + d.components(|c| { + *c = components; + c + }); + } + + d + }) + .await + .map(|_| ()) + } + } + InvokeModel::Text(m) => m + .channel_id + .send_message(http, |m| { + m.content(generic_response.content); if let Some(embed) = generic_response.embed { - d.add_embed(embed.clone()); + m.set_embed(embed.clone()); } if let Some(components) = generic_response.components { - d.components(|c| { + m.components(|c| { *c = components; c }); } - d - }, - ) - }) - .await - .map(|_| ()) - } - - async fn followup( - &self, - http: Arc, - generic_response: CreateGenericResponse, - ) -> SerenityResult<()> { - self.create_followup_message(http, |d| { - d.content(generic_response.content); - - if let Some(embed) = generic_response.embed { - d.add_embed(embed.clone()); - } - - if let Some(components) = generic_response.components { - d.components(|c| { - *c = components; - c - }); - } - - d - }) - .await - .map(|_| ()) + m + }) + .await + .map(|_| ()), + } } } @@ -283,6 +213,7 @@ pub struct Arg { pub options: &'static [&'static Self], } +#[derive(Serialize, Deserialize, Clone)] pub enum OptionValue { String(String), Integer(i64), @@ -330,8 +261,9 @@ impl OptionValue { } } +#[derive(Serialize, Deserialize, Clone)] pub struct CommandOptions { - pub command: String, + pub command: &'static str, pub subcommand: Option, pub subcommand_group: Option, pub options: HashMap, @@ -343,8 +275,17 @@ impl CommandOptions { } } -impl From for CommandOptions { - fn from(interaction: ApplicationCommandInteraction) -> Self { +impl CommandOptions { + fn new(command: &'static Command) -> Self { + Self { + command: command.names[0], + subcommand: None, + subcommand_group: None, + options: Default::default(), + } + } + + fn populate(mut self, interaction: &ApplicationCommandInteraction) -> Self { fn match_option( option: ApplicationCommandInteractionDataOption, cmd_opts: &mut CommandOptions, @@ -429,35 +370,31 @@ impl From for CommandOptions { } } - let mut cmd_opts = Self { - command: interaction.data.name, - subcommand: None, - subcommand_group: None, - options: Default::default(), - }; - - for option in interaction.data.options { - match_option(option, &mut cmd_opts) + for option in &interaction.data.options { + match_option(option.clone(), &mut self) } - cmd_opts + self } } -type SlashCommandFn = for<'fut> fn( - &'fut Context, - &'fut (dyn CommandInvoke + Sync + Send), - CommandOptions, -) -> BoxFuture<'fut, ()>; +pub enum HookResult { + Continue, + Halt, +} -type TextCommandFn = for<'fut> fn( - &'fut Context, - &'fut (dyn CommandInvoke + Sync + Send), - String, -) -> BoxFuture<'fut, ()>; +type SlashCommandFn = + for<'fut> fn(&'fut Context, CommandInvoke, CommandOptions) -> BoxFuture<'fut, ()>; -type MultiCommandFn = - for<'fut> fn(&'fut Context, &'fut (dyn CommandInvoke + Sync + Send)) -> BoxFuture<'fut, ()>; +type TextCommandFn = for<'fut> fn(&'fut Context, CommandInvoke, String) -> BoxFuture<'fut, ()>; + +type MultiCommandFn = for<'fut> fn(&'fut Context, CommandInvoke) -> BoxFuture<'fut, ()>; + +pub type HookFn = for<'fut> fn( + &'fut Context, + &'fut CommandInvoke, + &'fut CommandOptions, +) -> BoxFuture<'fut, HookResult>; pub enum CommandFnType { Slash(SlashCommandFn), @@ -474,6 +411,17 @@ impl CommandFnType { } } +pub struct Hook { + pub fun: HookFn, + pub uuid: u128, +} + +impl PartialEq for Hook { + fn eq(&self, other: &Self) -> bool { + self.uuid == other.uuid + } +} + pub struct Command { pub fun: CommandFnType, @@ -483,11 +431,12 @@ pub struct Command { pub examples: &'static [&'static str], pub group: &'static str, - pub required_permissions: PermissionLevel, pub args: &'static [&'static Arg], pub can_blacklist: bool, pub supports_dm: bool, + + pub hooks: &'static [&'static Hook], } impl Hash for Command { @@ -504,81 +453,6 @@ impl PartialEq for Command { impl Eq for Command {} -impl Command { - async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool { - if self.required_permissions == PermissionLevel::Unrestricted { - true - } else { - let permissions = guild.member_permissions(&ctx, &member.user).await.unwrap(); - - if permissions.manage_guild() - || (permissions.manage_messages() - && self.required_permissions == PermissionLevel::Managed) - { - return true; - } - - if self.required_permissions == PermissionLevel::Managed { - let pool = ctx - .data - .read() - .await - .get::() - .cloned() - .expect("Could not get SQLPool from data"); - - match sqlx::query!( - " -SELECT - role -FROM - roles -INNER JOIN - command_restrictions ON roles.id = command_restrictions.role_id -WHERE - command_restrictions.command = ? AND - roles.guild_id = ( - SELECT - id - FROM - guilds - WHERE - guild = ?) - ", - self.names[0], - guild.id.as_u64() - ) - .fetch_all(&pool) - .await - { - Ok(rows) => { - let role_ids = - member.roles.iter().map(|r| *r.as_u64()).collect::>(); - - for row in rows { - if role_ids.contains(&row.role) { - return true; - } - } - - false - } - - Err(sqlx::Error::RowNotFound) => false, - - Err(e) => { - warn!("Unexpected error occurred querying command_restrictions: {:?}", e); - - false - } - } - } else { - false - } - } - } -} - pub struct RegexFramework { pub commands_map: HashMap, pub commands: HashSet<&'static Command>, @@ -589,23 +463,14 @@ pub struct RegexFramework { ignore_bots: bool, case_insensitive: bool, dm_enabled: bool, - default_text_fun: TextCommandFn, debug_guild: Option, + hooks: Vec<&'static Hook>, } impl TypeMapKey for RegexFramework { type Value = Arc; } -fn drop_text<'fut>( - _: &'fut Context, - _: &'fut (dyn CommandInvoke + Sync + Send), - _: String, -) -> std::pin::Pin + std::marker::Send + 'fut)>> -{ - async move {}.boxed() -} - impl RegexFramework { pub fn new>(client_id: T) -> Self { Self { @@ -618,8 +483,8 @@ impl RegexFramework { ignore_bots: true, case_insensitive: true, dm_enabled: true, - default_text_fun: drop_text, debug_guild: None, + hooks: vec![], } } @@ -647,6 +512,12 @@ impl RegexFramework { self } + pub fn add_hook(mut self, fun: &'static Hook) -> Self { + self.hooks.push(fun); + + self + } + pub fn add_command(mut self, command: &'static Command) -> Self { self.commands.insert(command); @@ -791,77 +662,46 @@ impl RegexFramework { .expect(&format!("Received invalid command: {}", interaction.data.name)) }; - let guild = interaction.guild(ctx.cache.clone()).unwrap(); - let member = interaction.clone().member.unwrap(); + let args = CommandOptions::new(command).populate(&interaction); + let command_invoke = CommandInvoke::slash(interaction); - if command.check_permissions(&ctx, &guild, &member).await { - let args = CommandOptions::from(interaction.clone()); - - if !ctx.check_executing(interaction.author_id()).await { - ctx.set_executing(interaction.author_id()).await; - - match command.fun { - CommandFnType::Slash(t) => t(&ctx, &interaction, args).await, - CommandFnType::Multi(m) => m(&ctx, &interaction).await, - _ => (), + for hook in command.hooks { + match (hook.fun)(&ctx, &command_invoke, &args).await { + HookResult::Continue => {} + HookResult::Halt => { + return; } - - ctx.drop_executing(interaction.author_id()).await; } - } else if command.required_permissions == PermissionLevel::Restricted { - let _ = interaction - .respond( - ctx.http.clone(), - CreateGenericResponse::new().content( - "You must have the `Manage Server` permission to use this command.", - ), - ) - .await; - } else if command.required_permissions == PermissionLevel::Managed { - let _ = interaction - .respond( - ctx.http.clone(), - CreateGenericResponse::new().content( - "You must have `Manage Messages` or have a role capable of sending reminders to that channel. \ - Please talk to your server admin, and ask them to use the `/restrict` command to specify \ - allowed roles.", - ), - ) - .await; + } + + for hook in &self.hooks { + match (hook.fun)(&ctx, &command_invoke, &args).await { + HookResult::Continue => {} + HookResult::Halt => { + return; + } + } + } + + let user_id = command_invoke.author_id(); + + if !ctx.check_executing(user_id).await { + ctx.set_executing(user_id).await; + + match command.fun { + CommandFnType::Slash(t) => t(&ctx, command_invoke, args).await, + CommandFnType::Multi(m) => m(&ctx, command_invoke).await, + _ => (), + } + + ctx.drop_executing(user_id).await; } } } -enum PermissionCheck { - None, // No permissions - Basic(bool, bool), // Send + Embed permissions (sufficient to reply) - All, // Above + Manage Webhooks (sufficient to operate) -} - #[async_trait] impl Framework for RegexFramework { async fn dispatch(&self, ctx: Context, msg: Message) { - async fn check_self_permissions( - ctx: &Context, - guild: &Guild, - channel: &GuildChannel, - ) -> SerenityResult { - let user_id = ctx.cache.current_user_id(); - - let guild_perms = guild.member_permissions(&ctx, user_id).await?; - let channel_perms = channel.permissions_for_user(ctx, user_id)?; - - let basic_perms = channel_perms.send_messages(); - - Ok(if basic_perms && guild_perms.manage_webhooks() && channel_perms.embed_links() { - PermissionCheck::All - } else if basic_perms { - PermissionCheck::Basic(guild_perms.manage_webhooks(), channel_perms.embed_links()) - } else { - PermissionCheck::None - }) - } - async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option>) -> bool { if let Some(prefix) = prefix_opt { let guild_prefix = ctx.prefix(Some(guild.id)).await; @@ -874,144 +714,65 @@ impl Framework for RegexFramework { // gate to prevent analysing messages unnecessarily if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() { - } else { - // Guild Command - if let (Some(guild), Ok(Channel::Guild(channel))) = - (msg.guild(&ctx), msg.channel(&ctx).await) - { - let data = ctx.data.read().await; + return; + } - let pool = data.get::().cloned().expect("Could not get SQLPool from data"); + let user_id = msg.author.id; + let invoke = CommandInvoke::msg(msg.clone()); - if let Some(full_match) = self.command_matcher.captures(&msg.content) { - if check_prefix(&ctx, &guild, full_match.name("prefix")).await { - match check_self_permissions(&ctx, &guild, &channel).await { - Ok(perms) => match perms { - PermissionCheck::All => { - let command = self - .commands_map - .get( - &full_match - .name("cmd") - .unwrap() - .as_str() - .to_lowercase(), - ) - .unwrap(); - - let channel = msg.channel(&ctx).await.unwrap(); - - let channel_data = - ChannelData::from_channel(&channel, &pool).await.unwrap(); - - if !command.can_blacklist || !channel_data.blacklisted { - let args = full_match - .name("args") - .map(|m| m.as_str()) - .unwrap_or("") - .to_string(); - - let member = guild.member(&ctx, &msg.author).await.unwrap(); - - if command.check_permissions(&ctx, &guild, &member).await { - if msg.id == MessageId(0) - || !ctx.check_executing(msg.author.id).await - { - ctx.set_executing(msg.author.id).await; - - match command.fun { - CommandFnType::Text(t) => t(&ctx, &msg, args), - CommandFnType::Multi(m) => m(&ctx, &msg), - _ => (self.default_text_fun)(&ctx, &msg, args), - } - .await; - - ctx.drop_executing(msg.author.id).await; - } - } else if command.required_permissions - == PermissionLevel::Restricted - { - let _ = msg - .channel_id - .say( - &ctx, - "You must have the `Manage Server` permission to use this command.", - ) - .await; - } else if command.required_permissions - == PermissionLevel::Managed - { - let _ = msg - .channel_id - .say( - &ctx, - "You must have `Manage Messages` or have a role capable of \ - sending reminders to that channel. Please talk to your server \ - admin, and ask them to use the `/restrict` command to specify \ - allowed roles.", - ) - .await; - } - } - } - - PermissionCheck::Basic(manage_webhooks, embed_links) => { - let _ = msg - .channel_id - .say( - &ctx, - format!( - "Please ensure the bot has the correct permissions: - -✅ **Send Message** -{} **Embed Links** -{} **Manage Webhooks**", - if manage_webhooks { "✅" } else { "❌" }, - if embed_links { "✅" } else { "❌" }, - ), - ) - .await; - } - - PermissionCheck::None => { - warn!("Missing enough permissions for guild {}", guild.id); - } - }, - - Err(e) => { - error!( - "Error occurred getting permissions in guild {}: {:?}", - guild.id, e - ); - } - } - } - } - } - // DM Command - else if self.dm_enabled { - if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) { + // Guild Command + if let Some(guild) = msg.guild(&ctx) { + if let Some(full_match) = self.command_matcher.captures(&msg.content) { + if check_prefix(&ctx, &guild, full_match.name("prefix")).await { let command = self .commands_map .get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) .unwrap(); - let args = - full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string(); - if msg.id == MessageId(0) || !ctx.check_executing(msg.author.id).await { - ctx.set_executing(msg.author.id).await; + let channel_data = ctx.channel_data(invoke.channel_id()).await.unwrap(); - match command.fun { - CommandFnType::Text(t) => t(&ctx, &msg, args), - CommandFnType::Multi(m) => m(&ctx, &msg), - _ => (self.default_text_fun)(&ctx, &msg, args), + if !command.can_blacklist || !channel_data.blacklisted { + let args = + full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string(); + + if msg.id == MessageId(0) || !ctx.check_executing(user_id).await { + ctx.set_executing(user_id).await; + + match command.fun { + CommandFnType::Text(t) => t(&ctx, invoke, args).await, + CommandFnType::Multi(m) => m(&ctx, invoke).await, + _ => {} + }; + + ctx.drop_executing(user_id).await; } - .await; - - ctx.drop_executing(msg.author.id).await; } } } } + // DM Command + else if self.dm_enabled { + if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) { + let command = self + .commands_map + .get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) + .unwrap(); + let args = full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string(); + + let user_id = invoke.author_id(); + + if msg.id == MessageId(0) || !ctx.check_executing(user_id).await { + ctx.set_executing(user_id).await; + + match command.fun { + CommandFnType::Text(t) => t(&ctx, invoke, args).await, + CommandFnType::Multi(m) => m(&ctx, invoke).await, + _ => {} + }; + + ctx.drop_executing(user_id).await; + } + } + } } } diff --git a/src/hooks.rs b/src/hooks.rs new file mode 100644 index 0000000..5c51fa0 --- /dev/null +++ b/src/hooks.rs @@ -0,0 +1,217 @@ +use log::warn; +use regex_command_attr::check; +use serenity::{client::Context, model::channel::Channel}; + +use crate::{ + framework::{CommandInvoke, CommandOptions, CreateGenericResponse, HookResult}, + moderation_cmds, RecordingMacros, SQLPool, +}; + +#[check] +pub async fn macro_check( + ctx: &Context, + invoke: &CommandInvoke, + args: &CommandOptions, +) -> HookResult { + if let Some(guild_id) = invoke.guild_id() { + if args.command != moderation_cmds::MACRO_CMD_COMMAND.names[0] { + let active_recordings = + ctx.data.read().await.get::().cloned().unwrap(); + let mut lock = active_recordings.write().await; + + if let Some(command_macro) = lock.get_mut(&(guild_id, invoke.author_id())) { + command_macro.commands.push(args.clone()); + + let _ = invoke + .respond( + &ctx, + CreateGenericResponse::new().content("Command recorded to macro"), + ) + .await; + + HookResult::Halt + } else { + HookResult::Continue + } + } else { + HookResult::Continue + } + } else { + HookResult::Continue + } +} + +#[check] +pub async fn check_self_permissions( + ctx: &Context, + invoke: &CommandInvoke, + _args: &CommandOptions, +) -> HookResult { + if let Some(guild) = invoke.guild(&ctx) { + let user_id = ctx.cache.current_user_id(); + + let manage_webhooks = + guild.member_permissions(&ctx, user_id).await.map_or(false, |p| p.manage_webhooks()); + let (send_messages, embed_links) = invoke + .channel_id() + .to_channel_cached(&ctx) + .map(|c| { + if let Channel::Guild(channel) = c { + channel.permissions_for_user(ctx, user_id).ok() + } else { + None + } + }) + .flatten() + .map_or((false, false), |p| (p.send_messages(), p.embed_links())); + + if manage_webhooks && send_messages && embed_links { + HookResult::Continue + } else { + if send_messages { + let _ = invoke + .respond( + &ctx, + CreateGenericResponse::new().content(format!( + "Please ensure the bot has the correct permissions: + +✅ **Send Message** +{} **Embed Links** +{} **Manage Webhooks**", + if manage_webhooks { "✅" } else { "❌" }, + if embed_links { "✅" } else { "❌" }, + )), + ) + .await; + } else { + warn!("Missing permissions in guild {}", guild.id); + } + + HookResult::Halt + } + } else { + HookResult::Continue + } +} + +#[check] +pub async fn check_managed_permissions( + ctx: &Context, + invoke: &CommandInvoke, + args: &CommandOptions, +) -> HookResult { + if let Some(guild) = invoke.guild(&ctx) { + let permissions = guild.member_permissions(&ctx, invoke.author_id()).await.unwrap(); + + if permissions.manage_messages() { + return HookResult::Continue; + } + + let member = invoke.member(&ctx).await.unwrap(); + + let pool = ctx + .data + .read() + .await + .get::() + .cloned() + .expect("Could not get SQLPool from data"); + + match sqlx::query!( + " +SELECT + role +FROM + roles +INNER JOIN + command_restrictions ON roles.id = command_restrictions.role_id +WHERE + command_restrictions.command = ? AND + roles.guild_id = ( + SELECT + id + FROM + guilds + WHERE + guild = ?) + ", + args.command, + guild.id.as_u64() + ) + .fetch_all(&pool) + .await + { + Ok(rows) => { + let role_ids = member.roles.iter().map(|r| *r.as_u64()).collect::>(); + + for row in rows { + if role_ids.contains(&row.role) { + return HookResult::Continue; + } + } + + let _ = invoke + .respond( + &ctx, + CreateGenericResponse::new().content( + "You must have \"Manage Messages\" or have a role capable of sending reminders to that channel. \ +Please talk to your server admin, and ask them to use the `/restrict` command to specify allowed roles.", + ), + ) + .await; + + HookResult::Halt + } + + Err(sqlx::Error::RowNotFound) => { + let _ = invoke + .respond( + &ctx, + CreateGenericResponse::new().content( + "You must have \"Manage Messages\" or have a role capable of sending reminders to that channel. \ +Please talk to your server admin, and ask them to use the `/restrict` command to specify allowed roles.", + ), + ) + .await; + + HookResult::Halt + } + + Err(e) => { + warn!("Unexpected error occurred querying command_restrictions: {:?}", e); + + HookResult::Halt + } + } + } else { + HookResult::Continue + } +} + +#[check] +pub async fn check_guild_permissions( + ctx: &Context, + invoke: &CommandInvoke, + _args: &CommandOptions, +) -> HookResult { + if let Some(guild) = invoke.guild(&ctx) { + let permissions = guild.member_permissions(&ctx, invoke.author_id()).await.unwrap(); + + if !permissions.manage_guild() { + let _ = invoke + .respond( + &ctx, + CreateGenericResponse::new().content( + "You must have the \"Manage Server\" permission to use this command", + ), + ) + .await; + + HookResult::Halt + } else { + HookResult::Continue + } + } else { + HookResult::Continue + } +} diff --git a/src/main.rs b/src/main.rs index e4999d0..c2bd692 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ mod commands; mod component_models; mod consts; mod framework; +mod hooks; mod models; mod time_parser; @@ -39,7 +40,7 @@ use crate::{ component_models::ComponentDataModel, consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR}, framework::RegexFramework, - models::guild_data::GuildData, + models::{command_macro::CommandMacro, guild_data::GuildData}, }; struct GuildDataCache; @@ -72,6 +73,12 @@ impl TypeMapKey for CurrentlyExecuting { type Value = Arc>>; } +struct RecordingMacros; + +impl TypeMapKey for RecordingMacros { + type Value = Arc>>; +} + #[async_trait] trait LimitExecutors { async fn check_executing(&self, user: UserId) -> bool; @@ -326,10 +333,13 @@ async fn main() -> Result<(), Box> { .add_command(&moderation_cmds::RESTRICT_COMMAND) .add_command(&moderation_cmds::TIMEZONE_COMMAND) .add_command(&moderation_cmds::PREFIX_COMMAND) + .add_command(&moderation_cmds::MACRO_CMD_COMMAND) /* .add_command("alias", &moderation_cmds::ALIAS_COMMAND) .add_command("a", &moderation_cmds::ALIAS_COMMAND) */ + .add_hook(&hooks::CHECK_SELF_PERMISSIONS_HOOK) + .add_hook(&hooks::MACRO_CHECK_HOOK) .build(); let framework_arc = Arc::new(framework); @@ -375,6 +385,7 @@ async fn main() -> Result<(), Box> { data.insert::(Arc::new(popular_timezones)); data.insert::(Arc::new(reqwest::Client::new())); data.insert::(framework_arc.clone()); + data.insert::(Arc::new(RwLock::new(HashMap::new()))); } if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| { diff --git a/src/models/command_macro.rs b/src/models/command_macro.rs new file mode 100644 index 0000000..4121faa --- /dev/null +++ b/src/models/command_macro.rs @@ -0,0 +1,10 @@ +use serenity::model::id::GuildId; + +use crate::framework::CommandOptions; + +pub struct CommandMacro { + pub guild_id: GuildId, + pub name: String, + pub description: Option, + pub commands: Vec, +} diff --git a/src/models/mod.rs b/src/models/mod.rs index 7801c90..697bcf9 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,4 +1,5 @@ pub mod channel_data; +pub mod command_macro; pub mod guild_data; pub mod reminder; pub mod timer; diff --git a/src/models/reminder/errors.rs b/src/models/reminder/errors.rs index e8fc0f5..141feb8 100644 --- a/src/models/reminder/errors.rs +++ b/src/models/reminder/errors.rs @@ -1,32 +1,5 @@ use crate::consts::{MAX_TIME, MIN_INTERVAL}; -#[derive(Debug)] -pub enum InteractionError { - InvalidFormat, - InvalidBase64, - InvalidSize, - NoReminder, - SignatureMismatch, - InvalidAction, -} - -impl ToString for InteractionError { - fn to_string(&self) -> String { - match self { - InteractionError::InvalidFormat => { - String::from("The interaction data was improperly formatted") - } - InteractionError::InvalidBase64 => String::from("The interaction data was invalid"), - InteractionError::InvalidSize => String::from("The interaction data was invalid"), - InteractionError::NoReminder => String::from("Reminder could not be found"), - InteractionError::SignatureMismatch => { - String::from("Only the user who did the command can use interactions") - } - InteractionError::InvalidAction => String::from("The action was invalid"), - } - } -} - #[derive(PartialEq, Eq, Hash, Debug)] pub enum ReminderError { LongTime,