diff --git a/.idea/dataSources.local.xml b/.idea/dataSources.local.xml index c44ff77..55767b9 100644 --- a/.idea/dataSources.local.xml +++ b/.idea/dataSources.local.xml @@ -1,6 +1,6 @@ - + master_key diff --git a/Cargo.toml b/Cargo.toml index 4645352..ac0f01e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2018" [dependencies] songbird = { git = "https://github.com/FelixMcFelix/songbird", branch = "ws-fix" } -serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["voice", "collector"] } +serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["voice", "collector", "unstable_discord_api"] } sqlx = { version = "0.5", default-features = false, features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal"] } dotenv = "0.15" tokio = { version = "1", features = ["fs", "process", "io-util"] } diff --git a/regex_command_attr/src/lib.rs b/regex_command_attr/src/lib.rs index e72451d..635ac5a 100644 --- a/regex_command_attr/src/lib.rs +++ b/regex_command_attr/src/lib.rs @@ -57,11 +57,15 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { let name = &name[..]; match_options!(name, values, options, span => [ - permission_level + permission_level; + allow_slash ]); } - let Options { permission_level } = options; + let Options { + permission_level, + allow_slash, + } = options; propagate_err!(create_declaration_validations(&mut fun, DeclarFor::Command)); @@ -88,6 +92,7 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { func: #name, name: #lit_name, required_perms: #permission_level, + allow_slash: #allow_slash, }; #visibility fn #name<'fut> (#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, #ret> { diff --git a/regex_command_attr/src/structures.rs b/regex_command_attr/src/structures.rs index ed7952b..3b0043e 100644 --- a/regex_command_attr/src/structures.rs +++ b/regex_command_attr/src/structures.rs @@ -226,11 +226,15 @@ impl ToTokens for PermissionLevel { #[derive(Debug, Default)] pub struct Options { pub permission_level: PermissionLevel, + pub allow_slash: bool, } impl Options { #[inline] pub fn new() -> Self { - Self::default() + Self { + permission_level: PermissionLevel::default(), + allow_slash: false, + } } } diff --git a/regex_command_attr/src/util.rs b/regex_command_attr/src/util.rs index 0298135..ed514bf 100644 --- a/regex_command_attr/src/util.rs +++ b/regex_command_attr/src/util.rs @@ -182,7 +182,7 @@ pub fn create_declaration_validations(fun: &mut CommandFun, dec_for: DeclarFor) } let context: Type = parse_quote!(&serenity::client::Context); - let message: Type = parse_quote!(&serenity::model::channel::Message); + let message: Type = parse_quote!(&(dyn crate::framework::CommandInvoke + Sync + Send)); let args: Type = parse_quote!(serenity::framework::standard::Args); let args2: Type = parse_quote!(&mut serenity::framework::standard::Args); let options: Type = parse_quote!(&serenity::framework::standard::CommandOptions); diff --git a/src/event_handlers.rs b/src/event_handlers.rs index 2abba01..35101a2 100644 --- a/src/event_handlers.rs +++ b/src/event_handlers.rs @@ -1,7 +1,20 @@ -use serenity::async_trait; -use songbird::Event; -use songbird::EventContext; -use songbird::EventHandler as SongbirdEventHandler; +use crate::{ + guild_data::CtxGuildData, + join_channel, play_audio, + sound::{JoinSoundCtx, Sound}, + MySQL, ReqwestClient, +}; + +use serenity::{ + async_trait, + client::{Context, EventHandler}, + model::{channel::Channel, guild::Guild, id::GuildId, voice::VoiceState}, + utils::shard_id, +}; + +use songbird::{Event, EventContext, EventHandler as SongbirdEventHandler}; + +use std::{collections::HashMap, env}; pub struct RestartTrack; @@ -15,3 +28,130 @@ impl SongbirdEventHandler for RestartTrack { None } } + +pub struct Handler; + +#[serenity::async_trait] +impl EventHandler for Handler { + async fn guild_create(&self, ctx: Context, guild: Guild, is_new: bool) { + if is_new { + if let Ok(token) = env::var("DISCORDBOTS_TOKEN") { + let shard_count = ctx.cache.shard_count().await; + let current_shard_id = shard_id(guild.id.as_u64().to_owned(), shard_count); + + let guild_count = ctx + .cache + .guilds() + .await + .iter() + .filter(|g| shard_id(g.as_u64().to_owned(), shard_count) == current_shard_id) + .count() as u64; + + let mut hm = HashMap::new(); + hm.insert("server_count", guild_count); + hm.insert("shard_id", current_shard_id); + hm.insert("shard_count", shard_count); + + let client = ctx + .data + .read() + .await + .get::() + .cloned() + .expect("Could not get ReqwestClient from data"); + + let response = client + .post( + format!( + "https://top.gg/api/bots/{}/stats", + ctx.cache.current_user_id().await.as_u64() + ) + .as_str(), + ) + .header("Authorization", token) + .json(&hm) + .send() + .await; + + if let Err(res) = response { + println!("DiscordBots Response: {:?}", res); + } + } + } + } + + async fn voice_state_update( + &self, + ctx: Context, + guild_id_opt: Option, + old: Option, + new: VoiceState, + ) { + if let Some(past_state) = old { + if let (Some(guild_id), None) = (guild_id_opt, new.channel_id) { + if let Some(channel_id) = past_state.channel_id { + if let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx).await + { + if channel.members(&ctx).await.map(|m| m.len()).unwrap_or(0) <= 1 { + let songbird = songbird::get(&ctx).await.unwrap(); + + let _ = songbird.remove(guild_id).await; + } + } + } + } + } else if let (Some(guild_id), Some(user_channel)) = (guild_id_opt, new.channel_id) { + if let Some(guild) = ctx.cache.guild(guild_id).await { + let pool = ctx + .data + .read() + .await + .get::() + .cloned() + .expect("Could not get SQLPool from data"); + + let guild_data_opt = ctx.guild_data(guild.id).await; + + if let Ok(guild_data) = guild_data_opt { + let volume; + let allowed_greets; + + { + let read = guild_data.read().await; + + volume = read.volume; + allowed_greets = read.allow_greets; + } + + if allowed_greets { + if let Some(join_id) = ctx.join_sound(new.user_id).await { + let mut sound = sqlx::query_as_unchecked!( + Sound, + " +SELECT name, id, plays, public, server_id, uploader_id + FROM sounds + WHERE id = ? + ", + join_id + ) + .fetch_one(&pool) + .await + .unwrap(); + + let (handler, _) = join_channel(&ctx, guild, user_channel).await; + + let _ = play_audio( + &mut sound, + volume, + &mut handler.lock().await, + pool, + false, + ) + .await; + } + } + } + } + } + } +} diff --git a/src/framework.rs b/src/framework.rs index ebf6834..88b57b8 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -2,13 +2,17 @@ use serenity::{ async_trait, client::Context, constants::MESSAGE_CODE_LIMIT, - framework::{standard::Args, Framework}, + framework::{ + standard::{Args, CommandResult, Delimiter}, + Framework, + }, futures::prelude::future::BoxFuture, http::Http, model::{ channel::{Channel, GuildChannel, Message}, guild::{Guild, Member}, - id::ChannelId, + id::{ChannelId, GuildId, UserId}, + interactions::Interaction, }, Result as SerenityResult, }; @@ -20,9 +24,164 @@ use regex::{Match, Regex, RegexBuilder}; use std::{collections::HashMap, fmt}; use crate::{guild_data::CtxGuildData, MySQL}; -use serenity::framework::standard::{CommandResult, Delimiter}; +use serenity::builder::CreateEmbed; +use serenity::cache::Cache; +use serenity::model::prelude::InteractionResponseType; +use std::sync::Arc; -type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, Args) -> BoxFuture<'fut, CommandResult>; +type CommandFn = for<'fut> fn( + &'fut Context, + &'fut (dyn CommandInvoke + Sync + Send), + Args, +) -> BoxFuture<'fut, CommandResult>; + +pub struct CreateGenericResponse { + content: String, + embed: Option, +} + +impl CreateGenericResponse { + pub fn new() -> Self { + Self { + content: "".to_string(), + embed: None, + } + } + + pub fn content(mut self, content: D) -> Self { + self.content = content.to_string(); + + self + } + + pub fn embed &mut CreateEmbed>(mut self, f: F) -> Self { + let mut embed = CreateEmbed::default(); + + f(&mut embed); + + self.embed = Some(embed); + + self + } +} + +#[async_trait] +pub trait CommandInvoke { + fn channel_id(&self) -> ChannelId; + fn guild_id(&self) -> Option; + async fn guild(&self, cache: Arc) -> Option; + fn author_id(&self) -> UserId; + async fn member(&self, context: &Context) -> SerenityResult; + fn msg(&self) -> Option; + fn interaction(&self) -> Option; + async fn respond( + &self, + http: Arc, + generic_response: CreateGenericResponse, + ) -> SerenityResult<()>; +} + +#[async_trait] +impl CommandInvoke for Message { + fn channel_id(&self) -> ChannelId { + self.channel_id + } + + fn guild_id(&self) -> Option { + self.guild_id + } + + async fn guild(&self, cache: Arc) -> Option { + self.guild(cache).await + } + + fn author_id(&self) -> UserId { + self.author.id + } + + async fn member(&self, context: &Context) -> SerenityResult { + self.member(context).await + } + + fn msg(&self) -> Option { + Some(self.clone()) + } + + fn interaction(&self) -> Option { + None + } + + async fn respond( + &self, + http: Arc, + generic_response: CreateGenericResponse, + ) -> SerenityResult<()> { + self.channel_id + .send_message(http, |m| { + m.content(generic_response.content); + + if let Some(embed) = generic_response.embed { + m.set_embed(embed.clone()); + } + + m + }) + .await + .map(|_| ()) + } +} + +#[async_trait] +impl CommandInvoke for Interaction { + fn channel_id(&self) -> ChannelId { + self.channel_id.unwrap() + } + + fn guild_id(&self) -> Option { + self.guild_id + } + + async fn guild(&self, cache: Arc) -> Option { + self.guild(cache).await + } + + fn author_id(&self) -> UserId { + self.member.as_ref().unwrap().user.id + } + + async fn member(&self, _: &Context) -> SerenityResult { + Ok(self.member.clone().unwrap()) + } + + fn msg(&self) -> Option { + None + } + + fn interaction(&self) -> Option { + Some(self.clone()) + } + + async fn respond( + &self, + http: Arc, + generic_response: CreateGenericResponse, + ) -> SerenityResult<()> { + self.create_interaction_response(http, |r| { + r.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|d| { + d.content(generic_response.content); + + if let Some(embed) = generic_response.embed { + d.set_embed(embed.clone()); + } + + d + }) + }) + .await + .map(|_| ()) + } +} #[derive(Debug, PartialEq)] pub enum PermissionLevel { @@ -35,6 +194,7 @@ pub struct Command { pub name: &'static str, pub required_perms: PermissionLevel, pub func: CommandFn, + pub allow_slash: bool, } impl Command { diff --git a/src/main.rs b/src/main.rs index 4a657d9..6399575 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,10 +8,10 @@ mod guild_data; mod sound; use crate::{ - event_handlers::RestartTrack, - framework::RegexFramework, + event_handlers::{Handler, RestartTrack}, + framework::{CommandInvoke, CreateGenericResponse, RegexFramework}, guild_data::{CtxGuildData, GuildData}, - sound::Sound, + sound::{JoinSoundCtx, Sound}, }; use log::info; @@ -23,14 +23,10 @@ use serenity::{ framework::standard::{Args, CommandResult}, http::Http, model::{ - channel::Channel, - channel::Message, guild::Guild, - id::{ChannelId, GuildId, RoleId}, - voice::VoiceState, + id::{ChannelId, GuildId, RoleId, UserId}, }, prelude::*, - utils::shard_id, }; use songbird::{ @@ -46,10 +42,10 @@ use sqlx::mysql::MySqlPool; use dotenv::dotenv; -use crate::sound::JoinSoundCtx; use dashmap::DashMap; -use serenity::model::id::UserId; + use std::{collections::HashMap, convert::TryFrom, env, sync::Arc, time::Duration}; + use tokio::sync::{MutexGuard, RwLock}; struct MySQL; @@ -90,134 +86,6 @@ lazy_static! { static ref PATREON_ROLE: u64 = env::var("PATREON_ROLE").unwrap().parse::().unwrap(); } -// create event handler for bot -struct Handler; - -#[serenity::async_trait] -impl EventHandler for Handler { - async fn guild_create(&self, ctx: Context, guild: Guild, is_new: bool) { - if is_new { - if let Ok(token) = env::var("DISCORDBOTS_TOKEN") { - let shard_count = ctx.cache.shard_count().await; - let current_shard_id = shard_id(guild.id.as_u64().to_owned(), shard_count); - - let guild_count = ctx - .cache - .guilds() - .await - .iter() - .filter(|g| shard_id(g.as_u64().to_owned(), shard_count) == current_shard_id) - .count() as u64; - - let mut hm = HashMap::new(); - hm.insert("server_count", guild_count); - hm.insert("shard_id", current_shard_id); - hm.insert("shard_count", shard_count); - - let client = ctx - .data - .read() - .await - .get::() - .cloned() - .expect("Could not get ReqwestClient from data"); - - let response = client - .post( - format!( - "https://top.gg/api/bots/{}/stats", - ctx.cache.current_user_id().await.as_u64() - ) - .as_str(), - ) - .header("Authorization", token) - .json(&hm) - .send() - .await; - - if let Err(res) = response { - println!("DiscordBots Response: {:?}", res); - } - } - } - } - - async fn voice_state_update( - &self, - ctx: Context, - guild_id_opt: Option, - old: Option, - new: VoiceState, - ) { - if let Some(past_state) = old { - if let (Some(guild_id), None) = (guild_id_opt, new.channel_id) { - if let Some(channel_id) = past_state.channel_id { - if let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx).await - { - if channel.members(&ctx).await.map(|m| m.len()).unwrap_or(0) <= 1 { - let songbird = songbird::get(&ctx).await.unwrap(); - - let _ = songbird.remove(guild_id).await; - } - } - } - } - } else if let (Some(guild_id), Some(user_channel)) = (guild_id_opt, new.channel_id) { - if let Some(guild) = ctx.cache.guild(guild_id).await { - let pool = ctx - .data - .read() - .await - .get::() - .cloned() - .expect("Could not get SQLPool from data"); - - let guild_data_opt = ctx.guild_data(guild.id).await; - - if let Ok(guild_data) = guild_data_opt { - let volume; - let allowed_greets; - - { - let read = guild_data.read().await; - - volume = read.volume; - allowed_greets = read.allow_greets; - } - - if allowed_greets { - if let Some(join_id) = ctx.join_sound(new.user_id).await { - let mut sound = sqlx::query_as_unchecked!( - Sound, - " -SELECT name, id, plays, public, server_id, uploader_id - FROM sounds - WHERE id = ? - ", - join_id - ) - .fetch_one(&pool) - .await - .unwrap(); - - let (handler, _) = join_channel(&ctx, guild, user_channel).await; - - let _ = play_audio( - &mut sound, - volume, - &mut handler.lock().await, - pool, - false, - ) - .await; - } - } - } - } - } - } -} - async fn play_audio( sound: &mut Sound, volume: u8, @@ -293,6 +161,7 @@ async fn main() -> Result<(), Box> { let http = Http::new_with_token(&token); let logged_in_id = http.get_current_user().await?.id; + let application_id = http.get_current_application_info().await?.id; let audio_index = if let Ok(static_audio) = std::fs::read_to_string("audio/audio.json") { if let Ok(json) = serde_json::from_str::>(&static_audio) { @@ -356,6 +225,7 @@ async fn main() -> Result<(), Box> { | GatewayIntents::GUILDS, ) .framework(framework) + .application_id(application_id.0) .event_handler(Handler) .register_songbird() .await @@ -426,10 +296,14 @@ async fn main() -> Result<(), Box> { } #[command] -async fn help(ctx: &Context, msg: &Message, args: Args) -> CommandResult { +async fn help( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + args: Args, +) -> CommandResult { if args.is_empty() { let description = { - let guild_data = ctx.guild_data(msg.guild_id.unwrap()).await.unwrap(); + let guild_data = ctx.guild_data(invoke.guild_id().unwrap()).await.unwrap(); let read_lock = guild_data.read().await; @@ -439,7 +313,8 @@ async fn help(ctx: &Context, msg: &Message, args: Args) -> CommandResult { ) }; - msg.channel_id + invoke + .channel_id() .send_message(&ctx, |m| { m.embed(|e| { e.title("Help") @@ -538,7 +413,8 @@ Please select a category from the following: } }; - msg.channel_id + invoke + .channel_id() .send_message(&ctx, |m| { m.embed(|e| e.title("Help").color(THEME_COLOR).description(body)) }) @@ -550,30 +426,55 @@ Please select a category from the following: #[command] #[permission_level(Managed)] -async fn play(ctx: &Context, msg: &Message, args: Args) -> CommandResult { - play_cmd(ctx, msg, args, false).await +async fn play( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + args: Args, +) -> CommandResult { + let guild = invoke + .guild_id() + .unwrap() + .to_guild_cached(&ctx) + .await + .unwrap(); + + invoke + .channel_id() + .say( + &ctx, + play_cmd(ctx, guild, invoke.author_id(), args, false).await, + ) + .await?; + + Ok(()) } #[command] #[permission_level(Managed)] -async fn loop_play(ctx: &Context, msg: &Message, args: Args) -> CommandResult { - play_cmd(ctx, msg, args, true).await +async fn loop_play( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + args: Args, +) -> CommandResult { + let guild = invoke.guild(ctx.cache.clone()).await.unwrap(); + + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content(play_cmd(ctx, guild, invoke.author_id(), args, true).await), + ) + .await?; + + Ok(()) } -async fn play_cmd(ctx: &Context, msg: &Message, args: Args, loop_: bool) -> CommandResult { - let guild = match msg.guild(&ctx.cache).await { - Some(guild) => guild, - - None => { - return Ok(()); - } - }; - +async fn play_cmd(ctx: &Context, guild: Guild, user_id: UserId, args: Args, loop_: bool) -> String { let guild_id = guild.id; let channel_to_join = guild .voice_states - .get(&msg.author.id) + .get(&user_id) .and_then(|voice_state| voice_state.channel_id); match channel_to_join { @@ -588,14 +489,10 @@ async fn play_cmd(ctx: &Context, msg: &Message, args: Args, loop_: bool) -> Comm .cloned() .expect("Could not get SQLPool from data"); - let mut sound_vec = Sound::search_for_sound( - search_term, - *guild_id.as_u64(), - *msg.author.id.as_u64(), - pool.clone(), - true, - ) - .await?; + let mut sound_vec = + Sound::search_for_sound(search_term, guild_id, user_id, pool.clone(), true) + .await + .unwrap(); let sound_res = sound_vec.first_mut(); @@ -616,49 +513,33 @@ async fn play_cmd(ctx: &Context, msg: &Message, args: Args, loop_: bool) -> Comm pool, loop_, ) - .await?; + .await + .unwrap(); } - msg.channel_id - .say( - &ctx, - format!("Playing sound {} with ID {}", sound.name, sound.id), - ) - .await?; + format!("Playing sound {} with ID {}", sound.name, sound.id) } - None => { - msg.channel_id - .say(&ctx, "Couldn't find sound by term provided") - .await?; - } + None => "Couldn't find sound by term provided".to_string(), } } - None => { - msg.channel_id - .say(&ctx, "You are not in a voice chat!") - .await?; - } + None => "You are not in a voice chat!".to_string(), } - - Ok(()) } #[command] #[permission_level(Managed)] -async fn play_ambience(ctx: &Context, msg: &Message, args: Args) -> CommandResult { - let guild = match msg.guild(&ctx.cache).await { - Some(guild) => guild, - - None => { - return Ok(()); - } - }; +async fn play_ambience( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + args: Args, +) -> CommandResult { + let guild = invoke.guild(ctx.cache.clone()).await.unwrap(); let channel_to_join = guild .voice_states - .get(&msg.author.id) + .get(&invoke.author_id()) .and_then(|voice_state| voice_state.channel_id); match channel_to_join { @@ -692,13 +573,18 @@ async fn play_ambience(ctx: &Context, msg: &Message, args: Args) -> CommandResul RestartTrack {}, ); - msg.channel_id - .say(&ctx, format!("Playing ambience **{}**", search_name)) + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content(format!("Playing ambience **{}**", search_name)), + ) .await?; } else { - msg.channel_id - .send_message(&ctx, |m| { - m.embed(|e| { + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().embed(|e| { e.title("Not Found").description(format!( "Could not find ambience sound by name **{}** @@ -712,15 +598,18 @@ __Available ambience sounds:__ .collect::>() .join("\n") )) - }) - }) + }), + ) .await?; } } None => { - msg.channel_id - .say(&ctx, "You are not in a voice chat!") + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content("You are not in a voice chat!"), + ) .await?; } } @@ -730,8 +619,12 @@ __Available ambience sounds:__ #[command] #[permission_level(Managed)] -async fn stop_playing(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { - let guild_id = msg.guild_id.unwrap(); +async fn stop_playing( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + _args: Args, +) -> CommandResult { + let guild_id = invoke.guild_id().unwrap(); let songbird = songbird::get(ctx).await.unwrap(); let call_opt = songbird.get(guild_id); @@ -747,8 +640,12 @@ async fn stop_playing(ctx: &Context, msg: &Message, _args: Args) -> CommandResul #[command] #[permission_level(Managed)] -async fn disconnect(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { - let guild_id = msg.guild_id.unwrap(); +async fn disconnect( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + _args: Args, +) -> CommandResult { + let guild_id = invoke.guild_id().unwrap(); let songbird = songbird::get(ctx).await.unwrap(); let _ = songbird.leave(guild_id).await; @@ -757,10 +654,14 @@ async fn disconnect(ctx: &Context, msg: &Message, _args: Args) -> CommandResult } #[command] -async fn info(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { +async fn info( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + _args: Args, +) -> CommandResult { let current_user = ctx.cache.current_user().await; - msg.channel_id.send_message(&ctx, |m| m + invoke.channel_id().send_message(&ctx, |m| m .embed(|e| e .title("Info") .color(THEME_COLOR) @@ -791,7 +692,11 @@ There is a maximum sound limit per user. This can be removed by subscribing at * #[command] #[permission_level(Managed)] -async fn change_volume(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { +async fn change_volume( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + mut args: Args, +) -> CommandResult { let pool = ctx .data .read() @@ -800,7 +705,7 @@ async fn change_volume(ctx: &Context, msg: &Message, mut args: Args) -> CommandR .cloned() .expect("Could not get SQLPool from data"); - let guild_data_opt = ctx.guild_data(msg.guild_id.unwrap()).await; + let guild_data_opt = ctx.guild_data(invoke.guild_id().unwrap()).await; let guild_data = guild_data_opt.unwrap(); if args.len() == 1 { @@ -810,25 +715,41 @@ async fn change_volume(ctx: &Context, msg: &Message, mut args: Args) -> CommandR guild_data.read().await.commit(pool).await?; - msg.channel_id - .say(&ctx, format!("Volume changed to {}%", volume)) + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content(format!("Volume changed to {}%", volume)), + ) .await?; } Err(_) => { let read = guild_data.read().await; - msg.channel_id.say(&ctx, - format!("Current server volume: {vol}%. Change the volume with ```{prefix}volume ```", - vol = read.volume, prefix = read.prefix)).await?; + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(format!( + "Current server volume: {vol}%. Change the volume with `/volume `", + vol = read.volume + )), + ) + .await?; } } } else { let read = guild_data.read().await; - msg.channel_id.say(&ctx, - format!("Current server volume: {vol}%. Change the volume with ```{prefix}volume ```", - vol = read.volume, prefix = read.prefix)).await?; + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(format!( + "Current server volume: {vol}%. Change the volume with `/volume `", + vol = read.volume + )), + ) + .await?; } Ok(()) @@ -836,7 +757,11 @@ async fn change_volume(ctx: &Context, msg: &Message, mut args: Args) -> CommandR #[command] #[permission_level(Restricted)] -async fn change_prefix(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { +async fn change_prefix( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + mut args: Args, +) -> CommandResult { let pool = ctx .data .read() @@ -848,7 +773,7 @@ async fn change_prefix(ctx: &Context, msg: &Message, mut args: Args) -> CommandR let guild_data; { - let guild_data_opt = ctx.guild_data(msg.guild_id.unwrap()).await; + let guild_data_opt = ctx.guild_data(invoke.guild_id().unwrap()).await; guild_data = guild_data_opt.unwrap(); } @@ -869,34 +794,43 @@ async fn change_prefix(ctx: &Context, msg: &Message, mut args: Args) -> CommandR read.commit(pool).await?; } - msg.channel_id.say(&ctx, reply).await?; + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(reply), + ) + .await?; } else { - msg.channel_id - .say(&ctx, "Prefix must be less than 5 characters long") + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content("Prefix must be less than 5 characters long"), + ) .await?; } } Err(_) => { - msg.channel_id - .say( - &ctx, - format!( + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(format!( "Usage: `{prefix}prefix `", prefix = guild_data.read().await.prefix - ), + )), ) .await?; } } } else { - msg.channel_id - .say( - &ctx, - format!( + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(format!( "Usage: `{prefix}prefix `", prefix = guild_data.read().await.prefix - ), + )), ) .await?; } @@ -905,7 +839,14 @@ async fn change_prefix(ctx: &Context, msg: &Message, mut args: Args) -> CommandR } #[command] -async fn upload_new_sound(ctx: &Context, msg: &Message, args: Args) -> CommandResult { +#[allow_slash(false)] +async fn upload_new_sound( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + args: Args, +) -> CommandResult { + let msg = invoke.msg().unwrap(); + fn is_numeric(s: &String) -> bool { for char in s.chars() { if char.is_digit(10) { @@ -1032,7 +973,13 @@ async fn upload_new_sound(ctx: &Context, msg: &Message, args: Args) -> CommandRe #[command] #[permission_level(Restricted)] -async fn set_allowed_roles(ctx: &Context, msg: &Message, args: Args) -> CommandResult { +#[allow_slash(false)] +async fn set_allowed_roles( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + args: Args, +) -> CommandResult { + let msg = invoke.msg().unwrap(); let guild_id = *msg.guild_id.unwrap().as_u64(); let pool = ctx @@ -1114,7 +1061,11 @@ INSERT INTO roles (guild_id, role) } #[command] -async fn list_sounds(ctx: &Context, msg: &Message, args: Args) -> CommandResult { +async fn list_sounds( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + args: Args, +) -> CommandResult { let pool = ctx .data .read() @@ -1127,11 +1078,11 @@ async fn list_sounds(ctx: &Context, msg: &Message, args: Args) -> CommandResult let mut message_buffer; if args.rest() == "me" { - sounds = Sound::get_user_sounds(*msg.author.id.as_u64(), pool).await?; + sounds = Sound::get_user_sounds(invoke.author_id(), pool).await?; message_buffer = "All your sounds: ".to_string(); } else { - sounds = Sound::get_guild_sounds(*msg.guild_id.unwrap().as_u64(), pool).await?; + sounds = Sound::get_guild_sounds(invoke.guild_id().unwrap(), pool).await?; message_buffer = "All sounds on this server: ".to_string(); } @@ -1147,21 +1098,35 @@ async fn list_sounds(ctx: &Context, msg: &Message, args: Args) -> CommandResult ); if message_buffer.len() > 2000 { - msg.channel_id.say(&ctx, message_buffer).await?; + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(message_buffer), + ) + .await?; message_buffer = "".to_string(); } } if message_buffer.len() > 0 { - msg.channel_id.say(&ctx, message_buffer).await?; + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(message_buffer), + ) + .await?; } Ok(()) } #[command] -async fn change_public(ctx: &Context, msg: &Message, args: Args) -> CommandResult { +async fn change_public( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + args: Args, +) -> CommandResult { let pool = ctx .data .read() @@ -1169,30 +1134,38 @@ async fn change_public(ctx: &Context, msg: &Message, args: Args) -> CommandResul .get::() .cloned() .expect("Could not get SQLPool from data"); - let uid = msg.author.id.as_u64(); + + let uid = invoke.author_id().as_u64().to_owned(); let name = args.rest(); - let gid = *msg.guild_id.unwrap().as_u64(); + let gid = *invoke.guild_id().unwrap().as_u64(); - let mut sound_vec = Sound::search_for_sound(name, gid, *uid, pool.clone(), true).await?; + let mut sound_vec = Sound::search_for_sound(name, gid, uid, pool.clone(), true).await?; let sound_result = sound_vec.first_mut(); match sound_result { Some(sound) => { - if sound.uploader_id != Some(*uid) { - msg.channel_id.say(&ctx, "You can only change the availability of sounds you have uploaded. Use `?list me` to view your sounds").await?; + if sound.uploader_id != Some(uid) { + invoke.respond(ctx.http.clone(), CreateGenericResponse::new().content("You can only change the visibility of sounds you have uploaded. Use `?list me` to view your sounds")).await?; } else { if sound.public { sound.public = false; - msg.channel_id - .say(&ctx, "Sound has been set to private 🔒") + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content("Sound has been set to private 🔒"), + ) .await?; } else { sound.public = true; - msg.channel_id - .say(&ctx, "Sound has been set to public 🔓") + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content("Sound has been set to public 🔓"), + ) .await?; } @@ -1201,8 +1174,11 @@ async fn change_public(ctx: &Context, msg: &Message, args: Args) -> CommandResul } None => { - msg.channel_id - .say(&ctx, "Sound could not be found by that name.") + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content("Sound could not be found by that name."), + ) .await?; } } @@ -1211,7 +1187,11 @@ async fn change_public(ctx: &Context, msg: &Message, args: Args) -> CommandResul } #[command] -async fn delete_sound(ctx: &Context, msg: &Message, args: Args) -> CommandResult { +async fn delete_sound( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + args: Args, +) -> CommandResult { let pool = ctx .data .read() @@ -1220,8 +1200,8 @@ async fn delete_sound(ctx: &Context, msg: &Message, args: Args) -> CommandResult .cloned() .expect("Could not get SQLPool from data"); - let uid = *msg.author.id.as_u64(); - let gid = *msg.guild_id.unwrap().as_u64(); + let uid = invoke.author_id().0; + let gid = invoke.guild_id().unwrap().0; let name = args.rest(); @@ -1231,15 +1211,17 @@ async fn delete_sound(ctx: &Context, msg: &Message, args: Args) -> CommandResult match sound_result { Some(sound) => { if sound.uploader_id != Some(uid) && sound.server_id != gid { - msg.channel_id - .say( - &ctx, - "You can only delete sounds from this guild or that you have uploaded.", + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content( + "You can only delete sounds from this guild or that you have uploaded.", + ), ) .await?; } else { let has_perms = { - if let Ok(member) = msg.member(&ctx).await { + if let Ok(member) = invoke.member(&ctx).await { if let Ok(perms) = member.permissions(&ctx).await { perms.manage_guild() } else { @@ -1253,12 +1235,19 @@ async fn delete_sound(ctx: &Context, msg: &Message, args: Args) -> CommandResult if sound.uploader_id == Some(uid) || has_perms { sound.delete(pool).await?; - msg.channel_id.say(&ctx, "Sound has been deleted").await?; + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content("Sound has been deleted"), + ) + .await?; } else { - msg.channel_id - .say( - &ctx, - "Only server admins can delete sounds uploaded by other users.", + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content( + "Only server admins can delete sounds uploaded by other users.", + ), ) .await?; } @@ -1266,8 +1255,11 @@ async fn delete_sound(ctx: &Context, msg: &Message, args: Args) -> CommandResult } None => { - msg.channel_id - .say(&ctx, "Sound could not be found by that name.") + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content("Sound could not be found by that name."), + ) .await?; } } @@ -1275,11 +1267,7 @@ async fn delete_sound(ctx: &Context, msg: &Message, args: Args) -> CommandResult Ok(()) } -async fn format_search_results( - search_results: Vec, - msg: &Message, - ctx: &Context, -) -> Result<(), Box> { +fn format_search_results(search_results: Vec) -> CreateGenericResponse { let mut current_character_count = 0; let title = "Public sounds matching filter:"; @@ -1299,15 +1287,15 @@ async fn format_search_results( current_character_count <= 2048 - title.len() }); - msg.channel_id - .send_message(&ctx, |m| m.embed(|e| e.title(title).fields(field_iter))) - .await?; - - Ok(()) + CreateGenericResponse::new().embed(|e| e.title(title).fields(field_iter)) } #[command] -async fn search_sounds(ctx: &Context, msg: &Message, args: Args) -> CommandResult { +async fn search_sounds( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + args: Args, +) -> CommandResult { let pool = ctx .data .read() @@ -1320,20 +1308,26 @@ async fn search_sounds(ctx: &Context, msg: &Message, args: Args) -> CommandResul let search_results = Sound::search_for_sound( query, - *msg.guild_id.unwrap().as_u64(), - *msg.author.id.as_u64(), + invoke.guild_id().unwrap(), + invoke.author_id(), pool, false, ) .await?; - format_search_results(search_results, msg, ctx).await?; + invoke + .respond(ctx.http.clone(), format_search_results(search_results)) + .await?; Ok(()) } #[command] -async fn show_popular_sounds(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { +async fn show_popular_sounds( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + _args: Args, +) -> CommandResult { let pool = ctx .data .read() @@ -1347,6 +1341,7 @@ async fn show_popular_sounds(ctx: &Context, msg: &Message, _args: Args) -> Comma " SELECT name, id, plays, public, server_id, uploader_id FROM sounds + WHERE public = 1 ORDER BY plays DESC LIMIT 25 " @@ -1354,13 +1349,19 @@ SELECT name, id, plays, public, server_id, uploader_id .fetch_all(&pool) .await?; - format_search_results(search_results, msg, ctx).await?; + invoke + .respond(ctx.http.clone(), format_search_results(search_results)) + .await?; Ok(()) } #[command] -async fn show_random_sounds(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { +async fn show_random_sounds( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + _args: Args, +) -> CommandResult { let pool = ctx .data .read() @@ -1374,6 +1375,7 @@ async fn show_random_sounds(ctx: &Context, msg: &Message, _args: Args) -> Comman " SELECT name, id, plays, public, server_id, uploader_id FROM sounds + WHERE public = 1 ORDER BY rand() LIMIT 25 " @@ -1382,15 +1384,19 @@ SELECT name, id, plays, public, server_id, uploader_id .await .unwrap(); - format_search_results(search_results, msg, ctx) - .await - .unwrap(); + invoke + .respond(ctx.http.clone(), format_search_results(search_results)) + .await?; Ok(()) } #[command] -async fn set_greet_sound(ctx: &Context, msg: &Message, args: Args) -> CommandResult { +async fn set_greet_sound( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + args: Args, +) -> CommandResult { let pool = ctx .data .read() @@ -1400,18 +1406,21 @@ async fn set_greet_sound(ctx: &Context, msg: &Message, args: Args) -> CommandRes .expect("Could not get SQLPool from data"); let query = args.rest(); - let user_id = *msg.author.id.as_u64(); + let user_id = invoke.author_id(); if query.len() == 0 { ctx.update_join_sound(user_id, None).await; - msg.channel_id - .say(&ctx, "Your greet sound has been unset.") + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content("Your greet sound has been unset."), + ) .await?; } else { let sound_vec = Sound::search_for_sound( query, - *msg.guild_id.unwrap().as_u64(), + invoke.guild_id().unwrap(), user_id, pool.clone(), true, @@ -1422,20 +1431,24 @@ async fn set_greet_sound(ctx: &Context, msg: &Message, args: Args) -> CommandRes Some(sound) => { ctx.update_join_sound(user_id, Some(sound.id)).await; - msg.channel_id - .say( - &ctx, - format!( + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(format!( "Greet sound has been set to {} (ID {})", sound.name, sound.id - ), + )), ) .await?; } None => { - msg.channel_id - .say(&ctx, "Could not find a sound by that name.") + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content("Could not find a sound by that name."), + ) .await?; } } @@ -1446,7 +1459,11 @@ async fn set_greet_sound(ctx: &Context, msg: &Message, args: Args) -> CommandRes #[command] #[permission_level(Managed)] -async fn allow_greet_sounds(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { +async fn allow_greet_sounds( + ctx: &Context, + invoke: &(dyn CommandInvoke + Sync + Send), + _args: Args, +) -> CommandResult { let pool = ctx .data .read() @@ -1455,7 +1472,7 @@ async fn allow_greet_sounds(ctx: &Context, msg: &Message, _args: Args) -> Comman .cloned() .expect("Could not acquire SQL pool from data"); - let guild_data_opt = ctx.guild_data(msg.guild_id.unwrap()).await; + let guild_data_opt = ctx.guild_data(invoke.guild_id().unwrap()).await; if let Ok(guild_data) = guild_data_opt { let current = guild_data.read().await.allow_greets; @@ -1466,13 +1483,13 @@ async fn allow_greet_sounds(ctx: &Context, msg: &Message, _args: Args) -> Comman guild_data.read().await.commit(pool).await?; - msg.channel_id - .say( - &ctx, - format!( + invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(format!( "Greet sounds have been {}abled in this server", if !current { "en" } else { "dis" } - ), + )), ) .await?; } diff --git a/src/sound.rs b/src/sound.rs index 167e5a0..d85a857 100644 --- a/src/sound.rs +++ b/src/sound.rs @@ -121,13 +121,16 @@ pub struct Sound { } impl Sound { - pub async fn search_for_sound( + pub async fn search_for_sound, U: Into>( query: &str, - guild_id: u64, - user_id: u64, + guild_id: G, + user_id: U, db_pool: MySqlPool, strict: bool, ) -> Result, sqlx::Error> { + let guild_id = guild_id.into(); + let user_id = user_id.into(); + fn extract_id(s: &str) -> Option { if s.len() > 3 && s.to_lowercase().starts_with("id:") { match s[3..].parse::() { @@ -403,8 +406,8 @@ INSERT INTO sounds (name, server_id, uploader_id, public, src) } } - pub async fn get_user_sounds( - user_id: u64, + pub async fn get_user_sounds>( + user_id: U, db_pool: MySqlPool, ) -> Result, Box> { let sounds = sqlx::query_as_unchecked!( @@ -414,7 +417,7 @@ SELECT name, id, plays, public, server_id, uploader_id FROM sounds WHERE uploader_id = ? ", - user_id + user_id.into() ) .fetch_all(&db_pool) .await?; @@ -422,8 +425,8 @@ SELECT name, id, plays, public, server_id, uploader_id Ok(sounds) } - pub async fn get_guild_sounds( - guild_id: u64, + pub async fn get_guild_sounds>( + guild_id: G, db_pool: MySqlPool, ) -> Result, Box> { let sounds = sqlx::query_as_unchecked!( @@ -433,7 +436,7 @@ SELECT name, id, plays, public, server_id, uploader_id FROM sounds WHERE server_id = ? ", - guild_id + guild_id.into() ) .fetch_all(&db_pool) .await?;