added functionality for reusable hook functions that will execute on commands

This commit is contained in:
jellywx 2021-09-22 21:12:29 +01:00
parent a0974795e1
commit d84d7ab62b
18 changed files with 864 additions and 954 deletions

10
Cargo.lock generated
View File

@ -1213,6 +1213,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn",
"uuid",
] ]
[[package]] [[package]]
@ -2016,6 +2017,15 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "uuid"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
dependencies = [
"getrandom 0.2.3",
]
[[package]] [[package]]
name = "uwl" name = "uwl"
version = "0.6.0" version = "0.6.0"

View File

@ -15,7 +15,6 @@ Reminder Bot can be built by running `cargo build --release` in the top level di
These environment variables must be provided when compiling the bot These environment variables must be provided when compiling the bot
* `DATABASE_URL` - the URL of your MySQL database (`mysql://user[:password]@domain/database`) * `DATABASE_URL` - the URL of your MySQL database (`mysql://user[:password]@domain/database`)
* `WEBHOOK_AVATAR` - accepts the name of an image file located in `$CARGO_MANIFEST_DIR/assets/` to be used as the avatar when creating webhooks. **IMPORTANT: image file must be 128x128 or smaller in size** * `WEBHOOK_AVATAR` - accepts the name of an image file located in `$CARGO_MANIFEST_DIR/assets/` to be used as the avatar when creating webhooks. **IMPORTANT: image file must be 128x128 or smaller in size**
* `STRINGS_FILE` - accepts the name of a compiled strings file located in `$CARGO_MANIFEST_DIR/assets/` to be used for creating messages. Compiled string files can be generated with `compile.py` at https://github.com/reminder-bot/languages
### Setting up Python ### Setting up Python
Reminder Bot by default looks for a venv within it's working directory to run Python out of. To set up a venv, install `python3-venv` and run `python3 -m venv venv`. Then, run `source venv/bin/activate` to activate the venv, and do `pip install dateparser` to install the required library Reminder Bot by default looks for a venv within it's working directory to run Python out of. To set up a venv, install `python3-venv` and run `python3 -m venv venv`. Then, run `source venv/bin/activate` to activate the venv, and do `pip install dateparser` to install the required library
@ -29,14 +28,12 @@ __Required Variables__
__Other Variables__ __Other Variables__
* `MIN_INTERVAL` - default `600`, defines the shortest interval the bot should accept * `MIN_INTERVAL` - default `600`, defines the shortest interval the bot should accept
* `MAX_TIME` - default `1576800000`, defines the maximum time ahead that reminders can be set for
* `LOCAL_TIMEZONE` - default `UTC`, necessary for calculations in the natural language processor * `LOCAL_TIMEZONE` - default `UTC`, necessary for calculations in the natural language processor
* `DEFAULT_PREFIX` - default `$`, used for the default prefix on new guilds * `DEFAULT_PREFIX` - default `$`, used for the default prefix on new guilds
* `SUBSCRIPTION_ROLES` - default `None`, accepts a list of Discord role IDs that are given to subscribed users * `SUBSCRIPTION_ROLES` - default `None`, accepts a list of Discord role IDs that are given to subscribed users
* `CNC_GUILD` - default `None`, accepts a single Discord guild ID for the server that the subscription roles belong to * `CNC_GUILD` - default `None`, accepts a single Discord guild ID for the server that the subscription roles belong to
* `IGNORE_BOTS` - default `1`, if `1`, Reminder Bot will ignore all other bots * `IGNORE_BOTS` - default `1`, if `1`, Reminder Bot will ignore all other bots
* `PYTHON_LOCATION` - default `venv/bin/python3`. Can be changed if your Python executable is located somewhere else * `PYTHON_LOCATION` - default `venv/bin/python3`. Can be changed if your Python executable is located somewhere else
* `LOCAL_LANGUAGE` - default `EN`. Specifies the string set to fall back to if a string cannot be found (and to be used with new users)
* `THEME_COLOR` - default `8fb677`. Specifies the hex value of the color to use on info message embeds * `THEME_COLOR` - default `8fb677`. Specifies the hex value of the color to use on info message embeds
* `CASE_INSENSITIVE` - default `1`, if `1`, commands will be treated with case insensitivity (so both `$help` and `$HELP` will work) * `CASE_INSENSITIVE` - default `1`, if `1`, commands will be treated with case insensitivity (so both `$help` and `$HELP` will work)
* `SHARD_COUNT` - default `None`, accepts the number of shards that are being ran * `SHARD_COUNT` - default `None`, accepts the number of shards that are being ran

View File

@ -13,3 +13,4 @@ proc-macro = true
quote = "^1.0" quote = "^1.0"
syn = { version = "^1.0", features = ["full", "derive", "extra-traits"] } syn = { version = "^1.0", features = ["full", "derive", "extra-traits"] }
proc-macro2 = "1.0" proc-macro2 = "1.0"
uuid = { version = "0.8", features = ["v4"] }

View File

@ -8,7 +8,7 @@ use syn::{
}; };
use crate::{ use crate::{
structures::{ApplicationCommandOptionType, Arg, PermissionLevel}, structures::{ApplicationCommandOptionType, Arg},
util::{AsOption, LitExt}, util::{AsOption, LitExt},
}; };
@ -46,24 +46,15 @@ impl fmt::Display for ValueKind {
fn to_ident(p: Path) -> Result<Ident> { fn to_ident(p: Path) -> Result<Ident> {
if p.segments.is_empty() { if p.segments.is_empty() {
return Err(Error::new( return Err(Error::new(p.span(), "cannot convert an empty path to an identifier"));
p.span(),
"cannot convert an empty path to an identifier",
));
} }
if p.segments.len() > 1 { if p.segments.len() > 1 {
return Err(Error::new( return Err(Error::new(p.span(), "the path must not have more than one segment"));
p.span(),
"the path must not have more than one segment",
));
} }
if !p.segments[0].arguments.is_empty() { if !p.segments[0].arguments.is_empty() {
return Err(Error::new( return Err(Error::new(p.span(), "the singular path segment must not have any arguments"));
p.span(),
"the singular path segment must not have any arguments",
));
} }
Ok(p.segments[0].ident.clone()) Ok(p.segments[0].ident.clone())
@ -85,12 +76,7 @@ impl Values {
literals: Vec<(Option<String>, Lit)>, literals: Vec<(Option<String>, Lit)>,
span: Span, span: Span,
) -> Self { ) -> Self {
Values { Values { name, literals, kind, span }
name,
literals,
kind,
span,
}
} }
} }
@ -145,11 +131,7 @@ pub fn parse_values(attr: &Attribute) -> Result<Values> {
} }
} }
let kind = if lits.len() == 1 { let kind = if lits.len() == 1 { ValueKind::SingleList } else { ValueKind::List };
ValueKind::SingleList
} else {
ValueKind::List
};
Ok(Values::new(name, kind, lits, attr.span())) Ok(Values::new(name, kind, lits, attr.span()))
} else { } else {
@ -183,12 +165,7 @@ pub fn parse_values(attr: &Attribute) -> Result<Values> {
let name = to_ident(meta.path)?; let name = to_ident(meta.path)?;
let lit = meta.lit; let lit = meta.lit;
Ok(Values::new( Ok(Values::new(name, ValueKind::Equals, vec![(None, lit)], attr.span()))
name,
ValueKind::Equals,
vec![(None, lit)],
attr.span(),
))
} }
} }
} }
@ -231,10 +208,7 @@ fn validate(values: &Values, forms: &[ValueKind]) -> Result<()> {
return Err(Error::new( return Err(Error::new(
values.span, values.span,
// Using the `_args` version here to avoid an allocation. // Using the `_args` version here to avoid an allocation.
format_args!( format_args!("the attribute must be in of these forms:\n{}", DisplaySlice(forms)),
"the attribute must be in of these forms:\n{}",
DisplaySlice(forms)
),
)); ));
} }
@ -254,11 +228,7 @@ impl AttributeOption for Vec<String> {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::List])?; validate(&values, &[ValueKind::List])?;
Ok(values Ok(values.literals.into_iter().map(|(_, l)| l.to_str()).collect())
.literals
.into_iter()
.map(|(_, l)| l.to_str())
.collect())
} }
} }
@ -294,37 +264,18 @@ impl AttributeOption for Vec<Ident> {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::List])?; validate(&values, &[ValueKind::List])?;
Ok(values Ok(values.literals.into_iter().map(|(_, l)| l.to_ident()).collect())
.literals
.into_iter()
.map(|(_, l)| l.to_ident())
.collect())
} }
} }
impl AttributeOption for Option<String> { impl AttributeOption for Option<String> {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate( validate(&values, &[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList])?;
&values,
&[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList],
)?;
Ok(values.literals.get(0).map(|(_, l)| l.to_str())) Ok(values.literals.get(0).map(|(_, l)| l.to_str()))
} }
} }
impl AttributeOption for PermissionLevel {
fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::SingleList])?;
Ok(values
.literals
.get(0)
.map(|(_, l)| PermissionLevel::from_str(&*l.to_str()).unwrap())
.unwrap())
}
}
impl AttributeOption for Arg { impl AttributeOption for Arg {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::EqualsList])?; validate(&values, &[ValueKind::EqualsList])?;

View File

@ -2,6 +2,8 @@ pub mod suffixes {
pub const COMMAND: &str = "COMMAND"; pub const COMMAND: &str = "COMMAND";
pub const ARG: &str = "ARG"; pub const ARG: &str = "ARG";
pub const SUBCOMMAND: &str = "SUBCOMMAND"; pub const SUBCOMMAND: &str = "SUBCOMMAND";
pub const CHECK: &str = "CHECK";
pub const HOOK: &str = "HOOK";
} }
pub use self::suffixes::*; pub use self::suffixes::*;

View File

@ -5,6 +5,7 @@ use proc_macro::TokenStream;
use proc_macro2::Ident; use proc_macro2::Ident;
use quote::quote; use quote::quote;
use syn::{parse::Error, parse_macro_input, parse_quote, spanned::Spanned, Lit, Type}; use syn::{parse::Error, parse_macro_input, parse_quote, spanned::Spanned, Lit, Type};
use uuid::Uuid;
pub(crate) mod attributes; pub(crate) mod attributes;
pub(crate) mod consts; pub(crate) mod consts;
@ -43,6 +44,7 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
fun.name.to_string() fun.name.to_string()
}; };
let mut hooks: Vec<Ident> = Vec::new();
let mut options = Options::new(); let mut options = Options::new();
for attribute in &fun.attributes { for attribute in &fun.attributes {
@ -76,11 +78,13 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
util::append_line(&mut options.description, line); util::append_line(&mut options.description, line);
} }
} }
"hook" => {
hooks.push(propagate_err!(attributes::parse(values)));
}
_ => { _ => {
match_options!(name, values, options, span => [ match_options!(name, values, options, span => [
aliases; aliases;
group; group;
required_permissions;
can_blacklist; can_blacklist;
supports_dm supports_dm
]); ]);
@ -93,7 +97,6 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
description, description,
group, group,
examples, examples,
required_permissions,
can_blacklist, can_blacklist,
supports_dm, supports_dm,
mut cmd_args, mut cmd_args,
@ -235,10 +238,10 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
desc: #description, desc: #description,
group: #group, group: #group,
examples: &[#(#examples),*], examples: &[#(#examples),*],
required_permissions: #required_permissions,
can_blacklist: #can_blacklist, can_blacklist: #can_blacklist,
supports_dm: #supports_dm, supports_dm: #supports_dm,
args: &[#(&#arg_idents),*], args: &[#(&#arg_idents),*],
hooks: &[#(&#hooks),*],
}; };
}); });
@ -256,3 +259,44 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
tokens.into() tokens.into()
} }
#[proc_macro_attribute]
pub fn check(_attr: TokenStream, input: TokenStream) -> TokenStream {
let mut fun = parse_macro_input!(input as CommandFun);
let n = fun.name.clone();
let name = n.with_suffix(HOOK);
let fn_name = n.with_suffix(CHECK);
let visibility = fun.visibility;
let cooked = fun.cooked;
let body = fun.body;
let ret = fun.ret;
populate_fut_lifetimes_on_refs(&mut fun.args);
let args = fun.args;
let hook_path = quote!(crate::framework::Hook);
let uuid = Uuid::new_v4().as_u128();
(quote! {
#(#cooked)*
#[allow(missing_docs)]
#visibility fn #fn_name<'fut>(#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, #ret> {
use ::serenity::futures::future::FutureExt;
async move {
let _output: #ret = { #(#body)* };
#[allow(unreachable_code)]
_output
}.boxed()
}
#(#cooked)*
#[allow(missing_docs)]
pub static #name: #hook_path = #hook_path {
fun: #fn_name,
uuid: #uuid,
};
})
.into()
}

View File

@ -4,7 +4,7 @@ use syn::{
braced, braced,
parse::{Error, Parse, ParseStream, Result}, parse::{Error, Parse, ParseStream, Result},
spanned::Spanned, spanned::Spanned,
Attribute, Block, FnArg, Ident, Pat, Stmt, Token, Visibility, Attribute, Block, FnArg, Ident, Pat, ReturnType, Stmt, Token, Type, Visibility,
}; };
use crate::util::{Argument, Parenthesised}; use crate::util::{Argument, Parenthesised};
@ -78,6 +78,7 @@ pub struct CommandFun {
pub visibility: Visibility, pub visibility: Visibility,
pub name: Ident, pub name: Ident,
pub args: Vec<Argument>, pub args: Vec<Argument>,
pub ret: Type,
pub body: Vec<Stmt>, pub body: Vec<Stmt>,
} }
@ -97,6 +98,11 @@ impl Parse for CommandFun {
// (...) // (...)
let Parenthesised(args) = input.parse::<Parenthesised<FnArg>>()?; let Parenthesised(args) = input.parse::<Parenthesised<FnArg>>()?;
let ret = match input.parse::<ReturnType>()? {
ReturnType::Type(_, t) => (*t).clone(),
ReturnType::Default => Type::Verbatim(quote!(())),
};
// { ... } // { ... }
let bcont; let bcont;
braced!(bcont in input); braced!(bcont in input);
@ -104,72 +110,23 @@ impl Parse for CommandFun {
let args = args.into_iter().map(parse_argument).collect::<Result<Vec<_>>>()?; let args = args.into_iter().map(parse_argument).collect::<Result<Vec<_>>>()?;
Ok(Self { attributes, cooked, visibility, name, args, body }) Ok(Self { attributes, cooked, visibility, name, args, ret, body })
} }
} }
impl ToTokens for CommandFun { impl ToTokens for CommandFun {
fn to_tokens(&self, stream: &mut TokenStream2) { fn to_tokens(&self, stream: &mut TokenStream2) {
let Self { attributes: _, cooked, visibility, name, args, body } = self; let Self { attributes: _, cooked, visibility, name, args, ret, body } = self;
stream.extend(quote! { stream.extend(quote! {
#(#cooked)* #(#cooked)*
#visibility async fn #name (#(#args),*) { #visibility async fn #name (#(#args),*) -> #ret {
#(#body)* #(#body)*
} }
}); });
} }
} }
#[derive(Debug)]
pub enum PermissionLevel {
Unrestricted,
Managed,
Restricted,
}
impl Default for PermissionLevel {
fn default() -> Self {
Self::Unrestricted
}
}
impl PermissionLevel {
pub fn from_str(s: &str) -> Option<Self> {
Some(match s.to_uppercase().as_str() {
"UNRESTRICTED" => Self::Unrestricted,
"MANAGED" => Self::Managed,
"RESTRICTED" => Self::Restricted,
_ => return None,
})
}
}
impl ToTokens for PermissionLevel {
fn to_tokens(&self, stream: &mut TokenStream2) {
let path = quote!(crate::framework::PermissionLevel);
let variant;
match self {
Self::Unrestricted => {
variant = quote!(Unrestricted);
}
Self::Managed => {
variant = quote!(Managed);
}
Self::Restricted => {
variant = quote!(Restricted);
}
}
stream.extend(quote! {
#path::#variant
});
}
}
#[derive(Debug)] #[derive(Debug)]
pub(crate) enum ApplicationCommandOptionType { pub(crate) enum ApplicationCommandOptionType {
SubCommand, SubCommand,
@ -272,7 +229,6 @@ pub(crate) struct Options {
pub description: String, pub description: String,
pub group: String, pub group: String,
pub examples: Vec<String>, pub examples: Vec<String>,
pub required_permissions: PermissionLevel,
pub can_blacklist: bool, pub can_blacklist: bool,
pub supports_dm: bool, pub supports_dm: bool,
pub cmd_args: Vec<Arg>, pub cmd_args: Vec<Arg>,

11
migration/02-macro.sql Normal file
View File

@ -0,0 +1,11 @@
CREATE TABLE macro (
id INT UNSIGNED AUTO_INCREMENT,
guild_id INT UNSIGNED NOT NULL,
name VARCHAR(100) NOT NULL,
description VARCHAR(100),
commands TEXT,
FOREIGN KEY (guild_id) REFERENCES guilds(id),
PRIMARY KEY (id)
);

View File

@ -27,7 +27,7 @@ fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEm
#[aliases("invite")] #[aliases("invite")]
#[description("Get information about the bot")] #[description("Get information about the bot")]
#[group("Info")] #[group("Info")]
async fn info(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { async fn info(ctx: &Context, invoke: CommandInvoke) {
let prefix = ctx.prefix(invoke.guild_id()).await; let prefix = ctx.prefix(invoke.guild_id()).await;
let current_user = ctx.cache.current_user(); let current_user = ctx.cache.current_user();
let footer = footer(ctx); let footer = footer(ctx);
@ -61,7 +61,7 @@ Use our dashboard: https://reminder-bot.com/",
#[command] #[command]
#[description("Details on supporting the bot and Patreon benefits")] #[description("Details on supporting the bot and Patreon benefits")]
#[group("Info")] #[group("Info")]
async fn donate(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { async fn donate(ctx: &Context, invoke: CommandInvoke) {
let footer = footer(ctx); let footer = footer(ctx);
let _ = invoke let _ = invoke
@ -94,7 +94,7 @@ Just $2 USD/month!
#[command] #[command]
#[description("Get the link to the online dashboard")] #[description("Get the link to the online dashboard")]
#[group("Info")] #[group("Info")]
async fn dashboard(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { async fn dashboard(ctx: &Context, invoke: CommandInvoke) {
let footer = footer(ctx); let footer = footer(ctx);
let _ = invoke let _ = invoke
@ -113,7 +113,7 @@ async fn dashboard(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) {
#[command] #[command]
#[description("View the current time in your selected timezone")] #[description("View the current time in your selected timezone")]
#[group("Info")] #[group("Info")]
async fn clock(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { async fn clock(ctx: &Context, invoke: CommandInvoke) {
let ud = ctx.user_data(&invoke.author_id()).await.unwrap(); let ud = ctx.user_data(&invoke.author_id()).await.unwrap();
let now = Utc::now().with_timezone(&ud.timezone()); let now = Utc::now().with_timezone(&ud.timezone());

View File

@ -2,16 +2,21 @@ use chrono::offset::Utc;
use chrono_tz::{Tz, TZ_VARIANTS}; use chrono_tz::{Tz, TZ_VARIANTS};
use levenshtein::levenshtein; use levenshtein::levenshtein;
use regex_command_attr::command; use regex_command_attr::command;
use serenity::{client::Context, model::misc::Mentionable}; use serenity::{
client::Context,
model::{
interactions::InteractionResponseType, misc::Mentionable,
prelude::InteractionApplicationCommandCallbackDataFlags,
},
};
use crate::{ use crate::{
component_models::{ComponentDataModel, Restrict}, component_models::{ComponentDataModel, Restrict},
consts::THEME_COLOR, consts::THEME_COLOR,
framework::{ framework::{CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue},
CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue, PermissionLevel, hooks::{CHECK_GUILD_PERMISSIONS_HOOK, CHECK_MANAGED_PERMISSIONS_HOOK},
}, models::{channel_data::ChannelData, command_macro::CommandMacro, CtxData},
models::{channel_data::ChannelData, CtxData}, PopularTimezones, RecordingMacros, RegexFramework, SQLPool,
PopularTimezones, RegexFramework, SQLPool,
}; };
#[command("blacklist")] #[command("blacklist")]
@ -23,13 +28,9 @@ use crate::{
required = false required = false
)] )]
#[supports_dm(false)] #[supports_dm(false)]
#[required_permissions(Restricted)] #[hook(CHECK_GUILD_PERMISSIONS_HOOK)]
#[can_blacklist(false)] #[can_blacklist(false)]
async fn blacklist( async fn blacklist(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) {
ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync),
args: CommandOptions,
) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let channel = match args.get("channel") { let channel = match args.get("channel") {
@ -72,7 +73,7 @@ async fn blacklist(
kind = "String", kind = "String",
required = false required = false
)] )]
async fn timezone(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { async fn timezone(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let mut user_data = ctx.user_data(invoke.author_id()).await.unwrap(); let mut user_data = ctx.user_data(invoke.author_id()).await.unwrap();
@ -178,8 +179,8 @@ You may want to use one of the popular timezones below, otherwise click [here](h
#[command("prefix")] #[command("prefix")]
#[description("Configure a prefix for text-based commands (deprecated)")] #[description("Configure a prefix for text-based commands (deprecated)")]
#[supports_dm(false)] #[supports_dm(false)]
#[required_permissions(Restricted)] #[hook(CHECK_GUILD_PERMISSIONS_HOOK)]
async fn prefix(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: String) { async fn prefix(ctx: &Context, invoke: CommandInvoke, args: String) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let guild_data = ctx.guild_data(invoke.guild_id().unwrap()).await.unwrap(); let guild_data = ctx.guild_data(invoke.guild_id().unwrap()).await.unwrap();
@ -222,8 +223,8 @@ async fn prefix(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args:
required = true required = true
)] )]
#[supports_dm(false)] #[supports_dm(false)]
#[required_permissions(Restricted)] #[hook(CHECK_GUILD_PERMISSIONS_HOOK)]
async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { async fn restrict(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let framework = ctx.data.read().await.get::<RegexFramework>().cloned().unwrap(); let framework = ctx.data.read().await.get::<RegexFramework>().cloned().unwrap();
@ -240,7 +241,7 @@ async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), arg
let restrictable_commands = framework let restrictable_commands = framework
.commands .commands
.iter() .iter()
.filter(|c| c.required_permissions == PermissionLevel::Managed) .filter(|c| c.hooks.contains(&&CHECK_MANAGED_PERMISSIONS_HOOK))
.map(|c| c.names[0].to_string()) .map(|c| c.names[0].to_string())
.collect::<Vec<String>>(); .collect::<Vec<String>>();
@ -289,6 +290,132 @@ async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), arg
} }
} }
#[command("macro")]
#[description("Record and replay command sequences")]
#[subcommand("record")]
#[description("Start recording up to 5 commands to replay")]
#[arg(name = "name", description = "Name for the new macro", kind = "String", required = true)]
#[arg(
name = "description",
description = "Description for the new macro",
kind = "String",
required = false
)]
#[subcommand("finish")]
#[description("Finish current recording")]
#[subcommand("list")]
#[description("List recorded macros")]
#[subcommand("run")]
#[description("Run a recorded macro")]
#[arg(name = "name", description = "Name of the macro to run", kind = "String", required = true)]
#[supports_dm(false)]
#[hook(CHECK_MANAGED_PERMISSIONS_HOOK)]
async fn macro_cmd(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) {
let interaction = invoke.interaction().unwrap();
match args.subcommand.clone().unwrap().as_str() {
"record" => {
let macro_buffer = ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap();
{
let mut lock = macro_buffer.write().await;
let guild_id = interaction.guild_id.unwrap();
lock.insert(
(guild_id, interaction.user.id),
CommandMacro {
guild_id,
name: args.get("name").unwrap().to_string(),
description: args.get("description").map(|d| d.to_string()),
commands: vec![],
},
);
}
let _ = interaction
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
.create_embed(|e| {
e
.title("Macro Recording Started")
.description(
"Run up to 5 commands, or type `/macro finish` to stop at any point.
Any commands ran as part of recording will be inconsequential")
.color(*THEME_COLOR)
})
})
})
.await;
}
"finish" => {
let key = (interaction.guild_id.unwrap(), interaction.user.id);
let macro_buffer = ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap();
{
let lock = macro_buffer.read().await;
let contained = lock.get(&key);
if contained.map_or(true, |cmacro| cmacro.commands.is_empty()) {
let _ = interaction
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.create_embed(|e| {
e.title("No Macro Recorded")
.description(
"Use `/macro record` to start recording a macro",
)
.color(*THEME_COLOR)
})
})
})
.await;
} else {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let command_macro = contained.unwrap();
let json = serde_json::to_string(&command_macro.commands).unwrap();
sqlx::query!(
"INSERT INTO macro (guild_id, name, description, commands) VALUES ((SELECT id FROM guilds WHERE guild = ?), ?, ?, ?)",
command_macro.guild_id.0,
command_macro.name,
command_macro.description,
json
)
.execute(&pool)
.await
.unwrap();
let _ = interaction
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.create_embed(|e| {
e.title("Macro Recorded")
.description("Use `/macro run` to execute the macro")
.color(*THEME_COLOR)
})
})
})
.await;
}
}
{
let mut lock = macro_buffer.write().await;
lock.remove(&key);
}
}
"list" => {}
"run" => {}
_ => {}
}
}
/* /*
#[command("alias")] #[command("alias")]
#[supports_dm(false)] #[supports_dm(false)]

View File

@ -5,16 +5,13 @@ use std::{
}; };
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use chrono_tz::Tz;
use num_integer::Integer; use num_integer::Integer;
use regex_command_attr::command; use regex_command_attr::command;
use serenity::{ use serenity::{
builder::CreateEmbed, builder::{CreateEmbed, CreateInteractionResponse},
client::Context, client::Context,
model::{ model::{channel::Channel, interactions::InteractionResponseType},
channel::Channel,
id::{GuildId, UserId},
interactions::InteractionResponseType,
},
}; };
use crate::{ use crate::{
@ -28,6 +25,7 @@ use crate::{
REGEX_NATURAL_COMMAND_2, THEME_COLOR, REGEX_NATURAL_COMMAND_2, THEME_COLOR,
}, },
framework::{CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue}, framework::{CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue},
hooks::{CHECK_GUILD_PERMISSIONS_HOOK, CHECK_MANAGED_PERMISSIONS_HOOK},
models::{ models::{
channel_data::ChannelData, channel_data::ChannelData,
guild_data::GuildData, guild_data::GuildData,
@ -55,8 +53,8 @@ use crate::{
required = false required = false
)] )]
#[supports_dm(false)] #[supports_dm(false)]
#[required_permissions(Restricted)] #[hook(CHECK_GUILD_PERMISSIONS_HOOK)]
async fn pause(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { async fn pause(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await;
@ -141,8 +139,8 @@ async fn pause(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args:
kind = "Integer", kind = "Integer",
required = false required = false
)] )]
#[required_permissions(Restricted)] #[hook(CHECK_GUILD_PERMISSIONS_HOOK)]
async fn offset(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { async fn offset(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let combined_time = args.get("hours").map_or(0, |h| h.as_i64().unwrap() * 3600) let combined_time = args.get("hours").map_or(0, |h| h.as_i64().unwrap() * 3600)
@ -218,8 +216,8 @@ WHERE FIND_IN_SET(channels.`channel`, ?)",
kind = "Integer", kind = "Integer",
required = false required = false
)] )]
#[required_permissions(Restricted)] #[hook(CHECK_GUILD_PERMISSIONS_HOOK)]
async fn nudge(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { async fn nudge(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let combined_time = args.get("minutes").map_or(0, |m| m.as_i64().unwrap() * 60) let combined_time = args.get("minutes").map_or(0, |m| m.as_i64().unwrap() * 60)
@ -270,8 +268,8 @@ async fn nudge(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args:
kind = "Boolean", kind = "Boolean",
required = false required = false
)] )]
#[required_permissions(Managed)] #[hook(CHECK_MANAGED_PERMISSIONS_HOOK)]
async fn look(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { async fn look(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await;
@ -363,103 +361,132 @@ async fn look(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: C
#[command("del")] #[command("del")]
#[description("Delete reminders")] #[description("Delete reminders")]
#[required_permissions(Managed)] #[hook(CHECK_MANAGED_PERMISSIONS_HOOK)]
async fn delete(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { async fn delete(ctx: &Context, invoke: CommandInvoke, _args: CommandOptions) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let interaction = invoke.interaction().unwrap();
let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; let timezone = ctx.timezone(interaction.user.id).await;
let reminders = Reminder::from_guild(ctx, invoke.guild_id(), invoke.author_id()).await; let reminders = Reminder::from_guild(ctx, interaction.guild_id, interaction.user.id).await;
if reminders.is_empty() { let resp = show_delete_page(&reminders, 0, timezone).await;
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().content("No reminders to delete!"),
)
.await;
} else {
let mut char_count = 0;
let (shown_reminders, display_vec): (Vec<&Reminder>, Vec<String>) = reminders let _ = interaction
.iter() .create_interaction_response(&ctx, |r| {
.enumerate() *r = resp;
.map(|(count, reminder)| (reminder, reminder.display_del(count, &timezone))) r
.take_while(|(_, p)| { })
char_count += p.len(); .await;
char_count < EMBED_DESCRIPTION_MAX_LENGTH
})
.unzip();
let display = display_vec.join("\n");
let pages = reminders
.iter()
.enumerate()
.map(|(count, reminder)| reminder.display_del(count, &timezone))
.fold(0, |t, r| t + r.len())
.div_ceil(EMBED_DESCRIPTION_MAX_LENGTH);
let pager = DelPager::new(timezone);
let del_selector = ComponentDataModel::DelSelector(DelSelector { page: 0, timezone });
invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new()
.embed(|e| {
e.title("Delete Reminders")
.description(display)
.footer(|f| f.text(format!("Page {} of {}", 1, pages)))
.color(*THEME_COLOR)
})
.components(|comp| {
pager.create_button_row(pages, comp);
comp.create_action_row(|row| {
row.create_select_menu(|menu| {
menu.custom_id(del_selector.to_custom_id()).options(|opt| {
for (count, reminder) in shown_reminders.iter().enumerate() {
opt.create_option(|o| {
o.label(count + 1).value(reminder.id).description({
let c = reminder.display_content();
if c.len() > 100 {
format!(
"{}...",
reminder
.display_content()
.chars()
.take(97)
.collect::<String>()
)
} else {
c.to_string()
}
})
});
}
opt
})
})
})
}),
)
.await
.unwrap();
}
} }
async fn show_delete_page( pub fn max_delete_page(reminders: &Vec<Reminder>, timezone: &Tz) -> usize {
ctx: &Context, reminders
guild_id: Option<GuildId>, .iter()
user_id: UserId, .enumerate()
.map(|(count, reminder)| reminder.display_del(count, timezone))
.fold(0, |t, r| t + r.len())
.div_ceil(EMBED_DESCRIPTION_MAX_LENGTH)
}
pub async fn show_delete_page(
reminders: &Vec<Reminder>,
page: usize, page: usize,
timezone: Tz, timezone: Tz,
) { ) -> CreateInteractionResponse {
let pager = DelPager::new(timezone);
if reminders.is_empty() {
let mut embed = CreateEmbed::default();
embed.title("Delete Reminders").description("No Reminders").color(*THEME_COLOR);
let mut response = CreateInteractionResponse::default();
response.kind(InteractionResponseType::UpdateMessage).interaction_response_data(
|response| {
response.embeds(vec![embed]).components(|comp| {
pager.create_button_row(0, comp);
comp
})
},
);
return response;
}
let pages = max_delete_page(&reminders, &timezone);
let mut page = page;
if page >= pages {
page = pages - 1;
}
let mut char_count = 0;
let mut skip_char_count = 0;
let mut first_num = 0;
let (shown_reminders, display_vec): (Vec<&Reminder>, Vec<String>) = reminders
.iter()
.enumerate()
.map(|(count, reminder)| (reminder, reminder.display_del(count, &timezone)))
.skip_while(|(_, p)| {
first_num += 1;
skip_char_count += p.len();
skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * page
})
.take_while(|(_, p)| {
char_count += p.len();
char_count < EMBED_DESCRIPTION_MAX_LENGTH
})
.unzip();
let display = display_vec.join("\n");
let del_selector = ComponentDataModel::DelSelector(DelSelector { page, timezone });
let mut embed = CreateEmbed::default();
embed
.title("Delete Reminders")
.description(display)
.footer(|f| f.text(format!("Page {} of {}", page + 1, pages)))
.color(*THEME_COLOR);
let mut response = CreateInteractionResponse::default();
response.kind(InteractionResponseType::UpdateMessage).interaction_response_data(|d| {
d.embeds(vec![embed]).components(|comp| {
pager.create_button_row(pages, comp);
comp.create_action_row(|row| {
row.create_select_menu(|menu| {
menu.custom_id(del_selector.to_custom_id()).options(|opt| {
for (count, reminder) in shown_reminders.iter().enumerate() {
opt.create_option(|o| {
o.label(count + first_num).value(reminder.id).description({
let c = reminder.display_content();
if c.len() > 100 {
format!(
"{}...",
reminder
.display_content()
.chars()
.take(97)
.collect::<String>()
)
} else {
c.to_string()
}
})
});
}
opt
})
})
})
})
});
response
} }
#[command("timer")] #[command("timer")]
@ -472,8 +499,8 @@ async fn show_delete_page(
#[subcommand("delete")] #[subcommand("delete")]
#[description("Delete a timer")] #[description("Delete a timer")]
#[arg(name = "name", description = "Name of the timer to delete", kind = "String", required = true)] #[arg(name = "name", description = "Name of the timer to delete", kind = "String", required = true)]
#[required_permissions(Managed)] #[hook(CHECK_MANAGED_PERMISSIONS_HOOK)]
async fn timer(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { async fn timer(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) {
fn time_difference(start_time: NaiveDateTime) -> String { fn time_difference(start_time: NaiveDateTime) -> String {
let unix_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; let unix_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64;
let now = NaiveDateTime::from_timestamp(unix_time, 0); let now = NaiveDateTime::from_timestamp(unix_time, 0);
@ -638,8 +665,8 @@ DELETE FROM timers WHERE owner = ? AND name = ?
kind = "Boolean", kind = "Boolean",
required = false required = false
)] )]
#[required_permissions(Managed)] #[hook(CHECK_MANAGED_PERMISSIONS_HOOK)]
async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { async fn remind(ctx: &Context, invoke: CommandInvoke, args: CommandOptions) {
let interaction = invoke.interaction().unwrap(); let interaction = invoke.interaction().unwrap();
// defer response since processing times can take some time // defer response since processing times can take some time
@ -650,7 +677,7 @@ async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args:
.await .await
.unwrap(); .unwrap();
let user_data = ctx.user_data(invoke.author_id()).await.unwrap(); let user_data = ctx.user_data(interaction.user.id).await.unwrap();
let timezone = user_data.timezone(); let timezone = user_data.timezone();
let time = { let time = {
@ -675,7 +702,7 @@ async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args:
.unwrap_or(vec![]); .unwrap_or(vec![]);
if list.is_empty() { if list.is_empty() {
vec![ReminderScope::Channel(invoke.channel_id().0)] vec![ReminderScope::Channel(interaction.channel_id.0)]
} else { } else {
list list
} }
@ -698,7 +725,7 @@ async fn remind(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args:
} }
}; };
let mut builder = MultiReminderBuilder::new(ctx, invoke.guild_id()) let mut builder = MultiReminderBuilder::new(ctx, interaction.guild_id)
.author(user_data) .author(user_data)
.content(content) .content(content)
.time(time) .time(time)

View File

@ -17,6 +17,7 @@ use serenity::{
}; };
use crate::{ use crate::{
commands::reminder_cmds::{max_delete_page, show_delete_page},
component_models::pager::{DelPager, LookPager, Pager}, component_models::pager::{DelPager, LookPager, Pager},
consts::{EMBED_DESCRIPTION_MAX_LENGTH, THEME_COLOR}, consts::{EMBED_DESCRIPTION_MAX_LENGTH, THEME_COLOR},
models::reminder::Reminder, models::reminder::Reminder,
@ -165,98 +166,15 @@ INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, \"Role\", (SELECT id
let reminders = let reminders =
Reminder::from_guild(ctx, component.guild_id, component.user.id).await; Reminder::from_guild(ctx, component.guild_id, component.user.id).await;
let pages = reminders let max_pages = max_delete_page(&reminders, &pager.timezone);
.iter()
.enumerate()
.map(|(count, reminder)| reminder.display_del(count, &pager.timezone))
.fold(0, |t, r| t + r.len())
.div_ceil(EMBED_DESCRIPTION_MAX_LENGTH);
let next_page = pager.next_page(pages); let resp =
show_delete_page(&reminders, pager.next_page(max_pages), pager.timezone).await;
let mut char_count = 0; let _ = component
let mut skip_char_count = 0; .create_interaction_response(&ctx, move |r| {
let mut first_num = 0; *r = resp;
r
let (shown_reminders, display_vec): (Vec<&Reminder>, Vec<String>) = reminders
.iter()
.enumerate()
.map(|(count, reminder)| {
(reminder, reminder.display_del(count, &pager.timezone))
})
.skip_while(|(_, p)| {
first_num += 1;
skip_char_count += p.len();
skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * next_page
})
.take_while(|(_, p)| {
char_count += p.len();
char_count < EMBED_DESCRIPTION_MAX_LENGTH
})
.unzip();
let display = display_vec.join("\n");
let del_selector = ComponentDataModel::DelSelector(DelSelector {
page: next_page,
timezone: pager.timezone,
});
let mut embed = CreateEmbed::default();
embed
.title("Delete Reminders")
.description(display)
.footer(|f| f.text(format!("Page {} of {}", next_page + 1, pages)))
.color(*THEME_COLOR);
component
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::UpdateMessage).interaction_response_data(
|response| {
response.embeds(vec![embed]).components(|comp| {
pager.create_button_row(pages, comp);
comp.create_action_row(|row| {
row.create_select_menu(|menu| {
menu.custom_id(del_selector.to_custom_id()).options(
|opt| {
for (count, reminder) in
shown_reminders.iter().enumerate()
{
opt.create_option(|o| {
o.label(count + first_num)
.value(reminder.id)
.description({
let c =
reminder.display_content();
if c.len() > 100 {
format!(
"{}...",
reminder
.display_content()
.chars()
.take(97)
.collect::<String>(
)
)
} else {
c.to_string()
}
})
});
}
opt
},
)
})
})
})
},
)
}) })
.await; .await;
} }
@ -272,119 +190,12 @@ INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, \"Role\", (SELECT id
let reminders = let reminders =
Reminder::from_guild(ctx, component.guild_id, component.user.id).await; Reminder::from_guild(ctx, component.guild_id, component.user.id).await;
if reminders.is_empty() { let resp = show_delete_page(&reminders, selector.page, selector.timezone).await;
let mut embed = CreateEmbed::default();
embed.title("Delete Reminders").description("No Reminders").color(*THEME_COLOR);
component let _ = component
.create_interaction_response(&ctx, |r| { .create_interaction_response(&ctx, move |r| {
r.kind(InteractionResponseType::UpdateMessage) *r = resp;
.interaction_response_data(|response| { r
response.embeds(vec![embed]).components(|comp| comp)
})
})
.await;
return;
}
let pages = reminders
.iter()
.enumerate()
.map(|(count, reminder)| reminder.display_del(count, &selector.timezone))
.fold(0, |t, r| t + r.len())
.div_ceil(EMBED_DESCRIPTION_MAX_LENGTH);
let mut page = selector.page;
if page >= pages {
page = pages - 1;
}
let mut char_count = 0;
let mut skip_char_count = 0;
let mut first_num = 0;
let (shown_reminders, display_vec): (Vec<&Reminder>, Vec<String>) = reminders
.iter()
.enumerate()
.map(|(count, reminder)| {
(reminder, reminder.display_del(count, &selector.timezone))
})
.skip_while(|(_, p)| {
first_num += 1;
skip_char_count += p.len();
skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * page
})
.take_while(|(_, p)| {
char_count += p.len();
char_count < EMBED_DESCRIPTION_MAX_LENGTH
})
.unzip();
let display = display_vec.join("\n");
let pager = DelPager::new(selector.timezone);
let del_selector = ComponentDataModel::DelSelector(DelSelector {
page,
timezone: selector.timezone,
});
let mut embed = CreateEmbed::default();
embed
.title("Delete Reminders")
.description(display)
.footer(|f| f.text(format!("Page {} of {}", page + 1, pages)))
.color(*THEME_COLOR);
component
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::UpdateMessage).interaction_response_data(
|response| {
response.embeds(vec![embed]).components(|comp| {
pager.create_button_row(pages, comp);
comp.create_action_row(|row| {
row.create_select_menu(|menu| {
menu.custom_id(del_selector.to_custom_id()).options(
|opt| {
for (count, reminder) in
shown_reminders.iter().enumerate()
{
opt.create_option(|o| {
o.label(count + first_num)
.value(reminder.id)
.description({
let c =
reminder.display_content();
if c.len() > 100 {
format!(
"{}...",
reminder
.display_content()
.chars()
.take(97)
.collect::<String>(
)
)
} else {
c.to_string()
}
})
});
}
opt
},
)
})
})
})
},
)
}) })
.await; .await;
} }

View File

@ -6,8 +6,9 @@ use std::{
sync::Arc, sync::Arc,
}; };
use log::{error, info, warn}; use log::info;
use regex::{Match, Regex, RegexBuilder}; use regex::{Match, Regex, RegexBuilder};
use serde::{Deserialize, Serialize};
use serenity::{ use serenity::{
async_trait, async_trait,
builder::{CreateApplicationCommands, CreateComponents, CreateEmbed}, builder::{CreateApplicationCommands, CreateComponents, CreateEmbed},
@ -15,9 +16,9 @@ use serenity::{
client::Context, client::Context,
framework::Framework, framework::Framework,
futures::prelude::future::BoxFuture, futures::prelude::future::BoxFuture,
http::Http, http::{CacheHttp, Http},
model::{ model::{
channel::{Channel, GuildChannel, Message}, channel::Message,
guild::{Guild, Member}, guild::{Guild, Member},
id::{ChannelId, GuildId, MessageId, RoleId, UserId}, id::{ChannelId, GuildId, MessageId, RoleId, UserId},
interactions::{ interactions::{
@ -29,20 +30,10 @@ use serenity::{
prelude::application_command::ApplicationCommandInteractionDataOption, prelude::application_command::ApplicationCommandInteractionDataOption,
}, },
prelude::TypeMapKey, prelude::TypeMapKey,
FutureExt, Result as SerenityResult, Result as SerenityResult,
}; };
use crate::{ use crate::{models::CtxData, LimitExecutors};
models::{channel_data::ChannelData, CtxData},
LimitExecutors, SQLPool,
};
#[derive(Debug, PartialEq)]
pub enum PermissionLevel {
Unrestricted,
Managed,
Restricted,
}
pub struct CreateGenericResponse { pub struct CreateGenericResponse {
content: String, content: String,
@ -81,196 +72,135 @@ impl CreateGenericResponse {
} }
} }
#[async_trait] enum InvokeModel {
pub trait CommandInvoke { Slash(ApplicationCommandInteraction),
fn channel_id(&self) -> ChannelId; Text(Message),
fn guild_id(&self) -> Option<GuildId>;
fn guild(&self, cache: Arc<Cache>) -> Option<Guild>;
fn author_id(&self) -> UserId;
async fn member(&self, context: &Context) -> SerenityResult<Member>;
fn msg(&self) -> Option<Message>;
fn interaction(&self) -> Option<ApplicationCommandInteraction>;
async fn respond(
&self,
http: Arc<Http>,
generic_response: CreateGenericResponse,
) -> SerenityResult<()>;
async fn followup(
&self,
http: Arc<Http>,
generic_response: CreateGenericResponse,
) -> SerenityResult<()>;
} }
#[async_trait] pub struct CommandInvoke {
impl CommandInvoke for Message { model: InvokeModel,
fn channel_id(&self) -> ChannelId { already_responded: bool,
self.channel_id
}
fn guild_id(&self) -> Option<GuildId> {
self.guild_id
}
fn guild(&self, cache: Arc<Cache>) -> Option<Guild> {
self.guild(cache)
}
fn author_id(&self) -> UserId {
self.author.id
}
async fn member(&self, context: &Context) -> SerenityResult<Member> {
self.member(context).await
}
fn msg(&self) -> Option<Message> {
Some(self.clone())
}
fn interaction(&self) -> Option<ApplicationCommandInteraction> {
None
}
async fn respond(
&self,
http: Arc<Http>,
generic_response: CreateGenericResponse,
) -> SerenityResult<()> {
self.channel_id
.send_message(http, |m| {
m.content(generic_response.content);
if let Some(embed) = generic_response.embed {
m.set_embed(embed.clone());
}
if let Some(components) = generic_response.components {
m.components(|c| {
*c = components;
c
});
}
m
})
.await
.map(|_| ())
}
async fn followup(
&self,
http: Arc<Http>,
generic_response: CreateGenericResponse,
) -> SerenityResult<()> {
self.channel_id
.send_message(http, |m| {
m.content(generic_response.content);
if let Some(embed) = generic_response.embed {
m.set_embed(embed.clone());
}
if let Some(components) = generic_response.components {
m.components(|c| {
*c = components;
c
});
}
m
})
.await
.map(|_| ())
}
} }
#[async_trait] impl CommandInvoke {
impl CommandInvoke for ApplicationCommandInteraction { fn slash(interaction: ApplicationCommandInteraction) -> Self {
fn channel_id(&self) -> ChannelId { Self { model: InvokeModel::Slash(interaction), already_responded: false }
self.channel_id
} }
fn guild_id(&self) -> Option<GuildId> { fn msg(msg: Message) -> Self {
self.guild_id Self { model: InvokeModel::Text(msg), already_responded: false }
} }
fn guild(&self, cache: Arc<Cache>) -> Option<Guild> { pub fn interaction(self) -> Option<ApplicationCommandInteraction> {
if let Some(guild_id) = self.guild_id { match self.model {
guild_id.to_guild_cached(cache) InvokeModel::Slash(i) => Some(i),
} else { InvokeModel::Text(_) => None,
None
} }
} }
fn author_id(&self) -> UserId { pub fn channel_id(&self) -> ChannelId {
self.member.as_ref().unwrap().user.id match &self.model {
InvokeModel::Slash(i) => i.channel_id,
InvokeModel::Text(m) => m.channel_id,
}
} }
async fn member(&self, _: &Context) -> SerenityResult<Member> { pub fn guild_id(&self) -> Option<GuildId> {
Ok(self.member.clone().unwrap()) match &self.model {
InvokeModel::Slash(i) => i.guild_id,
InvokeModel::Text(m) => m.guild_id,
}
} }
fn msg(&self) -> Option<Message> { pub fn guild(&self, cache: impl AsRef<Cache>) -> Option<Guild> {
None self.guild_id().map(|id| id.to_guild_cached(cache)).flatten()
} }
fn interaction(&self) -> Option<ApplicationCommandInteraction> { pub fn author_id(&self) -> UserId {
Some(self.clone()) match &self.model {
InvokeModel::Slash(i) => i.user.id,
InvokeModel::Text(m) => m.author.id,
}
} }
async fn respond( pub async fn member(&self, cache_http: impl CacheHttp) -> Option<Member> {
match &self.model {
InvokeModel::Slash(i) => i.member.clone(),
InvokeModel::Text(m) => m.member(cache_http).await.ok(),
}
}
pub async fn respond(
&self, &self,
http: Arc<Http>, http: impl AsRef<Http>,
generic_response: CreateGenericResponse, generic_response: CreateGenericResponse,
) -> SerenityResult<()> { ) -> SerenityResult<()> {
self.create_interaction_response(http, |r| { match &self.model {
r.kind(InteractionResponseType::ChannelMessageWithSource).interaction_response_data( InvokeModel::Slash(i) => {
|d| { if !self.already_responded {
d.content(generic_response.content); i.create_interaction_response(http, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.content(generic_response.content);
if let Some(embed) = generic_response.embed {
d.add_embed(embed.clone());
}
if let Some(components) = generic_response.components {
d.components(|c| {
*c = components;
c
});
}
d
})
})
.await
.map(|_| ())
} else {
i.create_followup_message(http, |d| {
d.content(generic_response.content);
if let Some(embed) = generic_response.embed {
d.add_embed(embed.clone());
}
if let Some(components) = generic_response.components {
d.components(|c| {
*c = components;
c
});
}
d
})
.await
.map(|_| ())
}
}
InvokeModel::Text(m) => m
.channel_id
.send_message(http, |m| {
m.content(generic_response.content);
if let Some(embed) = generic_response.embed { if let Some(embed) = generic_response.embed {
d.add_embed(embed.clone()); m.set_embed(embed.clone());
} }
if let Some(components) = generic_response.components { if let Some(components) = generic_response.components {
d.components(|c| { m.components(|c| {
*c = components; *c = components;
c c
}); });
} }
d m
}, })
) .await
}) .map(|_| ()),
.await }
.map(|_| ())
}
async fn followup(
&self,
http: Arc<Http>,
generic_response: CreateGenericResponse,
) -> SerenityResult<()> {
self.create_followup_message(http, |d| {
d.content(generic_response.content);
if let Some(embed) = generic_response.embed {
d.add_embed(embed.clone());
}
if let Some(components) = generic_response.components {
d.components(|c| {
*c = components;
c
});
}
d
})
.await
.map(|_| ())
} }
} }
@ -283,6 +213,7 @@ pub struct Arg {
pub options: &'static [&'static Self], pub options: &'static [&'static Self],
} }
#[derive(Serialize, Deserialize, Clone)]
pub enum OptionValue { pub enum OptionValue {
String(String), String(String),
Integer(i64), Integer(i64),
@ -330,8 +261,9 @@ impl OptionValue {
} }
} }
#[derive(Serialize, Deserialize, Clone)]
pub struct CommandOptions { pub struct CommandOptions {
pub command: String, pub command: &'static str,
pub subcommand: Option<String>, pub subcommand: Option<String>,
pub subcommand_group: Option<String>, pub subcommand_group: Option<String>,
pub options: HashMap<String, OptionValue>, pub options: HashMap<String, OptionValue>,
@ -343,8 +275,17 @@ impl CommandOptions {
} }
} }
impl From<ApplicationCommandInteraction> for CommandOptions { impl CommandOptions {
fn from(interaction: ApplicationCommandInteraction) -> Self { fn new(command: &'static Command) -> Self {
Self {
command: command.names[0],
subcommand: None,
subcommand_group: None,
options: Default::default(),
}
}
fn populate(mut self, interaction: &ApplicationCommandInteraction) -> Self {
fn match_option( fn match_option(
option: ApplicationCommandInteractionDataOption, option: ApplicationCommandInteractionDataOption,
cmd_opts: &mut CommandOptions, cmd_opts: &mut CommandOptions,
@ -429,35 +370,31 @@ impl From<ApplicationCommandInteraction> for CommandOptions {
} }
} }
let mut cmd_opts = Self { for option in &interaction.data.options {
command: interaction.data.name, match_option(option.clone(), &mut self)
subcommand: None,
subcommand_group: None,
options: Default::default(),
};
for option in interaction.data.options {
match_option(option, &mut cmd_opts)
} }
cmd_opts self
} }
} }
type SlashCommandFn = for<'fut> fn( pub enum HookResult {
&'fut Context, Continue,
&'fut (dyn CommandInvoke + Sync + Send), Halt,
CommandOptions, }
) -> BoxFuture<'fut, ()>;
type TextCommandFn = for<'fut> fn( type SlashCommandFn =
&'fut Context, for<'fut> fn(&'fut Context, CommandInvoke, CommandOptions) -> BoxFuture<'fut, ()>;
&'fut (dyn CommandInvoke + Sync + Send),
String,
) -> BoxFuture<'fut, ()>;
type MultiCommandFn = type TextCommandFn = for<'fut> fn(&'fut Context, CommandInvoke, String) -> BoxFuture<'fut, ()>;
for<'fut> fn(&'fut Context, &'fut (dyn CommandInvoke + Sync + Send)) -> BoxFuture<'fut, ()>;
type MultiCommandFn = for<'fut> fn(&'fut Context, CommandInvoke) -> BoxFuture<'fut, ()>;
pub type HookFn = for<'fut> fn(
&'fut Context,
&'fut CommandInvoke,
&'fut CommandOptions,
) -> BoxFuture<'fut, HookResult>;
pub enum CommandFnType { pub enum CommandFnType {
Slash(SlashCommandFn), Slash(SlashCommandFn),
@ -474,6 +411,17 @@ impl CommandFnType {
} }
} }
pub struct Hook {
pub fun: HookFn,
pub uuid: u128,
}
impl PartialEq for Hook {
fn eq(&self, other: &Self) -> bool {
self.uuid == other.uuid
}
}
pub struct Command { pub struct Command {
pub fun: CommandFnType, pub fun: CommandFnType,
@ -483,11 +431,12 @@ pub struct Command {
pub examples: &'static [&'static str], pub examples: &'static [&'static str],
pub group: &'static str, pub group: &'static str,
pub required_permissions: PermissionLevel,
pub args: &'static [&'static Arg], pub args: &'static [&'static Arg],
pub can_blacklist: bool, pub can_blacklist: bool,
pub supports_dm: bool, pub supports_dm: bool,
pub hooks: &'static [&'static Hook],
} }
impl Hash for Command { impl Hash for Command {
@ -504,81 +453,6 @@ impl PartialEq for Command {
impl Eq for Command {} impl Eq for Command {}
impl Command {
async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool {
if self.required_permissions == PermissionLevel::Unrestricted {
true
} else {
let permissions = guild.member_permissions(&ctx, &member.user).await.unwrap();
if permissions.manage_guild()
|| (permissions.manage_messages()
&& self.required_permissions == PermissionLevel::Managed)
{
return true;
}
if self.required_permissions == PermissionLevel::Managed {
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
match sqlx::query!(
"
SELECT
role
FROM
roles
INNER JOIN
command_restrictions ON roles.id = command_restrictions.role_id
WHERE
command_restrictions.command = ? AND
roles.guild_id = (
SELECT
id
FROM
guilds
WHERE
guild = ?)
",
self.names[0],
guild.id.as_u64()
)
.fetch_all(&pool)
.await
{
Ok(rows) => {
let role_ids =
member.roles.iter().map(|r| *r.as_u64()).collect::<Vec<u64>>();
for row in rows {
if role_ids.contains(&row.role) {
return true;
}
}
false
}
Err(sqlx::Error::RowNotFound) => false,
Err(e) => {
warn!("Unexpected error occurred querying command_restrictions: {:?}", e);
false
}
}
} else {
false
}
}
}
}
pub struct RegexFramework { pub struct RegexFramework {
pub commands_map: HashMap<String, &'static Command>, pub commands_map: HashMap<String, &'static Command>,
pub commands: HashSet<&'static Command>, pub commands: HashSet<&'static Command>,
@ -589,23 +463,14 @@ pub struct RegexFramework {
ignore_bots: bool, ignore_bots: bool,
case_insensitive: bool, case_insensitive: bool,
dm_enabled: bool, dm_enabled: bool,
default_text_fun: TextCommandFn,
debug_guild: Option<GuildId>, debug_guild: Option<GuildId>,
hooks: Vec<&'static Hook>,
} }
impl TypeMapKey for RegexFramework { impl TypeMapKey for RegexFramework {
type Value = Arc<RegexFramework>; type Value = Arc<RegexFramework>;
} }
fn drop_text<'fut>(
_: &'fut Context,
_: &'fut (dyn CommandInvoke + Sync + Send),
_: String,
) -> std::pin::Pin<std::boxed::Box<(dyn std::future::Future<Output = ()> + std::marker::Send + 'fut)>>
{
async move {}.boxed()
}
impl RegexFramework { impl RegexFramework {
pub fn new<T: Into<u64>>(client_id: T) -> Self { pub fn new<T: Into<u64>>(client_id: T) -> Self {
Self { Self {
@ -618,8 +483,8 @@ impl RegexFramework {
ignore_bots: true, ignore_bots: true,
case_insensitive: true, case_insensitive: true,
dm_enabled: true, dm_enabled: true,
default_text_fun: drop_text,
debug_guild: None, debug_guild: None,
hooks: vec![],
} }
} }
@ -647,6 +512,12 @@ impl RegexFramework {
self self
} }
pub fn add_hook(mut self, fun: &'static Hook) -> Self {
self.hooks.push(fun);
self
}
pub fn add_command(mut self, command: &'static Command) -> Self { pub fn add_command(mut self, command: &'static Command) -> Self {
self.commands.insert(command); self.commands.insert(command);
@ -791,77 +662,46 @@ impl RegexFramework {
.expect(&format!("Received invalid command: {}", interaction.data.name)) .expect(&format!("Received invalid command: {}", interaction.data.name))
}; };
let guild = interaction.guild(ctx.cache.clone()).unwrap(); let args = CommandOptions::new(command).populate(&interaction);
let member = interaction.clone().member.unwrap(); let command_invoke = CommandInvoke::slash(interaction);
if command.check_permissions(&ctx, &guild, &member).await { for hook in command.hooks {
let args = CommandOptions::from(interaction.clone()); match (hook.fun)(&ctx, &command_invoke, &args).await {
HookResult::Continue => {}
if !ctx.check_executing(interaction.author_id()).await { HookResult::Halt => {
ctx.set_executing(interaction.author_id()).await; return;
match command.fun {
CommandFnType::Slash(t) => t(&ctx, &interaction, args).await,
CommandFnType::Multi(m) => m(&ctx, &interaction).await,
_ => (),
} }
ctx.drop_executing(interaction.author_id()).await;
} }
} else if command.required_permissions == PermissionLevel::Restricted { }
let _ = interaction
.respond( for hook in &self.hooks {
ctx.http.clone(), match (hook.fun)(&ctx, &command_invoke, &args).await {
CreateGenericResponse::new().content( HookResult::Continue => {}
"You must have the `Manage Server` permission to use this command.", HookResult::Halt => {
), return;
) }
.await; }
} else if command.required_permissions == PermissionLevel::Managed { }
let _ = interaction
.respond( let user_id = command_invoke.author_id();
ctx.http.clone(),
CreateGenericResponse::new().content( if !ctx.check_executing(user_id).await {
"You must have `Manage Messages` or have a role capable of sending reminders to that channel. \ ctx.set_executing(user_id).await;
Please talk to your server admin, and ask them to use the `/restrict` command to specify \
allowed roles.", match command.fun {
), CommandFnType::Slash(t) => t(&ctx, command_invoke, args).await,
) CommandFnType::Multi(m) => m(&ctx, command_invoke).await,
.await; _ => (),
}
ctx.drop_executing(user_id).await;
} }
} }
} }
enum PermissionCheck {
None, // No permissions
Basic(bool, bool), // Send + Embed permissions (sufficient to reply)
All, // Above + Manage Webhooks (sufficient to operate)
}
#[async_trait] #[async_trait]
impl Framework for RegexFramework { impl Framework for RegexFramework {
async fn dispatch(&self, ctx: Context, msg: Message) { async fn dispatch(&self, ctx: Context, msg: Message) {
async fn check_self_permissions(
ctx: &Context,
guild: &Guild,
channel: &GuildChannel,
) -> SerenityResult<PermissionCheck> {
let user_id = ctx.cache.current_user_id();
let guild_perms = guild.member_permissions(&ctx, user_id).await?;
let channel_perms = channel.permissions_for_user(ctx, user_id)?;
let basic_perms = channel_perms.send_messages();
Ok(if basic_perms && guild_perms.manage_webhooks() && channel_perms.embed_links() {
PermissionCheck::All
} else if basic_perms {
PermissionCheck::Basic(guild_perms.manage_webhooks(), channel_perms.embed_links())
} else {
PermissionCheck::None
})
}
async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option<Match<'_>>) -> bool { async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option<Match<'_>>) -> bool {
if let Some(prefix) = prefix_opt { if let Some(prefix) = prefix_opt {
let guild_prefix = ctx.prefix(Some(guild.id)).await; let guild_prefix = ctx.prefix(Some(guild.id)).await;
@ -874,144 +714,65 @@ impl Framework for RegexFramework {
// gate to prevent analysing messages unnecessarily // gate to prevent analysing messages unnecessarily
if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() { if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() {
} else { return;
// Guild Command }
if let (Some(guild), Ok(Channel::Guild(channel))) =
(msg.guild(&ctx), msg.channel(&ctx).await)
{
let data = ctx.data.read().await;
let pool = data.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); let user_id = msg.author.id;
let invoke = CommandInvoke::msg(msg.clone());
if let Some(full_match) = self.command_matcher.captures(&msg.content) { // Guild Command
if check_prefix(&ctx, &guild, full_match.name("prefix")).await { if let Some(guild) = msg.guild(&ctx) {
match check_self_permissions(&ctx, &guild, &channel).await { if let Some(full_match) = self.command_matcher.captures(&msg.content) {
Ok(perms) => match perms { if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
PermissionCheck::All => {
let command = self
.commands_map
.get(
&full_match
.name("cmd")
.unwrap()
.as_str()
.to_lowercase(),
)
.unwrap();
let channel = msg.channel(&ctx).await.unwrap();
let channel_data =
ChannelData::from_channel(&channel, &pool).await.unwrap();
if !command.can_blacklist || !channel_data.blacklisted {
let args = full_match
.name("args")
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
let member = guild.member(&ctx, &msg.author).await.unwrap();
if command.check_permissions(&ctx, &guild, &member).await {
if msg.id == MessageId(0)
|| !ctx.check_executing(msg.author.id).await
{
ctx.set_executing(msg.author.id).await;
match command.fun {
CommandFnType::Text(t) => t(&ctx, &msg, args),
CommandFnType::Multi(m) => m(&ctx, &msg),
_ => (self.default_text_fun)(&ctx, &msg, args),
}
.await;
ctx.drop_executing(msg.author.id).await;
}
} else if command.required_permissions
== PermissionLevel::Restricted
{
let _ = msg
.channel_id
.say(
&ctx,
"You must have the `Manage Server` permission to use this command.",
)
.await;
} else if command.required_permissions
== PermissionLevel::Managed
{
let _ = msg
.channel_id
.say(
&ctx,
"You must have `Manage Messages` or have a role capable of \
sending reminders to that channel. Please talk to your server \
admin, and ask them to use the `/restrict` command to specify \
allowed roles.",
)
.await;
}
}
}
PermissionCheck::Basic(manage_webhooks, embed_links) => {
let _ = msg
.channel_id
.say(
&ctx,
format!(
"Please ensure the bot has the correct permissions:
**Send Message**
{} **Embed Links**
{} **Manage Webhooks**",
if manage_webhooks { "" } else { "" },
if embed_links { "" } else { "" },
),
)
.await;
}
PermissionCheck::None => {
warn!("Missing enough permissions for guild {}", guild.id);
}
},
Err(e) => {
error!(
"Error occurred getting permissions in guild {}: {:?}",
guild.id, e
);
}
}
}
}
}
// DM Command
else if self.dm_enabled {
if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) {
let command = self let command = self
.commands_map .commands_map
.get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) .get(&full_match.name("cmd").unwrap().as_str().to_lowercase())
.unwrap(); .unwrap();
let args =
full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string();
if msg.id == MessageId(0) || !ctx.check_executing(msg.author.id).await { let channel_data = ctx.channel_data(invoke.channel_id()).await.unwrap();
ctx.set_executing(msg.author.id).await;
match command.fun { if !command.can_blacklist || !channel_data.blacklisted {
CommandFnType::Text(t) => t(&ctx, &msg, args), let args =
CommandFnType::Multi(m) => m(&ctx, &msg), full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string();
_ => (self.default_text_fun)(&ctx, &msg, args),
if msg.id == MessageId(0) || !ctx.check_executing(user_id).await {
ctx.set_executing(user_id).await;
match command.fun {
CommandFnType::Text(t) => t(&ctx, invoke, args).await,
CommandFnType::Multi(m) => m(&ctx, invoke).await,
_ => {}
};
ctx.drop_executing(user_id).await;
} }
.await;
ctx.drop_executing(msg.author.id).await;
} }
} }
} }
} }
// DM Command
else if self.dm_enabled {
if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) {
let command = self
.commands_map
.get(&full_match.name("cmd").unwrap().as_str().to_lowercase())
.unwrap();
let args = full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string();
let user_id = invoke.author_id();
if msg.id == MessageId(0) || !ctx.check_executing(user_id).await {
ctx.set_executing(user_id).await;
match command.fun {
CommandFnType::Text(t) => t(&ctx, invoke, args).await,
CommandFnType::Multi(m) => m(&ctx, invoke).await,
_ => {}
};
ctx.drop_executing(user_id).await;
}
}
}
} }
} }

217
src/hooks.rs Normal file
View File

@ -0,0 +1,217 @@
use log::warn;
use regex_command_attr::check;
use serenity::{client::Context, model::channel::Channel};
use crate::{
framework::{CommandInvoke, CommandOptions, CreateGenericResponse, HookResult},
moderation_cmds, RecordingMacros, SQLPool,
};
#[check]
pub async fn macro_check(
ctx: &Context,
invoke: &CommandInvoke,
args: &CommandOptions,
) -> HookResult {
if let Some(guild_id) = invoke.guild_id() {
if args.command != moderation_cmds::MACRO_CMD_COMMAND.names[0] {
let active_recordings =
ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap();
let mut lock = active_recordings.write().await;
if let Some(command_macro) = lock.get_mut(&(guild_id, invoke.author_id())) {
command_macro.commands.push(args.clone());
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content("Command recorded to macro"),
)
.await;
HookResult::Halt
} else {
HookResult::Continue
}
} else {
HookResult::Continue
}
} else {
HookResult::Continue
}
}
#[check]
pub async fn check_self_permissions(
ctx: &Context,
invoke: &CommandInvoke,
_args: &CommandOptions,
) -> HookResult {
if let Some(guild) = invoke.guild(&ctx) {
let user_id = ctx.cache.current_user_id();
let manage_webhooks =
guild.member_permissions(&ctx, user_id).await.map_or(false, |p| p.manage_webhooks());
let (send_messages, embed_links) = invoke
.channel_id()
.to_channel_cached(&ctx)
.map(|c| {
if let Channel::Guild(channel) = c {
channel.permissions_for_user(ctx, user_id).ok()
} else {
None
}
})
.flatten()
.map_or((false, false), |p| (p.send_messages(), p.embed_links()));
if manage_webhooks && send_messages && embed_links {
HookResult::Continue
} else {
if send_messages {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content(format!(
"Please ensure the bot has the correct permissions:
**Send Message**
{} **Embed Links**
{} **Manage Webhooks**",
if manage_webhooks { "" } else { "" },
if embed_links { "" } else { "" },
)),
)
.await;
} else {
warn!("Missing permissions in guild {}", guild.id);
}
HookResult::Halt
}
} else {
HookResult::Continue
}
}
#[check]
pub async fn check_managed_permissions(
ctx: &Context,
invoke: &CommandInvoke,
args: &CommandOptions,
) -> HookResult {
if let Some(guild) = invoke.guild(&ctx) {
let permissions = guild.member_permissions(&ctx, invoke.author_id()).await.unwrap();
if permissions.manage_messages() {
return HookResult::Continue;
}
let member = invoke.member(&ctx).await.unwrap();
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
match sqlx::query!(
"
SELECT
role
FROM
roles
INNER JOIN
command_restrictions ON roles.id = command_restrictions.role_id
WHERE
command_restrictions.command = ? AND
roles.guild_id = (
SELECT
id
FROM
guilds
WHERE
guild = ?)
",
args.command,
guild.id.as_u64()
)
.fetch_all(&pool)
.await
{
Ok(rows) => {
let role_ids = member.roles.iter().map(|r| *r.as_u64()).collect::<Vec<u64>>();
for row in rows {
if role_ids.contains(&row.role) {
return HookResult::Continue;
}
}
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content(
"You must have \"Manage Messages\" or have a role capable of sending reminders to that channel. \
Please talk to your server admin, and ask them to use the `/restrict` command to specify allowed roles.",
),
)
.await;
HookResult::Halt
}
Err(sqlx::Error::RowNotFound) => {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content(
"You must have \"Manage Messages\" or have a role capable of sending reminders to that channel. \
Please talk to your server admin, and ask them to use the `/restrict` command to specify allowed roles.",
),
)
.await;
HookResult::Halt
}
Err(e) => {
warn!("Unexpected error occurred querying command_restrictions: {:?}", e);
HookResult::Halt
}
}
} else {
HookResult::Continue
}
}
#[check]
pub async fn check_guild_permissions(
ctx: &Context,
invoke: &CommandInvoke,
_args: &CommandOptions,
) -> HookResult {
if let Some(guild) = invoke.guild(&ctx) {
let permissions = guild.member_permissions(&ctx, invoke.author_id()).await.unwrap();
if !permissions.manage_guild() {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content(
"You must have the \"Manage Server\" permission to use this command",
),
)
.await;
HookResult::Halt
} else {
HookResult::Continue
}
} else {
HookResult::Continue
}
}

View File

@ -6,6 +6,7 @@ mod commands;
mod component_models; mod component_models;
mod consts; mod consts;
mod framework; mod framework;
mod hooks;
mod models; mod models;
mod time_parser; mod time_parser;
@ -39,7 +40,7 @@ use crate::{
component_models::ComponentDataModel, component_models::ComponentDataModel,
consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR}, consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR},
framework::RegexFramework, framework::RegexFramework,
models::guild_data::GuildData, models::{command_macro::CommandMacro, guild_data::GuildData},
}; };
struct GuildDataCache; struct GuildDataCache;
@ -72,6 +73,12 @@ impl TypeMapKey for CurrentlyExecuting {
type Value = Arc<RwLock<HashMap<UserId, Instant>>>; type Value = Arc<RwLock<HashMap<UserId, Instant>>>;
} }
struct RecordingMacros;
impl TypeMapKey for RecordingMacros {
type Value = Arc<RwLock<HashMap<(GuildId, UserId), CommandMacro>>>;
}
#[async_trait] #[async_trait]
trait LimitExecutors { trait LimitExecutors {
async fn check_executing(&self, user: UserId) -> bool; async fn check_executing(&self, user: UserId) -> bool;
@ -326,10 +333,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.add_command(&moderation_cmds::RESTRICT_COMMAND) .add_command(&moderation_cmds::RESTRICT_COMMAND)
.add_command(&moderation_cmds::TIMEZONE_COMMAND) .add_command(&moderation_cmds::TIMEZONE_COMMAND)
.add_command(&moderation_cmds::PREFIX_COMMAND) .add_command(&moderation_cmds::PREFIX_COMMAND)
.add_command(&moderation_cmds::MACRO_CMD_COMMAND)
/* /*
.add_command("alias", &moderation_cmds::ALIAS_COMMAND) .add_command("alias", &moderation_cmds::ALIAS_COMMAND)
.add_command("a", &moderation_cmds::ALIAS_COMMAND) .add_command("a", &moderation_cmds::ALIAS_COMMAND)
*/ */
.add_hook(&hooks::CHECK_SELF_PERMISSIONS_HOOK)
.add_hook(&hooks::MACRO_CHECK_HOOK)
.build(); .build();
let framework_arc = Arc::new(framework); let framework_arc = Arc::new(framework);
@ -375,6 +385,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
data.insert::<PopularTimezones>(Arc::new(popular_timezones)); data.insert::<PopularTimezones>(Arc::new(popular_timezones));
data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new())); data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new()));
data.insert::<RegexFramework>(framework_arc.clone()); data.insert::<RegexFramework>(framework_arc.clone());
data.insert::<RecordingMacros>(Arc::new(RwLock::new(HashMap::new())));
} }
if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| { if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| {

View File

@ -0,0 +1,10 @@
use serenity::model::id::GuildId;
use crate::framework::CommandOptions;
pub struct CommandMacro {
pub guild_id: GuildId,
pub name: String,
pub description: Option<String>,
pub commands: Vec<CommandOptions>,
}

View File

@ -1,4 +1,5 @@
pub mod channel_data; pub mod channel_data;
pub mod command_macro;
pub mod guild_data; pub mod guild_data;
pub mod reminder; pub mod reminder;
pub mod timer; pub mod timer;

View File

@ -1,32 +1,5 @@
use crate::consts::{MAX_TIME, MIN_INTERVAL}; use crate::consts::{MAX_TIME, MIN_INTERVAL};
#[derive(Debug)]
pub enum InteractionError {
InvalidFormat,
InvalidBase64,
InvalidSize,
NoReminder,
SignatureMismatch,
InvalidAction,
}
impl ToString for InteractionError {
fn to_string(&self) -> String {
match self {
InteractionError::InvalidFormat => {
String::from("The interaction data was improperly formatted")
}
InteractionError::InvalidBase64 => String::from("The interaction data was invalid"),
InteractionError::InvalidSize => String::from("The interaction data was invalid"),
InteractionError::NoReminder => String::from("Reminder could not be found"),
InteractionError::SignatureMismatch => {
String::from("Only the user who did the command can use interactions")
}
InteractionError::InvalidAction => String::from("The action was invalid"),
}
}
}
#[derive(PartialEq, Eq, Hash, Debug)] #[derive(PartialEq, Eq, Hash, Debug)]
pub enum ReminderError { pub enum ReminderError {
LongTime, LongTime,