generics for Database pool

This commit is contained in:
jellywx 2022-01-30 15:55:53 +00:00
parent c364343fe9
commit eb5ea3167d
8 changed files with 49 additions and 69 deletions

View File

@ -29,11 +29,10 @@ pub async fn upload_new_sound(
if !name.is_empty() && name.len() <= 20 { if !name.is_empty() && name.len() <= 20 {
if !is_numeric(&name) { if !is_numeric(&name) {
let pool = ctx.data().database.clone();
// need to check the name is not currently in use by the user // need to check the name is not currently in use by the user
let count_name = let count_name =
Sound::count_named_user_sounds(ctx.author().id, &name, pool.clone()).await?; Sound::count_named_user_sounds(ctx.author().id, &name, &ctx.data().database)
.await?;
if count_name > 0 { if count_name > 0 {
ctx.say( ctx.say(
"You are already using that name. Please choose a unique name for your upload.", "You are already using that name. Please choose a unique name for your upload.",
@ -41,7 +40,7 @@ pub async fn upload_new_sound(
.await?; .await?;
} else { } else {
// need to check how many sounds user currently has // need to check how many sounds user currently has
let count = Sound::count_user_sounds(ctx.author().id, pool.clone()).await?; let count = Sound::count_user_sounds(ctx.author().id, &ctx.data().database).await?;
let mut permit_upload = true; let mut permit_upload = true;
// need to check if user is patreon or nah // need to check if user is patreon or nah
@ -93,7 +92,7 @@ pub async fn upload_new_sound(
url.as_str(), url.as_str(),
ctx.guild_id().unwrap(), ctx.guild_id().unwrap(),
ctx.author().id, ctx.author().id,
pool, &ctx.data().database,
) )
.await .await
{ {
@ -160,7 +159,7 @@ pub async fn delete_sound(
}; };
if sound.uploader_id == Some(uid) || has_perms { if sound.uploader_id == Some(uid) || has_perms {
sound.delete(pool).await?; sound.delete(&pool).await?;
ctx.say("Sound has been deleted").await?; ctx.say("Sound has been deleted").await?;
} else { } else {
@ -209,7 +208,7 @@ pub async fn change_public(
ctx.say("Sound has been set to public 🔓").await?; ctx.say("Sound has been set to public 🔓").await?;
} }
sound.commit(pool).await? sound.commit(&pool).await?
} }
} }
@ -238,9 +237,7 @@ pub async fn download_file(
match sound.first() { match sound.first() {
Some(sound) => { Some(sound) => {
let source = sound let source = sound.store_sound_source(&ctx.data().database).await?;
.store_sound_source(ctx.data().database.clone())
.await?;
let file = File::open(&source).await?; let file = File::open(&source).await?;
let name = format!("{}-{}.opus", sound.id, sound.name); let name = format!("{}-{}.opus", sound.id, sound.name);

View File

@ -15,11 +15,7 @@ pub async fn change_volume(
if let Some(volume) = volume { if let Some(volume) = volume {
guild_data.write().await.volume = volume as u8; guild_data.write().await.volume = volume as u8;
guild_data guild_data.read().await.commit(&ctx.data().database).await?;
.read()
.await
.commit(ctx.data().database.clone())
.await?;
ctx.say(format!("Volume changed to {}%", volume)).await?; ctx.say(format!("Volume changed to {}%", volume)).await?;
} else { } else {
@ -91,11 +87,7 @@ pub async fn disable_greet_sound(ctx: Context<'_>) -> Result<(), Error> {
if let Ok(guild_data) = guild_data_opt { if let Ok(guild_data) = guild_data_opt {
guild_data.write().await.allow_greets = false; guild_data.write().await.allow_greets = false;
guild_data guild_data.read().await.commit(&ctx.data().database).await?;
.read()
.await
.commit(ctx.data().database.clone())
.await?;
} }
ctx.say("Greet sounds have been disabled in this server") ctx.say("Greet sounds have been disabled in this server")
@ -112,11 +104,7 @@ pub async fn enable_greet_sound(ctx: Context<'_>) -> Result<(), Error> {
if let Ok(guild_data) = guild_data_opt { if let Ok(guild_data) = guild_data_opt {
guild_data.write().await.allow_greets = true; guild_data.write().await.allow_greets = true;
guild_data guild_data.read().await.commit(&ctx.data().database).await?;
.read()
.await
.commit(ctx.data().database.clone())
.await?;
} }
ctx.say("Greet sounds have been enable in this server") ctx.say("Greet sounds have been enable in this server")

View File

@ -72,8 +72,6 @@ pub async fn listener(ctx: &Context, event: &poise::Event<'_>, data: &Data) -> R
} }
} else if let (Some(guild_id), Some(user_channel)) = (new.guild_id, new.channel_id) { } else if let (Some(guild_id), Some(user_channel)) = (new.guild_id, new.channel_id) {
if let Some(guild) = ctx.cache.guild(guild_id) { if let Some(guild) = ctx.cache.guild(guild_id) {
let pool = data.database.clone();
let guild_data_opt = data.guild_data(guild.id).await; let guild_data_opt = data.guild_data(guild.id).await;
if let Ok(guild_data) = guild_data_opt { if let Ok(guild_data) = guild_data_opt {
@ -98,7 +96,7 @@ SELECT name, id, public, server_id, uploader_id
", ",
join_id join_id
) )
.fetch_one(&pool) .fetch_one(&data.database)
.await .await
.unwrap(); .unwrap();
@ -108,7 +106,7 @@ SELECT name, id, public, server_id, uploader_id
&mut sound, &mut sound,
volume, volume,
&mut handler.lock().await, &mut handler.lock().await,
pool, &data.database,
false, false,
) )
.await; .await;

View File

@ -20,13 +20,16 @@ use poise::serenity::{
}, },
}; };
use songbird::SerenityInit; use songbird::SerenityInit;
use sqlx::mysql::MySqlPool; use sqlx::{MySql, Pool};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::{event_handlers::listener, models::guild_data::GuildData}; use crate::{event_handlers::listener, models::guild_data::GuildData};
// Which database driver are we using?
type Database = MySql;
pub struct Data { pub struct Data {
database: MySqlPool, database: Pool<Database>,
http: reqwest::Client, http: reqwest::Client,
guild_data_cache: DashMap<GuildId, Arc<RwLock<GuildData>>>, guild_data_cache: DashMap<GuildId, Arc<RwLock<GuildData>>>,
join_sound_cache: DashMap<UserId, Option<u32>>, join_sound_cache: DashMap<UserId, Option<u32>>,
@ -112,7 +115,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
..Default::default() ..Default::default()
}; };
let database = MySqlPool::connect(&env::var("DATABASE_URL").expect("No database URL provided")) let database = Pool::connect(&env::var("DATABASE_URL").expect("No database URL provided"))
.await .await
.unwrap(); .unwrap();

View File

@ -1,10 +1,10 @@
use std::sync::Arc; use std::sync::Arc;
use poise::serenity::{async_trait, model::id::GuildId}; use poise::serenity::{async_trait, model::id::GuildId};
use sqlx::mysql::MySqlPool; use sqlx::Executor;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::{Context, Data}; use crate::{Context, Data, Database};
#[derive(Clone)] #[derive(Clone)]
pub struct GuildData { pub struct GuildData {
@ -44,9 +44,7 @@ impl CtxGuildData for Data {
let x = if let Some(guild_data) = self.guild_data_cache.get(&guild_id) { let x = if let Some(guild_data) = self.guild_data_cache.get(&guild_id) {
Ok(guild_data.clone()) Ok(guild_data.clone())
} else { } else {
let pool = self.database.clone(); match GuildData::from_id(guild_id, &self.database).await {
match GuildData::from_id(guild_id, pool).await {
Ok(d) => { Ok(d) => {
let lock = Arc::new(RwLock::new(d)); let lock = Arc::new(RwLock::new(d));
@ -66,7 +64,7 @@ impl CtxGuildData for Data {
impl GuildData { impl GuildData {
pub async fn from_id<G: Into<GuildId>>( pub async fn from_id<G: Into<GuildId>>(
guild_id: G, guild_id: G,
db_pool: MySqlPool, db_pool: impl Executor<'_, Database = Database> + Copy,
) -> Result<GuildData, sqlx::Error> { ) -> Result<GuildData, sqlx::Error> {
let guild_id = guild_id.into(); let guild_id = guild_id.into();
@ -79,7 +77,7 @@ SELECT id, prefix, volume, allow_greets, allowed_role
", ",
guild_id.as_u64() guild_id.as_u64()
) )
.fetch_one(&db_pool) .fetch_one(db_pool)
.await; .await;
match guild_data { match guild_data {
@ -93,7 +91,7 @@ SELECT id, prefix, volume, allow_greets, allowed_role
async fn create_from_guild<G: Into<GuildId>>( async fn create_from_guild<G: Into<GuildId>>(
guild_id: G, guild_id: G,
db_pool: MySqlPool, db_pool: impl Executor<'_, Database = Database>,
) -> Result<GuildData, sqlx::Error> { ) -> Result<GuildData, sqlx::Error> {
let guild_id = guild_id.into(); let guild_id = guild_id.into();
@ -104,7 +102,7 @@ INSERT INTO servers (id)
", ",
guild_id.as_u64() guild_id.as_u64()
) )
.execute(&db_pool) .execute(db_pool)
.await?; .await?;
Ok(GuildData { Ok(GuildData {
@ -118,7 +116,7 @@ INSERT INTO servers (id)
pub async fn commit( pub async fn commit(
&self, &self,
db_pool: MySqlPool, db_pool: impl Executor<'_, Database = Database>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
sqlx::query!( sqlx::query!(
" "
@ -137,7 +135,7 @@ WHERE
self.allowed_role, self.allowed_role,
self.id self.id
) )
.execute(&db_pool) .execute(db_pool)
.await?; .await?;
Ok(()) Ok(())

View File

@ -21,8 +21,6 @@ impl JoinSoundCtx for Data {
join_sound_id.value().clone() join_sound_id.value().clone()
} else { } else {
let join_sound_id = { let join_sound_id = {
let pool = self.database.clone();
let join_id_res = sqlx::query!( let join_id_res = sqlx::query!(
" "
SELECT join_sound_id SELECT join_sound_id
@ -31,7 +29,7 @@ SELECT join_sound_id
", ",
user_id.as_u64() user_id.as_u64()
) )
.fetch_one(&pool) .fetch_one(&self.database)
.await; .await;
if let Ok(row) = join_id_res { if let Ok(row) = join_id_res {

View File

@ -2,10 +2,10 @@ use std::{env, path::Path};
use poise::serenity::async_trait; use poise::serenity::async_trait;
use songbird::input::restartable::Restartable; use songbird::input::restartable::Restartable;
use sqlx::{mysql::MySqlPool, Error}; use sqlx::{Error, Executor};
use tokio::{fs::File, io::AsyncWriteExt, process::Command}; use tokio::{fs::File, io::AsyncWriteExt, process::Command};
use crate::{consts::UPLOAD_MAX_SIZE, error::ErrorTypes, Data}; use crate::{consts::UPLOAD_MAX_SIZE, error::ErrorTypes, Data, Database};
#[derive(Clone)] #[derive(Clone)]
pub struct Sound { pub struct Sound {
@ -208,7 +208,7 @@ SELECT name, id, public, server_id, uploader_id
} }
impl Sound { impl Sound {
async fn src(&self, db_pool: MySqlPool) -> Vec<u8> { async fn src(&self, db_pool: impl Executor<'_, Database = Database>) -> Vec<u8> {
struct Src { struct Src {
src: Vec<u8>, src: Vec<u8>,
} }
@ -223,7 +223,7 @@ SELECT src
", ",
self.id self.id
) )
.fetch_one(&db_pool) .fetch_one(db_pool)
.await .await
.unwrap(); .unwrap();
@ -232,7 +232,7 @@ SELECT src
pub async fn store_sound_source( pub async fn store_sound_source(
&self, &self,
db_pool: MySqlPool, db_pool: impl Executor<'_, Database = Database>,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let caching_location = env::var("CACHING_LOCATION").unwrap_or(String::from("/tmp")); let caching_location = env::var("CACHING_LOCATION").unwrap_or(String::from("/tmp"));
@ -250,7 +250,7 @@ SELECT src
pub async fn playable( pub async fn playable(
&self, &self,
db_pool: MySqlPool, db_pool: impl Executor<'_, Database = Database>,
) -> Result<Restartable, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<Restartable, Box<dyn std::error::Error + Send + Sync>> {
let path_name = self.store_sound_source(db_pool).await?; let path_name = self.store_sound_source(db_pool).await?;
@ -261,7 +261,7 @@ SELECT src
pub async fn count_user_sounds<U: Into<u64>>( pub async fn count_user_sounds<U: Into<u64>>(
user_id: U, user_id: U,
db_pool: MySqlPool, db_pool: impl Executor<'_, Database = Database>,
) -> Result<u32, sqlx::error::Error> { ) -> Result<u32, sqlx::error::Error> {
let user_id = user_id.into(); let user_id = user_id.into();
@ -273,7 +273,7 @@ SELECT COUNT(1) as count
", ",
user_id user_id
) )
.fetch_one(&db_pool) .fetch_one(db_pool)
.await? .await?
.count; .count;
@ -283,7 +283,7 @@ SELECT COUNT(1) as count
pub async fn count_named_user_sounds<U: Into<u64>>( pub async fn count_named_user_sounds<U: Into<u64>>(
user_id: U, user_id: U,
name: &String, name: &String,
db_pool: MySqlPool, db_pool: impl Executor<'_, Database = Database>,
) -> Result<u32, sqlx::error::Error> { ) -> Result<u32, sqlx::error::Error> {
let user_id = user_id.into(); let user_id = user_id.into();
@ -298,7 +298,7 @@ SELECT COUNT(1) as count
user_id, user_id,
name name
) )
.fetch_one(&db_pool) .fetch_one(db_pool)
.await? .await?
.count; .count;
@ -307,7 +307,7 @@ SELECT COUNT(1) as count
pub async fn commit( pub async fn commit(
&self, &self,
db_pool: MySqlPool, db_pool: impl Executor<'_, Database = Database>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
sqlx::query!( sqlx::query!(
" "
@ -320,7 +320,7 @@ WHERE
self.public, self.public,
self.id self.id
) )
.execute(&db_pool) .execute(db_pool)
.await?; .await?;
Ok(()) Ok(())
@ -328,7 +328,7 @@ WHERE
pub async fn delete( pub async fn delete(
&self, &self,
db_pool: MySqlPool, db_pool: impl Executor<'_, Database = Database>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
sqlx::query!( sqlx::query!(
" "
@ -338,7 +338,7 @@ DELETE
", ",
self.id self.id
) )
.execute(&db_pool) .execute(db_pool)
.await?; .await?;
Ok(()) Ok(())
@ -349,7 +349,7 @@ DELETE
src_url: &str, src_url: &str,
server_id: G, server_id: G,
user_id: U, user_id: U,
db_pool: MySqlPool, db_pool: impl Executor<'_, Database = Database>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + Send>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync + Send>> {
let server_id = server_id.into(); let server_id = server_id.into();
let user_id = user_id.into(); let user_id = user_id.into();
@ -396,7 +396,7 @@ INSERT INTO sounds (name, server_id, uploader_id, public, src)
user_id, user_id,
data data
) )
.execute(&db_pool) .execute(db_pool)
.await .await
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),

View File

@ -6,7 +6,7 @@ use poise::serenity::model::{
id::{ChannelId, UserId}, id::{ChannelId, UserId},
}; };
use songbird::{create_player, error::JoinResult, tracks::TrackHandle, Call}; use songbird::{create_player, error::JoinResult, tracks::TrackHandle, Call};
use sqlx::MySqlPool; use sqlx::Executor;
use tokio::sync::{Mutex, MutexGuard}; use tokio::sync::{Mutex, MutexGuard};
use crate::{ use crate::{
@ -14,17 +14,17 @@ use crate::{
guild_data::CtxGuildData, guild_data::CtxGuildData,
sound::{Sound, SoundCtx}, sound::{Sound, SoundCtx},
}, },
Data, Data, Database,
}; };
pub async fn play_audio( pub async fn play_audio(
sound: &mut Sound, sound: &mut Sound,
volume: u8, volume: u8,
call_handler: &mut MutexGuard<'_, Call>, call_handler: &mut MutexGuard<'_, Call>,
mysql_pool: MySqlPool, db_pool: impl Executor<'_, Database = Database>,
loop_: bool, loop_: bool,
) -> Result<TrackHandle, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<TrackHandle, Box<dyn std::error::Error + Send + Sync>> {
let (track, track_handler) = create_player(sound.playable(mysql_pool.clone()).await?.into()); let (track, track_handler) = create_player(sound.playable(db_pool).await?.into());
let _ = track_handler.set_volume(volume as f32 / 100.0); let _ = track_handler.set_volume(volume as f32 / 100.0);
@ -99,8 +99,6 @@ pub async fn play_from_query(
match channel_to_join { match channel_to_join {
Some(user_channel) => { Some(user_channel) => {
let pool = data.database.clone();
let mut sound_vec = data let mut sound_vec = data
.search_for_sound(query, guild_id, user_id, true) .search_for_sound(query, guild_id, user_id, true)
.await .await
@ -122,7 +120,7 @@ pub async fn play_from_query(
sound, sound,
guild_data.read().await.volume, guild_data.read().await.volume,
&mut lock, &mut lock,
pool, &data.database,
loop_, loop_,
) )
.await .await