aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa

This commit is contained in:
jellywx 2021-06-14 21:35:38 +01:00
parent b8bbfbfade
commit 60ead9a1ef
3 changed files with 231 additions and 165 deletions

View File

@ -171,7 +171,7 @@ pub fn create_declaration_validations(fun: &mut CommandFun) -> SynResult<()> {
let context: Type = parse_quote!(&serenity::client::Context); let context: Type = parse_quote!(&serenity::client::Context);
let message: Type = parse_quote!(&(dyn crate::framework::CommandInvoke + Sync + Send)); let message: Type = parse_quote!(&(dyn crate::framework::CommandInvoke + Sync + Send));
let args: Type = parse_quote!(serenity::framework::standard::Args); let args: Type = parse_quote!(crate::framework::Args);
let mut index = 0; let mut index = 0;

View File

@ -3,17 +3,14 @@ use serenity::{
builder::CreateEmbed, builder::CreateEmbed,
cache::Cache, cache::Cache,
client::Context, client::Context,
framework::{ framework::{standard::CommandResult, Framework},
standard::{Args, CommandResult, Delimiter},
Framework,
},
futures::prelude::future::BoxFuture, futures::prelude::future::BoxFuture,
http::Http, http::Http,
model::{ model::{
channel::{Channel, GuildChannel, Message}, channel::{Channel, GuildChannel, Message},
guild::{Guild, Member}, guild::{Guild, Member},
id::{ChannelId, GuildId, UserId}, id::{ChannelId, GuildId, UserId},
interactions::Interaction, interactions::{ApplicationCommand, Interaction, InteractionType},
prelude::{ApplicationCommandOptionType, InteractionResponseType}, prelude::{ApplicationCommandOptionType, InteractionResponseType},
}, },
prelude::TypeMapKey, prelude::TypeMapKey,
@ -24,11 +21,9 @@ use log::{error, info, warn};
use regex::{Match, Regex, RegexBuilder}; use regex::{Match, Regex, RegexBuilder};
use std::{collections::HashMap, env, fmt}; use std::{collections::HashMap, env, fmt, sync::Arc};
use crate::{guild_data::CtxGuildData, MySQL}; use crate::{guild_data::CtxGuildData, MySQL};
use serenity::model::prelude::InteractionType;
use std::sync::Arc;
type CommandFn = for<'fut> fn( type CommandFn = for<'fut> fn(
&'fut Context, &'fut Context,
@ -36,6 +31,54 @@ type CommandFn = for<'fut> fn(
Args, Args,
) -> BoxFuture<'fut, CommandResult>; ) -> BoxFuture<'fut, CommandResult>;
pub struct Args {
args: HashMap<String, String>,
}
impl Args {
pub fn from(message: &str, arg_schema: &'static [&'static Arg]) -> Self {
// construct regex from arg schema
let mut re = arg_schema
.iter()
.map(|a| a.to_regex())
.collect::<Vec<String>>()
.join(r#"\s*"#);
re.push_str("$");
let regex = Regex::new(&re).unwrap();
let capture_names = regex.capture_names();
let captures = regex.captures(message);
let mut args = HashMap::new();
if let Some(captures) = captures {
for name in capture_names.filter(|n| n.is_some()).map(|n| n.unwrap()) {
args.insert(
name.to_string(),
captures.name(name).unwrap().as_str().to_string(),
);
}
}
Self { args }
}
pub fn len(&self) -> usize {
self.args.len()
}
pub fn is_empty(&self) -> bool {
self.args.is_empty()
}
pub fn named<D: ToString>(&self, name: D) -> Option<&String> {
let name = name.to_string();
self.args.get(&name)
}
}
pub struct CreateGenericResponse { pub struct CreateGenericResponse {
content: String, content: String,
embed: Option<CreateEmbed>, embed: Option<CreateEmbed>,
@ -203,6 +246,23 @@ pub struct Arg {
pub required: bool, pub required: bool,
} }
impl Arg {
pub fn to_regex(&self) -> String {
match self.kind {
ApplicationCommandOptionType::String => format!(r#"(?P<{}>.*?)"#, self.name),
ApplicationCommandOptionType::Integer => format!(r#"(?P<{}>\d+)"#, self.name),
ApplicationCommandOptionType::Boolean => format!(r#"(?P<{0}>{0})?"#, self.name),
ApplicationCommandOptionType::User => format!(r#"<(@|@!)(?P<{}>\d+)>"#, self.name),
ApplicationCommandOptionType::Channel => format!(r#"<#(?P<{}>\d+)>"#, self.name),
ApplicationCommandOptionType::Role => format!(r#"<@&(?P<{}>\d+)>"#, self.name),
ApplicationCommandOptionType::Mentionable => {
format!(r#"<(?P<{0}_pref>@|@!|@&|#)(?P<{0}>\d+)>"#, self.name)
}
_ => String::new(),
}
}
}
pub struct Command { pub struct Command {
pub fun: CommandFn, pub fun: CommandFn,
pub names: &'static [&'static str], pub names: &'static [&'static str],
@ -403,7 +463,29 @@ impl RegexFramework {
count += 1; count += 1;
} }
} else { } else {
// register application commands globally for (handle, command) in self.commands.iter().filter(|(_, c)| c.allow_slash) {
ApplicationCommand::create_global_application_command(&http, |a| {
a.name(handle).description(command.desc);
for arg in command.args {
a.create_option(|o| {
o.name(arg.name)
.description(arg.description)
.kind(arg.kind)
.required(arg.required)
});
}
a
})
.await
.expect(&format!(
"Failed to create application command for {}",
handle
));
count += 1;
}
} }
info!("{} slash commands built! Ready to go", count); info!("{} slash commands built! Ready to go", count);
@ -411,11 +493,12 @@ impl RegexFramework {
pub async fn execute(&self, ctx: Context, interaction: Interaction) { pub async fn execute(&self, ctx: Context, interaction: Interaction) {
if interaction.kind == InteractionType::ApplicationCommand { if interaction.kind == InteractionType::ApplicationCommand {
if let Some(data) = interaction.data.clone() {
let command = { let command = {
let name = &interaction.data.as_ref().unwrap().name; let name = data.name;
self.commands self.commands
.get(name) .get(&name)
.expect(&format!("Received invalid command: {}", name)) .expect(&format!("Received invalid command: {}", name))
}; };
@ -427,7 +510,13 @@ impl RegexFramework {
) )
.await .await
{ {
(command.fun)(&ctx, &interaction, Args::new("", &[Delimiter::Single(' ')])) let mut args = HashMap::new();
for arg in data.options.iter().filter(|o| o.value.is_some()) {
args.insert(arg.name.clone(), arg.value.clone().unwrap().to_string());
}
(command.fun)(&ctx, &interaction, Args { args })
.await .await
.unwrap(); .unwrap();
} else if command.required_permissions == PermissionLevel::Managed { } else if command.required_permissions == PermissionLevel::Managed {
@ -449,6 +538,7 @@ impl RegexFramework {
} }
} }
} }
}
enum PermissionCheck { enum PermissionCheck {
None, // No permissions None, // No permissions
@ -513,11 +603,7 @@ impl Framework for RegexFramework {
let member = guild.member(&ctx, &msg.author).await.unwrap(); let member = guild.member(&ctx, &msg.author).await.unwrap();
if command.check_permissions(&ctx, &guild, &member).await { if command.check_permissions(&ctx, &guild, &member).await {
(command.fun)( (command.fun)(&ctx, &msg, Args::from(&args, command.args))
&ctx,
&msg,
Args::new(&args, &[Delimiter::Single(' ')]),
)
.await .await
.unwrap(); .unwrap();
} else if command.required_permissions == PermissionLevel::Managed { } else if command.required_permissions == PermissionLevel::Managed {

View File

@ -9,7 +9,7 @@ mod sound;
use crate::{ use crate::{
event_handlers::{Handler, RestartTrack}, event_handlers::{Handler, RestartTrack},
framework::{CommandInvoke, CreateGenericResponse, RegexFramework}, framework::{Args, CommandInvoke, CreateGenericResponse, RegexFramework},
guild_data::{CtxGuildData, GuildData}, guild_data::{CtxGuildData, GuildData},
sound::{JoinSoundCtx, Sound}, sound::{JoinSoundCtx, Sound},
}; };
@ -20,7 +20,7 @@ use regex_command_attr::command;
use serenity::{ use serenity::{
client::{bridge::gateway::GatewayIntents, Client, Context}, client::{bridge::gateway::GatewayIntents, Client, Context},
framework::standard::{Args, CommandResult}, framework::standard::CommandResult,
http::Http, http::Http,
model::{ model::{
guild::Guild, guild::Guild,
@ -297,41 +297,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
#[command] #[command]
#[description("Get information on the commands of the bot")] #[description("Get information on the commands of the bot")]
#[arg(
name = "category",
description = "Get help for a specific category",
kind = "String",
required = false
)]
async fn help( async fn help(
ctx: &Context, ctx: &Context,
invoke: &(dyn CommandInvoke + Sync + Send), invoke: &(dyn CommandInvoke + Sync + Send),
args: Args, args: Args,
) -> CommandResult { ) -> CommandResult {
if args.is_empty() { if let Some(category) = args.named("category") {
let description = { let body = match category.to_lowercase().as_str() {
let guild_data = ctx.guild_data(invoke.guild_id().unwrap()).await.unwrap();
let read_lock = guild_data.read().await;
format!(
"Type `{}help category` to view help for a command category below:",
read_lock.prefix
)
};
invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
e.title("Help")
.color(THEME_COLOR)
.description(description)
.field("Info", "`help` `info` `invite` `donate`", false)
.field("Play", "`play` `p` `stop` `dc` `loop`", false)
.field("Manage", "`upload` `delete` `list` `public`", false)
.field("Settings", "`prefix` `roles` `volume` `allow_greet`", false)
.field("Search", "`search` `random` `popular`", false)
.field("Other", "`greet` `ambience`", false)
}),
)
.await?;
} else {
let body = match args.rest().to_lowercase().as_str() {
"info" => { "info" => {
"__Info Commands__ "__Info Commands__
`help` - view all commands `help` - view all commands
@ -421,6 +399,34 @@ Please select a category from the following:
.embed(|e| e.title("Help").color(THEME_COLOR).description(body)), .embed(|e| e.title("Help").color(THEME_COLOR).description(body)),
) )
.await?; .await?;
} else {
let description = {
let guild_data = ctx.guild_data(invoke.guild_id().unwrap()).await.unwrap();
let read_lock = guild_data.read().await;
format!(
"Type `{}help category` to view help for a command category below:",
read_lock.prefix
)
};
invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
e.title("Help")
.color(THEME_COLOR)
.description(description)
.field("Info", "`help` `info` `invite` `donate`", false)
.field("Play", "`play` `p` `stop` `dc` `loop`", false)
.field("Manage", "`upload` `delete` `list` `public`", false)
.field("Settings", "`prefix` `roles` `volume` `allow_greet`", false)
.field("Search", "`search` `random` `popular`", false)
.field("Other", "`greet` `ambience`", false)
}),
)
.await?;
} }
Ok(()) Ok(())
@ -436,12 +442,6 @@ Please select a category from the following:
kind = "String", kind = "String",
required = true required = true
)] )]
#[arg(
name = "loop",
description = "Whether to loop the sound or not (default: no)",
kind = "Boolean",
required = false
)]
async fn play( async fn play(
ctx: &Context, ctx: &Context,
invoke: &(dyn CommandInvoke + Sync + Send), invoke: &(dyn CommandInvoke + Sync + Send),
@ -497,7 +497,7 @@ async fn play_cmd(ctx: &Context, guild: Guild, user_id: UserId, args: Args, loop
match channel_to_join { match channel_to_join {
Some(user_channel) => { Some(user_channel) => {
let search_term = args.rest(); let search_term = args.named("query").unwrap();
let pool = ctx let pool = ctx
.data .data
@ -569,7 +569,7 @@ async fn play_ambience(
match channel_to_join { match channel_to_join {
Some(user_channel) => { Some(user_channel) => {
let search_name = args.rest().to_lowercase(); let search_name = args.named("query").unwrap().to_lowercase();
let audio_index = ctx.data.read().await.get::<AudioIndex>().cloned().unwrap(); let audio_index = ctx.data.read().await.get::<AudioIndex>().cloned().unwrap();
if let Some(filename) = audio_index.get(&search_name) { if let Some(filename) = audio_index.get(&search_name) {
@ -724,10 +724,16 @@ There is a maximum sound limit per user. This can be removed by subscribing at *
#[aliases("vol")] #[aliases("vol")]
#[required_permissions(Managed)] #[required_permissions(Managed)]
#[description("Change the bot's volume in this server")] #[description("Change the bot's volume in this server")]
#[arg(
name = "volume",
description = "New volume for the bot to use",
kind = "Integer",
required = false
)]
async fn change_volume( async fn change_volume(
ctx: &Context, ctx: &Context,
invoke: &(dyn CommandInvoke + Sync + Send), invoke: &(dyn CommandInvoke + Sync + Send),
mut args: Args, args: Args,
) -> CommandResult { ) -> CommandResult {
let pool = ctx let pool = ctx
.data .data
@ -740,9 +746,7 @@ async fn change_volume(
let guild_data_opt = ctx.guild_data(invoke.guild_id().unwrap()).await; let guild_data_opt = ctx.guild_data(invoke.guild_id().unwrap()).await;
let guild_data = guild_data_opt.unwrap(); let guild_data = guild_data_opt.unwrap();
if args.len() == 1 { if let Some(volume) = args.named("volume").map(|i| i.parse::<u8>().ok()).flatten() {
match args.single::<u8>() {
Ok(volume) => {
guild_data.write().await.volume = volume; guild_data.write().await.volume = volume;
guild_data.read().await.commit(pool).await?; guild_data.read().await.commit(pool).await?;
@ -750,26 +754,9 @@ async fn change_volume(
invoke invoke
.respond( .respond(
ctx.http.clone(), ctx.http.clone(),
CreateGenericResponse::new() CreateGenericResponse::new().content(format!("Volume changed to {}%", volume)),
.content(format!("Volume changed to {}%", volume)),
) )
.await?; .await?;
}
Err(_) => {
let read = guild_data.read().await;
invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().content(format!(
"Current server volume: {vol}%. Change the volume with `/volume <new volume>`",
vol = read.volume
)),
)
.await?;
}
}
} else { } else {
let read = guild_data.read().await; let read = guild_data.read().await;
@ -793,7 +780,7 @@ async fn change_volume(
async fn change_prefix( async fn change_prefix(
ctx: &Context, ctx: &Context,
invoke: &(dyn CommandInvoke + Sync + Send), invoke: &(dyn CommandInvoke + Sync + Send),
mut args: Args, args: Args,
) -> CommandResult { ) -> CommandResult {
let pool = ctx let pool = ctx
.data .data
@ -811,14 +798,12 @@ async fn change_prefix(
guild_data = guild_data_opt.unwrap(); guild_data = guild_data_opt.unwrap();
} }
if args.len() == 1 { if let Some(prefix) = args.named("prefix") {
match args.single::<String>() {
Ok(prefix) => {
if prefix.len() <= 5 { if prefix.len() <= 5 {
let reply = format!("Prefix changed to `{}`", prefix); let reply = format!("Prefix changed to `{}`", prefix);
{ {
guild_data.write().await.prefix = prefix; guild_data.write().await.prefix = prefix.to_string();
} }
{ {
@ -842,20 +827,6 @@ async fn change_prefix(
) )
.await?; .await?;
} }
}
Err(_) => {
invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().content(format!(
"Usage: `{prefix}prefix <new prefix>`",
prefix = guild_data.read().await.prefix
)),
)
.await?;
}
}
} else { } else {
invoke invoke
.respond( .respond(
@ -873,6 +844,12 @@ async fn change_prefix(
#[command("upload")] #[command("upload")]
#[allow_slash(false)] #[allow_slash(false)]
#[arg(
name = "name",
description = "Name to upload sound to",
kind = "String",
required = true
)]
async fn upload_new_sound( async fn upload_new_sound(
ctx: &Context, ctx: &Context,
invoke: &(dyn CommandInvoke + Sync + Send), invoke: &(dyn CommandInvoke + Sync + Send),
@ -891,7 +868,10 @@ async fn upload_new_sound(
true true
} }
let new_name = args.rest().to_string(); let new_name = args
.named("name")
.map(|n| n.to_string())
.unwrap_or(String::new());
if !new_name.is_empty() && new_name.len() <= 20 { if !new_name.is_empty() && new_name.len() <= 20 {
if !is_numeric(&new_name) { if !is_numeric(&new_name) {
@ -1023,7 +1003,7 @@ async fn set_allowed_roles(
.cloned() .cloned()
.expect("Could not get SQLPool from data"); .expect("Could not get SQLPool from data");
if args.len() == 0 { if args.is_empty() {
let roles = sqlx::query!( let roles = sqlx::query!(
" "
SELECT role SELECT role
@ -1117,7 +1097,7 @@ async fn list_sounds(
let sounds; let sounds;
let mut message_buffer; let mut message_buffer;
if args.rest() == "me" { if args.named("me").is_some() {
sounds = Sound::get_user_sounds(invoke.author_id(), pool).await?; sounds = Sound::get_user_sounds(invoke.author_id(), pool).await?;
message_buffer = "All your sounds: ".to_string(); message_buffer = "All your sounds: ".to_string();
@ -1178,7 +1158,7 @@ async fn change_public(
let uid = invoke.author_id().as_u64().to_owned(); let uid = invoke.author_id().as_u64().to_owned();
let name = args.rest(); let name = args.named("query").unwrap();
let gid = *invoke.guild_id().unwrap().as_u64(); let gid = *invoke.guild_id().unwrap().as_u64();
let mut sound_vec = Sound::search_for_sound(name, gid, uid, pool.clone(), true).await?; let mut sound_vec = Sound::search_for_sound(name, gid, uid, pool.clone(), true).await?;
@ -1245,7 +1225,7 @@ async fn delete_sound(
let uid = invoke.author_id().0; let uid = invoke.author_id().0;
let gid = invoke.guild_id().unwrap().0; let gid = invoke.guild_id().unwrap().0;
let name = args.rest(); let name = args.named("query").unwrap();
let sound_vec = Sound::search_for_sound(name, gid, uid, pool.clone(), true).await?; let sound_vec = Sound::search_for_sound(name, gid, uid, pool.clone(), true).await?;
let sound_result = sound_vec.first(); let sound_result = sound_vec.first();
@ -1347,7 +1327,7 @@ async fn search_sounds(
.cloned() .cloned()
.expect("Could not get SQLPool from data"); .expect("Could not get SQLPool from data");
let query = args.rest(); let query = args.named("query").unwrap();
let search_results = Sound::search_for_sound( let search_results = Sound::search_for_sound(
query, query,
@ -1451,7 +1431,7 @@ async fn set_greet_sound(
.cloned() .cloned()
.expect("Could not get SQLPool from data"); .expect("Could not get SQLPool from data");
let query = args.rest(); let query = args.named("query").unwrap();
let user_id = invoke.author_id(); let user_id = invoke.author_id();
if query.len() == 0 { if query.len() == 0 {