From 395a8481f12e4921de9bd6a92e22db19ca4d76b8 Mon Sep 17 00:00:00 2001 From: jellywx Date: Sun, 12 Sep 2021 16:59:19 +0100 Subject: [PATCH] typing --- Cargo.toml | 2 +- .../Cargo.toml | 0 .../src/attributes.rs | 0 .../src/consts.rs | 0 .../src/lib.rs | 0 .../src/structures.rs | 0 .../src/util.rs | 0 src/commands/moderation_cmds.rs | 117 ++++++----- src/commands/reminder_cmds.rs | 64 ++---- src/framework.rs | 193 +++++++++++++++--- 10 files changed, 242 insertions(+), 134 deletions(-) rename {regex_command_attr => command_attributes}/Cargo.toml (100%) rename {regex_command_attr => command_attributes}/src/attributes.rs (100%) rename {regex_command_attr => command_attributes}/src/consts.rs (100%) rename {regex_command_attr => command_attributes}/src/lib.rs (100%) rename {regex_command_attr => command_attributes}/src/structures.rs (100%) rename {regex_command_attr => command_attributes}/src/util.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index 082e7e2..9dedada 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,4 +27,4 @@ sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", base64 = "0.13.0" [dependencies.regex_command_attr] -path = "./regex_command_attr" +path = "command_attributes" diff --git a/regex_command_attr/Cargo.toml b/command_attributes/Cargo.toml similarity index 100% rename from regex_command_attr/Cargo.toml rename to command_attributes/Cargo.toml diff --git a/regex_command_attr/src/attributes.rs b/command_attributes/src/attributes.rs similarity index 100% rename from regex_command_attr/src/attributes.rs rename to command_attributes/src/attributes.rs diff --git a/regex_command_attr/src/consts.rs b/command_attributes/src/consts.rs similarity index 100% rename from regex_command_attr/src/consts.rs rename to command_attributes/src/consts.rs diff --git a/regex_command_attr/src/lib.rs b/command_attributes/src/lib.rs similarity index 100% rename from regex_command_attr/src/lib.rs rename to command_attributes/src/lib.rs diff --git a/regex_command_attr/src/structures.rs b/command_attributes/src/structures.rs similarity index 100% rename from regex_command_attr/src/structures.rs rename to command_attributes/src/structures.rs diff --git a/regex_command_attr/src/util.rs b/command_attributes/src/util.rs similarity index 100% rename from regex_command_attr/src/util.rs rename to command_attributes/src/util.rs diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index 2c97b75..a68f9e1 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -9,9 +9,7 @@ use serenity::{ client::Context, model::{ channel::Message, - guild::ActionRole::Create, id::{ChannelId, MessageId, RoleId}, - interactions::message_component::ButtonStyle, misc::Mentionable, }, }; @@ -19,7 +17,9 @@ use serenity::{ use crate::{ component_models::{ComponentDataModel, Restrict}, consts::{REGEX_ALIAS, REGEX_COMMANDS, THEME_COLOR}, - framework::{CommandInvoke, CreateGenericResponse, PermissionLevel}, + framework::{ + CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue, PermissionLevel, + }, models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData, CtxData}, PopularTimezones, RegexFramework, SQLPool, }; @@ -38,14 +38,14 @@ use crate::{ async fn blacklist( ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), - args: HashMap, + args: CommandOptions, ) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let channel = match args.get("channel") { - Some(channel_id) => ChannelId(channel_id.parse::().unwrap()), + Some(OptionValue::Channel(channel_id)) => *channel_id, - None => invoke.channel_id(), + _ => invoke.channel_id(), } .to_channel_cached(&ctx) .unwrap(); @@ -82,17 +82,13 @@ async fn blacklist( kind = "String", required = false )] -async fn timezone( - ctx: &Context, - invoke: &(dyn CommandInvoke + Send + Sync), - args: HashMap, -) { +async fn timezone(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let mut user_data = ctx.user_data(invoke.author_id()).await.unwrap(); let footer_text = format!("Current timezone: {}", user_data.timezone); - if let Some(timezone) = args.get("timezone") { + if let Some(OptionValue::String(timezone)) = args.get("timezone") { match timezone.parse::() { Ok(tz) => { user_data.timezone = timezone.clone(); @@ -237,65 +233,66 @@ async fn prefix(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: )] #[supports_dm(false)] #[required_permissions(Restricted)] -async fn restrict( - ctx: &Context, - invoke: &(dyn CommandInvoke + Send + Sync), - args: HashMap, -) { +async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let framework = ctx.data.read().await.get::().cloned().unwrap(); - let role = RoleId(args.get("role").unwrap().parse::().unwrap()); + if let Some(OptionValue::Role(role)) = args.get("role") { + let restricted_commands = + sqlx::query!("SELECT command FROM command_restrictions WHERE role_id = ?", role.0) + .fetch_all(&pool) + .await + .unwrap() + .iter() + .map(|row| row.command.clone()) + .collect::>(); - let restricted_commands = - sqlx::query!("SELECT command FROM command_restrictions WHERE role_id = ?", role.0) - .fetch_all(&pool) - .await - .unwrap() + let restrictable_commands = framework + .commands .iter() - .map(|row| row.command.clone()) + .filter(|c| c.required_permissions == PermissionLevel::Managed) + .map(|c| c.names[0].to_string()) .collect::>(); - let restrictable_commands = framework - .commands - .iter() - .filter(|c| c.required_permissions == PermissionLevel::Managed) - .map(|c| c.names[0].to_string()) - .collect::>(); + let len = restrictable_commands.len(); - let len = restrictable_commands.len(); + let restrict_pl = ComponentDataModel::Restrict(Restrict { role_id: *role }); - let restrict_pl = ComponentDataModel::Restrict(Restrict { role_id: role }); + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content(format!( + "Select the commands to allow to {} from below:", + role.mention() + )) + .components(|c| { + c.create_action_row(|row| { + row.create_select_menu(|select| { + select + .custom_id(restrict_pl.to_custom_id()) + .options(|options| { + for command in restrictable_commands { + options.create_option(|opt| { + opt.label(&command) + .value(&command) + .default_selection( + restricted_commands.contains(&command), + ) + }); + } - invoke - .respond( - ctx.http.clone(), - CreateGenericResponse::new() - .content(format!("Select the commands to allow to {} from below:", role.mention())) - .components(|c| { - c.create_action_row(|row| { - row.create_select_menu(|select| { - select - .custom_id(restrict_pl.to_custom_id()) - .options(|options| { - for command in restrictable_commands { - options.create_option(|opt| { - opt.label(&command).value(&command).default_selection( - restricted_commands.contains(&command), - ) - }); - } - - options - }) - .min_values(0) - .max_values(len as u64) + options + }) + .min_values(0) + .max_values(len as u64) + }) }) - }) - }), - ) - .await - .unwrap(); + }), + ) + .await + .unwrap(); + } } /* diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index 9927e4a..5a34be2 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -24,7 +24,7 @@ use crate::{ EMBED_DESCRIPTION_MAX_LENGTH, REGEX_CHANNEL_USER, REGEX_NATURAL_COMMAND_1, REGEX_NATURAL_COMMAND_2, REGEX_REMIND_COMMAND, THEME_COLOR, }, - framework::{CommandInvoke, CreateGenericResponse}, + framework::{CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue}, models::{ channel_data::ChannelData, guild_data::GuildData, @@ -52,11 +52,7 @@ use crate::{ )] #[supports_dm(false)] #[required_permissions(Restricted)] -async fn pause( - ctx: &Context, - invoke: &(dyn CommandInvoke + Send + Sync), - args: HashMap, -) { +async fn pause(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; @@ -64,7 +60,7 @@ async fn pause( let mut channel = ctx.channel_data(invoke.channel_id()).await.unwrap(); match args.get("until") { - Some(until) => { + Some(OptionValue::String(until)) => { let parsed = natural_parser(until, &timezone.to_string()).await; if let Some(timestamp) = parsed { @@ -94,7 +90,7 @@ async fn pause( .await; } } - None => { + _ => { channel.paused = !channel.paused; channel.paused_until = None; @@ -142,16 +138,12 @@ async fn pause( required = false )] #[required_permissions(Restricted)] -async fn offset( - ctx: &Context, - invoke: &(dyn CommandInvoke + Send + Sync), - args: HashMap, -) { +async fn offset(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); - let combined_time = args.get("hours").map_or(0, |h| h.parse::().unwrap() * 3600) - + args.get("minutes").map_or(0, |m| m.parse::().unwrap() * 60) - + args.get("seconds").map_or(0, |s| s.parse::().unwrap()); + let combined_time = args.get("hours").map_or(0, |h| h.as_i64().unwrap() * 3600) + + args.get("minutes").map_or(0, |m| m.as_i64().unwrap() * 60) + + args.get("seconds").map_or(0, |s| s.as_i64().unwrap()); if combined_time == 0 { let _ = invoke @@ -223,15 +215,11 @@ WHERE FIND_IN_SET(channels.`channel`, ?)", required = false )] #[required_permissions(Restricted)] -async fn nudge( - ctx: &Context, - invoke: &(dyn CommandInvoke + Send + Sync), - args: HashMap, -) { +async fn nudge(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); - let combined_time = args.get("minutes").map_or(0, |m| m.parse::().unwrap() * 60) - + args.get("seconds").map_or(0, |s| s.parse::().unwrap()); + let combined_time = args.get("minutes").map_or(0, |m| m.as_i64().unwrap() * 60) + + args.get("seconds").map_or(0, |s| s.as_i64().unwrap()); if combined_time < i16::MIN as i64 || combined_time > i16::MAX as i64 { let _ = invoke @@ -279,20 +267,16 @@ async fn nudge( required = false )] #[required_permissions(Managed)] -async fn look( - ctx: &Context, - invoke: &(dyn CommandInvoke + Send + Sync), - args: HashMap, -) { +async fn look(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { let pool = ctx.data.read().await.get::().cloned().unwrap(); let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; let flags = LookFlags { - show_disabled: args.get("disabled").map(|b| b == "true").unwrap_or(true), - channel_id: args.get("channel").map(|c| ChannelId(c.parse::().unwrap())), + show_disabled: args.get("disabled").map(|i| i.as_bool()).flatten().unwrap_or(true), + channel_id: args.get("channel").map(|i| i.as_channel_id()).flatten(), time_display: args.get("relative").map_or(TimeDisplayType::Relative, |b| { - if b == "true" { + if b.as_bool() == Some(true) { TimeDisplayType::Relative } else { TimeDisplayType::Absolute @@ -473,11 +457,7 @@ INSERT INTO events (event_name, bulk_count, guild_id, user_id) VALUES ('delete', #[description("Delete a timer")] #[arg(name = "name", description = "Name of the timer to delete", kind = "String", required = true)] #[required_permissions(Managed)] -async fn timer( - ctx: &Context, - invoke: &(dyn CommandInvoke + Send + Sync), - args: HashMap, -) { +async fn timer(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) { fn time_difference(start_time: NaiveDateTime) -> String { let unix_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; let now = NaiveDateTime::from_timestamp(unix_time, 0); @@ -495,8 +475,8 @@ async fn timer( let owner = invoke.guild_id().map(|g| g.0).unwrap_or_else(|| invoke.author_id().0); - match args.get("").map(|s| s.as_str()) { - Some("start") => { + match args.subcommand.clone().unwrap().as_str() { + "start" => { let count = Timer::count_from_owner(owner, &pool).await; if count >= 25 { @@ -508,7 +488,7 @@ async fn timer( ) .await; } else { - let name = args.get("name").unwrap(); + let name = args.get("name").unwrap().to_string(); if name.len() <= 32 { Timer::create(&name, owner, &pool).await; @@ -530,8 +510,8 @@ async fn timer( } } } - Some("delete") => { - let name = args.get("name").unwrap(); + "delete" => { + let name = args.get("name").unwrap().to_string(); let exists = sqlx::query!( " @@ -570,7 +550,7 @@ DELETE FROM timers WHERE owner = ? AND name = ? .await; } } - Some("list") => { + "list" => { let timers = Timer::from_owner(owner, &pool).await; if timers.len() > 0 { diff --git a/src/framework.rs b/src/framework.rs index c0fc710..938b183 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -18,13 +18,15 @@ use serenity::{ model::{ channel::{Channel, GuildChannel, Message}, guild::{Guild, Member}, - id::{ChannelId, GuildId, MessageId, UserId}, + id::{ChannelId, GuildId, MessageId, RoleId, UserId}, interactions::{ application_command::{ - ApplicationCommand, ApplicationCommandInteraction, ApplicationCommandOptionType, + ApplicationCommand, ApplicationCommandInteraction, ApplicationCommandOption, + ApplicationCommandOptionType, }, InteractionResponseType, }, + prelude::application_command::ApplicationCommandInteractionDataOption, }, prelude::TypeMapKey, FutureExt, Result as SerenityResult, @@ -281,10 +283,166 @@ pub struct Arg { pub options: &'static [&'static Self], } +pub enum OptionValue { + String(String), + Integer(i64), + Boolean(bool), + User(UserId), + Channel(ChannelId), + Role(RoleId), + Mentionable(u64), + Number(f64), +} + +impl OptionValue { + pub fn as_i64(&self) -> Option { + match self { + OptionValue::Integer(i) => Some(*i), + _ => None, + } + } + + pub fn as_bool(&self) -> Option { + match self { + OptionValue::Boolean(b) => Some(*b), + _ => None, + } + } + + pub fn as_channel_id(&self) -> Option { + match self { + OptionValue::Channel(c) => Some(*c), + _ => None, + } + } + + pub fn to_string(&self) -> String { + match self { + OptionValue::String(s) => s.to_string(), + OptionValue::Integer(i) => i.to_string(), + OptionValue::Boolean(b) => b.to_string(), + OptionValue::User(u) => u.to_string(), + OptionValue::Channel(c) => c.to_string(), + OptionValue::Role(r) => r.to_string(), + OptionValue::Mentionable(m) => m.to_string(), + OptionValue::Number(n) => n.to_string(), + } + } +} + +pub struct CommandOptions { + pub command: String, + pub subcommand: Option, + pub subcommand_group: Option, + pub options: HashMap, +} + +impl CommandOptions { + pub fn get(&self, key: &str) -> Option<&OptionValue> { + self.options.get(key) + } +} + +impl From for CommandOptions { + fn from(interaction: ApplicationCommandInteraction) -> Self { + fn match_option( + option: ApplicationCommandInteractionDataOption, + cmd_opts: &mut CommandOptions, + ) { + match option.kind { + ApplicationCommandOptionType::SubCommand => { + cmd_opts.subcommand = Some(option.name); + + for opt in option.options { + match_option(opt, cmd_opts); + } + } + ApplicationCommandOptionType::SubCommandGroup => { + cmd_opts.subcommand_group = Some(option.name); + + for opt in option.options { + match_option(opt, cmd_opts); + } + } + ApplicationCommandOptionType::String => { + cmd_opts.options.insert( + option.name, + OptionValue::String(option.value.unwrap().to_string()), + ); + } + ApplicationCommandOptionType::Integer => { + cmd_opts.options.insert( + option.name, + OptionValue::Integer(option.value.map(|m| m.as_i64()).flatten().unwrap()), + ); + } + ApplicationCommandOptionType::Boolean => { + cmd_opts.options.insert( + option.name, + OptionValue::Boolean(option.value.map(|m| m.as_bool()).flatten().unwrap()), + ); + } + ApplicationCommandOptionType::User => { + cmd_opts.options.insert( + option.name, + OptionValue::User(UserId( + option.value.map(|m| m.as_u64()).flatten().unwrap(), + )), + ); + } + ApplicationCommandOptionType::Channel => { + cmd_opts.options.insert( + option.name, + OptionValue::Channel(ChannelId( + option.value.map(|m| m.as_u64()).flatten().unwrap(), + )), + ); + } + ApplicationCommandOptionType::Role => { + cmd_opts.options.insert( + option.name, + OptionValue::Role(RoleId( + option.value.map(|m| m.as_u64()).flatten().unwrap(), + )), + ); + } + ApplicationCommandOptionType::Mentionable => { + cmd_opts.options.insert( + option.name, + OptionValue::Mentionable( + option.value.map(|m| m.as_u64()).flatten().unwrap(), + ), + ); + } + ApplicationCommandOptionType::Number => { + cmd_opts.options.insert( + option.name, + OptionValue::Number(option.value.map(|m| m.as_f64()).flatten().unwrap()), + ); + } + _ => {} + } + } + + let mut cmd_opts = Self { + command: interaction.data.name, + subcommand: None, + subcommand_group: None, + options: Default::default(), + }; + + for option in interaction.data.options { + match_option(option, &mut cmd_opts) + } + + cmd_opts + } +} + type SlashCommandFn = for<'fut> fn( &'fut Context, &'fut (dyn CommandInvoke + Sync + Send), - HashMap, + CommandOptions, ) -> BoxFuture<'fut, ()>; type TextCommandFn = for<'fut> fn( @@ -631,34 +789,7 @@ impl RegexFramework { let member = interaction.clone().member.unwrap(); if command.check_permissions(&ctx, &guild, &member).await { - let mut args = HashMap::new(); - - for arg in interaction.data.options.iter() { - if let Some(value) = &arg.value { - args.insert( - arg.name.clone(), - match value { - Value::Bool(b) => b.to_string(), - Value::Number(n) => n.to_string(), - Value::String(s) => s.to_owned(), - _ => String::new(), - }, - ); - } else { - args.insert("".to_string(), arg.name.clone()); - for sub_arg in arg.options.iter().filter(|o| o.value.is_some()) { - args.insert( - sub_arg.name.clone(), - match sub_arg.value.as_ref().unwrap() { - Value::Bool(b) => b.to_string(), - Value::Number(n) => n.to_string(), - Value::String(s) => s.to_owned(), - _ => String::new(), - }, - ); - } - } - } + let args = CommandOptions::from(interaction.clone()); if !ctx.check_executing(interaction.author_id()).await { ctx.set_executing(interaction.author_id()).await;