From eb5ea3167d54300a521c086732c234bcb82d0e87 Mon Sep 17 00:00:00 2001 From: jellywx Date: Sun, 30 Jan 2022 15:55:53 +0000 Subject: [PATCH] generics for Database pool --- src/cmds/manage.rs | 17 +++++++---------- src/cmds/settings.rs | 18 +++--------------- src/event_handlers.rs | 6 ++---- src/main.rs | 9 ++++++--- src/models/guild_data.rs | 20 +++++++++----------- src/models/join_sound.rs | 4 +--- src/models/sound.rs | 32 ++++++++++++++++---------------- src/utils.rs | 12 +++++------- 8 files changed, 49 insertions(+), 69 deletions(-) diff --git a/src/cmds/manage.rs b/src/cmds/manage.rs index a50cbe4..e20ce52 100644 --- a/src/cmds/manage.rs +++ b/src/cmds/manage.rs @@ -29,11 +29,10 @@ pub async fn upload_new_sound( if !name.is_empty() && name.len() <= 20 { if !is_numeric(&name) { - let pool = ctx.data().database.clone(); - // need to check the name is not currently in use by the user let count_name = - Sound::count_named_user_sounds(ctx.author().id, &name, pool.clone()).await?; + Sound::count_named_user_sounds(ctx.author().id, &name, &ctx.data().database) + .await?; if count_name > 0 { ctx.say( "You are already using that name. Please choose a unique name for your upload.", @@ -41,7 +40,7 @@ pub async fn upload_new_sound( .await?; } else { // need to check how many sounds user currently has - let count = Sound::count_user_sounds(ctx.author().id, pool.clone()).await?; + let count = Sound::count_user_sounds(ctx.author().id, &ctx.data().database).await?; let mut permit_upload = true; // need to check if user is patreon or nah @@ -93,7 +92,7 @@ pub async fn upload_new_sound( url.as_str(), ctx.guild_id().unwrap(), ctx.author().id, - pool, + &ctx.data().database, ) .await { @@ -160,7 +159,7 @@ pub async fn delete_sound( }; if sound.uploader_id == Some(uid) || has_perms { - sound.delete(pool).await?; + sound.delete(&pool).await?; ctx.say("Sound has been deleted").await?; } else { @@ -209,7 +208,7 @@ pub async fn change_public( ctx.say("Sound has been set to public 🔓").await?; } - sound.commit(pool).await? + sound.commit(&pool).await? } } @@ -238,9 +237,7 @@ pub async fn download_file( match sound.first() { Some(sound) => { - let source = sound - .store_sound_source(ctx.data().database.clone()) - .await?; + let source = sound.store_sound_source(&ctx.data().database).await?; let file = File::open(&source).await?; let name = format!("{}-{}.opus", sound.id, sound.name); diff --git a/src/cmds/settings.rs b/src/cmds/settings.rs index c8a8676..04da24b 100644 --- a/src/cmds/settings.rs +++ b/src/cmds/settings.rs @@ -15,11 +15,7 @@ pub async fn change_volume( if let Some(volume) = volume { guild_data.write().await.volume = volume as u8; - guild_data - .read() - .await - .commit(ctx.data().database.clone()) - .await?; + guild_data.read().await.commit(&ctx.data().database).await?; ctx.say(format!("Volume changed to {}%", volume)).await?; } else { @@ -91,11 +87,7 @@ pub async fn disable_greet_sound(ctx: Context<'_>) -> Result<(), Error> { if let Ok(guild_data) = guild_data_opt { guild_data.write().await.allow_greets = false; - guild_data - .read() - .await - .commit(ctx.data().database.clone()) - .await?; + guild_data.read().await.commit(&ctx.data().database).await?; } ctx.say("Greet sounds have been disabled in this server") @@ -112,11 +104,7 @@ pub async fn enable_greet_sound(ctx: Context<'_>) -> Result<(), Error> { if let Ok(guild_data) = guild_data_opt { guild_data.write().await.allow_greets = true; - guild_data - .read() - .await - .commit(ctx.data().database.clone()) - .await?; + guild_data.read().await.commit(&ctx.data().database).await?; } ctx.say("Greet sounds have been enable in this server") diff --git a/src/event_handlers.rs b/src/event_handlers.rs index 8ffa91d..d19bed3 100644 --- a/src/event_handlers.rs +++ b/src/event_handlers.rs @@ -72,8 +72,6 @@ pub async fn listener(ctx: &Context, event: &poise::Event<'_>, data: &Data) -> R } } else if let (Some(guild_id), Some(user_channel)) = (new.guild_id, new.channel_id) { if let Some(guild) = ctx.cache.guild(guild_id) { - let pool = data.database.clone(); - let guild_data_opt = data.guild_data(guild.id).await; if let Ok(guild_data) = guild_data_opt { @@ -98,7 +96,7 @@ SELECT name, id, public, server_id, uploader_id ", join_id ) - .fetch_one(&pool) + .fetch_one(&data.database) .await .unwrap(); @@ -108,7 +106,7 @@ SELECT name, id, public, server_id, uploader_id &mut sound, volume, &mut handler.lock().await, - pool, + &data.database, false, ) .await; diff --git a/src/main.rs b/src/main.rs index 09e9826..b6d1c41 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,13 +20,16 @@ use poise::serenity::{ }, }; use songbird::SerenityInit; -use sqlx::mysql::MySqlPool; +use sqlx::{MySql, Pool}; use tokio::sync::RwLock; use crate::{event_handlers::listener, models::guild_data::GuildData}; +// Which database driver are we using? +type Database = MySql; + pub struct Data { - database: MySqlPool, + database: Pool, http: reqwest::Client, guild_data_cache: DashMap>>, join_sound_cache: DashMap>, @@ -112,7 +115,7 @@ async fn main() -> Result<(), Box> { ..Default::default() }; - let database = MySqlPool::connect(&env::var("DATABASE_URL").expect("No database URL provided")) + let database = Pool::connect(&env::var("DATABASE_URL").expect("No database URL provided")) .await .unwrap(); diff --git a/src/models/guild_data.rs b/src/models/guild_data.rs index 2ab5efd..883bd7a 100644 --- a/src/models/guild_data.rs +++ b/src/models/guild_data.rs @@ -1,10 +1,10 @@ use std::sync::Arc; use poise::serenity::{async_trait, model::id::GuildId}; -use sqlx::mysql::MySqlPool; +use sqlx::Executor; use tokio::sync::RwLock; -use crate::{Context, Data}; +use crate::{Context, Data, Database}; #[derive(Clone)] pub struct GuildData { @@ -44,9 +44,7 @@ impl CtxGuildData for Data { let x = if let Some(guild_data) = self.guild_data_cache.get(&guild_id) { Ok(guild_data.clone()) } else { - let pool = self.database.clone(); - - match GuildData::from_id(guild_id, pool).await { + match GuildData::from_id(guild_id, &self.database).await { Ok(d) => { let lock = Arc::new(RwLock::new(d)); @@ -66,7 +64,7 @@ impl CtxGuildData for Data { impl GuildData { pub async fn from_id>( guild_id: G, - db_pool: MySqlPool, + db_pool: impl Executor<'_, Database = Database> + Copy, ) -> Result { let guild_id = guild_id.into(); @@ -79,7 +77,7 @@ SELECT id, prefix, volume, allow_greets, allowed_role ", guild_id.as_u64() ) - .fetch_one(&db_pool) + .fetch_one(db_pool) .await; match guild_data { @@ -93,7 +91,7 @@ SELECT id, prefix, volume, allow_greets, allowed_role async fn create_from_guild>( guild_id: G, - db_pool: MySqlPool, + db_pool: impl Executor<'_, Database = Database>, ) -> Result { let guild_id = guild_id.into(); @@ -104,7 +102,7 @@ INSERT INTO servers (id) ", guild_id.as_u64() ) - .execute(&db_pool) + .execute(db_pool) .await?; Ok(GuildData { @@ -118,7 +116,7 @@ INSERT INTO servers (id) pub async fn commit( &self, - db_pool: MySqlPool, + db_pool: impl Executor<'_, Database = Database>, ) -> Result<(), Box> { sqlx::query!( " @@ -137,7 +135,7 @@ WHERE self.allowed_role, self.id ) - .execute(&db_pool) + .execute(db_pool) .await?; Ok(()) diff --git a/src/models/join_sound.rs b/src/models/join_sound.rs index b6c8cd4..8be1278 100644 --- a/src/models/join_sound.rs +++ b/src/models/join_sound.rs @@ -21,8 +21,6 @@ impl JoinSoundCtx for Data { join_sound_id.value().clone() } else { let join_sound_id = { - let pool = self.database.clone(); - let join_id_res = sqlx::query!( " SELECT join_sound_id @@ -31,7 +29,7 @@ SELECT join_sound_id ", user_id.as_u64() ) - .fetch_one(&pool) + .fetch_one(&self.database) .await; if let Ok(row) = join_id_res { diff --git a/src/models/sound.rs b/src/models/sound.rs index 45e2bce..7f57c56 100644 --- a/src/models/sound.rs +++ b/src/models/sound.rs @@ -2,10 +2,10 @@ use std::{env, path::Path}; use poise::serenity::async_trait; use songbird::input::restartable::Restartable; -use sqlx::{mysql::MySqlPool, Error}; +use sqlx::{Error, Executor}; use tokio::{fs::File, io::AsyncWriteExt, process::Command}; -use crate::{consts::UPLOAD_MAX_SIZE, error::ErrorTypes, Data}; +use crate::{consts::UPLOAD_MAX_SIZE, error::ErrorTypes, Data, Database}; #[derive(Clone)] pub struct Sound { @@ -208,7 +208,7 @@ SELECT name, id, public, server_id, uploader_id } impl Sound { - async fn src(&self, db_pool: MySqlPool) -> Vec { + async fn src(&self, db_pool: impl Executor<'_, Database = Database>) -> Vec { struct Src { src: Vec, } @@ -223,7 +223,7 @@ SELECT src ", self.id ) - .fetch_one(&db_pool) + .fetch_one(db_pool) .await .unwrap(); @@ -232,7 +232,7 @@ SELECT src pub async fn store_sound_source( &self, - db_pool: MySqlPool, + db_pool: impl Executor<'_, Database = Database>, ) -> Result> { let caching_location = env::var("CACHING_LOCATION").unwrap_or(String::from("/tmp")); @@ -250,7 +250,7 @@ SELECT src pub async fn playable( &self, - db_pool: MySqlPool, + db_pool: impl Executor<'_, Database = Database>, ) -> Result> { let path_name = self.store_sound_source(db_pool).await?; @@ -261,7 +261,7 @@ SELECT src pub async fn count_user_sounds>( user_id: U, - db_pool: MySqlPool, + db_pool: impl Executor<'_, Database = Database>, ) -> Result { let user_id = user_id.into(); @@ -273,7 +273,7 @@ SELECT COUNT(1) as count ", user_id ) - .fetch_one(&db_pool) + .fetch_one(db_pool) .await? .count; @@ -283,7 +283,7 @@ SELECT COUNT(1) as count pub async fn count_named_user_sounds>( user_id: U, name: &String, - db_pool: MySqlPool, + db_pool: impl Executor<'_, Database = Database>, ) -> Result { let user_id = user_id.into(); @@ -298,7 +298,7 @@ SELECT COUNT(1) as count user_id, name ) - .fetch_one(&db_pool) + .fetch_one(db_pool) .await? .count; @@ -307,7 +307,7 @@ SELECT COUNT(1) as count pub async fn commit( &self, - db_pool: MySqlPool, + db_pool: impl Executor<'_, Database = Database>, ) -> Result<(), Box> { sqlx::query!( " @@ -320,7 +320,7 @@ WHERE self.public, self.id ) - .execute(&db_pool) + .execute(db_pool) .await?; Ok(()) @@ -328,7 +328,7 @@ WHERE pub async fn delete( &self, - db_pool: MySqlPool, + db_pool: impl Executor<'_, Database = Database>, ) -> Result<(), Box> { sqlx::query!( " @@ -338,7 +338,7 @@ DELETE ", self.id ) - .execute(&db_pool) + .execute(db_pool) .await?; Ok(()) @@ -349,7 +349,7 @@ DELETE src_url: &str, server_id: G, user_id: U, - db_pool: MySqlPool, + db_pool: impl Executor<'_, Database = Database>, ) -> Result<(), Box> { let server_id = server_id.into(); let user_id = user_id.into(); @@ -396,7 +396,7 @@ INSERT INTO sounds (name, server_id, uploader_id, public, src) user_id, data ) - .execute(&db_pool) + .execute(db_pool) .await { Ok(_) => Ok(()), diff --git a/src/utils.rs b/src/utils.rs index ed5ac71..2fbcaf9 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -6,7 +6,7 @@ use poise::serenity::model::{ id::{ChannelId, UserId}, }; use songbird::{create_player, error::JoinResult, tracks::TrackHandle, Call}; -use sqlx::MySqlPool; +use sqlx::Executor; use tokio::sync::{Mutex, MutexGuard}; use crate::{ @@ -14,17 +14,17 @@ use crate::{ guild_data::CtxGuildData, sound::{Sound, SoundCtx}, }, - Data, + Data, Database, }; pub async fn play_audio( sound: &mut Sound, volume: u8, call_handler: &mut MutexGuard<'_, Call>, - mysql_pool: MySqlPool, + db_pool: impl Executor<'_, Database = Database>, loop_: bool, ) -> Result> { - let (track, track_handler) = create_player(sound.playable(mysql_pool.clone()).await?.into()); + let (track, track_handler) = create_player(sound.playable(db_pool).await?.into()); let _ = track_handler.set_volume(volume as f32 / 100.0); @@ -99,8 +99,6 @@ pub async fn play_from_query( match channel_to_join { Some(user_channel) => { - let pool = data.database.clone(); - let mut sound_vec = data .search_for_sound(query, guild_id, user_id, true) .await @@ -122,7 +120,7 @@ pub async fn play_from_query( sound, guild_data.read().await.volume, &mut lock, - pool, + &data.database, loop_, ) .await