From 0abe696f589e3a1e3cd478b4929fae45788f92cd Mon Sep 17 00:00:00 2001 From: jellywx Date: Tue, 15 Jun 2021 12:09:48 +0100 Subject: [PATCH] added an env var to enable slash command building. added responses to the stop commands. fixed something with boolean arguments. blah blah blah --- Cargo.lock | 4 +-- Cargo.toml | 2 +- src/framework.rs | 69 +++++++++++++++++++++++++++++++++++++++--------- src/main.rs | 22 ++++++++++----- 4 files changed, 75 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b2b7917..15668d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1691,8 +1691,8 @@ dependencies = [ [[package]] name = "songbird" -version = "0.2.0-beta.2" -source = "git+https://github.com/FelixMcFelix/songbird?branch=ws-fix#a9ae0528405ed783153175f8d4d52dd51532aef3" +version = "0.2.0-beta.3" +source = "git+https://github.com/serenity-rs/songbird?branch=next#e0ea2f5fe2fe14cb2de0774227f27cb1175d8295" dependencies = [ "async-trait", "async-tungstenite", diff --git a/Cargo.toml b/Cargo.toml index ac0f01e..7dbafe0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ authors = ["jellywx "] edition = "2018" [dependencies] -songbird = { git = "https://github.com/FelixMcFelix/songbird", branch = "ws-fix" } +songbird = { git = "https://github.com/serenity-rs/songbird", branch = "next" } serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["voice", "collector", "unstable_discord_api"] } sqlx = { version = "0.5", default-features = false, features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal"] } dotenv = "0.15" diff --git a/src/framework.rs b/src/framework.rs index e5aa010..69d13cd 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -21,9 +21,15 @@ use log::{error, info, warn}; use regex::{Match, Regex, RegexBuilder}; -use std::{collections::HashMap, env, fmt, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + env, fmt, + hash::{Hash, Hasher}, + sync::Arc, +}; use crate::{guild_data::CtxGuildData, MySQL}; +use serde_json::Value; type CommandFn = for<'fut> fn( &'fut Context, @@ -54,10 +60,9 @@ impl Args { if let Some(captures) = captures { for name in capture_names.filter(|n| n.is_some()).map(|n| n.unwrap()) { - args.insert( - name.to_string(), - captures.name(name).unwrap().as_str().to_string(), - ); + if let Some(cap) = captures.name(name) { + args.insert(name.to_string(), cap.as_str().to_string()); + } } } @@ -274,6 +279,20 @@ pub struct Command { pub args: &'static [&'static Arg], } +impl Hash for Command { + fn hash(&self, state: &mut H) { + self.names[0].hash(state) + } +} + +impl PartialEq for Command { + fn eq(&self, other: &Self) -> bool { + self.names[0] == other.names[0] + } +} + +impl Eq for Command {} + impl Command { async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool { if self.required_permissions == PermissionLevel::Unrestricted { @@ -348,6 +367,7 @@ impl fmt::Debug for Command { pub struct RegexFramework { commands: HashMap, + commands_: HashSet<&'static Command>, command_matcher: Regex, default_prefix: String, client_id: u64, @@ -363,6 +383,7 @@ impl RegexFramework { pub fn new>(client_id: T) -> Self { Self { commands: HashMap::new(), + commands_: HashSet::new(), command_matcher: Regex::new(r#"^$"#).unwrap(), default_prefix: "".to_string(), client_id: client_id.into(), @@ -392,6 +413,8 @@ impl RegexFramework { pub fn add_command(mut self, command: &'static Command) -> Self { info!("{:?}", command); + self.commands_.insert(command); + for name in command.names { self.commands.insert(name.to_string(), command); } @@ -428,6 +451,11 @@ impl RegexFramework { } pub async fn build_slash(&self, http: impl AsRef) { + if env::var("REBUILD_COMMANDS").is_err() { + info!("No rebuild"); + return; + } + info!("Building slash commands..."); let mut count = 0; @@ -438,10 +466,10 @@ impl RegexFramework { .flatten() .map(|v| GuildId(v)) { - for (handle, command) in self.commands.iter().filter(|(_, c)| c.allow_slash) { + for command in self.commands_.iter().filter(|c| c.allow_slash) { guild_id .create_application_command(&http, |a| { - a.name(handle).description(command.desc); + a.name(command.names[0]).description(command.desc); for arg in command.args { a.create_option(|o| { @@ -457,15 +485,15 @@ impl RegexFramework { .await .expect(&format!( "Failed to create application command for {}", - handle + command.names[0] )); count += 1; } } else { - for (handle, command) in self.commands.iter().filter(|(_, c)| c.allow_slash) { + for command in self.commands_.iter().filter(|c| c.allow_slash) { ApplicationCommand::create_global_application_command(&http, |a| { - a.name(handle).description(command.desc); + a.name(command.names[0]).description(command.desc); for arg in command.args { a.create_option(|o| { @@ -481,7 +509,7 @@ impl RegexFramework { .await .expect(&format!( "Failed to create application command for {}", - handle + command.names[0] )); count += 1; @@ -492,7 +520,8 @@ impl RegexFramework { } pub async fn execute(&self, ctx: Context, interaction: Interaction) { - if interaction.kind == InteractionType::ApplicationCommand { + if interaction.kind == InteractionType::ApplicationCommand && interaction.guild_id.is_some() + { if let Some(data) = interaction.data.clone() { let command = { let name = data.name; @@ -513,7 +542,21 @@ impl RegexFramework { let mut args = HashMap::new(); for arg in data.options.iter().filter(|o| o.value.is_some()) { - args.insert(arg.name.clone(), arg.value.clone().unwrap().to_string()); + args.insert( + arg.name.clone(), + match arg.value.clone().unwrap() { + Value::Bool(b) => { + if b { + arg.name.clone() + } else { + String::new() + } + } + Value::Number(n) => n.to_string(), + Value::String(s) => s, + _ => String::new(), + }, + ); } (command.fun)(&ctx, &interaction, Args { args }) diff --git a/src/main.rs b/src/main.rs index 9ca7eae..9da9c40 100644 --- a/src/main.rs +++ b/src/main.rs @@ -450,10 +450,10 @@ async fn play( let guild = invoke.guild(ctx.cache.clone()).await.unwrap(); invoke - .channel_id() - .say( - &ctx, - play_cmd(ctx, guild, invoke.author_id(), args, false).await, + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content(play_cmd(ctx, guild, invoke.author_id(), args, false).await), ) .await?; @@ -499,6 +499,8 @@ async fn play_cmd(ctx: &Context, guild: Guild, user_id: UserId, args: Args, loop Some(user_channel) => { let search_term = args.named("query").unwrap(); + println!("{}", search_term); + let pool = ctx .data .read() @@ -569,7 +571,7 @@ async fn play_ambience( match channel_to_join { Some(user_channel) => { - let search_name = args.named("query").unwrap().to_lowercase(); + let search_name = args.named("name").unwrap().to_lowercase(); let audio_index = ctx.data.read().await.get::().cloned().unwrap(); if let Some(filename) = audio_index.get(&search_name) { @@ -661,6 +663,10 @@ async fn stop_playing( lock.stop(); } + invoke + .respond(ctx.http.clone(), CreateGenericResponse::new().content("👍")) + .await?; + Ok(()) } @@ -678,6 +684,10 @@ async fn disconnect( let songbird = songbird::get(ctx).await.unwrap(); let _ = songbird.leave(guild_id).await; + invoke + .respond(ctx.http.clone(), CreateGenericResponse::new().content("👍")) + .await?; + Ok(()) } @@ -1097,7 +1107,7 @@ async fn list_sounds( let sounds; let mut message_buffer; - if args.named("me").is_some() { + if args.named("me").map(|i| i.to_owned()) == Some("me".to_string()) { sounds = Sound::get_user_sounds(invoke.author_id(), pool).await?; message_buffer = "All your sounds: ".to_string();