From 60ead9a1ef549842894113448cb1d167fb7b77bf Mon Sep 17 00:00:00 2001 From: jellywx Date: Mon, 14 Jun 2021 21:35:38 +0100 Subject: [PATCH] aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa --- regex_command_attr/src/util.rs | 2 +- src/framework.rs | 180 +++++++++++++++++++-------- src/main.rs | 214 +++++++++++++++------------------ 3 files changed, 231 insertions(+), 165 deletions(-) diff --git a/regex_command_attr/src/util.rs b/regex_command_attr/src/util.rs index 7c44f9a..12197d8 100644 --- a/regex_command_attr/src/util.rs +++ b/regex_command_attr/src/util.rs @@ -171,7 +171,7 @@ pub fn create_declaration_validations(fun: &mut CommandFun) -> SynResult<()> { let context: Type = parse_quote!(&serenity::client::Context); let message: Type = parse_quote!(&(dyn crate::framework::CommandInvoke + Sync + Send)); - let args: Type = parse_quote!(serenity::framework::standard::Args); + let args: Type = parse_quote!(crate::framework::Args); let mut index = 0; diff --git a/src/framework.rs b/src/framework.rs index 98d9550..e5aa010 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -3,17 +3,14 @@ use serenity::{ builder::CreateEmbed, cache::Cache, client::Context, - framework::{ - standard::{Args, CommandResult, Delimiter}, - Framework, - }, + framework::{standard::CommandResult, Framework}, futures::prelude::future::BoxFuture, http::Http, model::{ channel::{Channel, GuildChannel, Message}, guild::{Guild, Member}, id::{ChannelId, GuildId, UserId}, - interactions::Interaction, + interactions::{ApplicationCommand, Interaction, InteractionType}, prelude::{ApplicationCommandOptionType, InteractionResponseType}, }, prelude::TypeMapKey, @@ -24,11 +21,9 @@ use log::{error, info, warn}; use regex::{Match, Regex, RegexBuilder}; -use std::{collections::HashMap, env, fmt}; +use std::{collections::HashMap, env, fmt, sync::Arc}; use crate::{guild_data::CtxGuildData, MySQL}; -use serenity::model::prelude::InteractionType; -use std::sync::Arc; type CommandFn = for<'fut> fn( &'fut Context, @@ -36,6 +31,54 @@ type CommandFn = for<'fut> fn( Args, ) -> BoxFuture<'fut, CommandResult>; +pub struct Args { + args: HashMap, +} + +impl Args { + pub fn from(message: &str, arg_schema: &'static [&'static Arg]) -> Self { + // construct regex from arg schema + let mut re = arg_schema + .iter() + .map(|a| a.to_regex()) + .collect::>() + .join(r#"\s*"#); + + re.push_str("$"); + + let regex = Regex::new(&re).unwrap(); + let capture_names = regex.capture_names(); + let captures = regex.captures(message); + + let mut args = HashMap::new(); + + if let Some(captures) = captures { + for name in capture_names.filter(|n| n.is_some()).map(|n| n.unwrap()) { + args.insert( + name.to_string(), + captures.name(name).unwrap().as_str().to_string(), + ); + } + } + + Self { args } + } + + pub fn len(&self) -> usize { + self.args.len() + } + + pub fn is_empty(&self) -> bool { + self.args.is_empty() + } + + pub fn named(&self, name: D) -> Option<&String> { + let name = name.to_string(); + + self.args.get(&name) + } +} + pub struct CreateGenericResponse { content: String, embed: Option, @@ -203,6 +246,23 @@ pub struct Arg { pub required: bool, } +impl Arg { + pub fn to_regex(&self) -> String { + match self.kind { + ApplicationCommandOptionType::String => format!(r#"(?P<{}>.*?)"#, self.name), + ApplicationCommandOptionType::Integer => format!(r#"(?P<{}>\d+)"#, self.name), + ApplicationCommandOptionType::Boolean => format!(r#"(?P<{0}>{0})?"#, self.name), + ApplicationCommandOptionType::User => format!(r#"<(@|@!)(?P<{}>\d+)>"#, self.name), + ApplicationCommandOptionType::Channel => format!(r#"<#(?P<{}>\d+)>"#, self.name), + ApplicationCommandOptionType::Role => format!(r#"<@&(?P<{}>\d+)>"#, self.name), + ApplicationCommandOptionType::Mentionable => { + format!(r#"<(?P<{0}_pref>@|@!|@&|#)(?P<{0}>\d+)>"#, self.name) + } + _ => String::new(), + } + } +} + pub struct Command { pub fun: CommandFn, pub names: &'static [&'static str], @@ -403,7 +463,29 @@ impl RegexFramework { count += 1; } } else { - // register application commands globally + for (handle, command) in self.commands.iter().filter(|(_, c)| c.allow_slash) { + ApplicationCommand::create_global_application_command(&http, |a| { + a.name(handle).description(command.desc); + + for arg in command.args { + a.create_option(|o| { + o.name(arg.name) + .description(arg.description) + .kind(arg.kind) + .required(arg.required) + }); + } + + a + }) + .await + .expect(&format!( + "Failed to create application command for {}", + handle + )); + + count += 1; + } } info!("{} slash commands built! Ready to go", count); @@ -411,40 +493,48 @@ impl RegexFramework { pub async fn execute(&self, ctx: Context, interaction: Interaction) { if interaction.kind == InteractionType::ApplicationCommand { - let command = { - let name = &interaction.data.as_ref().unwrap().name; + if let Some(data) = interaction.data.clone() { + let command = { + let name = data.name; - self.commands - .get(name) - .expect(&format!("Received invalid command: {}", name)) - }; + self.commands + .get(&name) + .expect(&format!("Received invalid command: {}", name)) + }; - if command - .check_permissions( - &ctx, - &interaction.guild(ctx.cache.clone()).await.unwrap(), - &interaction.member(&ctx).await.unwrap(), - ) - .await - { - (command.fun)(&ctx, &interaction, Args::new("", &[Delimiter::Single(' ')])) + if command + .check_permissions( + &ctx, + &interaction.guild(ctx.cache.clone()).await.unwrap(), + &interaction.member(&ctx).await.unwrap(), + ) .await - .unwrap(); - } else if command.required_permissions == PermissionLevel::Managed { - let _ = interaction - .respond( - ctx.http.clone(), - CreateGenericResponse::new().content("You must either be an Admin or have a role specified in `?roles` to do this command") - ) - .await; - } else if command.required_permissions == PermissionLevel::Restricted { - let _ = interaction - .respond( - ctx.http.clone(), - CreateGenericResponse::new() - .content("You must be an Admin to do this command"), - ) - .await; + { + let mut args = HashMap::new(); + + for arg in data.options.iter().filter(|o| o.value.is_some()) { + args.insert(arg.name.clone(), arg.value.clone().unwrap().to_string()); + } + + (command.fun)(&ctx, &interaction, Args { args }) + .await + .unwrap(); + } else if command.required_permissions == PermissionLevel::Managed { + let _ = interaction + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content("You must either be an Admin or have a role specified in `?roles` to do this command") + ) + .await; + } else if command.required_permissions == PermissionLevel::Restricted { + let _ = interaction + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content("You must be an Admin to do this command"), + ) + .await; + } } } } @@ -513,13 +603,9 @@ impl Framework for RegexFramework { let member = guild.member(&ctx, &msg.author).await.unwrap(); if command.check_permissions(&ctx, &guild, &member).await { - (command.fun)( - &ctx, - &msg, - Args::new(&args, &[Delimiter::Single(' ')]), - ) - .await - .unwrap(); + (command.fun)(&ctx, &msg, Args::from(&args, command.args)) + .await + .unwrap(); } else if command.required_permissions == PermissionLevel::Managed { let _ = msg.channel_id.say(&ctx, "You must either be an Admin or have a role specified in `?roles` to do this command").await; } else if command.required_permissions diff --git a/src/main.rs b/src/main.rs index 195ae1d..9ca7eae 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ mod sound; use crate::{ event_handlers::{Handler, RestartTrack}, - framework::{CommandInvoke, CreateGenericResponse, RegexFramework}, + framework::{Args, CommandInvoke, CreateGenericResponse, RegexFramework}, guild_data::{CtxGuildData, GuildData}, sound::{JoinSoundCtx, Sound}, }; @@ -20,7 +20,7 @@ use regex_command_attr::command; use serenity::{ client::{bridge::gateway::GatewayIntents, Client, Context}, - framework::standard::{Args, CommandResult}, + framework::standard::CommandResult, http::Http, model::{ guild::Guild, @@ -297,41 +297,19 @@ async fn main() -> Result<(), Box> { #[command] #[description("Get information on the commands of the bot")] +#[arg( + name = "category", + description = "Get help for a specific category", + kind = "String", + required = false +)] async fn help( ctx: &Context, invoke: &(dyn CommandInvoke + Sync + Send), args: Args, ) -> CommandResult { - if args.is_empty() { - let description = { - let guild_data = ctx.guild_data(invoke.guild_id().unwrap()).await.unwrap(); - - let read_lock = guild_data.read().await; - - format!( - "Type `{}help category` to view help for a command category below:", - read_lock.prefix - ) - }; - - invoke - .respond( - ctx.http.clone(), - CreateGenericResponse::new().embed(|e| { - e.title("Help") - .color(THEME_COLOR) - .description(description) - .field("Info", "`help` `info` `invite` `donate`", false) - .field("Play", "`play` `p` `stop` `dc` `loop`", false) - .field("Manage", "`upload` `delete` `list` `public`", false) - .field("Settings", "`prefix` `roles` `volume` `allow_greet`", false) - .field("Search", "`search` `random` `popular`", false) - .field("Other", "`greet` `ambience`", false) - }), - ) - .await?; - } else { - let body = match args.rest().to_lowercase().as_str() { + if let Some(category) = args.named("category") { + let body = match category.to_lowercase().as_str() { "info" => { "__Info Commands__ `help` - view all commands @@ -421,6 +399,34 @@ Please select a category from the following: .embed(|e| e.title("Help").color(THEME_COLOR).description(body)), ) .await?; + } else { + let description = { + let guild_data = ctx.guild_data(invoke.guild_id().unwrap()).await.unwrap(); + + let read_lock = guild_data.read().await; + + format!( + "Type `{}help category` to view help for a command category below:", + read_lock.prefix + ) + }; + + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().embed(|e| { + e.title("Help") + .color(THEME_COLOR) + .description(description) + .field("Info", "`help` `info` `invite` `donate`", false) + .field("Play", "`play` `p` `stop` `dc` `loop`", false) + .field("Manage", "`upload` `delete` `list` `public`", false) + .field("Settings", "`prefix` `roles` `volume` `allow_greet`", false) + .field("Search", "`search` `random` `popular`", false) + .field("Other", "`greet` `ambience`", false) + }), + ) + .await?; } Ok(()) @@ -436,12 +442,6 @@ Please select a category from the following: kind = "String", required = true )] -#[arg( - name = "loop", - description = "Whether to loop the sound or not (default: no)", - kind = "Boolean", - required = false -)] async fn play( ctx: &Context, invoke: &(dyn CommandInvoke + Sync + Send), @@ -497,7 +497,7 @@ async fn play_cmd(ctx: &Context, guild: Guild, user_id: UserId, args: Args, loop match channel_to_join { Some(user_channel) => { - let search_term = args.rest(); + let search_term = args.named("query").unwrap(); let pool = ctx .data @@ -569,7 +569,7 @@ async fn play_ambience( match channel_to_join { Some(user_channel) => { - let search_name = args.rest().to_lowercase(); + let search_name = args.named("query").unwrap().to_lowercase(); let audio_index = ctx.data.read().await.get::().cloned().unwrap(); if let Some(filename) = audio_index.get(&search_name) { @@ -724,10 +724,16 @@ There is a maximum sound limit per user. This can be removed by subscribing at * #[aliases("vol")] #[required_permissions(Managed)] #[description("Change the bot's volume in this server")] +#[arg( + name = "volume", + description = "New volume for the bot to use", + kind = "Integer", + required = false +)] async fn change_volume( ctx: &Context, invoke: &(dyn CommandInvoke + Sync + Send), - mut args: Args, + args: Args, ) -> CommandResult { let pool = ctx .data @@ -740,36 +746,17 @@ async fn change_volume( let guild_data_opt = ctx.guild_data(invoke.guild_id().unwrap()).await; let guild_data = guild_data_opt.unwrap(); - if args.len() == 1 { - match args.single::() { - Ok(volume) => { - guild_data.write().await.volume = volume; + if let Some(volume) = args.named("volume").map(|i| i.parse::().ok()).flatten() { + guild_data.write().await.volume = volume; - guild_data.read().await.commit(pool).await?; + guild_data.read().await.commit(pool).await?; - invoke - .respond( - ctx.http.clone(), - CreateGenericResponse::new() - .content(format!("Volume changed to {}%", volume)), - ) - .await?; - } - - Err(_) => { - let read = guild_data.read().await; - - invoke - .respond( - ctx.http.clone(), - CreateGenericResponse::new().content(format!( - "Current server volume: {vol}%. Change the volume with `/volume `", - vol = read.volume - )), - ) - .await?; - } - } + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(format!("Volume changed to {}%", volume)), + ) + .await?; } else { let read = guild_data.read().await; @@ -793,7 +780,7 @@ async fn change_volume( async fn change_prefix( ctx: &Context, invoke: &(dyn CommandInvoke + Sync + Send), - mut args: Args, + args: Args, ) -> CommandResult { let pool = ctx .data @@ -811,50 +798,34 @@ async fn change_prefix( guild_data = guild_data_opt.unwrap(); } - if args.len() == 1 { - match args.single::() { - Ok(prefix) => { - if prefix.len() <= 5 { - let reply = format!("Prefix changed to `{}`", prefix); + if let Some(prefix) = args.named("prefix") { + if prefix.len() <= 5 { + let reply = format!("Prefix changed to `{}`", prefix); - { - guild_data.write().await.prefix = prefix; - } - - { - let read = guild_data.read().await; - - read.commit(pool).await?; - } - - invoke - .respond( - ctx.http.clone(), - CreateGenericResponse::new().content(reply), - ) - .await?; - } else { - invoke - .respond( - ctx.http.clone(), - CreateGenericResponse::new() - .content("Prefix must be less than 5 characters long"), - ) - .await?; - } + { + guild_data.write().await.prefix = prefix.to_string(); } - Err(_) => { - invoke - .respond( - ctx.http.clone(), - CreateGenericResponse::new().content(format!( - "Usage: `{prefix}prefix `", - prefix = guild_data.read().await.prefix - )), - ) - .await?; + { + let read = guild_data.read().await; + + read.commit(pool).await?; } + + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(reply), + ) + .await?; + } else { + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content("Prefix must be less than 5 characters long"), + ) + .await?; } } else { invoke @@ -873,6 +844,12 @@ async fn change_prefix( #[command("upload")] #[allow_slash(false)] +#[arg( + name = "name", + description = "Name to upload sound to", + kind = "String", + required = true +)] async fn upload_new_sound( ctx: &Context, invoke: &(dyn CommandInvoke + Sync + Send), @@ -891,7 +868,10 @@ async fn upload_new_sound( true } - let new_name = args.rest().to_string(); + let new_name = args + .named("name") + .map(|n| n.to_string()) + .unwrap_or(String::new()); if !new_name.is_empty() && new_name.len() <= 20 { if !is_numeric(&new_name) { @@ -1023,7 +1003,7 @@ async fn set_allowed_roles( .cloned() .expect("Could not get SQLPool from data"); - if args.len() == 0 { + if args.is_empty() { let roles = sqlx::query!( " SELECT role @@ -1117,7 +1097,7 @@ async fn list_sounds( let sounds; let mut message_buffer; - if args.rest() == "me" { + if args.named("me").is_some() { sounds = Sound::get_user_sounds(invoke.author_id(), pool).await?; message_buffer = "All your sounds: ".to_string(); @@ -1178,7 +1158,7 @@ async fn change_public( let uid = invoke.author_id().as_u64().to_owned(); - let name = args.rest(); + let name = args.named("query").unwrap(); let gid = *invoke.guild_id().unwrap().as_u64(); let mut sound_vec = Sound::search_for_sound(name, gid, uid, pool.clone(), true).await?; @@ -1245,7 +1225,7 @@ async fn delete_sound( let uid = invoke.author_id().0; let gid = invoke.guild_id().unwrap().0; - let name = args.rest(); + let name = args.named("query").unwrap(); let sound_vec = Sound::search_for_sound(name, gid, uid, pool.clone(), true).await?; let sound_result = sound_vec.first(); @@ -1347,7 +1327,7 @@ async fn search_sounds( .cloned() .expect("Could not get SQLPool from data"); - let query = args.rest(); + let query = args.named("query").unwrap(); let search_results = Sound::search_for_sound( query, @@ -1451,7 +1431,7 @@ async fn set_greet_sound( .cloned() .expect("Could not get SQLPool from data"); - let query = args.rest(); + let query = args.named("query").unwrap(); let user_id = invoke.author_id(); if query.len() == 0 {