added functionality for reusable hook functions that will execute on commands
This commit is contained in:
		
							
								
								
									
										10
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										10
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							| @@ -1213,6 +1213,7 @@ dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
|  "syn", | ||||
|  "uuid", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| @@ -2016,6 +2017,15 @@ version = "0.7.6" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" | ||||
|  | ||||
| [[package]] | ||||
| name = "uuid" | ||||
| version = "0.8.2" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" | ||||
| dependencies = [ | ||||
|  "getrandom 0.2.3", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "uwl" | ||||
| version = "0.6.0" | ||||
|   | ||||
| @@ -15,7 +15,6 @@ Reminder Bot can be built by running `cargo build --release` in the top level di | ||||
| These environment variables must be provided when compiling the bot | ||||
| * `DATABASE_URL` - the URL of your MySQL database (`mysql://user[:password]@domain/database`) | ||||
| * `WEBHOOK_AVATAR` - accepts the name of an image file located in `$CARGO_MANIFEST_DIR/assets/` to be used as the avatar when creating webhooks. **IMPORTANT: image file must be 128x128 or smaller in size** | ||||
| * `STRINGS_FILE` - accepts the name of a compiled strings file located in `$CARGO_MANIFEST_DIR/assets/` to be used for creating messages. Compiled string files can be generated with `compile.py` at https://github.com/reminder-bot/languages | ||||
|  | ||||
| ### Setting up Python | ||||
| Reminder Bot by default looks for a venv within it's working directory to run Python out of. To set up a venv, install `python3-venv` and run `python3 -m venv venv`. Then, run `source venv/bin/activate` to activate the venv, and do `pip install dateparser` to install the required library | ||||
| @@ -29,14 +28,12 @@ __Required Variables__ | ||||
|  | ||||
| __Other Variables__ | ||||
| * `MIN_INTERVAL` - default `600`, defines the shortest interval the bot should accept | ||||
| * `MAX_TIME` - default `1576800000`, defines the maximum time ahead that reminders can be set for | ||||
| * `LOCAL_TIMEZONE` - default `UTC`, necessary for calculations in the natural language processor | ||||
| * `DEFAULT_PREFIX` - default `$`, used for the default prefix on new guilds | ||||
| * `SUBSCRIPTION_ROLES` - default `None`, accepts a list of Discord role IDs that are given to subscribed users | ||||
| * `CNC_GUILD` - default `None`, accepts a single Discord guild ID for the server that the subscription roles belong to | ||||
| * `IGNORE_BOTS` - default `1`, if `1`, Reminder Bot will ignore all other bots | ||||
| * `PYTHON_LOCATION` - default `venv/bin/python3`. Can be changed if your Python executable is located somewhere else | ||||
| * `LOCAL_LANGUAGE` - default `EN`. Specifies the string set to fall back to if a string cannot be found (and to be used with new users) | ||||
| * `THEME_COLOR` - default `8fb677`. Specifies the hex value of the color to use on info message embeds  | ||||
| * `CASE_INSENSITIVE` - default `1`, if `1`, commands will be treated with case insensitivity (so both `$help` and `$HELP` will work) | ||||
| * `SHARD_COUNT` - default `None`, accepts the number of shards that are being ran | ||||
|   | ||||
| @@ -13,3 +13,4 @@ proc-macro = true | ||||
| quote = "^1.0" | ||||
| syn = { version = "^1.0", features = ["full", "derive", "extra-traits"] } | ||||
| proc-macro2 = "1.0" | ||||
| uuid = { version = "0.8", features = ["v4"] } | ||||
|   | ||||
| @@ -8,7 +8,7 @@ use syn::{ | ||||
| }; | ||||
|  | ||||
| use crate::{ | ||||
|     structures::{ApplicationCommandOptionType, Arg, PermissionLevel}, | ||||
|     structures::{ApplicationCommandOptionType, Arg}, | ||||
|     util::{AsOption, LitExt}, | ||||
| }; | ||||
|  | ||||
| @@ -46,24 +46,15 @@ impl fmt::Display for ValueKind { | ||||
|  | ||||
| fn to_ident(p: Path) -> Result<Ident> { | ||||
|     if p.segments.is_empty() { | ||||
|         return Err(Error::new( | ||||
|             p.span(), | ||||
|             "cannot convert an empty path to an identifier", | ||||
|         )); | ||||
|         return Err(Error::new(p.span(), "cannot convert an empty path to an identifier")); | ||||
|     } | ||||
|  | ||||
|     if p.segments.len() > 1 { | ||||
|         return Err(Error::new( | ||||
|             p.span(), | ||||
|             "the path must not have more than one segment", | ||||
|         )); | ||||
|         return Err(Error::new(p.span(), "the path must not have more than one segment")); | ||||
|     } | ||||
|  | ||||
|     if !p.segments[0].arguments.is_empty() { | ||||
|         return Err(Error::new( | ||||
|             p.span(), | ||||
|             "the singular path segment must not have any arguments", | ||||
|         )); | ||||
|         return Err(Error::new(p.span(), "the singular path segment must not have any arguments")); | ||||
|     } | ||||
|  | ||||
|     Ok(p.segments[0].ident.clone()) | ||||
| @@ -85,12 +76,7 @@ impl Values { | ||||
|         literals: Vec<(Option<String>, Lit)>, | ||||
|         span: Span, | ||||
|     ) -> Self { | ||||
|         Values { | ||||
|             name, | ||||
|             literals, | ||||
|             kind, | ||||
|             span, | ||||
|         } | ||||
|         Values { name, literals, kind, span } | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -145,11 +131,7 @@ pub fn parse_values(attr: &Attribute) -> Result<Values> { | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 let kind = if lits.len() == 1 { | ||||
|                     ValueKind::SingleList | ||||
|                 } else { | ||||
|                     ValueKind::List | ||||
|                 }; | ||||
|                 let kind = if lits.len() == 1 { ValueKind::SingleList } else { ValueKind::List }; | ||||
|  | ||||
|                 Ok(Values::new(name, kind, lits, attr.span())) | ||||
|             } else { | ||||
| @@ -183,12 +165,7 @@ pub fn parse_values(attr: &Attribute) -> Result<Values> { | ||||
|             let name = to_ident(meta.path)?; | ||||
|             let lit = meta.lit; | ||||
|  | ||||
|             Ok(Values::new( | ||||
|                 name, | ||||
|                 ValueKind::Equals, | ||||
|                 vec![(None, lit)], | ||||
|                 attr.span(), | ||||
|             )) | ||||
|             Ok(Values::new(name, ValueKind::Equals, vec![(None, lit)], attr.span())) | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -231,10 +208,7 @@ fn validate(values: &Values, forms: &[ValueKind]) -> Result<()> { | ||||
|         return Err(Error::new( | ||||
|             values.span, | ||||
|             // Using the `_args` version here to avoid an allocation. | ||||
|             format_args!( | ||||
|                 "the attribute must be in of these forms:\n{}", | ||||
|                 DisplaySlice(forms) | ||||
|             ), | ||||
|             format_args!("the attribute must be in of these forms:\n{}", DisplaySlice(forms)), | ||||
|         )); | ||||
|     } | ||||
|  | ||||
| @@ -254,11 +228,7 @@ impl AttributeOption for Vec<String> { | ||||
|     fn parse(values: Values) -> Result<Self> { | ||||
|         validate(&values, &[ValueKind::List])?; | ||||
|  | ||||
|         Ok(values | ||||
|             .literals | ||||
|             .into_iter() | ||||
|             .map(|(_, l)| l.to_str()) | ||||
|             .collect()) | ||||
|         Ok(values.literals.into_iter().map(|(_, l)| l.to_str()).collect()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -294,37 +264,18 @@ impl AttributeOption for Vec<Ident> { | ||||
|     fn parse(values: Values) -> Result<Self> { | ||||
|         validate(&values, &[ValueKind::List])?; | ||||
|  | ||||
|         Ok(values | ||||
|             .literals | ||||
|             .into_iter() | ||||
|             .map(|(_, l)| l.to_ident()) | ||||
|             .collect()) | ||||
|         Ok(values.literals.into_iter().map(|(_, l)| l.to_ident()).collect()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl AttributeOption for Option<String> { | ||||
|     fn parse(values: Values) -> Result<Self> { | ||||
|         validate( | ||||
|             &values, | ||||
|             &[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList], | ||||
|         )?; | ||||
|         validate(&values, &[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList])?; | ||||
|  | ||||
|         Ok(values.literals.get(0).map(|(_, l)| l.to_str())) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl AttributeOption for PermissionLevel { | ||||
|     fn parse(values: Values) -> Result<Self> { | ||||
|         validate(&values, &[ValueKind::SingleList])?; | ||||
|  | ||||
|         Ok(values | ||||
|             .literals | ||||
|             .get(0) | ||||
|             .map(|(_, l)| PermissionLevel::from_str(&*l.to_str()).unwrap()) | ||||
|             .unwrap()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl AttributeOption for Arg { | ||||
|     fn parse(values: Values) -> Result<Self> { | ||||
|         validate(&values, &[ValueKind::EqualsList])?; | ||||
|   | ||||
| @@ -2,6 +2,8 @@ pub mod suffixes { | ||||
|     pub const COMMAND: &str = "COMMAND"; | ||||
|     pub const ARG: &str = "ARG"; | ||||
|     pub const SUBCOMMAND: &str = "SUBCOMMAND"; | ||||
|     pub const CHECK: &str = "CHECK"; | ||||
|     pub const HOOK: &str = "HOOK"; | ||||
| } | ||||
|  | ||||
| pub use self::suffixes::*; | ||||
|   | ||||
| @@ -5,6 +5,7 @@ use proc_macro::TokenStream; | ||||
| use proc_macro2::Ident; | ||||
| use quote::quote; | ||||
| use syn::{parse::Error, parse_macro_input, parse_quote, spanned::Spanned, Lit, Type}; | ||||
| use uuid::Uuid; | ||||
|  | ||||
| pub(crate) mod attributes; | ||||
| pub(crate) mod consts; | ||||
| @@ -43,6 +44,7 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { | ||||
|         fun.name.to_string() | ||||
|     }; | ||||
|  | ||||
|     let mut hooks: Vec<Ident> = Vec::new(); | ||||
|     let mut options = Options::new(); | ||||
|  | ||||
|     for attribute in &fun.attributes { | ||||
| @@ -76,11 +78,13 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { | ||||
|                     util::append_line(&mut options.description, line); | ||||
|                 } | ||||
|             } | ||||
|             "hook" => { | ||||
|                 hooks.push(propagate_err!(attributes::parse(values))); | ||||
|             } | ||||
|             _ => { | ||||
|                 match_options!(name, values, options, span => [ | ||||
|                     aliases; | ||||
|                     group; | ||||
|                     required_permissions; | ||||
|                     can_blacklist; | ||||
|                     supports_dm | ||||
|                 ]); | ||||
| @@ -93,7 +97,6 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { | ||||
|         description, | ||||
|         group, | ||||
|         examples, | ||||
|         required_permissions, | ||||
|         can_blacklist, | ||||
|         supports_dm, | ||||
|         mut cmd_args, | ||||
| @@ -235,10 +238,10 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { | ||||
|             desc: #description, | ||||
|             group: #group, | ||||
|             examples: &[#(#examples),*], | ||||
|             required_permissions: #required_permissions, | ||||
|             can_blacklist: #can_blacklist, | ||||
|             supports_dm: #supports_dm, | ||||
|             args: &[#(&#arg_idents),*], | ||||
|             hooks: &[#(&#hooks),*], | ||||
|         }; | ||||
|     }); | ||||
|  | ||||
| @@ -256,3 +259,44 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { | ||||
|  | ||||
|     tokens.into() | ||||
| } | ||||
|  | ||||
| #[proc_macro_attribute] | ||||
| pub fn check(_attr: TokenStream, input: TokenStream) -> TokenStream { | ||||
|     let mut fun = parse_macro_input!(input as CommandFun); | ||||
|  | ||||
|     let n = fun.name.clone(); | ||||
|     let name = n.with_suffix(HOOK); | ||||
|     let fn_name = n.with_suffix(CHECK); | ||||
|     let visibility = fun.visibility; | ||||
|  | ||||
|     let cooked = fun.cooked; | ||||
|     let body = fun.body; | ||||
|     let ret = fun.ret; | ||||
|     populate_fut_lifetimes_on_refs(&mut fun.args); | ||||
|     let args = fun.args; | ||||
|  | ||||
|     let hook_path = quote!(crate::framework::Hook); | ||||
|     let uuid = Uuid::new_v4().as_u128(); | ||||
|  | ||||
|     (quote! { | ||||
|         #(#cooked)* | ||||
|         #[allow(missing_docs)] | ||||
|         #visibility fn #fn_name<'fut>(#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, #ret> { | ||||
|             use ::serenity::futures::future::FutureExt; | ||||
|  | ||||
|             async move { | ||||
|                 let _output: #ret = { #(#body)* }; | ||||
|                 #[allow(unreachable_code)] | ||||
|                 _output | ||||
|             }.boxed() | ||||
|         } | ||||
|  | ||||
|         #(#cooked)* | ||||
|         #[allow(missing_docs)] | ||||
|         pub static #name: #hook_path = #hook_path { | ||||
|             fun: #fn_name, | ||||
|             uuid: #uuid, | ||||
|         }; | ||||
|     }) | ||||
|     .into() | ||||
| } | ||||
|   | ||||
| @@ -4,7 +4,7 @@ use syn::{ | ||||
|     braced, | ||||
|     parse::{Error, Parse, ParseStream, Result}, | ||||
|     spanned::Spanned, | ||||
|     Attribute, Block, FnArg, Ident, Pat, Stmt, Token, Visibility, | ||||
|     Attribute, Block, FnArg, Ident, Pat, ReturnType, Stmt, Token, Type, Visibility, | ||||
| }; | ||||
|  | ||||
| use crate::util::{Argument, Parenthesised}; | ||||
| @@ -78,6 +78,7 @@ pub struct CommandFun { | ||||
|     pub visibility: Visibility, | ||||
|     pub name: Ident, | ||||
|     pub args: Vec<Argument>, | ||||
|     pub ret: Type, | ||||
|     pub body: Vec<Stmt>, | ||||
| } | ||||
|  | ||||
| @@ -97,6 +98,11 @@ impl Parse for CommandFun { | ||||
|         // (...) | ||||
|         let Parenthesised(args) = input.parse::<Parenthesised<FnArg>>()?; | ||||
|  | ||||
|         let ret = match input.parse::<ReturnType>()? { | ||||
|             ReturnType::Type(_, t) => (*t).clone(), | ||||
|             ReturnType::Default => Type::Verbatim(quote!(())), | ||||
|         }; | ||||
|  | ||||
|         // { ... } | ||||
|         let bcont; | ||||
|         braced!(bcont in input); | ||||
| @@ -104,72 +110,23 @@ impl Parse for CommandFun { | ||||
|  | ||||
|         let args = args.into_iter().map(parse_argument).collect::<Result<Vec<_>>>()?; | ||||
|  | ||||
|         Ok(Self { attributes, cooked, visibility, name, args, body }) | ||||
|         Ok(Self { attributes, cooked, visibility, name, args, ret, body }) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl ToTokens for CommandFun { | ||||
|     fn to_tokens(&self, stream: &mut TokenStream2) { | ||||
|         let Self { attributes: _, cooked, visibility, name, args, body } = self; | ||||
|         let Self { attributes: _, cooked, visibility, name, args, ret, body } = self; | ||||
|  | ||||
|         stream.extend(quote! { | ||||
|             #(#cooked)* | ||||
|             #visibility async fn #name (#(#args),*) { | ||||
|             #visibility async fn #name (#(#args),*) -> #ret { | ||||
|                 #(#body)* | ||||
|             } | ||||
|         }); | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub enum PermissionLevel { | ||||
|     Unrestricted, | ||||
|     Managed, | ||||
|     Restricted, | ||||
| } | ||||
|  | ||||
| impl Default for PermissionLevel { | ||||
|     fn default() -> Self { | ||||
|         Self::Unrestricted | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl PermissionLevel { | ||||
|     pub fn from_str(s: &str) -> Option<Self> { | ||||
|         Some(match s.to_uppercase().as_str() { | ||||
|             "UNRESTRICTED" => Self::Unrestricted, | ||||
|             "MANAGED" => Self::Managed, | ||||
|             "RESTRICTED" => Self::Restricted, | ||||
|             _ => return None, | ||||
|         }) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl ToTokens for PermissionLevel { | ||||
|     fn to_tokens(&self, stream: &mut TokenStream2) { | ||||
|         let path = quote!(crate::framework::PermissionLevel); | ||||
|         let variant; | ||||
|  | ||||
|         match self { | ||||
|             Self::Unrestricted => { | ||||
|                 variant = quote!(Unrestricted); | ||||
|             } | ||||
|  | ||||
|             Self::Managed => { | ||||
|                 variant = quote!(Managed); | ||||
|             } | ||||
|  | ||||
|             Self::Restricted => { | ||||
|                 variant = quote!(Restricted); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         stream.extend(quote! { | ||||
|             #path::#variant | ||||
|         }); | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub(crate) enum ApplicationCommandOptionType { | ||||
|     SubCommand, | ||||
| @@ -272,7 +229,6 @@ pub(crate) struct Options { | ||||
|     pub description: String, | ||||
|     pub group: String, | ||||
|     pub examples: Vec<String>, | ||||
|     pub required_permissions: PermissionLevel, | ||||
|     pub can_blacklist: bool, | ||||
|     pub supports_dm: bool, | ||||
|     pub cmd_args: Vec<Arg>, | ||||
|   | ||||
							
								
								
									
										11
									
								
								migration/02-macro.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								migration/02-macro.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,11 @@ | ||||
| CREATE TABLE macro ( | ||||
|     id INT UNSIGNED AUTO_INCREMENT, | ||||
|     guild_id INT UNSIGNED NOT NULL, | ||||
|  | ||||
|     name VARCHAR(100) NOT NULL, | ||||
|     description VARCHAR(100), | ||||
|     commands TEXT, | ||||
|  | ||||
|     FOREIGN KEY (guild_id) REFERENCES guilds(id), | ||||
|     PRIMARY KEY (id) | ||||
| ); | ||||
| @@ -27,7 +27,7 @@ fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEm | ||||
| #[aliases("invite")] | ||||
| #[description("Get information about the bot")] | ||||
| #[group("Info")] | ||||
| async fn info(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { | ||||
| async fn info(ctx: &Context, invoke: CommandInvoke) { | ||||
|     let prefix = ctx.prefix(invoke.guild_id()).await; | ||||
|     let current_user = ctx.cache.current_user(); | ||||
|     let footer = footer(ctx); | ||||
| @@ -61,7 +61,7 @@ Use our dashboard: https://reminder-bot.com/", | ||||
| #[command] | ||||
| #[description("Details on supporting the bot and Patreon benefits")] | ||||
| #[group("Info")] | ||||
| async fn donate(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { | ||||
| async fn donate(ctx: &Context, invoke: CommandInvoke) { | ||||
|     let footer = footer(ctx); | ||||
|  | ||||
|     let _ = invoke | ||||
| @@ -94,7 +94,7 @@ Just $2 USD/month! | ||||
| #[command] | ||||
| #[description("Get the link to the online dashboard")] | ||||
| #[group("Info")] | ||||
| async fn dashboard(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { | ||||
| async fn dashboard(ctx: &Context, invoke: CommandInvoke) { | ||||
|     let footer = footer(ctx); | ||||
|  | ||||
|     let _ = invoke | ||||
| @@ -113,7 +113,7 @@ async fn dashboard(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { | ||||
| #[command] | ||||
| #[description("View the current time in your selected timezone")] | ||||
| #[group("Info")] | ||||
| async fn clock(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { | ||||
| async fn clock(ctx: &Context, invoke: CommandInvoke) { | ||||
|     let ud = ctx.user_data(&invoke.author_id()).await.unwrap(); | ||||
|     let now = Utc::now().with_timezone(&ud.timezone()); | ||||
|  | ||||
|   | ||||
| @@ -2,16 +2,21 @@ use chrono::offset::Utc; | ||||
| use chrono_tz::{Tz, TZ_VARIANTS}; | ||||
| use levenshtein::levenshtein; | ||||
| use regex_command_attr::command; | ||||
| use serenity::{client::Context, model::misc::Mentionable}; | ||||
| use serenity::{ | ||||
|     client::Context, | ||||
|     model::{ | ||||
|         interactions::InteractionResponseType, misc::Mentionable, | ||||
|         prelude::InteractionApplicationCommandCallbackDataFlags, | ||||
|     }, | ||||
| }; | ||||
|  | ||||
| use crate::{ | ||||
|     component_models::{ComponentDataModel, Restrict}, | ||||
|     consts::THEME_COLOR, | ||||
|     framework::{ | ||||
|         CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue, PermissionLevel, | ||||
|     }, | ||||
|     models::{channel_data::ChannelData, CtxData}, | ||||
|     PopularTimezones, RegexFramework, SQLPool, | ||||
|     framework::{CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue}, | ||||
|     hooks::{CHECK_GUILD_PERMISSIONS_HOOK, CHECK_MANAGED_PERMISSIONS_HOOK}, | ||||
|     models::{channel_data::ChannelData, command_macro::CommandMacro, CtxData}, | ||||
|     PopularTimezones, RecordingMacros, RegexFramework, SQLPool, | ||||
| }; | ||||
|  | ||||
| #[command("blacklist")] | ||||
| @@ -23,13 +28,9 @@ use crate::{ | ||||
|     required = false | ||||
| )] | ||||
| #[supports_dm(false)] | ||||
| #[required_permissions(Restricted)] | ||||
| #[hook(CHECK_GUILD_PERMISSIONS_HOOK)] | ||||
| #[can_blacklist(false)] | ||||
| async fn blacklist( | ||||
|     ctx: &Context, | ||||
|     invoke: &(dyn CommandInvoke + Send + Sync), | ||||
|     args: CommandOptions, | ||||
| ) { | ||||
| async fn blacklist(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { | ||||
|     let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); | ||||
|  | ||||
|     let channel = match args.get("channel") { | ||||
| @@ -72,7 +73,7 @@ async fn blacklist( | ||||
|     kind = "String", | ||||
|     required = false | ||||
| )] | ||||
| async fn timezone(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { | ||||
| async fn timezone(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { | ||||
|     let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); | ||||
|     let mut user_data = ctx.user_data(invoke.author_id()).await.unwrap(); | ||||
|  | ||||
| @@ -178,8 +179,8 @@ You may want to use one of the popular timezones below, otherwise click [here](h | ||||
| #[command("prefix")] | ||||
| #[description("Configure a prefix for text-based commands (deprecated)")] | ||||
| #[supports_dm(false)] | ||||
| #[required_permissions(Restricted)] | ||||
| async fn prefix(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: String) { | ||||
| #[hook(CHECK_GUILD_PERMISSIONS_HOOK)] | ||||
| async fn prefix(ctx: &Context, invoke: CommandInvoke, args: String) { | ||||
|     let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); | ||||
|  | ||||
|     let guild_data = ctx.guild_data(invoke.guild_id().unwrap()).await.unwrap(); | ||||
| @@ -222,8 +223,8 @@ async fn prefix(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: | ||||
|     required = true | ||||
| )] | ||||
| #[supports_dm(false)] | ||||
| #[required_permissions(Restricted)] | ||||
| async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { | ||||
| #[hook(CHECK_GUILD_PERMISSIONS_HOOK)] | ||||
| async fn restrict(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { | ||||
|     let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); | ||||
|     let framework = ctx.data.read().await.get::<RegexFramework>().cloned().unwrap(); | ||||
|  | ||||
| @@ -240,7 +241,7 @@ async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), arg | ||||
|         let restrictable_commands = framework | ||||
|             .commands | ||||
|             .iter() | ||||
|             .filter(|c| c.required_permissions == PermissionLevel::Managed) | ||||
|             .filter(|c| c.hooks.contains(&&CHECK_MANAGED_PERMISSIONS_HOOK)) | ||||
|             .map(|c| c.names[0].to_string()) | ||||
|             .collect::<Vec<String>>(); | ||||
|  | ||||
| @@ -289,6 +290,132 @@ async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), arg | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[command("macro")] | ||||
| #[description("Record and replay command sequences")] | ||||
| #[subcommand("record")] | ||||
| #[description("Start recording up to 5 commands to replay")] | ||||
| #[arg(name = "name", description = "Name for the new macro", kind = "String", required = true)] | ||||
| #[arg( | ||||
|     name = "description", | ||||
|     description = "Description for the new macro", | ||||
|     kind = "String", | ||||
|     required = false | ||||
| )] | ||||
| #[subcommand("finish")] | ||||
| #[description("Finish current recording")] | ||||
| #[subcommand("list")] | ||||
| #[description("List recorded macros")] | ||||
| #[subcommand("run")] | ||||
| #[description("Run a recorded macro")] | ||||
| #[arg(name = "name", description = "Name of the macro to run", kind = "String", required = true)] | ||||
| #[supports_dm(false)] | ||||
| #[hook(CHECK_MANAGED_PERMISSIONS_HOOK)] | ||||
| async fn macro_cmd(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { | ||||
|     let interaction = invoke.interaction().unwrap(); | ||||
|  | ||||
|     match args.subcommand.clone().unwrap().as_str() { | ||||
|         "record" => { | ||||
|             let macro_buffer = ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap(); | ||||
|  | ||||
|             { | ||||
|                 let mut lock = macro_buffer.write().await; | ||||
|  | ||||
|                 let guild_id = interaction.guild_id.unwrap(); | ||||
|  | ||||
|                 lock.insert( | ||||
|                     (guild_id, interaction.user.id), | ||||
|                     CommandMacro { | ||||
|                         guild_id, | ||||
|                         name: args.get("name").unwrap().to_string(), | ||||
|                         description: args.get("description").map(|d| d.to_string()), | ||||
|                         commands: vec![], | ||||
|                     }, | ||||
|                 ); | ||||
|             } | ||||
|  | ||||
|             let _ = interaction | ||||
|                 .create_interaction_response(&ctx, |r| { | ||||
|                     r.kind(InteractionResponseType::ChannelMessageWithSource) | ||||
|                         .interaction_response_data(|d| { | ||||
|                             d.flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL) | ||||
|                                 .create_embed(|e| { | ||||
|                                     e | ||||
|                                     .title("Macro Recording Started") | ||||
|                                     .description( | ||||
| "Run up to 5 commands, or type `/macro finish` to stop at any point. | ||||
| Any commands ran as part of recording will be inconsequential") | ||||
|                                     .color(*THEME_COLOR) | ||||
|                                 }) | ||||
|                         }) | ||||
|                 }) | ||||
|                 .await; | ||||
|         } | ||||
|         "finish" => { | ||||
|             let key = (interaction.guild_id.unwrap(), interaction.user.id); | ||||
|             let macro_buffer = ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap(); | ||||
|  | ||||
|             { | ||||
|                 let lock = macro_buffer.read().await; | ||||
|                 let contained = lock.get(&key); | ||||
|  | ||||
|                 if contained.map_or(true, |cmacro| cmacro.commands.is_empty()) { | ||||
|                     let _ = interaction | ||||
|                         .create_interaction_response(&ctx, |r| { | ||||
|                             r.kind(InteractionResponseType::ChannelMessageWithSource) | ||||
|                                 .interaction_response_data(|d| { | ||||
|                                     d.create_embed(|e| { | ||||
|                                         e.title("No Macro Recorded") | ||||
|                                             .description( | ||||
|                                                 "Use `/macro record` to start recording a macro", | ||||
|                                             ) | ||||
|                                             .color(*THEME_COLOR) | ||||
|                                     }) | ||||
|                                 }) | ||||
|                         }) | ||||
|                         .await; | ||||
|                 } else { | ||||
|                     let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); | ||||
|  | ||||
|                     let command_macro = contained.unwrap(); | ||||
|                     let json = serde_json::to_string(&command_macro.commands).unwrap(); | ||||
|  | ||||
|                     sqlx::query!( | ||||
|                         "INSERT INTO macro (guild_id, name, description, commands) VALUES ((SELECT id FROM guilds WHERE guild = ?), ?, ?, ?)", | ||||
|                         command_macro.guild_id.0, | ||||
|                         command_macro.name, | ||||
|                         command_macro.description, | ||||
|                         json | ||||
|                     ) | ||||
|                         .execute(&pool) | ||||
|                         .await | ||||
|                         .unwrap(); | ||||
|  | ||||
|                     let _ = interaction | ||||
|                         .create_interaction_response(&ctx, |r| { | ||||
|                             r.kind(InteractionResponseType::ChannelMessageWithSource) | ||||
|                                 .interaction_response_data(|d| { | ||||
|                                     d.create_embed(|e| { | ||||
|                                         e.title("Macro Recorded") | ||||
|                                             .description("Use `/macro run` to execute the macro") | ||||
|                                             .color(*THEME_COLOR) | ||||
|                                     }) | ||||
|                                 }) | ||||
|                         }) | ||||
|                         .await; | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             { | ||||
|                 let mut lock = macro_buffer.write().await; | ||||
|                 lock.remove(&key); | ||||
|             } | ||||
|         } | ||||
|         "list" => {} | ||||
|         "run" => {} | ||||
|         _ => {} | ||||
|     } | ||||
| } | ||||
|  | ||||
| /* | ||||
| #[command("alias")] | ||||
| #[supports_dm(false)] | ||||
|   | ||||
| @@ -5,16 +5,13 @@ use std::{ | ||||
| }; | ||||
|  | ||||
| use chrono::NaiveDateTime; | ||||
| use chrono_tz::Tz; | ||||
| use num_integer::Integer; | ||||
| use regex_command_attr::command; | ||||
| use serenity::{ | ||||
|     builder::CreateEmbed, | ||||
|     builder::{CreateEmbed, CreateInteractionResponse}, | ||||
|     client::Context, | ||||
|     model::{ | ||||
|         channel::Channel, | ||||
|         id::{GuildId, UserId}, | ||||
|         interactions::InteractionResponseType, | ||||
|     }, | ||||
|     model::{channel::Channel, interactions::InteractionResponseType}, | ||||
| }; | ||||
|  | ||||
| use crate::{ | ||||
| @@ -28,6 +25,7 @@ use crate::{ | ||||
|         REGEX_NATURAL_COMMAND_2, THEME_COLOR, | ||||
|     }, | ||||
|     framework::{CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue}, | ||||
|     hooks::{CHECK_GUILD_PERMISSIONS_HOOK, CHECK_MANAGED_PERMISSIONS_HOOK}, | ||||
|     models::{ | ||||
|         channel_data::ChannelData, | ||||
|         guild_data::GuildData, | ||||
| @@ -55,8 +53,8 @@ use crate::{ | ||||
|     required = false | ||||
| )] | ||||
| #[supports_dm(false)] | ||||
| #[required_permissions(Restricted)] | ||||
| async fn pause(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { | ||||
| #[hook(CHECK_GUILD_PERMISSIONS_HOOK)] | ||||
| async fn pause(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { | ||||
|     let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); | ||||
|  | ||||
|     let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; | ||||
| @@ -141,8 +139,8 @@ async fn pause(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: | ||||
|     kind = "Integer", | ||||
|     required = false | ||||
| )] | ||||
| #[required_permissions(Restricted)] | ||||
| async fn offset(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { | ||||
| #[hook(CHECK_GUILD_PERMISSIONS_HOOK)] | ||||
| async fn offset(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { | ||||
|     let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); | ||||
|  | ||||
|     let combined_time = args.get("hours").map_or(0, |h| h.as_i64().unwrap() * 3600) | ||||
| @@ -218,8 +216,8 @@ WHERE FIND_IN_SET(channels.`channel`, ?)", | ||||
|     kind = "Integer", | ||||
|     required = false | ||||
| )] | ||||
| #[required_permissions(Restricted)] | ||||
| async fn nudge(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { | ||||
| #[hook(CHECK_GUILD_PERMISSIONS_HOOK)] | ||||
| async fn nudge(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { | ||||
|     let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); | ||||
|  | ||||
|     let combined_time = args.get("minutes").map_or(0, |m| m.as_i64().unwrap() * 60) | ||||
| @@ -270,8 +268,8 @@ async fn nudge(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: | ||||
|     kind = "Boolean", | ||||
|     required = false | ||||
| )] | ||||
| #[required_permissions(Managed)] | ||||
| async fn look(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { | ||||
| #[hook(CHECK_MANAGED_PERMISSIONS_HOOK)] | ||||
| async fn look(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { | ||||
|     let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); | ||||
|  | ||||
|     let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; | ||||
| @@ -363,103 +361,132 @@ async fn look(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: C | ||||
|  | ||||
| #[command("del")] | ||||
| #[description("Delete reminders")] | ||||
| #[required_permissions(Managed)] | ||||
| async fn delete(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { | ||||
|     let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); | ||||
| #[hook(CHECK_MANAGED_PERMISSIONS_HOOK)] | ||||
| async fn delete(ctx: &Context, invoke: CommandInvoke, _args: CommandOptions) { | ||||
|     let interaction = invoke.interaction().unwrap(); | ||||
|  | ||||
|     let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; | ||||
|     let timezone = ctx.timezone(interaction.user.id).await; | ||||
|  | ||||
|     let reminders = Reminder::from_guild(ctx, invoke.guild_id(), invoke.author_id()).await; | ||||
|     let reminders = Reminder::from_guild(ctx, interaction.guild_id, interaction.user.id).await; | ||||
|  | ||||
|     if reminders.is_empty() { | ||||
|         let _ = invoke | ||||
|             .respond( | ||||
|                 ctx.http.clone(), | ||||
|                 CreateGenericResponse::new().content("No reminders to delete!"), | ||||
|             ) | ||||
|             .await; | ||||
|     } else { | ||||
|         let mut char_count = 0; | ||||
|     let resp = show_delete_page(&reminders, 0, timezone).await; | ||||
|  | ||||
|         let (shown_reminders, display_vec): (Vec<&Reminder>, Vec<String>) = reminders | ||||
|             .iter() | ||||
|             .enumerate() | ||||
|             .map(|(count, reminder)| (reminder, reminder.display_del(count, &timezone))) | ||||
|             .take_while(|(_, p)| { | ||||
|                 char_count += p.len(); | ||||
|  | ||||
|                 char_count < EMBED_DESCRIPTION_MAX_LENGTH | ||||
|             }) | ||||
|             .unzip(); | ||||
|  | ||||
|         let display = display_vec.join("\n"); | ||||
|  | ||||
|         let pages = reminders | ||||
|             .iter() | ||||
|             .enumerate() | ||||
|             .map(|(count, reminder)| reminder.display_del(count, &timezone)) | ||||
|             .fold(0, |t, r| t + r.len()) | ||||
|             .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH); | ||||
|  | ||||
|         let pager = DelPager::new(timezone); | ||||
|  | ||||
|         let del_selector = ComponentDataModel::DelSelector(DelSelector { page: 0, timezone }); | ||||
|  | ||||
|         invoke | ||||
|             .respond( | ||||
|                 ctx.http.clone(), | ||||
|                 CreateGenericResponse::new() | ||||
|                     .embed(|e| { | ||||
|                         e.title("Delete Reminders") | ||||
|                             .description(display) | ||||
|                             .footer(|f| f.text(format!("Page {} of {}", 1, pages))) | ||||
|                             .color(*THEME_COLOR) | ||||
|                     }) | ||||
|                     .components(|comp| { | ||||
|                         pager.create_button_row(pages, comp); | ||||
|  | ||||
|                         comp.create_action_row(|row| { | ||||
|                             row.create_select_menu(|menu| { | ||||
|                                 menu.custom_id(del_selector.to_custom_id()).options(|opt| { | ||||
|                                     for (count, reminder) in shown_reminders.iter().enumerate() { | ||||
|                                         opt.create_option(|o| { | ||||
|                                             o.label(count + 1).value(reminder.id).description({ | ||||
|                                                 let c = reminder.display_content(); | ||||
|  | ||||
|                                                 if c.len() > 100 { | ||||
|                                                     format!( | ||||
|                                                         "{}...", | ||||
|                                                         reminder | ||||
|                                                             .display_content() | ||||
|                                                             .chars() | ||||
|                                                             .take(97) | ||||
|                                                             .collect::<String>() | ||||
|                                                     ) | ||||
|                                                 } else { | ||||
|                                                     c.to_string() | ||||
|                                                 } | ||||
|                                             }) | ||||
|                                         }); | ||||
|                                     } | ||||
|  | ||||
|                                     opt | ||||
|                                 }) | ||||
|                             }) | ||||
|                         }) | ||||
|                     }), | ||||
|             ) | ||||
|             .await | ||||
|             .unwrap(); | ||||
|     } | ||||
|     let _ = interaction | ||||
|         .create_interaction_response(&ctx, |r| { | ||||
|             *r = resp; | ||||
|             r | ||||
|         }) | ||||
|         .await; | ||||
| } | ||||
|  | ||||
| async fn show_delete_page( | ||||
|     ctx: &Context, | ||||
|     guild_id: Option<GuildId>, | ||||
|     user_id: UserId, | ||||
| pub fn max_delete_page(reminders: &Vec<Reminder>, timezone: &Tz) -> usize { | ||||
|     reminders | ||||
|         .iter() | ||||
|         .enumerate() | ||||
|         .map(|(count, reminder)| reminder.display_del(count, timezone)) | ||||
|         .fold(0, |t, r| t + r.len()) | ||||
|         .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH) | ||||
| } | ||||
|  | ||||
| pub async fn show_delete_page( | ||||
|     reminders: &Vec<Reminder>, | ||||
|     page: usize, | ||||
|     timezone: Tz, | ||||
| ) { | ||||
| ) -> CreateInteractionResponse { | ||||
|     let pager = DelPager::new(timezone); | ||||
|  | ||||
|     if reminders.is_empty() { | ||||
|         let mut embed = CreateEmbed::default(); | ||||
|         embed.title("Delete Reminders").description("No Reminders").color(*THEME_COLOR); | ||||
|  | ||||
|         let mut response = CreateInteractionResponse::default(); | ||||
|         response.kind(InteractionResponseType::UpdateMessage).interaction_response_data( | ||||
|             |response| { | ||||
|                 response.embeds(vec![embed]).components(|comp| { | ||||
|                     pager.create_button_row(0, comp); | ||||
|                     comp | ||||
|                 }) | ||||
|             }, | ||||
|         ); | ||||
|  | ||||
|         return response; | ||||
|     } | ||||
|  | ||||
|     let pages = max_delete_page(&reminders, &timezone); | ||||
|  | ||||
|     let mut page = page; | ||||
|     if page >= pages { | ||||
|         page = pages - 1; | ||||
|     } | ||||
|  | ||||
|     let mut char_count = 0; | ||||
|     let mut skip_char_count = 0; | ||||
|     let mut first_num = 0; | ||||
|  | ||||
|     let (shown_reminders, display_vec): (Vec<&Reminder>, Vec<String>) = reminders | ||||
|         .iter() | ||||
|         .enumerate() | ||||
|         .map(|(count, reminder)| (reminder, reminder.display_del(count, &timezone))) | ||||
|         .skip_while(|(_, p)| { | ||||
|             first_num += 1; | ||||
|             skip_char_count += p.len(); | ||||
|  | ||||
|             skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * page | ||||
|         }) | ||||
|         .take_while(|(_, p)| { | ||||
|             char_count += p.len(); | ||||
|  | ||||
|             char_count < EMBED_DESCRIPTION_MAX_LENGTH | ||||
|         }) | ||||
|         .unzip(); | ||||
|  | ||||
|     let display = display_vec.join("\n"); | ||||
|  | ||||
|     let del_selector = ComponentDataModel::DelSelector(DelSelector { page, timezone }); | ||||
|  | ||||
|     let mut embed = CreateEmbed::default(); | ||||
|     embed | ||||
|         .title("Delete Reminders") | ||||
|         .description(display) | ||||
|         .footer(|f| f.text(format!("Page {} of {}", page + 1, pages))) | ||||
|         .color(*THEME_COLOR); | ||||
|  | ||||
|     let mut response = CreateInteractionResponse::default(); | ||||
|     response.kind(InteractionResponseType::UpdateMessage).interaction_response_data(|d| { | ||||
|         d.embeds(vec![embed]).components(|comp| { | ||||
|             pager.create_button_row(pages, comp); | ||||
|  | ||||
|             comp.create_action_row(|row| { | ||||
|                 row.create_select_menu(|menu| { | ||||
|                     menu.custom_id(del_selector.to_custom_id()).options(|opt| { | ||||
|                         for (count, reminder) in shown_reminders.iter().enumerate() { | ||||
|                             opt.create_option(|o| { | ||||
|                                 o.label(count + first_num).value(reminder.id).description({ | ||||
|                                     let c = reminder.display_content(); | ||||
|  | ||||
|                                     if c.len() > 100 { | ||||
|                                         format!( | ||||
|                                             "{}...", | ||||
|                                             reminder | ||||
|                                                 .display_content() | ||||
|                                                 .chars() | ||||
|                                                 .take(97) | ||||
|                                                 .collect::<String>() | ||||
|                                         ) | ||||
|                                     } else { | ||||
|                                         c.to_string() | ||||
|                                     } | ||||
|                                 }) | ||||
|                             }); | ||||
|                         } | ||||
|  | ||||
|                         opt | ||||
|                     }) | ||||
|                 }) | ||||
|             }) | ||||
|         }) | ||||
|     }); | ||||
|     response | ||||
| } | ||||
|  | ||||
| #[command("timer")] | ||||
| @@ -472,8 +499,8 @@ async fn show_delete_page( | ||||
| #[subcommand("delete")] | ||||
| #[description("Delete a timer")] | ||||
| #[arg(name = "name", description = "Name of the timer to delete", kind = "String", required = true)] | ||||
| #[required_permissions(Managed)] | ||||
| async fn timer(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { | ||||
| #[hook(CHECK_MANAGED_PERMISSIONS_HOOK)] | ||||
| async fn timer(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { | ||||
|     fn time_difference(start_time: NaiveDateTime) -> String { | ||||
|         let unix_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; | ||||
|         let now = NaiveDateTime::from_timestamp(unix_time, 0); | ||||
| @@ -638,8 +665,8 @@ DELETE FROM timers WHERE owner = ? AND name = ? | ||||
|     kind = "Boolean", | ||||
|     required = false | ||||
| )] | ||||
| #[required_permissions(Managed)] | ||||
| async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { | ||||
| #[hook(CHECK_MANAGED_PERMISSIONS_HOOK)] | ||||
| async fn remind(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) { | ||||
|     let interaction = invoke.interaction().unwrap(); | ||||
|  | ||||
|     // defer response since processing times can take some time | ||||
| @@ -650,7 +677,7 @@ async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: | ||||
|         .await | ||||
|         .unwrap(); | ||||
|  | ||||
|     let user_data = ctx.user_data(invoke.author_id()).await.unwrap(); | ||||
|     let user_data = ctx.user_data(interaction.user.id).await.unwrap(); | ||||
|     let timezone = user_data.timezone(); | ||||
|  | ||||
|     let time = { | ||||
| @@ -675,7 +702,7 @@ async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: | ||||
|                     .unwrap_or(vec![]); | ||||
|  | ||||
|                 if list.is_empty() { | ||||
|                     vec![ReminderScope::Channel(invoke.channel_id().0)] | ||||
|                     vec![ReminderScope::Channel(interaction.channel_id.0)] | ||||
|                 } else { | ||||
|                     list | ||||
|                 } | ||||
| @@ -698,7 +725,7 @@ async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: | ||||
|                 } | ||||
|             }; | ||||
|  | ||||
|             let mut builder = MultiReminderBuilder::new(ctx, invoke.guild_id()) | ||||
|             let mut builder = MultiReminderBuilder::new(ctx, interaction.guild_id) | ||||
|                 .author(user_data) | ||||
|                 .content(content) | ||||
|                 .time(time) | ||||
|   | ||||
| @@ -17,6 +17,7 @@ use serenity::{ | ||||
| }; | ||||
|  | ||||
| use crate::{ | ||||
|     commands::reminder_cmds::{max_delete_page, show_delete_page}, | ||||
|     component_models::pager::{DelPager, LookPager, Pager}, | ||||
|     consts::{EMBED_DESCRIPTION_MAX_LENGTH, THEME_COLOR}, | ||||
|     models::reminder::Reminder, | ||||
| @@ -165,98 +166,15 @@ INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, \"Role\", (SELECT id | ||||
|                 let reminders = | ||||
|                     Reminder::from_guild(ctx, component.guild_id, component.user.id).await; | ||||
|  | ||||
|                 let pages = reminders | ||||
|                     .iter() | ||||
|                     .enumerate() | ||||
|                     .map(|(count, reminder)| reminder.display_del(count, &pager.timezone)) | ||||
|                     .fold(0, |t, r| t + r.len()) | ||||
|                     .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH); | ||||
|                 let max_pages = max_delete_page(&reminders, &pager.timezone); | ||||
|  | ||||
|                 let next_page = pager.next_page(pages); | ||||
|                 let resp = | ||||
|                     show_delete_page(&reminders, pager.next_page(max_pages), pager.timezone).await; | ||||
|  | ||||
|                 let mut char_count = 0; | ||||
|                 let mut skip_char_count = 0; | ||||
|                 let mut first_num = 0; | ||||
|  | ||||
|                 let (shown_reminders, display_vec): (Vec<&Reminder>, Vec<String>) = reminders | ||||
|                     .iter() | ||||
|                     .enumerate() | ||||
|                     .map(|(count, reminder)| { | ||||
|                         (reminder, reminder.display_del(count, &pager.timezone)) | ||||
|                     }) | ||||
|                     .skip_while(|(_, p)| { | ||||
|                         first_num += 1; | ||||
|                         skip_char_count += p.len(); | ||||
|  | ||||
|                         skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * next_page | ||||
|                     }) | ||||
|                     .take_while(|(_, p)| { | ||||
|                         char_count += p.len(); | ||||
|  | ||||
|                         char_count < EMBED_DESCRIPTION_MAX_LENGTH | ||||
|                     }) | ||||
|                     .unzip(); | ||||
|  | ||||
|                 let display = display_vec.join("\n"); | ||||
|  | ||||
|                 let del_selector = ComponentDataModel::DelSelector(DelSelector { | ||||
|                     page: next_page, | ||||
|                     timezone: pager.timezone, | ||||
|                 }); | ||||
|  | ||||
|                 let mut embed = CreateEmbed::default(); | ||||
|                 embed | ||||
|                     .title("Delete Reminders") | ||||
|                     .description(display) | ||||
|                     .footer(|f| f.text(format!("Page {} of {}", next_page + 1, pages))) | ||||
|                     .color(*THEME_COLOR); | ||||
|  | ||||
|                 component | ||||
|                     .create_interaction_response(&ctx, |r| { | ||||
|                         r.kind(InteractionResponseType::UpdateMessage).interaction_response_data( | ||||
|                             |response| { | ||||
|                                 response.embeds(vec![embed]).components(|comp| { | ||||
|                                     pager.create_button_row(pages, comp); | ||||
|  | ||||
|                                     comp.create_action_row(|row| { | ||||
|                                         row.create_select_menu(|menu| { | ||||
|                                             menu.custom_id(del_selector.to_custom_id()).options( | ||||
|                                                 |opt| { | ||||
|                                                     for (count, reminder) in | ||||
|                                                         shown_reminders.iter().enumerate() | ||||
|                                                     { | ||||
|                                                         opt.create_option(|o| { | ||||
|                                                             o.label(count + first_num) | ||||
|                                                                 .value(reminder.id) | ||||
|                                                                 .description({ | ||||
|                                                                     let c = | ||||
|                                                                         reminder.display_content(); | ||||
|  | ||||
|                                                                     if c.len() > 100 { | ||||
|                                                                         format!( | ||||
|                                                                             "{}...", | ||||
|                                                                             reminder | ||||
|                                                                                 .display_content() | ||||
|                                                                                 .chars() | ||||
|                                                                                 .take(97) | ||||
|                                                                                 .collect::<String>( | ||||
|                                                                                 ) | ||||
|                                                                         ) | ||||
|                                                                     } else { | ||||
|                                                                         c.to_string() | ||||
|                                                                     } | ||||
|                                                                 }) | ||||
|                                                         }); | ||||
|                                                     } | ||||
|  | ||||
|                                                     opt | ||||
|                                                 }, | ||||
|                                             ) | ||||
|                                         }) | ||||
|                                     }) | ||||
|                                 }) | ||||
|                             }, | ||||
|                         ) | ||||
|                 let _ = component | ||||
|                     .create_interaction_response(&ctx, move |r| { | ||||
|                         *r = resp; | ||||
|                         r | ||||
|                     }) | ||||
|                     .await; | ||||
|             } | ||||
| @@ -272,119 +190,12 @@ INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, \"Role\", (SELECT id | ||||
|                 let reminders = | ||||
|                     Reminder::from_guild(ctx, component.guild_id, component.user.id).await; | ||||
|  | ||||
|                 if reminders.is_empty() { | ||||
|                     let mut embed = CreateEmbed::default(); | ||||
|                     embed.title("Delete Reminders").description("No Reminders").color(*THEME_COLOR); | ||||
|                 let resp = show_delete_page(&reminders, selector.page, selector.timezone).await; | ||||
|  | ||||
|                     component | ||||
|                         .create_interaction_response(&ctx, |r| { | ||||
|                             r.kind(InteractionResponseType::UpdateMessage) | ||||
|                                 .interaction_response_data(|response| { | ||||
|                                     response.embeds(vec![embed]).components(|comp| comp) | ||||
|                                 }) | ||||
|                         }) | ||||
|                         .await; | ||||
|  | ||||
|                     return; | ||||
|                 } | ||||
|  | ||||
|                 let pages = reminders | ||||
|                     .iter() | ||||
|                     .enumerate() | ||||
|                     .map(|(count, reminder)| reminder.display_del(count, &selector.timezone)) | ||||
|                     .fold(0, |t, r| t + r.len()) | ||||
|                     .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH); | ||||
|  | ||||
|                 let mut page = selector.page; | ||||
|                 if page >= pages { | ||||
|                     page = pages - 1; | ||||
|                 } | ||||
|  | ||||
|                 let mut char_count = 0; | ||||
|                 let mut skip_char_count = 0; | ||||
|                 let mut first_num = 0; | ||||
|  | ||||
|                 let (shown_reminders, display_vec): (Vec<&Reminder>, Vec<String>) = reminders | ||||
|                     .iter() | ||||
|                     .enumerate() | ||||
|                     .map(|(count, reminder)| { | ||||
|                         (reminder, reminder.display_del(count, &selector.timezone)) | ||||
|                     }) | ||||
|                     .skip_while(|(_, p)| { | ||||
|                         first_num += 1; | ||||
|                         skip_char_count += p.len(); | ||||
|  | ||||
|                         skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * page | ||||
|                     }) | ||||
|                     .take_while(|(_, p)| { | ||||
|                         char_count += p.len(); | ||||
|  | ||||
|                         char_count < EMBED_DESCRIPTION_MAX_LENGTH | ||||
|                     }) | ||||
|                     .unzip(); | ||||
|  | ||||
|                 let display = display_vec.join("\n"); | ||||
|  | ||||
|                 let pager = DelPager::new(selector.timezone); | ||||
|  | ||||
|                 let del_selector = ComponentDataModel::DelSelector(DelSelector { | ||||
|                     page, | ||||
|                     timezone: selector.timezone, | ||||
|                 }); | ||||
|  | ||||
|                 let mut embed = CreateEmbed::default(); | ||||
|                 embed | ||||
|                     .title("Delete Reminders") | ||||
|                     .description(display) | ||||
|                     .footer(|f| f.text(format!("Page {} of {}", page + 1, pages))) | ||||
|                     .color(*THEME_COLOR); | ||||
|  | ||||
|                 component | ||||
|                     .create_interaction_response(&ctx, |r| { | ||||
|                         r.kind(InteractionResponseType::UpdateMessage).interaction_response_data( | ||||
|                             |response| { | ||||
|                                 response.embeds(vec![embed]).components(|comp| { | ||||
|                                     pager.create_button_row(pages, comp); | ||||
|  | ||||
|                                     comp.create_action_row(|row| { | ||||
|                                         row.create_select_menu(|menu| { | ||||
|                                             menu.custom_id(del_selector.to_custom_id()).options( | ||||
|                                                 |opt| { | ||||
|                                                     for (count, reminder) in | ||||
|                                                         shown_reminders.iter().enumerate() | ||||
|                                                     { | ||||
|                                                         opt.create_option(|o| { | ||||
|                                                             o.label(count + first_num) | ||||
|                                                                 .value(reminder.id) | ||||
|                                                                 .description({ | ||||
|                                                                     let c = | ||||
|                                                                         reminder.display_content(); | ||||
|  | ||||
|                                                                     if c.len() > 100 { | ||||
|                                                                         format!( | ||||
|                                                                             "{}...", | ||||
|                                                                             reminder | ||||
|                                                                                 .display_content() | ||||
|                                                                                 .chars() | ||||
|                                                                                 .take(97) | ||||
|                                                                                 .collect::<String>( | ||||
|                                                                                 ) | ||||
|                                                                         ) | ||||
|                                                                     } else { | ||||
|                                                                         c.to_string() | ||||
|                                                                     } | ||||
|                                                                 }) | ||||
|                                                         }); | ||||
|                                                     } | ||||
|  | ||||
|                                                     opt | ||||
|                                                 }, | ||||
|                                             ) | ||||
|                                         }) | ||||
|                                     }) | ||||
|                                 }) | ||||
|                             }, | ||||
|                         ) | ||||
|                 let _ = component | ||||
|                     .create_interaction_response(&ctx, move |r| { | ||||
|                         *r = resp; | ||||
|                         r | ||||
|                     }) | ||||
|                     .await; | ||||
|             } | ||||
|   | ||||
							
								
								
									
										705
									
								
								src/framework.rs
									
									
									
									
									
								
							
							
						
						
									
										705
									
								
								src/framework.rs
									
									
									
									
									
								
							| @@ -6,8 +6,9 @@ use std::{ | ||||
|     sync::Arc, | ||||
| }; | ||||
|  | ||||
| use log::{error, info, warn}; | ||||
| use log::info; | ||||
| use regex::{Match, Regex, RegexBuilder}; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use serenity::{ | ||||
|     async_trait, | ||||
|     builder::{CreateApplicationCommands, CreateComponents, CreateEmbed}, | ||||
| @@ -15,9 +16,9 @@ use serenity::{ | ||||
|     client::Context, | ||||
|     framework::Framework, | ||||
|     futures::prelude::future::BoxFuture, | ||||
|     http::Http, | ||||
|     http::{CacheHttp, Http}, | ||||
|     model::{ | ||||
|         channel::{Channel, GuildChannel, Message}, | ||||
|         channel::Message, | ||||
|         guild::{Guild, Member}, | ||||
|         id::{ChannelId, GuildId, MessageId, RoleId, UserId}, | ||||
|         interactions::{ | ||||
| @@ -29,20 +30,10 @@ use serenity::{ | ||||
|         prelude::application_command::ApplicationCommandInteractionDataOption, | ||||
|     }, | ||||
|     prelude::TypeMapKey, | ||||
|     FutureExt, Result as SerenityResult, | ||||
|     Result as SerenityResult, | ||||
| }; | ||||
|  | ||||
| use crate::{ | ||||
|     models::{channel_data::ChannelData, CtxData}, | ||||
|     LimitExecutors, SQLPool, | ||||
| }; | ||||
|  | ||||
| #[derive(Debug, PartialEq)] | ||||
| pub enum PermissionLevel { | ||||
|     Unrestricted, | ||||
|     Managed, | ||||
|     Restricted, | ||||
| } | ||||
| use crate::{models::CtxData, LimitExecutors}; | ||||
|  | ||||
| pub struct CreateGenericResponse { | ||||
|     content: String, | ||||
| @@ -81,196 +72,135 @@ impl CreateGenericResponse { | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[async_trait] | ||||
| pub trait CommandInvoke { | ||||
|     fn channel_id(&self) -> ChannelId; | ||||
|     fn guild_id(&self) -> Option<GuildId>; | ||||
|     fn guild(&self, cache: Arc<Cache>) -> Option<Guild>; | ||||
|     fn author_id(&self) -> UserId; | ||||
|     async fn member(&self, context: &Context) -> SerenityResult<Member>; | ||||
|     fn msg(&self) -> Option<Message>; | ||||
|     fn interaction(&self) -> Option<ApplicationCommandInteraction>; | ||||
|     async fn respond( | ||||
|         &self, | ||||
|         http: Arc<Http>, | ||||
|         generic_response: CreateGenericResponse, | ||||
|     ) -> SerenityResult<()>; | ||||
|     async fn followup( | ||||
|         &self, | ||||
|         http: Arc<Http>, | ||||
|         generic_response: CreateGenericResponse, | ||||
|     ) -> SerenityResult<()>; | ||||
| enum InvokeModel { | ||||
|     Slash(ApplicationCommandInteraction), | ||||
|     Text(Message), | ||||
| } | ||||
|  | ||||
| #[async_trait] | ||||
| impl CommandInvoke for Message { | ||||
|     fn channel_id(&self) -> ChannelId { | ||||
|         self.channel_id | ||||
|     } | ||||
|  | ||||
|     fn guild_id(&self) -> Option<GuildId> { | ||||
|         self.guild_id | ||||
|     } | ||||
|  | ||||
|     fn guild(&self, cache: Arc<Cache>) -> Option<Guild> { | ||||
|         self.guild(cache) | ||||
|     } | ||||
|  | ||||
|     fn author_id(&self) -> UserId { | ||||
|         self.author.id | ||||
|     } | ||||
|  | ||||
|     async fn member(&self, context: &Context) -> SerenityResult<Member> { | ||||
|         self.member(context).await | ||||
|     } | ||||
|  | ||||
|     fn msg(&self) -> Option<Message> { | ||||
|         Some(self.clone()) | ||||
|     } | ||||
|  | ||||
|     fn interaction(&self) -> Option<ApplicationCommandInteraction> { | ||||
|         None | ||||
|     } | ||||
|  | ||||
|     async fn respond( | ||||
|         &self, | ||||
|         http: Arc<Http>, | ||||
|         generic_response: CreateGenericResponse, | ||||
|     ) -> SerenityResult<()> { | ||||
|         self.channel_id | ||||
|             .send_message(http, |m| { | ||||
|                 m.content(generic_response.content); | ||||
|  | ||||
|                 if let Some(embed) = generic_response.embed { | ||||
|                     m.set_embed(embed.clone()); | ||||
|                 } | ||||
|  | ||||
|                 if let Some(components) = generic_response.components { | ||||
|                     m.components(|c| { | ||||
|                         *c = components; | ||||
|                         c | ||||
|                     }); | ||||
|                 } | ||||
|  | ||||
|                 m | ||||
|             }) | ||||
|             .await | ||||
|             .map(|_| ()) | ||||
|     } | ||||
|  | ||||
|     async fn followup( | ||||
|         &self, | ||||
|         http: Arc<Http>, | ||||
|         generic_response: CreateGenericResponse, | ||||
|     ) -> SerenityResult<()> { | ||||
|         self.channel_id | ||||
|             .send_message(http, |m| { | ||||
|                 m.content(generic_response.content); | ||||
|  | ||||
|                 if let Some(embed) = generic_response.embed { | ||||
|                     m.set_embed(embed.clone()); | ||||
|                 } | ||||
|  | ||||
|                 if let Some(components) = generic_response.components { | ||||
|                     m.components(|c| { | ||||
|                         *c = components; | ||||
|                         c | ||||
|                     }); | ||||
|                 } | ||||
|  | ||||
|                 m | ||||
|             }) | ||||
|             .await | ||||
|             .map(|_| ()) | ||||
|     } | ||||
| pub struct CommandInvoke { | ||||
|     model: InvokeModel, | ||||
|     already_responded: bool, | ||||
| } | ||||
|  | ||||
| #[async_trait] | ||||
| impl CommandInvoke for ApplicationCommandInteraction { | ||||
|     fn channel_id(&self) -> ChannelId { | ||||
|         self.channel_id | ||||
| impl CommandInvoke { | ||||
|     fn slash(interaction: ApplicationCommandInteraction) -> Self { | ||||
|         Self { model: InvokeModel::Slash(interaction), already_responded: false } | ||||
|     } | ||||
|  | ||||
|     fn guild_id(&self) -> Option<GuildId> { | ||||
|         self.guild_id | ||||
|     fn msg(msg: Message) -> Self { | ||||
|         Self { model: InvokeModel::Text(msg), already_responded: false } | ||||
|     } | ||||
|  | ||||
|     fn guild(&self, cache: Arc<Cache>) -> Option<Guild> { | ||||
|         if let Some(guild_id) = self.guild_id { | ||||
|             guild_id.to_guild_cached(cache) | ||||
|         } else { | ||||
|             None | ||||
|     pub fn interaction(self) -> Option<ApplicationCommandInteraction> { | ||||
|         match self.model { | ||||
|             InvokeModel::Slash(i) => Some(i), | ||||
|             InvokeModel::Text(_) => None, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn author_id(&self) -> UserId { | ||||
|         self.member.as_ref().unwrap().user.id | ||||
|     pub fn channel_id(&self) -> ChannelId { | ||||
|         match &self.model { | ||||
|             InvokeModel::Slash(i) => i.channel_id, | ||||
|             InvokeModel::Text(m) => m.channel_id, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     async fn member(&self, _: &Context) -> SerenityResult<Member> { | ||||
|         Ok(self.member.clone().unwrap()) | ||||
|     pub fn guild_id(&self) -> Option<GuildId> { | ||||
|         match &self.model { | ||||
|             InvokeModel::Slash(i) => i.guild_id, | ||||
|             InvokeModel::Text(m) => m.guild_id, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn msg(&self) -> Option<Message> { | ||||
|         None | ||||
|     pub fn guild(&self, cache: impl AsRef<Cache>) -> Option<Guild> { | ||||
|         self.guild_id().map(|id| id.to_guild_cached(cache)).flatten() | ||||
|     } | ||||
|  | ||||
|     fn interaction(&self) -> Option<ApplicationCommandInteraction> { | ||||
|         Some(self.clone()) | ||||
|     pub fn author_id(&self) -> UserId { | ||||
|         match &self.model { | ||||
|             InvokeModel::Slash(i) => i.user.id, | ||||
|             InvokeModel::Text(m) => m.author.id, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     async fn respond( | ||||
|     pub async fn member(&self, cache_http: impl CacheHttp) -> Option<Member> { | ||||
|         match &self.model { | ||||
|             InvokeModel::Slash(i) => i.member.clone(), | ||||
|             InvokeModel::Text(m) => m.member(cache_http).await.ok(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub async fn respond( | ||||
|         &self, | ||||
|         http: Arc<Http>, | ||||
|         http: impl AsRef<Http>, | ||||
|         generic_response: CreateGenericResponse, | ||||
|     ) -> SerenityResult<()> { | ||||
|         self.create_interaction_response(http, |r| { | ||||
|             r.kind(InteractionResponseType::ChannelMessageWithSource).interaction_response_data( | ||||
|                 |d| { | ||||
|                     d.content(generic_response.content); | ||||
|         match &self.model { | ||||
|             InvokeModel::Slash(i) => { | ||||
|                 if !self.already_responded { | ||||
|                     i.create_interaction_response(http, |r| { | ||||
|                         r.kind(InteractionResponseType::ChannelMessageWithSource) | ||||
|                             .interaction_response_data(|d| { | ||||
|                                 d.content(generic_response.content); | ||||
|  | ||||
|                                 if let Some(embed) = generic_response.embed { | ||||
|                                     d.add_embed(embed.clone()); | ||||
|                                 } | ||||
|  | ||||
|                                 if let Some(components) = generic_response.components { | ||||
|                                     d.components(|c| { | ||||
|                                         *c = components; | ||||
|                                         c | ||||
|                                     }); | ||||
|                                 } | ||||
|  | ||||
|                                 d | ||||
|                             }) | ||||
|                     }) | ||||
|                     .await | ||||
|                     .map(|_| ()) | ||||
|                 } else { | ||||
|                     i.create_followup_message(http, |d| { | ||||
|                         d.content(generic_response.content); | ||||
|  | ||||
|                         if let Some(embed) = generic_response.embed { | ||||
|                             d.add_embed(embed.clone()); | ||||
|                         } | ||||
|  | ||||
|                         if let Some(components) = generic_response.components { | ||||
|                             d.components(|c| { | ||||
|                                 *c = components; | ||||
|                                 c | ||||
|                             }); | ||||
|                         } | ||||
|  | ||||
|                         d | ||||
|                     }) | ||||
|                     .await | ||||
|                     .map(|_| ()) | ||||
|                 } | ||||
|             } | ||||
|             InvokeModel::Text(m) => m | ||||
|                 .channel_id | ||||
|                 .send_message(http, |m| { | ||||
|                     m.content(generic_response.content); | ||||
|  | ||||
|                     if let Some(embed) = generic_response.embed { | ||||
|                         d.add_embed(embed.clone()); | ||||
|                         m.set_embed(embed.clone()); | ||||
|                     } | ||||
|  | ||||
|                     if let Some(components) = generic_response.components { | ||||
|                         d.components(|c| { | ||||
|                         m.components(|c| { | ||||
|                             *c = components; | ||||
|                             c | ||||
|                         }); | ||||
|                     } | ||||
|  | ||||
|                     d | ||||
|                 }, | ||||
|             ) | ||||
|         }) | ||||
|         .await | ||||
|         .map(|_| ()) | ||||
|     } | ||||
|  | ||||
|     async fn followup( | ||||
|         &self, | ||||
|         http: Arc<Http>, | ||||
|         generic_response: CreateGenericResponse, | ||||
|     ) -> SerenityResult<()> { | ||||
|         self.create_followup_message(http, |d| { | ||||
|             d.content(generic_response.content); | ||||
|  | ||||
|             if let Some(embed) = generic_response.embed { | ||||
|                 d.add_embed(embed.clone()); | ||||
|             } | ||||
|  | ||||
|             if let Some(components) = generic_response.components { | ||||
|                 d.components(|c| { | ||||
|                     *c = components; | ||||
|                     c | ||||
|                 }); | ||||
|             } | ||||
|  | ||||
|             d | ||||
|         }) | ||||
|         .await | ||||
|         .map(|_| ()) | ||||
|                     m | ||||
|                 }) | ||||
|                 .await | ||||
|                 .map(|_| ()), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -283,6 +213,7 @@ pub struct Arg { | ||||
|     pub options: &'static [&'static Self], | ||||
| } | ||||
|  | ||||
| #[derive(Serialize, Deserialize, Clone)] | ||||
| pub enum OptionValue { | ||||
|     String(String), | ||||
|     Integer(i64), | ||||
| @@ -330,8 +261,9 @@ impl OptionValue { | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Serialize, Deserialize, Clone)] | ||||
| pub struct CommandOptions { | ||||
|     pub command: String, | ||||
|     pub command: &'static str, | ||||
|     pub subcommand: Option<String>, | ||||
|     pub subcommand_group: Option<String>, | ||||
|     pub options: HashMap<String, OptionValue>, | ||||
| @@ -343,8 +275,17 @@ impl CommandOptions { | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<ApplicationCommandInteraction> for CommandOptions { | ||||
|     fn from(interaction: ApplicationCommandInteraction) -> Self { | ||||
| impl CommandOptions { | ||||
|     fn new(command: &'static Command) -> Self { | ||||
|         Self { | ||||
|             command: command.names[0], | ||||
|             subcommand: None, | ||||
|             subcommand_group: None, | ||||
|             options: Default::default(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn populate(mut self, interaction: &ApplicationCommandInteraction) -> Self { | ||||
|         fn match_option( | ||||
|             option: ApplicationCommandInteractionDataOption, | ||||
|             cmd_opts: &mut CommandOptions, | ||||
| @@ -429,35 +370,31 @@ impl From<ApplicationCommandInteraction> for CommandOptions { | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         let mut cmd_opts = Self { | ||||
|             command: interaction.data.name, | ||||
|             subcommand: None, | ||||
|             subcommand_group: None, | ||||
|             options: Default::default(), | ||||
|         }; | ||||
|  | ||||
|         for option in interaction.data.options { | ||||
|             match_option(option, &mut cmd_opts) | ||||
|         for option in &interaction.data.options { | ||||
|             match_option(option.clone(), &mut self) | ||||
|         } | ||||
|  | ||||
|         cmd_opts | ||||
|         self | ||||
|     } | ||||
| } | ||||
|  | ||||
| type SlashCommandFn = for<'fut> fn( | ||||
|     &'fut Context, | ||||
|     &'fut (dyn CommandInvoke + Sync + Send), | ||||
|     CommandOptions, | ||||
| ) -> BoxFuture<'fut, ()>; | ||||
| pub enum HookResult { | ||||
|     Continue, | ||||
|     Halt, | ||||
| } | ||||
|  | ||||
| type TextCommandFn = for<'fut> fn( | ||||
|     &'fut Context, | ||||
|     &'fut (dyn CommandInvoke + Sync + Send), | ||||
|     String, | ||||
| ) -> BoxFuture<'fut, ()>; | ||||
| type SlashCommandFn = | ||||
|     for<'fut> fn(&'fut Context, CommandInvoke, CommandOptions) -> BoxFuture<'fut, ()>; | ||||
|  | ||||
| type MultiCommandFn = | ||||
|     for<'fut> fn(&'fut Context, &'fut (dyn CommandInvoke + Sync + Send)) -> BoxFuture<'fut, ()>; | ||||
| type TextCommandFn = for<'fut> fn(&'fut Context, CommandInvoke, String) -> BoxFuture<'fut, ()>; | ||||
|  | ||||
| type MultiCommandFn = for<'fut> fn(&'fut Context, CommandInvoke) -> BoxFuture<'fut, ()>; | ||||
|  | ||||
| pub type HookFn = for<'fut> fn( | ||||
|     &'fut Context, | ||||
|     &'fut CommandInvoke, | ||||
|     &'fut CommandOptions, | ||||
| ) -> BoxFuture<'fut, HookResult>; | ||||
|  | ||||
| pub enum CommandFnType { | ||||
|     Slash(SlashCommandFn), | ||||
| @@ -474,6 +411,17 @@ impl CommandFnType { | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub struct Hook { | ||||
|     pub fun: HookFn, | ||||
|     pub uuid: u128, | ||||
| } | ||||
|  | ||||
| impl PartialEq for Hook { | ||||
|     fn eq(&self, other: &Self) -> bool { | ||||
|         self.uuid == other.uuid | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub struct Command { | ||||
|     pub fun: CommandFnType, | ||||
|  | ||||
| @@ -483,11 +431,12 @@ pub struct Command { | ||||
|     pub examples: &'static [&'static str], | ||||
|     pub group: &'static str, | ||||
|  | ||||
|     pub required_permissions: PermissionLevel, | ||||
|     pub args: &'static [&'static Arg], | ||||
|  | ||||
|     pub can_blacklist: bool, | ||||
|     pub supports_dm: bool, | ||||
|  | ||||
|     pub hooks: &'static [&'static Hook], | ||||
| } | ||||
|  | ||||
| impl Hash for Command { | ||||
| @@ -504,81 +453,6 @@ impl PartialEq for Command { | ||||
|  | ||||
| impl Eq for Command {} | ||||
|  | ||||
| impl Command { | ||||
|     async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool { | ||||
|         if self.required_permissions == PermissionLevel::Unrestricted { | ||||
|             true | ||||
|         } else { | ||||
|             let permissions = guild.member_permissions(&ctx, &member.user).await.unwrap(); | ||||
|  | ||||
|             if permissions.manage_guild() | ||||
|                 || (permissions.manage_messages() | ||||
|                     && self.required_permissions == PermissionLevel::Managed) | ||||
|             { | ||||
|                 return true; | ||||
|             } | ||||
|  | ||||
|             if self.required_permissions == PermissionLevel::Managed { | ||||
|                 let pool = ctx | ||||
|                     .data | ||||
|                     .read() | ||||
|                     .await | ||||
|                     .get::<SQLPool>() | ||||
|                     .cloned() | ||||
|                     .expect("Could not get SQLPool from data"); | ||||
|  | ||||
|                 match sqlx::query!( | ||||
|                     " | ||||
| SELECT | ||||
|     role | ||||
| FROM | ||||
|     roles | ||||
| INNER JOIN | ||||
|     command_restrictions ON roles.id = command_restrictions.role_id | ||||
| WHERE | ||||
|     command_restrictions.command = ? AND | ||||
|     roles.guild_id = ( | ||||
|         SELECT | ||||
|             id | ||||
|         FROM | ||||
|             guilds | ||||
|         WHERE | ||||
|             guild = ?) | ||||
|                     ", | ||||
|                     self.names[0], | ||||
|                     guild.id.as_u64() | ||||
|                 ) | ||||
|                 .fetch_all(&pool) | ||||
|                 .await | ||||
|                 { | ||||
|                     Ok(rows) => { | ||||
|                         let role_ids = | ||||
|                             member.roles.iter().map(|r| *r.as_u64()).collect::<Vec<u64>>(); | ||||
|  | ||||
|                         for row in rows { | ||||
|                             if role_ids.contains(&row.role) { | ||||
|                                 return true; | ||||
|                             } | ||||
|                         } | ||||
|  | ||||
|                         false | ||||
|                     } | ||||
|  | ||||
|                     Err(sqlx::Error::RowNotFound) => false, | ||||
|  | ||||
|                     Err(e) => { | ||||
|                         warn!("Unexpected error occurred querying command_restrictions: {:?}", e); | ||||
|  | ||||
|                         false | ||||
|                     } | ||||
|                 } | ||||
|             } else { | ||||
|                 false | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub struct RegexFramework { | ||||
|     pub commands_map: HashMap<String, &'static Command>, | ||||
|     pub commands: HashSet<&'static Command>, | ||||
| @@ -589,23 +463,14 @@ pub struct RegexFramework { | ||||
|     ignore_bots: bool, | ||||
|     case_insensitive: bool, | ||||
|     dm_enabled: bool, | ||||
|     default_text_fun: TextCommandFn, | ||||
|     debug_guild: Option<GuildId>, | ||||
|     hooks: Vec<&'static Hook>, | ||||
| } | ||||
|  | ||||
| impl TypeMapKey for RegexFramework { | ||||
|     type Value = Arc<RegexFramework>; | ||||
| } | ||||
|  | ||||
| fn drop_text<'fut>( | ||||
|     _: &'fut Context, | ||||
|     _: &'fut (dyn CommandInvoke + Sync + Send), | ||||
|     _: String, | ||||
| ) -> std::pin::Pin<std::boxed::Box<(dyn std::future::Future<Output = ()> + std::marker::Send + 'fut)>> | ||||
| { | ||||
|     async move {}.boxed() | ||||
| } | ||||
|  | ||||
| impl RegexFramework { | ||||
|     pub fn new<T: Into<u64>>(client_id: T) -> Self { | ||||
|         Self { | ||||
| @@ -618,8 +483,8 @@ impl RegexFramework { | ||||
|             ignore_bots: true, | ||||
|             case_insensitive: true, | ||||
|             dm_enabled: true, | ||||
|             default_text_fun: drop_text, | ||||
|             debug_guild: None, | ||||
|             hooks: vec![], | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -647,6 +512,12 @@ impl RegexFramework { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     pub fn add_hook(mut self, fun: &'static Hook) -> Self { | ||||
|         self.hooks.push(fun); | ||||
|  | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     pub fn add_command(mut self, command: &'static Command) -> Self { | ||||
|         self.commands.insert(command); | ||||
|  | ||||
| @@ -791,77 +662,46 @@ impl RegexFramework { | ||||
|                 .expect(&format!("Received invalid command: {}", interaction.data.name)) | ||||
|         }; | ||||
|  | ||||
|         let guild = interaction.guild(ctx.cache.clone()).unwrap(); | ||||
|         let member = interaction.clone().member.unwrap(); | ||||
|         let args = CommandOptions::new(command).populate(&interaction); | ||||
|         let command_invoke = CommandInvoke::slash(interaction); | ||||
|  | ||||
|         if command.check_permissions(&ctx, &guild, &member).await { | ||||
|             let args = CommandOptions::from(interaction.clone()); | ||||
|  | ||||
|             if !ctx.check_executing(interaction.author_id()).await { | ||||
|                 ctx.set_executing(interaction.author_id()).await; | ||||
|  | ||||
|                 match command.fun { | ||||
|                     CommandFnType::Slash(t) => t(&ctx, &interaction, args).await, | ||||
|                     CommandFnType::Multi(m) => m(&ctx, &interaction).await, | ||||
|                     _ => (), | ||||
|         for hook in command.hooks { | ||||
|             match (hook.fun)(&ctx, &command_invoke, &args).await { | ||||
|                 HookResult::Continue => {} | ||||
|                 HookResult::Halt => { | ||||
|                     return; | ||||
|                 } | ||||
|  | ||||
|                 ctx.drop_executing(interaction.author_id()).await; | ||||
|             } | ||||
|         } else if command.required_permissions == PermissionLevel::Restricted { | ||||
|             let _ = interaction | ||||
|                 .respond( | ||||
|                     ctx.http.clone(), | ||||
|                     CreateGenericResponse::new().content( | ||||
|                         "You must have the `Manage Server` permission to use this command.", | ||||
|                     ), | ||||
|                 ) | ||||
|                 .await; | ||||
|         } else if command.required_permissions == PermissionLevel::Managed { | ||||
|             let _ = interaction | ||||
|                 .respond( | ||||
|                     ctx.http.clone(), | ||||
|                     CreateGenericResponse::new().content( | ||||
|                         "You must have `Manage Messages` or have a role capable of sending reminders to that channel. \ | ||||
|                          Please talk to your server admin, and ask them to use the `/restrict` command to specify \ | ||||
|                          allowed roles.", | ||||
|                     ), | ||||
|                 ) | ||||
|                 .await; | ||||
|         } | ||||
|  | ||||
|         for hook in &self.hooks { | ||||
|             match (hook.fun)(&ctx, &command_invoke, &args).await { | ||||
|                 HookResult::Continue => {} | ||||
|                 HookResult::Halt => { | ||||
|                     return; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         let user_id = command_invoke.author_id(); | ||||
|  | ||||
|         if !ctx.check_executing(user_id).await { | ||||
|             ctx.set_executing(user_id).await; | ||||
|  | ||||
|             match command.fun { | ||||
|                 CommandFnType::Slash(t) => t(&ctx, command_invoke, args).await, | ||||
|                 CommandFnType::Multi(m) => m(&ctx, command_invoke).await, | ||||
|                 _ => (), | ||||
|             } | ||||
|  | ||||
|             ctx.drop_executing(user_id).await; | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| enum PermissionCheck { | ||||
|     None,              // No permissions | ||||
|     Basic(bool, bool), // Send + Embed permissions (sufficient to reply) | ||||
|     All,               // Above + Manage Webhooks (sufficient to operate) | ||||
| } | ||||
|  | ||||
| #[async_trait] | ||||
| impl Framework for RegexFramework { | ||||
|     async fn dispatch(&self, ctx: Context, msg: Message) { | ||||
|         async fn check_self_permissions( | ||||
|             ctx: &Context, | ||||
|             guild: &Guild, | ||||
|             channel: &GuildChannel, | ||||
|         ) -> SerenityResult<PermissionCheck> { | ||||
|             let user_id = ctx.cache.current_user_id(); | ||||
|  | ||||
|             let guild_perms = guild.member_permissions(&ctx, user_id).await?; | ||||
|             let channel_perms = channel.permissions_for_user(ctx, user_id)?; | ||||
|  | ||||
|             let basic_perms = channel_perms.send_messages(); | ||||
|  | ||||
|             Ok(if basic_perms && guild_perms.manage_webhooks() && channel_perms.embed_links() { | ||||
|                 PermissionCheck::All | ||||
|             } else if basic_perms { | ||||
|                 PermissionCheck::Basic(guild_perms.manage_webhooks(), channel_perms.embed_links()) | ||||
|             } else { | ||||
|                 PermissionCheck::None | ||||
|             }) | ||||
|         } | ||||
|  | ||||
|         async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option<Match<'_>>) -> bool { | ||||
|             if let Some(prefix) = prefix_opt { | ||||
|                 let guild_prefix = ctx.prefix(Some(guild.id)).await; | ||||
| @@ -874,144 +714,65 @@ impl Framework for RegexFramework { | ||||
|  | ||||
|         // gate to prevent analysing messages unnecessarily | ||||
|         if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() { | ||||
|         } else { | ||||
|             // Guild Command | ||||
|             if let (Some(guild), Ok(Channel::Guild(channel))) = | ||||
|                 (msg.guild(&ctx), msg.channel(&ctx).await) | ||||
|             { | ||||
|                 let data = ctx.data.read().await; | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|                 let pool = data.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); | ||||
|         let user_id = msg.author.id; | ||||
|         let invoke = CommandInvoke::msg(msg.clone()); | ||||
|  | ||||
|                 if let Some(full_match) = self.command_matcher.captures(&msg.content) { | ||||
|                     if check_prefix(&ctx, &guild, full_match.name("prefix")).await { | ||||
|                         match check_self_permissions(&ctx, &guild, &channel).await { | ||||
|                             Ok(perms) => match perms { | ||||
|                                 PermissionCheck::All => { | ||||
|                                     let command = self | ||||
|                                         .commands_map | ||||
|                                         .get( | ||||
|                                             &full_match | ||||
|                                                 .name("cmd") | ||||
|                                                 .unwrap() | ||||
|                                                 .as_str() | ||||
|                                                 .to_lowercase(), | ||||
|                                         ) | ||||
|                                         .unwrap(); | ||||
|  | ||||
|                                     let channel = msg.channel(&ctx).await.unwrap(); | ||||
|  | ||||
|                                     let channel_data = | ||||
|                                         ChannelData::from_channel(&channel, &pool).await.unwrap(); | ||||
|  | ||||
|                                     if !command.can_blacklist || !channel_data.blacklisted { | ||||
|                                         let args = full_match | ||||
|                                             .name("args") | ||||
|                                             .map(|m| m.as_str()) | ||||
|                                             .unwrap_or("") | ||||
|                                             .to_string(); | ||||
|  | ||||
|                                         let member = guild.member(&ctx, &msg.author).await.unwrap(); | ||||
|  | ||||
|                                         if command.check_permissions(&ctx, &guild, &member).await { | ||||
|                                             if msg.id == MessageId(0) | ||||
|                                                 || !ctx.check_executing(msg.author.id).await | ||||
|                                             { | ||||
|                                                 ctx.set_executing(msg.author.id).await; | ||||
|  | ||||
|                                                 match command.fun { | ||||
|                                                     CommandFnType::Text(t) => t(&ctx, &msg, args), | ||||
|                                                     CommandFnType::Multi(m) => m(&ctx, &msg), | ||||
|                                                     _ => (self.default_text_fun)(&ctx, &msg, args), | ||||
|                                                 } | ||||
|                                                 .await; | ||||
|  | ||||
|                                                 ctx.drop_executing(msg.author.id).await; | ||||
|                                             } | ||||
|                                         } else if command.required_permissions | ||||
|                                             == PermissionLevel::Restricted | ||||
|                                         { | ||||
|                                             let _ = msg | ||||
|                                                 .channel_id | ||||
|                                                 .say( | ||||
|                                                     &ctx, | ||||
|                                                     "You must have the `Manage Server` permission to use this command.", | ||||
|                                                 ) | ||||
|                                                 .await; | ||||
|                                         } else if command.required_permissions | ||||
|                                             == PermissionLevel::Managed | ||||
|                                         { | ||||
|                                             let _ = msg | ||||
|                                                 .channel_id | ||||
|                                                 .say( | ||||
|                                                     &ctx, | ||||
|                                                     "You must have `Manage Messages` or have a role capable of \ | ||||
|                                                      sending reminders to that channel. Please talk to your server \ | ||||
|                                                      admin, and ask them to use the `/restrict` command to specify \ | ||||
|                                                      allowed roles.", | ||||
|                                                 ) | ||||
|                                                 .await; | ||||
|                                         } | ||||
|                                     } | ||||
|                                 } | ||||
|  | ||||
|                                 PermissionCheck::Basic(manage_webhooks, embed_links) => { | ||||
|                                     let _ = msg | ||||
|                                         .channel_id | ||||
|                                         .say( | ||||
|                                             &ctx, | ||||
|                                             format!( | ||||
|                                                 "Please ensure the bot has the correct permissions: | ||||
|  | ||||
| ✅     **Send Message** | ||||
| {}     **Embed Links** | ||||
| {}     **Manage Webhooks**", | ||||
|                                                 if manage_webhooks { "✅" } else { "❌" }, | ||||
|                                                 if embed_links { "✅" } else { "❌" }, | ||||
|                                             ), | ||||
|                                         ) | ||||
|                                         .await; | ||||
|                                 } | ||||
|  | ||||
|                                 PermissionCheck::None => { | ||||
|                                     warn!("Missing enough permissions for guild {}", guild.id); | ||||
|                                 } | ||||
|                             }, | ||||
|  | ||||
|                             Err(e) => { | ||||
|                                 error!( | ||||
|                                     "Error occurred getting permissions in guild {}: {:?}", | ||||
|                                     guild.id, e | ||||
|                                 ); | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             // DM Command | ||||
|             else if self.dm_enabled { | ||||
|                 if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) { | ||||
|         // Guild Command | ||||
|         if let Some(guild) = msg.guild(&ctx) { | ||||
|             if let Some(full_match) = self.command_matcher.captures(&msg.content) { | ||||
|                 if check_prefix(&ctx, &guild, full_match.name("prefix")).await { | ||||
|                     let command = self | ||||
|                         .commands_map | ||||
|                         .get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) | ||||
|                         .unwrap(); | ||||
|                     let args = | ||||
|                         full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string(); | ||||
|  | ||||
|                     if msg.id == MessageId(0) || !ctx.check_executing(msg.author.id).await { | ||||
|                         ctx.set_executing(msg.author.id).await; | ||||
|                     let channel_data = ctx.channel_data(invoke.channel_id()).await.unwrap(); | ||||
|  | ||||
|                         match command.fun { | ||||
|                             CommandFnType::Text(t) => t(&ctx, &msg, args), | ||||
|                             CommandFnType::Multi(m) => m(&ctx, &msg), | ||||
|                             _ => (self.default_text_fun)(&ctx, &msg, args), | ||||
|                     if !command.can_blacklist || !channel_data.blacklisted { | ||||
|                         let args = | ||||
|                             full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string(); | ||||
|  | ||||
|                         if msg.id == MessageId(0) || !ctx.check_executing(user_id).await { | ||||
|                             ctx.set_executing(user_id).await; | ||||
|  | ||||
|                             match command.fun { | ||||
|                                 CommandFnType::Text(t) => t(&ctx, invoke, args).await, | ||||
|                                 CommandFnType::Multi(m) => m(&ctx, invoke).await, | ||||
|                                 _ => {} | ||||
|                             }; | ||||
|  | ||||
|                             ctx.drop_executing(user_id).await; | ||||
|                         } | ||||
|                         .await; | ||||
|  | ||||
|                         ctx.drop_executing(msg.author.id).await; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         // DM Command | ||||
|         else if self.dm_enabled { | ||||
|             if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) { | ||||
|                 let command = self | ||||
|                     .commands_map | ||||
|                     .get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) | ||||
|                     .unwrap(); | ||||
|                 let args = full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string(); | ||||
|  | ||||
|                 let user_id = invoke.author_id(); | ||||
|  | ||||
|                 if msg.id == MessageId(0) || !ctx.check_executing(user_id).await { | ||||
|                     ctx.set_executing(user_id).await; | ||||
|  | ||||
|                     match command.fun { | ||||
|                         CommandFnType::Text(t) => t(&ctx, invoke, args).await, | ||||
|                         CommandFnType::Multi(m) => m(&ctx, invoke).await, | ||||
|                         _ => {} | ||||
|                     }; | ||||
|  | ||||
|                     ctx.drop_executing(user_id).await; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
							
								
								
									
										217
									
								
								src/hooks.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										217
									
								
								src/hooks.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,217 @@ | ||||
| use log::warn; | ||||
| use regex_command_attr::check; | ||||
| use serenity::{client::Context, model::channel::Channel}; | ||||
|  | ||||
| use crate::{ | ||||
|     framework::{CommandInvoke, CommandOptions, CreateGenericResponse, HookResult}, | ||||
|     moderation_cmds, RecordingMacros, SQLPool, | ||||
| }; | ||||
|  | ||||
| #[check] | ||||
| pub async fn macro_check( | ||||
|     ctx: &Context, | ||||
|     invoke: &CommandInvoke, | ||||
|     args: &CommandOptions, | ||||
| ) -> HookResult { | ||||
|     if let Some(guild_id) = invoke.guild_id() { | ||||
|         if args.command != moderation_cmds::MACRO_CMD_COMMAND.names[0] { | ||||
|             let active_recordings = | ||||
|                 ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap(); | ||||
|             let mut lock = active_recordings.write().await; | ||||
|  | ||||
|             if let Some(command_macro) = lock.get_mut(&(guild_id, invoke.author_id())) { | ||||
|                 command_macro.commands.push(args.clone()); | ||||
|  | ||||
|                 let _ = invoke | ||||
|                     .respond( | ||||
|                         &ctx, | ||||
|                         CreateGenericResponse::new().content("Command recorded to macro"), | ||||
|                     ) | ||||
|                     .await; | ||||
|  | ||||
|                 HookResult::Halt | ||||
|             } else { | ||||
|                 HookResult::Continue | ||||
|             } | ||||
|         } else { | ||||
|             HookResult::Continue | ||||
|         } | ||||
|     } else { | ||||
|         HookResult::Continue | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[check] | ||||
| pub async fn check_self_permissions( | ||||
|     ctx: &Context, | ||||
|     invoke: &CommandInvoke, | ||||
|     _args: &CommandOptions, | ||||
| ) -> HookResult { | ||||
|     if let Some(guild) = invoke.guild(&ctx) { | ||||
|         let user_id = ctx.cache.current_user_id(); | ||||
|  | ||||
|         let manage_webhooks = | ||||
|             guild.member_permissions(&ctx, user_id).await.map_or(false, |p| p.manage_webhooks()); | ||||
|         let (send_messages, embed_links) = invoke | ||||
|             .channel_id() | ||||
|             .to_channel_cached(&ctx) | ||||
|             .map(|c| { | ||||
|                 if let Channel::Guild(channel) = c { | ||||
|                     channel.permissions_for_user(ctx, user_id).ok() | ||||
|                 } else { | ||||
|                     None | ||||
|                 } | ||||
|             }) | ||||
|             .flatten() | ||||
|             .map_or((false, false), |p| (p.send_messages(), p.embed_links())); | ||||
|  | ||||
|         if manage_webhooks && send_messages && embed_links { | ||||
|             HookResult::Continue | ||||
|         } else { | ||||
|             if send_messages { | ||||
|                 let _ = invoke | ||||
|                     .respond( | ||||
|                         &ctx, | ||||
|                         CreateGenericResponse::new().content(format!( | ||||
|                             "Please ensure the bot has the correct permissions: | ||||
|  | ||||
| ✅     **Send Message** | ||||
| {}     **Embed Links** | ||||
| {}     **Manage Webhooks**", | ||||
|                             if manage_webhooks { "✅" } else { "❌" }, | ||||
|                             if embed_links { "✅" } else { "❌" }, | ||||
|                         )), | ||||
|                     ) | ||||
|                     .await; | ||||
|             } else { | ||||
|                 warn!("Missing permissions in guild {}", guild.id); | ||||
|             } | ||||
|  | ||||
|             HookResult::Halt | ||||
|         } | ||||
|     } else { | ||||
|         HookResult::Continue | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[check] | ||||
| pub async fn check_managed_permissions( | ||||
|     ctx: &Context, | ||||
|     invoke: &CommandInvoke, | ||||
|     args: &CommandOptions, | ||||
| ) -> HookResult { | ||||
|     if let Some(guild) = invoke.guild(&ctx) { | ||||
|         let permissions = guild.member_permissions(&ctx, invoke.author_id()).await.unwrap(); | ||||
|  | ||||
|         if permissions.manage_messages() { | ||||
|             return HookResult::Continue; | ||||
|         } | ||||
|  | ||||
|         let member = invoke.member(&ctx).await.unwrap(); | ||||
|  | ||||
|         let pool = ctx | ||||
|             .data | ||||
|             .read() | ||||
|             .await | ||||
|             .get::<SQLPool>() | ||||
|             .cloned() | ||||
|             .expect("Could not get SQLPool from data"); | ||||
|  | ||||
|         match sqlx::query!( | ||||
|             " | ||||
| SELECT | ||||
|     role | ||||
| FROM | ||||
|     roles | ||||
| INNER JOIN | ||||
|     command_restrictions ON roles.id = command_restrictions.role_id | ||||
| WHERE | ||||
|     command_restrictions.command = ? AND | ||||
|     roles.guild_id = ( | ||||
|         SELECT | ||||
|             id | ||||
|         FROM | ||||
|             guilds | ||||
|         WHERE | ||||
|             guild = ?) | ||||
|                     ", | ||||
|             args.command, | ||||
|             guild.id.as_u64() | ||||
|         ) | ||||
|         .fetch_all(&pool) | ||||
|         .await | ||||
|         { | ||||
|             Ok(rows) => { | ||||
|                 let role_ids = member.roles.iter().map(|r| *r.as_u64()).collect::<Vec<u64>>(); | ||||
|  | ||||
|                 for row in rows { | ||||
|                     if role_ids.contains(&row.role) { | ||||
|                         return HookResult::Continue; | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 let _ = invoke | ||||
|                     .respond( | ||||
|                         &ctx, | ||||
|                         CreateGenericResponse::new().content( | ||||
|                             "You must have \"Manage Messages\" or have a role capable of sending reminders to that channel. \ | ||||
| Please talk to your server admin, and ask them to use the `/restrict` command to specify allowed roles.", | ||||
|                         ), | ||||
|                     ) | ||||
|                     .await; | ||||
|  | ||||
|                 HookResult::Halt | ||||
|             } | ||||
|  | ||||
|             Err(sqlx::Error::RowNotFound) => { | ||||
|                 let _ = invoke | ||||
|                     .respond( | ||||
|                         &ctx, | ||||
|                         CreateGenericResponse::new().content( | ||||
|                             "You must have \"Manage Messages\" or have a role capable of sending reminders to that channel. \ | ||||
| Please talk to your server admin, and ask them to use the `/restrict` command to specify allowed roles.", | ||||
|                         ), | ||||
|                     ) | ||||
|                     .await; | ||||
|  | ||||
|                 HookResult::Halt | ||||
|             } | ||||
|  | ||||
|             Err(e) => { | ||||
|                 warn!("Unexpected error occurred querying command_restrictions: {:?}", e); | ||||
|  | ||||
|                 HookResult::Halt | ||||
|             } | ||||
|         } | ||||
|     } else { | ||||
|         HookResult::Continue | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[check] | ||||
| pub async fn check_guild_permissions( | ||||
|     ctx: &Context, | ||||
|     invoke: &CommandInvoke, | ||||
|     _args: &CommandOptions, | ||||
| ) -> HookResult { | ||||
|     if let Some(guild) = invoke.guild(&ctx) { | ||||
|         let permissions = guild.member_permissions(&ctx, invoke.author_id()).await.unwrap(); | ||||
|  | ||||
|         if !permissions.manage_guild() { | ||||
|             let _ = invoke | ||||
|                 .respond( | ||||
|                     &ctx, | ||||
|                     CreateGenericResponse::new().content( | ||||
|                         "You must have the \"Manage Server\" permission to use this command", | ||||
|                     ), | ||||
|                 ) | ||||
|                 .await; | ||||
|  | ||||
|             HookResult::Halt | ||||
|         } else { | ||||
|             HookResult::Continue | ||||
|         } | ||||
|     } else { | ||||
|         HookResult::Continue | ||||
|     } | ||||
| } | ||||
							
								
								
									
										13
									
								
								src/main.rs
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								src/main.rs
									
									
									
									
									
								
							| @@ -6,6 +6,7 @@ mod commands; | ||||
| mod component_models; | ||||
| mod consts; | ||||
| mod framework; | ||||
| mod hooks; | ||||
| mod models; | ||||
| mod time_parser; | ||||
|  | ||||
| @@ -39,7 +40,7 @@ use crate::{ | ||||
|     component_models::ComponentDataModel, | ||||
|     consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR}, | ||||
|     framework::RegexFramework, | ||||
|     models::guild_data::GuildData, | ||||
|     models::{command_macro::CommandMacro, guild_data::GuildData}, | ||||
| }; | ||||
|  | ||||
| struct GuildDataCache; | ||||
| @@ -72,6 +73,12 @@ impl TypeMapKey for CurrentlyExecuting { | ||||
|     type Value = Arc<RwLock<HashMap<UserId, Instant>>>; | ||||
| } | ||||
|  | ||||
| struct RecordingMacros; | ||||
|  | ||||
| impl TypeMapKey for RecordingMacros { | ||||
|     type Value = Arc<RwLock<HashMap<(GuildId, UserId), CommandMacro>>>; | ||||
| } | ||||
|  | ||||
| #[async_trait] | ||||
| trait LimitExecutors { | ||||
|     async fn check_executing(&self, user: UserId) -> bool; | ||||
| @@ -326,10 +333,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { | ||||
|         .add_command(&moderation_cmds::RESTRICT_COMMAND) | ||||
|         .add_command(&moderation_cmds::TIMEZONE_COMMAND) | ||||
|         .add_command(&moderation_cmds::PREFIX_COMMAND) | ||||
|         .add_command(&moderation_cmds::MACRO_CMD_COMMAND) | ||||
|         /* | ||||
|         .add_command("alias", &moderation_cmds::ALIAS_COMMAND) | ||||
|         .add_command("a", &moderation_cmds::ALIAS_COMMAND) | ||||
|         */ | ||||
|         .add_hook(&hooks::CHECK_SELF_PERMISSIONS_HOOK) | ||||
|         .add_hook(&hooks::MACRO_CHECK_HOOK) | ||||
|         .build(); | ||||
|  | ||||
|     let framework_arc = Arc::new(framework); | ||||
| @@ -375,6 +385,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { | ||||
|         data.insert::<PopularTimezones>(Arc::new(popular_timezones)); | ||||
|         data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new())); | ||||
|         data.insert::<RegexFramework>(framework_arc.clone()); | ||||
|         data.insert::<RecordingMacros>(Arc::new(RwLock::new(HashMap::new()))); | ||||
|     } | ||||
|  | ||||
|     if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| { | ||||
|   | ||||
							
								
								
									
										10
									
								
								src/models/command_macro.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								src/models/command_macro.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
| use serenity::model::id::GuildId; | ||||
|  | ||||
| use crate::framework::CommandOptions; | ||||
|  | ||||
| pub struct CommandMacro { | ||||
|     pub guild_id: GuildId, | ||||
|     pub name: String, | ||||
|     pub description: Option<String>, | ||||
|     pub commands: Vec<CommandOptions>, | ||||
| } | ||||
| @@ -1,4 +1,5 @@ | ||||
| pub mod channel_data; | ||||
| pub mod command_macro; | ||||
| pub mod guild_data; | ||||
| pub mod reminder; | ||||
| pub mod timer; | ||||
|   | ||||
| @@ -1,32 +1,5 @@ | ||||
| use crate::consts::{MAX_TIME, MIN_INTERVAL}; | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub enum InteractionError { | ||||
|     InvalidFormat, | ||||
|     InvalidBase64, | ||||
|     InvalidSize, | ||||
|     NoReminder, | ||||
|     SignatureMismatch, | ||||
|     InvalidAction, | ||||
| } | ||||
|  | ||||
| impl ToString for InteractionError { | ||||
|     fn to_string(&self) -> String { | ||||
|         match self { | ||||
|             InteractionError::InvalidFormat => { | ||||
|                 String::from("The interaction data was improperly formatted") | ||||
|             } | ||||
|             InteractionError::InvalidBase64 => String::from("The interaction data was invalid"), | ||||
|             InteractionError::InvalidSize => String::from("The interaction data was invalid"), | ||||
|             InteractionError::NoReminder => String::from("Reminder could not be found"), | ||||
|             InteractionError::SignatureMismatch => { | ||||
|                 String::from("Only the user who did the command can use interactions") | ||||
|             } | ||||
|             InteractionError::InvalidAction => String::from("The action was invalid"), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(PartialEq, Eq, Hash, Debug)] | ||||
| pub enum ReminderError { | ||||
|     LongTime, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user