// todo move framework to its own module, split out permission checks use std::{ collections::{HashMap, HashSet}, hash::{Hash, Hasher}, sync::Arc, }; use log::info; use regex::{Match, Regex, RegexBuilder}; use serde::{Deserialize, Serialize}; use serenity::{ async_trait, builder::{CreateApplicationCommands, CreateComponents, CreateEmbed}, cache::Cache, client::Context, framework::Framework, futures::prelude::future::BoxFuture, http::{CacheHttp, Http}, model::{ channel::Message, guild::{Guild, Member}, id::{ChannelId, GuildId, MessageId, RoleId, UserId}, interactions::{ application_command::{ ApplicationCommand, ApplicationCommandInteraction, ApplicationCommandOptionType, }, InteractionResponseType, }, prelude::application_command::ApplicationCommandInteractionDataOption, }, prelude::TypeMapKey, Result as SerenityResult, }; use crate::{models::CtxData, LimitExecutors}; pub struct CreateGenericResponse { content: String, embed: Option, components: Option, } impl CreateGenericResponse { pub fn new() -> Self { Self { content: "".to_string(), embed: None, components: None } } pub fn content(mut self, content: D) -> Self { self.content = content.to_string(); self } pub fn embed &mut CreateEmbed>(mut self, f: F) -> Self { let mut embed = CreateEmbed::default(); f(&mut embed); self.embed = Some(embed); self } pub fn components &mut CreateComponents>( mut self, f: F, ) -> Self { let mut components = CreateComponents::default(); f(&mut components); self.components = Some(components); self } } enum InvokeModel { Slash(ApplicationCommandInteraction), Text(Message), } pub struct CommandInvoke { model: InvokeModel, already_responded: bool, } impl CommandInvoke { fn slash(interaction: ApplicationCommandInteraction) -> Self { Self { model: InvokeModel::Slash(interaction), already_responded: false } } fn msg(msg: Message) -> Self { Self { model: InvokeModel::Text(msg), already_responded: false } } pub fn interaction(self) -> Option { match self.model { InvokeModel::Slash(i) => Some(i), InvokeModel::Text(_) => None, } } pub fn channel_id(&self) -> ChannelId { match &self.model { InvokeModel::Slash(i) => i.channel_id, InvokeModel::Text(m) => m.channel_id, } } pub fn guild_id(&self) -> Option { match &self.model { InvokeModel::Slash(i) => i.guild_id, InvokeModel::Text(m) => m.guild_id, } } pub fn guild(&self, cache: impl AsRef) -> Option { self.guild_id().map(|id| id.to_guild_cached(cache)).flatten() } pub fn author_id(&self) -> UserId { match &self.model { InvokeModel::Slash(i) => i.user.id, InvokeModel::Text(m) => m.author.id, } } pub async fn member(&self, cache_http: impl CacheHttp) -> Option { match &self.model { InvokeModel::Slash(i) => i.member.clone(), InvokeModel::Text(m) => m.member(cache_http).await.ok(), } } pub async fn respond( &self, http: impl AsRef, generic_response: CreateGenericResponse, ) -> SerenityResult<()> { match &self.model { InvokeModel::Slash(i) => { if !self.already_responded { i.create_interaction_response(http, |r| { r.kind(InteractionResponseType::ChannelMessageWithSource) .interaction_response_data(|d| { d.content(generic_response.content); if let Some(embed) = generic_response.embed { d.add_embed(embed.clone()); } if let Some(components) = generic_response.components { d.components(|c| { *c = components; c }); } d }) }) .await .map(|_| ()) } else { i.create_followup_message(http, |d| { d.content(generic_response.content); if let Some(embed) = generic_response.embed { d.add_embed(embed.clone()); } if let Some(components) = generic_response.components { d.components(|c| { *c = components; c }); } d }) .await .map(|_| ()) } } InvokeModel::Text(m) => m .channel_id .send_message(http, |m| { m.content(generic_response.content); if let Some(embed) = generic_response.embed { m.set_embed(embed.clone()); } if let Some(components) = generic_response.components { m.components(|c| { *c = components; c }); } m }) .await .map(|_| ()), } } } #[derive(Debug)] pub struct Arg { pub name: &'static str, pub description: &'static str, pub kind: ApplicationCommandOptionType, pub required: bool, pub options: &'static [&'static Self], } #[derive(Serialize, Deserialize, Clone)] pub enum OptionValue { String(String), Integer(i64), Boolean(bool), User(UserId), Channel(ChannelId), Role(RoleId), Mentionable(u64), Number(f64), } impl OptionValue { pub fn as_i64(&self) -> Option { match self { OptionValue::Integer(i) => Some(*i), _ => None, } } pub fn as_bool(&self) -> Option { match self { OptionValue::Boolean(b) => Some(*b), _ => None, } } pub fn as_channel_id(&self) -> Option { match self { OptionValue::Channel(c) => Some(*c), _ => None, } } pub fn to_string(&self) -> String { match self { OptionValue::String(s) => s.to_string(), OptionValue::Integer(i) => i.to_string(), OptionValue::Boolean(b) => b.to_string(), OptionValue::User(u) => u.to_string(), OptionValue::Channel(c) => c.to_string(), OptionValue::Role(r) => r.to_string(), OptionValue::Mentionable(m) => m.to_string(), OptionValue::Number(n) => n.to_string(), } } } #[derive(Serialize, Deserialize, Clone)] pub struct CommandOptions { pub command: &'static str, pub subcommand: Option, pub subcommand_group: Option, pub options: HashMap, } impl CommandOptions { pub fn get(&self, key: &str) -> Option<&OptionValue> { self.options.get(key) } } impl CommandOptions { fn new(command: &'static Command) -> Self { Self { command: command.names[0], subcommand: None, subcommand_group: None, options: Default::default(), } } fn populate(mut self, interaction: &ApplicationCommandInteraction) -> Self { fn match_option( option: ApplicationCommandInteractionDataOption, cmd_opts: &mut CommandOptions, ) { match option.kind { ApplicationCommandOptionType::SubCommand => { cmd_opts.subcommand = Some(option.name); for opt in option.options { match_option(opt, cmd_opts); } } ApplicationCommandOptionType::SubCommandGroup => { cmd_opts.subcommand_group = Some(option.name); for opt in option.options { match_option(opt, cmd_opts); } } ApplicationCommandOptionType::String => { cmd_opts.options.insert( option.name, OptionValue::String(option.value.unwrap().to_string()), ); } ApplicationCommandOptionType::Integer => { cmd_opts.options.insert( option.name, OptionValue::Integer(option.value.map(|m| m.as_i64()).flatten().unwrap()), ); } ApplicationCommandOptionType::Boolean => { cmd_opts.options.insert( option.name, OptionValue::Boolean(option.value.map(|m| m.as_bool()).flatten().unwrap()), ); } ApplicationCommandOptionType::User => { cmd_opts.options.insert( option.name, OptionValue::User(UserId( option.value.map(|m| m.as_u64()).flatten().unwrap(), )), ); } ApplicationCommandOptionType::Channel => { cmd_opts.options.insert( option.name, OptionValue::Channel(ChannelId( option.value.map(|m| m.as_u64()).flatten().unwrap(), )), ); } ApplicationCommandOptionType::Role => { cmd_opts.options.insert( option.name, OptionValue::Role(RoleId( option .value .map(|m| m.as_str().map(|s| s.parse::().ok())) .flatten() .flatten() .unwrap(), )), ); } ApplicationCommandOptionType::Mentionable => { cmd_opts.options.insert( option.name, OptionValue::Mentionable( option.value.map(|m| m.as_u64()).flatten().unwrap(), ), ); } ApplicationCommandOptionType::Number => { cmd_opts.options.insert( option.name, OptionValue::Number(option.value.map(|m| m.as_f64()).flatten().unwrap()), ); } _ => {} } } for option in &interaction.data.options { match_option(option.clone(), &mut self) } self } } pub enum HookResult { Continue, Halt, } type SlashCommandFn = for<'fut> fn(&'fut Context, CommandInvoke, CommandOptions) -> BoxFuture<'fut, ()>; type TextCommandFn = for<'fut> fn(&'fut Context, CommandInvoke, String) -> BoxFuture<'fut, ()>; type MultiCommandFn = for<'fut> fn(&'fut Context, CommandInvoke) -> BoxFuture<'fut, ()>; pub type HookFn = for<'fut> fn( &'fut Context, &'fut CommandInvoke, &'fut CommandOptions, ) -> BoxFuture<'fut, HookResult>; pub enum CommandFnType { Slash(SlashCommandFn), Text(TextCommandFn), Multi(MultiCommandFn), } impl CommandFnType { pub fn is_slash(&self) -> bool { match self { CommandFnType::Text(_) => false, _ => true, } } } pub struct Hook { pub fun: HookFn, pub uuid: u128, } impl PartialEq for Hook { fn eq(&self, other: &Self) -> bool { self.uuid == other.uuid } } pub struct Command { pub fun: CommandFnType, pub names: &'static [&'static str], pub desc: &'static str, pub examples: &'static [&'static str], pub group: &'static str, pub args: &'static [&'static Arg], pub can_blacklist: bool, pub supports_dm: bool, pub hooks: &'static [&'static Hook], } impl Hash for Command { fn hash(&self, state: &mut H) { self.names[0].hash(state) } } impl PartialEq for Command { fn eq(&self, other: &Self) -> bool { self.names[0] == other.names[0] } } impl Eq for Command {} pub struct RegexFramework { pub commands_map: HashMap, pub commands: HashSet<&'static Command>, command_matcher: Regex, dm_regex_matcher: Regex, default_prefix: String, client_id: u64, ignore_bots: bool, case_insensitive: bool, dm_enabled: bool, debug_guild: Option, hooks: Vec<&'static Hook>, } impl TypeMapKey for RegexFramework { type Value = Arc; } impl RegexFramework { pub fn new>(client_id: T) -> Self { Self { commands_map: HashMap::new(), commands: HashSet::new(), command_matcher: Regex::new(r#"^$"#).unwrap(), dm_regex_matcher: Regex::new(r#"^$"#).unwrap(), default_prefix: "".to_string(), client_id: client_id.into(), ignore_bots: true, case_insensitive: true, dm_enabled: true, debug_guild: None, hooks: vec![], } } pub fn case_insensitive(mut self, case_insensitive: bool) -> Self { self.case_insensitive = case_insensitive; self } pub fn default_prefix(mut self, new_prefix: T) -> Self { self.default_prefix = new_prefix.to_string(); self } pub fn ignore_bots(mut self, ignore_bots: bool) -> Self { self.ignore_bots = ignore_bots; self } pub fn dm_enabled(mut self, dm_enabled: bool) -> Self { self.dm_enabled = dm_enabled; self } pub fn add_hook(mut self, fun: &'static Hook) -> Self { self.hooks.push(fun); self } pub fn add_command(mut self, command: &'static Command) -> Self { self.commands.insert(command); for name in command.names { self.commands_map.insert(name.to_string(), command); } self } pub fn debug_guild(mut self, guild_id: Option) -> Self { self.debug_guild = guild_id; self } pub fn build(mut self) -> Self { { let command_names; { let mut command_names_vec = self.commands_map.keys().map(|k| &k[..]).collect::>(); command_names_vec.sort_unstable_by_key(|a| a.len()); command_names = command_names_vec.join("|"); } info!("Command names: {}", command_names); { let match_string = r#"^(?:(?:<@ID>\s*)|(?:<@!ID>\s*)|(?P\S{1,5}?))(?PCOMMANDS)(?:$|\s+(?P.*))$"# .replace("COMMANDS", command_names.as_str()) .replace("ID", self.client_id.to_string().as_str()); self.command_matcher = RegexBuilder::new(match_string.as_str()) .case_insensitive(self.case_insensitive) .dot_matches_new_line(true) .build() .unwrap(); } } { let dm_command_names; { let mut command_names_vec = self .commands_map .iter() .filter_map( |(key, command)| if command.supports_dm { Some(&key[..]) } else { None }, ) .collect::>(); command_names_vec.sort_unstable_by_key(|a| a.len()); dm_command_names = command_names_vec.join("|"); } { let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(\$)|())(?PCOMMANDS)(?:$|\s+(?P.*))$"# .replace("COMMANDS", dm_command_names.as_str()) .replace("ID", self.client_id.to_string().as_str()); self.dm_regex_matcher = RegexBuilder::new(match_string.as_str()) .case_insensitive(self.case_insensitive) .dot_matches_new_line(true) .build() .unwrap(); } } self } fn _populate_commands<'a>( &self, commands: &'a mut CreateApplicationCommands, ) -> &'a mut CreateApplicationCommands { for command in &self.commands { if command.fun.is_slash() { commands.create_application_command(|c| { c.name(command.names[0]).description(command.desc); for arg in command.args { c.create_option(|o| { o.name(arg.name) .description(arg.description) .kind(arg.kind) .required(arg.required); for option in arg.options { o.create_sub_option(|s| { s.name(option.name) .description(option.description) .kind(option.kind) .required(option.required) }); } o }); } c }); } } commands } pub async fn build_slash(&self, http: impl AsRef) { info!("Building slash commands..."); match self.debug_guild { None => { ApplicationCommand::set_global_application_commands(&http, |c| { self._populate_commands(c) }) .await .unwrap(); } Some(debug_guild) => { debug_guild .set_application_commands(&http, |c| self._populate_commands(c)) .await .unwrap(); } } info!("Slash commands built!"); } pub async fn execute(&self, ctx: Context, interaction: ApplicationCommandInteraction) { let command = { self.commands_map .get(&interaction.data.name) .expect(&format!("Received invalid command: {}", interaction.data.name)) }; let args = CommandOptions::new(command).populate(&interaction); let command_invoke = CommandInvoke::slash(interaction); for hook in command.hooks { match (hook.fun)(&ctx, &command_invoke, &args).await { HookResult::Continue => {} HookResult::Halt => { return; } } } for hook in &self.hooks { match (hook.fun)(&ctx, &command_invoke, &args).await { HookResult::Continue => {} HookResult::Halt => { return; } } } let user_id = command_invoke.author_id(); if !ctx.check_executing(user_id).await { ctx.set_executing(user_id).await; match command.fun { CommandFnType::Slash(t) => t(&ctx, command_invoke, args).await, CommandFnType::Multi(m) => m(&ctx, command_invoke).await, _ => (), } ctx.drop_executing(user_id).await; } } } #[async_trait] impl Framework for RegexFramework { async fn dispatch(&self, ctx: Context, msg: Message) { async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option>) -> bool { if let Some(prefix) = prefix_opt { let guild_prefix = ctx.prefix(Some(guild.id)).await; guild_prefix.as_str() == prefix.as_str() } else { true } } // gate to prevent analysing messages unnecessarily if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() { return; } let user_id = msg.author.id; let invoke = CommandInvoke::msg(msg.clone()); // Guild Command if let Some(guild) = msg.guild(&ctx) { if let Some(full_match) = self.command_matcher.captures(&msg.content) { if check_prefix(&ctx, &guild, full_match.name("prefix")).await { let command = self .commands_map .get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) .unwrap(); let channel_data = ctx.channel_data(invoke.channel_id()).await.unwrap(); if !command.can_blacklist || !channel_data.blacklisted { let args = full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string(); if msg.id == MessageId(0) || !ctx.check_executing(user_id).await { ctx.set_executing(user_id).await; match command.fun { CommandFnType::Text(t) => t(&ctx, invoke, args).await, CommandFnType::Multi(m) => m(&ctx, invoke).await, _ => {} }; ctx.drop_executing(user_id).await; } } } } } // DM Command else if self.dm_enabled { if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) { let command = self .commands_map .get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) .unwrap(); let args = full_match.name("args").map(|m| m.as_str()).unwrap_or("").to_string(); let user_id = invoke.author_id(); if msg.id == MessageId(0) || !ctx.check_executing(user_id).await { ctx.set_executing(user_id).await; match command.fun { CommandFnType::Text(t) => t(&ctx, invoke, args).await, CommandFnType::Multi(m) => m(&ctx, invoke).await, _ => {} }; ctx.drop_executing(user_id).await; } } } } }