diff --git a/Cargo.lock b/Cargo.lock index dd7c375..b6c9b6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1840,6 +1840,7 @@ dependencies = [ name = "soundfx-rs" version = "1.2.0" dependencies = [ + "dashmap", "dotenv", "env_logger", "lazy_static", diff --git a/Cargo.toml b/Cargo.toml index 06e232d..50e4a66 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ songbird = { git = "https://github.com/serenity-rs/songbird", branch = "current" regex = "1.4" log = "0.4" serde_json = "1.0" +dashmap = "4.0" [dependencies.regex_command_attr] path = "./regex_command_attr" diff --git a/src/framework.rs b/src/framework.rs index 1cf77a0..454efdc 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -19,7 +19,7 @@ use regex::{Match, Regex, RegexBuilder}; use std::{collections::HashMap, fmt}; -use crate::{guild_data::GuildData, MySQL}; +use crate::{guild_data::CtxGuildData, MySQL}; use serenity::framework::standard::{CommandResult, Delimiter}; type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, Args) -> BoxFuture<'fut, CommandResult>; @@ -252,24 +252,11 @@ impl Framework for RegexFramework { async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option>) -> bool { if let Some(prefix) = prefix_opt { - let pool = ctx - .data - .read() - .await - .get::() - .cloned() - .expect("Could not get SQLPool from data"); + match ctx.get_from_id(guild.id).await { + Ok(guild_data) => prefix.as_str() == guild_data.read().await.prefix, - let guild_prefix = match GuildData::get_from_id(guild.clone(), pool.clone()).await { - Some(guild_data) => guild_data.prefix, - - None => { - GuildData::create_from_guild(guild, pool).await.unwrap(); - String::from("?") - } - }; - - guild_prefix.as_str() == prefix.as_str() + Err(_) => prefix.as_str() == "?", + } } else { true } diff --git a/src/guild_data.rs b/src/guild_data.rs index 90f432c..7ff8e02 100644 --- a/src/guild_data.rs +++ b/src/guild_data.rs @@ -1,6 +1,10 @@ -use serenity::model::guild::Guild; +use crate::{GuildDataCache, MySQL}; +use serenity::{async_trait, model::id::GuildId, prelude::Context}; use sqlx::mysql::MySqlPool; +use std::sync::Arc; +use tokio::sync::RwLock; +#[derive(Clone)] pub struct GuildData { pub id: u64, pub prefix: String, @@ -8,8 +12,59 @@ pub struct GuildData { pub allow_greets: bool, } +#[async_trait] +pub trait CtxGuildData { + async fn get_from_id + Send + Sync>( + &self, + guild_id: G, + ) -> Result>, sqlx::Error>; +} + +#[async_trait] +impl CtxGuildData for Context { + async fn get_from_id + Send + Sync>( + &self, + guild_id: G, + ) -> Result>, sqlx::Error> { + let guild_id = guild_id.into(); + + let guild_cache = self + .data + .read() + .await + .get::() + .cloned() + .unwrap(); + + let x = if let Some(guild_data) = guild_cache.get(&guild_id) { + Ok(guild_data.clone()) + } else { + let pool = self.data.read().await.get::().cloned().unwrap(); + + match GuildData::get_from_id(guild_id, pool).await { + Ok(d) => { + let lock = Arc::new(RwLock::new(d)); + + guild_cache.insert(guild_id, lock.clone()); + + Ok(lock) + } + + Err(e) => Err(e), + } + }; + + x + } +} + impl GuildData { - pub async fn get_from_id(guild: Guild, db_pool: MySqlPool) -> Option { + pub async fn get_from_id>( + guild_id: G, + db_pool: MySqlPool, + ) -> Result { + let guild_id = guild_id.into(); + let guild_data = sqlx::query_as_unchecked!( GuildData, " @@ -17,35 +72,32 @@ SELECT id, prefix, volume, allow_greets FROM servers WHERE id = ? ", - guild.id.as_u64() + guild_id.as_u64() ) .fetch_one(&db_pool) .await; match guild_data { - Ok(g) => Some(g), - - Err(sqlx::Error::RowNotFound) => Self::create_from_guild(&guild, db_pool).await.ok(), - - Err(e) => { - println!("{:?}", e); - - None + Err(sqlx::error::Error::RowNotFound) => { + Self::create_from_guild(guild_id, db_pool).await } + + d => d, } } - pub async fn create_from_guild( - guild: &Guild, + async fn create_from_guild>( + guild_id: G, db_pool: MySqlPool, - ) -> Result> { + ) -> Result { + let guild_id = guild_id.into(); + sqlx::query!( " -INSERT INTO servers (id, name) - VALUES (?, ?) +INSERT INTO servers (id) + VALUES (?) ", - guild.id.as_u64(), - guild.name + guild_id.as_u64() ) .execute(&db_pool) .await?; @@ -55,14 +107,14 @@ INSERT INTO servers (id, name) INSERT IGNORE INTO roles (guild_id, role) VALUES (?, ?) ", - guild.id.as_u64(), - guild.id.as_u64() + guild_id.as_u64(), + guild_id.as_u64() ) .execute(&db_pool) .await?; Ok(GuildData { - id: guild.id.as_u64().to_owned(), + id: guild_id.as_u64().to_owned(), prefix: String::from("?"), volume: 100, allow_greets: true, diff --git a/src/main.rs b/src/main.rs index 9e7dc2c..7936b06 100644 --- a/src/main.rs +++ b/src/main.rs @@ -41,9 +41,10 @@ use sqlx::mysql::MySqlPool; use dotenv::dotenv; -use crate::framework::RegexFramework; +use crate::{framework::RegexFramework, guild_data::CtxGuildData}; +use dashmap::DashMap; use std::{collections::HashMap, convert::TryFrom, env, sync::Arc, time::Duration}; -use tokio::sync::MutexGuard; +use tokio::sync::{MutexGuard, RwLock}; struct MySQL; @@ -63,6 +64,12 @@ impl TypeMapKey for AudioIndex { type Value = Arc>; } +struct GuildDataCache; + +impl TypeMapKey for GuildDataCache { + type Value = Arc>>>; +} + const THEME_COLOR: u32 = 0x00e0f3; lazy_static! { @@ -145,10 +152,20 @@ impl EventHandler for Handler { .cloned() .expect("Could not get SQLPool from data"); - let guild_data_opt = GuildData::get_from_id(guild.clone(), pool.clone()).await; + let guild_data_opt = ctx.get_from_id(guild.id).await; - if let Some(guild_data) = guild_data_opt { - if guild_data.allow_greets { + if let Ok(guild_data) = guild_data_opt { + let volume; + let allowed_greets; + + { + let read = guild_data.read().await; + + volume = read.volume; + allowed_greets = read.allow_greets; + } + + if allowed_greets { let join_id_res = sqlx::query!( " SELECT join_sound_id @@ -180,7 +197,7 @@ SELECT name, id, plays, public, server_id, uploader_id let _ = play_audio( &mut sound, - guild_data, + volume, &mut handler.lock().await, pool, false, @@ -197,7 +214,7 @@ SELECT name, id, plays, public, server_id, uploader_id async fn play_audio( sound: &mut Sound, - guild: GuildData, + volume: u8, call_handler: &mut MutexGuard<'_, Call>, mysql_pool: MySqlPool, loop_: bool, @@ -206,7 +223,7 @@ async fn play_audio( let (track, track_handler) = create_player(sound.store_sound_source(mysql_pool.clone()).await?.into()); - let _ = track_handler.set_volume(guild.volume as f32 / 100.0); + let _ = track_handler.set_volume(volume as f32 / 100.0); if loop_ { let _ = track_handler.enable_loop(); @@ -240,12 +257,31 @@ async fn join_channel( let call_opt = songbird.get(guild.id); if let Some(call) = call_opt { + { + // set call to deafen + let _ = call.lock().await.deafen(true).await; + } + (call, Ok(())) } else { - songbird.join(guild.id, channel_id).await + let (call, res) = songbird.join(guild.id, channel_id).await; + + { + // set call to deafen + let _ = call.lock().await.deafen(true).await; + } + + (call, res) } } else { - songbird.join(guild.id, channel_id).await + let (call, res) = songbird.join(guild.id, channel_id).await; + + { + // set call to deafen + let _ = call.lock().await.deafen(true).await; + } + + (call, res) }; (call, res) @@ -335,8 +371,11 @@ async fn main() -> Result<(), Box> { .await .unwrap(); + let guild_data_cache = Arc::new(DashMap::new()); + let mut data = client.data.write().await; + data.insert::(guild_data_cache); data.insert::(mysql_pool); data.insert::(Arc::new(reqwest::Client::new())); @@ -408,11 +447,18 @@ async fn play_cmd(ctx: &Context, msg: &Message, args: Args, loop_: bool) -> Comm let (call_handler, _) = join_channel(ctx, guild.clone(), user_channel).await; - let guild_data = GuildData::get_from_id(guild, pool.clone()).await.unwrap(); + let guild_data = ctx.get_from_id(guild).await.unwrap(); let mut lock = call_handler.lock().await; - play_audio(sound, guild_data, &mut lock, pool, loop_).await?; + play_audio( + sound, + guild_data.read().await.volume, + &mut lock, + pool, + loop_, + ) + .await?; } msg.channel_id @@ -474,13 +520,12 @@ async fn play_ambience(ctx: &Context, msg: &Message, args: Args) -> CommandResul Some(user_channel) => { let search_name = args.rest().to_lowercase(); let audio_index = ctx.data.read().await.get::().cloned().unwrap(); - let pool = ctx.data.read().await.get::().cloned().unwrap(); if let Some(filename) = audio_index.get(&search_name) { { let (call_handler, _) = join_channel(ctx, guild.clone(), user_channel).await; - let guild_data = GuildData::get_from_id(guild, pool.clone()).await.unwrap(); + let guild_data = ctx.get_from_id(guild).await.unwrap(); let mut lock = call_handler.lock().await; @@ -495,7 +540,7 @@ async fn play_ambience(ctx: &Context, msg: &Message, args: Args) -> CommandResul .unwrap(), ); - let _ = track_handler.set_volume(guild_data.volume as f32 / 100.0); + let _ = track_handler.set_volume(guild_data.read().await.volume as f32 / 100.0); let _ = track_handler.add_event( Event::Periodic( track_handler.metadata().duration.unwrap() - Duration::from_millis(500), @@ -604,14 +649,6 @@ There is a maximum sound limit per user. This can be removed by donating at http #[command] #[permission_level(Managed)] async fn change_volume(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let guild = match msg.guild(&ctx.cache).await { - Some(guild) => guild, - - None => { - return Ok(()); - } - }; - let pool = ctx .data .read() @@ -620,15 +657,15 @@ async fn change_volume(ctx: &Context, msg: &Message, mut args: Args) -> CommandR .cloned() .expect("Could not get SQLPool from data"); - let guild_data_opt = GuildData::get_from_id(guild, pool.clone()).await; - let mut guild_data = guild_data_opt.unwrap(); + let guild_data_opt = ctx.get_from_id(msg.guild_id.unwrap()).await; + let guild_data = guild_data_opt.unwrap(); if args.len() == 1 { match args.single::() { Ok(volume) => { - guild_data.volume = volume; + guild_data.write().await.volume = volume; - guild_data.commit(pool).await?; + guild_data.read().await.commit(pool).await?; msg.channel_id .say(&ctx, format!("Volume changed to {}%", volume)) @@ -636,15 +673,19 @@ async fn change_volume(ctx: &Context, msg: &Message, mut args: Args) -> CommandR } Err(_) => { + let read = guild_data.read().await; + msg.channel_id.say(&ctx, format!("Current server volume: {vol}%. Change the volume with ```{prefix}volume ```", - vol = guild_data.volume, prefix = guild_data.prefix)).await?; + vol = read.volume, prefix = read.prefix)).await?; } } } else { + let read = guild_data.read().await; + msg.channel_id.say(&ctx, format!("Current server volume: {vol}%. Change the volume with ```{prefix}volume ```", - vol = guild_data.volume, prefix = guild_data.prefix)).await?; + vol = read.volume, prefix = read.prefix)).await?; } Ok(()) @@ -653,14 +694,6 @@ async fn change_volume(ctx: &Context, msg: &Message, mut args: Args) -> CommandR #[command] #[permission_level(Restricted)] async fn change_prefix(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let guild = match msg.guild(&ctx.cache).await { - Some(guild) => guild, - - None => { - return Ok(()); - } - }; - let pool = ctx .data .read() @@ -669,10 +702,10 @@ async fn change_prefix(ctx: &Context, msg: &Message, mut args: Args) -> CommandR .cloned() .expect("Could not get SQLPool from data"); - let mut guild_data; + let guild_data; { - let guild_data_opt = GuildData::get_from_id(guild, pool.clone()).await; + let guild_data_opt = ctx.get_from_id(msg.guild_id.unwrap()).await; guild_data = guild_data_opt.unwrap(); } @@ -681,13 +714,19 @@ async fn change_prefix(ctx: &Context, msg: &Message, mut args: Args) -> CommandR match args.single::() { Ok(prefix) => { if prefix.len() <= 5 { - guild_data.prefix = prefix; + let reply = format!("Prefix changed to `{}`", prefix); - guild_data.commit(pool).await?; + { + guild_data.write().await.prefix = prefix; + } - msg.channel_id - .say(&ctx, format!("Prefix changed to `{}`", guild_data.prefix)) - .await?; + { + let read = guild_data.read().await; + + read.commit(pool).await?; + } + + msg.channel_id.say(&ctx, reply).await?; } else { msg.channel_id .say(&ctx, "Prefix must be less than 5 characters long") @@ -701,7 +740,7 @@ async fn change_prefix(ctx: &Context, msg: &Message, mut args: Args) -> CommandR &ctx, format!( "Usage: `{prefix}prefix `", - prefix = guild_data.prefix + prefix = guild_data.read().await.prefix ), ) .await?; @@ -713,7 +752,7 @@ async fn change_prefix(ctx: &Context, msg: &Message, mut args: Args) -> CommandR &ctx, format!( "Usage: `{prefix}prefix `", - prefix = guild_data.prefix + prefix = guild_data.read().await.prefix ), ) .await?; @@ -1265,14 +1304,6 @@ WHERE #[command] #[permission_level(Managed)] async fn allow_greet_sounds(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { - let guild = match msg.guild(&ctx.cache).await { - Some(guild) => guild, - - None => { - return Ok(()); - } - }; - let pool = ctx .data .read() @@ -1281,19 +1312,23 @@ async fn allow_greet_sounds(ctx: &Context, msg: &Message, _args: Args) -> Comman .cloned() .expect("Could not acquire SQL pool from data"); - let guild_data_opt = GuildData::get_from_id(guild, pool.clone()).await; + let guild_data_opt = ctx.get_from_id(msg.guild_id.unwrap()).await; - if let Some(mut guild_data) = guild_data_opt { - guild_data.allow_greets = !guild_data.allow_greets; + if let Ok(guild_data) = guild_data_opt { + let current = guild_data.read().await.allow_greets; - guild_data.commit(pool).await?; + { + guild_data.write().await.allow_greets = !current; + } + + guild_data.read().await.commit(pool).await?; msg.channel_id .say( &ctx, format!( "Greet sounds have been {}abled in this server", - if guild_data.allow_greets { "en" } else { "dis" } + if !current { "en" } else { "dis" } ), ) .await?;