generics for Database pool

This commit is contained in:
jellywx 2022-01-30 15:55:53 +00:00
parent c364343fe9
commit eb5ea3167d
8 changed files with 49 additions and 69 deletions

View File

@ -29,11 +29,10 @@ pub async fn upload_new_sound(
if !name.is_empty() && name.len() <= 20 {
if !is_numeric(&name) {
let pool = ctx.data().database.clone();
// need to check the name is not currently in use by the user
let count_name =
Sound::count_named_user_sounds(ctx.author().id, &name, pool.clone()).await?;
Sound::count_named_user_sounds(ctx.author().id, &name, &ctx.data().database)
.await?;
if count_name > 0 {
ctx.say(
"You are already using that name. Please choose a unique name for your upload.",
@ -41,7 +40,7 @@ pub async fn upload_new_sound(
.await?;
} else {
// need to check how many sounds user currently has
let count = Sound::count_user_sounds(ctx.author().id, pool.clone()).await?;
let count = Sound::count_user_sounds(ctx.author().id, &ctx.data().database).await?;
let mut permit_upload = true;
// need to check if user is patreon or nah
@ -93,7 +92,7 @@ pub async fn upload_new_sound(
url.as_str(),
ctx.guild_id().unwrap(),
ctx.author().id,
pool,
&ctx.data().database,
)
.await
{
@ -160,7 +159,7 @@ pub async fn delete_sound(
};
if sound.uploader_id == Some(uid) || has_perms {
sound.delete(pool).await?;
sound.delete(&pool).await?;
ctx.say("Sound has been deleted").await?;
} else {
@ -209,7 +208,7 @@ pub async fn change_public(
ctx.say("Sound has been set to public 🔓").await?;
}
sound.commit(pool).await?
sound.commit(&pool).await?
}
}
@ -238,9 +237,7 @@ pub async fn download_file(
match sound.first() {
Some(sound) => {
let source = sound
.store_sound_source(ctx.data().database.clone())
.await?;
let source = sound.store_sound_source(&ctx.data().database).await?;
let file = File::open(&source).await?;
let name = format!("{}-{}.opus", sound.id, sound.name);

View File

@ -15,11 +15,7 @@ pub async fn change_volume(
if let Some(volume) = volume {
guild_data.write().await.volume = volume as u8;
guild_data
.read()
.await
.commit(ctx.data().database.clone())
.await?;
guild_data.read().await.commit(&ctx.data().database).await?;
ctx.say(format!("Volume changed to {}%", volume)).await?;
} else {
@ -91,11 +87,7 @@ pub async fn disable_greet_sound(ctx: Context<'_>) -> Result<(), Error> {
if let Ok(guild_data) = guild_data_opt {
guild_data.write().await.allow_greets = false;
guild_data
.read()
.await
.commit(ctx.data().database.clone())
.await?;
guild_data.read().await.commit(&ctx.data().database).await?;
}
ctx.say("Greet sounds have been disabled in this server")
@ -112,11 +104,7 @@ pub async fn enable_greet_sound(ctx: Context<'_>) -> Result<(), Error> {
if let Ok(guild_data) = guild_data_opt {
guild_data.write().await.allow_greets = true;
guild_data
.read()
.await
.commit(ctx.data().database.clone())
.await?;
guild_data.read().await.commit(&ctx.data().database).await?;
}
ctx.say("Greet sounds have been enable in this server")

View File

@ -72,8 +72,6 @@ pub async fn listener(ctx: &Context, event: &poise::Event<'_>, data: &Data) -> R
}
} else if let (Some(guild_id), Some(user_channel)) = (new.guild_id, new.channel_id) {
if let Some(guild) = ctx.cache.guild(guild_id) {
let pool = data.database.clone();
let guild_data_opt = data.guild_data(guild.id).await;
if let Ok(guild_data) = guild_data_opt {
@ -98,7 +96,7 @@ SELECT name, id, public, server_id, uploader_id
",
join_id
)
.fetch_one(&pool)
.fetch_one(&data.database)
.await
.unwrap();
@ -108,7 +106,7 @@ SELECT name, id, public, server_id, uploader_id
&mut sound,
volume,
&mut handler.lock().await,
pool,
&data.database,
false,
)
.await;

View File

@ -20,13 +20,16 @@ use poise::serenity::{
},
};
use songbird::SerenityInit;
use sqlx::mysql::MySqlPool;
use sqlx::{MySql, Pool};
use tokio::sync::RwLock;
use crate::{event_handlers::listener, models::guild_data::GuildData};
// Which database driver are we using?
type Database = MySql;
pub struct Data {
database: MySqlPool,
database: Pool<Database>,
http: reqwest::Client,
guild_data_cache: DashMap<GuildId, Arc<RwLock<GuildData>>>,
join_sound_cache: DashMap<UserId, Option<u32>>,
@ -112,7 +115,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
..Default::default()
};
let database = MySqlPool::connect(&env::var("DATABASE_URL").expect("No database URL provided"))
let database = Pool::connect(&env::var("DATABASE_URL").expect("No database URL provided"))
.await
.unwrap();

View File

@ -1,10 +1,10 @@
use std::sync::Arc;
use poise::serenity::{async_trait, model::id::GuildId};
use sqlx::mysql::MySqlPool;
use sqlx::Executor;
use tokio::sync::RwLock;
use crate::{Context, Data};
use crate::{Context, Data, Database};
#[derive(Clone)]
pub struct GuildData {
@ -44,9 +44,7 @@ impl CtxGuildData for Data {
let x = if let Some(guild_data) = self.guild_data_cache.get(&guild_id) {
Ok(guild_data.clone())
} else {
let pool = self.database.clone();
match GuildData::from_id(guild_id, pool).await {
match GuildData::from_id(guild_id, &self.database).await {
Ok(d) => {
let lock = Arc::new(RwLock::new(d));
@ -66,7 +64,7 @@ impl CtxGuildData for Data {
impl GuildData {
pub async fn from_id<G: Into<GuildId>>(
guild_id: G,
db_pool: MySqlPool,
db_pool: impl Executor<'_, Database = Database> + Copy,
) -> Result<GuildData, sqlx::Error> {
let guild_id = guild_id.into();
@ -79,7 +77,7 @@ SELECT id, prefix, volume, allow_greets, allowed_role
",
guild_id.as_u64()
)
.fetch_one(&db_pool)
.fetch_one(db_pool)
.await;
match guild_data {
@ -93,7 +91,7 @@ SELECT id, prefix, volume, allow_greets, allowed_role
async fn create_from_guild<G: Into<GuildId>>(
guild_id: G,
db_pool: MySqlPool,
db_pool: impl Executor<'_, Database = Database>,
) -> Result<GuildData, sqlx::Error> {
let guild_id = guild_id.into();
@ -104,7 +102,7 @@ INSERT INTO servers (id)
",
guild_id.as_u64()
)
.execute(&db_pool)
.execute(db_pool)
.await?;
Ok(GuildData {
@ -118,7 +116,7 @@ INSERT INTO servers (id)
pub async fn commit(
&self,
db_pool: MySqlPool,
db_pool: impl Executor<'_, Database = Database>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
sqlx::query!(
"
@ -137,7 +135,7 @@ WHERE
self.allowed_role,
self.id
)
.execute(&db_pool)
.execute(db_pool)
.await?;
Ok(())

View File

@ -21,8 +21,6 @@ impl JoinSoundCtx for Data {
join_sound_id.value().clone()
} else {
let join_sound_id = {
let pool = self.database.clone();
let join_id_res = sqlx::query!(
"
SELECT join_sound_id
@ -31,7 +29,7 @@ SELECT join_sound_id
",
user_id.as_u64()
)
.fetch_one(&pool)
.fetch_one(&self.database)
.await;
if let Ok(row) = join_id_res {

View File

@ -2,10 +2,10 @@ use std::{env, path::Path};
use poise::serenity::async_trait;
use songbird::input::restartable::Restartable;
use sqlx::{mysql::MySqlPool, Error};
use sqlx::{Error, Executor};
use tokio::{fs::File, io::AsyncWriteExt, process::Command};
use crate::{consts::UPLOAD_MAX_SIZE, error::ErrorTypes, Data};
use crate::{consts::UPLOAD_MAX_SIZE, error::ErrorTypes, Data, Database};
#[derive(Clone)]
pub struct Sound {
@ -208,7 +208,7 @@ SELECT name, id, public, server_id, uploader_id
}
impl Sound {
async fn src(&self, db_pool: MySqlPool) -> Vec<u8> {
async fn src(&self, db_pool: impl Executor<'_, Database = Database>) -> Vec<u8> {
struct Src {
src: Vec<u8>,
}
@ -223,7 +223,7 @@ SELECT src
",
self.id
)
.fetch_one(&db_pool)
.fetch_one(db_pool)
.await
.unwrap();
@ -232,7 +232,7 @@ SELECT src
pub async fn store_sound_source(
&self,
db_pool: MySqlPool,
db_pool: impl Executor<'_, Database = Database>,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let caching_location = env::var("CACHING_LOCATION").unwrap_or(String::from("/tmp"));
@ -250,7 +250,7 @@ SELECT src
pub async fn playable(
&self,
db_pool: MySqlPool,
db_pool: impl Executor<'_, Database = Database>,
) -> Result<Restartable, Box<dyn std::error::Error + Send + Sync>> {
let path_name = self.store_sound_source(db_pool).await?;
@ -261,7 +261,7 @@ SELECT src
pub async fn count_user_sounds<U: Into<u64>>(
user_id: U,
db_pool: MySqlPool,
db_pool: impl Executor<'_, Database = Database>,
) -> Result<u32, sqlx::error::Error> {
let user_id = user_id.into();
@ -273,7 +273,7 @@ SELECT COUNT(1) as count
",
user_id
)
.fetch_one(&db_pool)
.fetch_one(db_pool)
.await?
.count;
@ -283,7 +283,7 @@ SELECT COUNT(1) as count
pub async fn count_named_user_sounds<U: Into<u64>>(
user_id: U,
name: &String,
db_pool: MySqlPool,
db_pool: impl Executor<'_, Database = Database>,
) -> Result<u32, sqlx::error::Error> {
let user_id = user_id.into();
@ -298,7 +298,7 @@ SELECT COUNT(1) as count
user_id,
name
)
.fetch_one(&db_pool)
.fetch_one(db_pool)
.await?
.count;
@ -307,7 +307,7 @@ SELECT COUNT(1) as count
pub async fn commit(
&self,
db_pool: MySqlPool,
db_pool: impl Executor<'_, Database = Database>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
sqlx::query!(
"
@ -320,7 +320,7 @@ WHERE
self.public,
self.id
)
.execute(&db_pool)
.execute(db_pool)
.await?;
Ok(())
@ -328,7 +328,7 @@ WHERE
pub async fn delete(
&self,
db_pool: MySqlPool,
db_pool: impl Executor<'_, Database = Database>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
sqlx::query!(
"
@ -338,7 +338,7 @@ DELETE
",
self.id
)
.execute(&db_pool)
.execute(db_pool)
.await?;
Ok(())
@ -349,7 +349,7 @@ DELETE
src_url: &str,
server_id: G,
user_id: U,
db_pool: MySqlPool,
db_pool: impl Executor<'_, Database = Database>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + Send>> {
let server_id = server_id.into();
let user_id = user_id.into();
@ -396,7 +396,7 @@ INSERT INTO sounds (name, server_id, uploader_id, public, src)
user_id,
data
)
.execute(&db_pool)
.execute(db_pool)
.await
{
Ok(_) => Ok(()),

View File

@ -6,7 +6,7 @@ use poise::serenity::model::{
id::{ChannelId, UserId},
};
use songbird::{create_player, error::JoinResult, tracks::TrackHandle, Call};
use sqlx::MySqlPool;
use sqlx::Executor;
use tokio::sync::{Mutex, MutexGuard};
use crate::{
@ -14,17 +14,17 @@ use crate::{
guild_data::CtxGuildData,
sound::{Sound, SoundCtx},
},
Data,
Data, Database,
};
pub async fn play_audio(
sound: &mut Sound,
volume: u8,
call_handler: &mut MutexGuard<'_, Call>,
mysql_pool: MySqlPool,
db_pool: impl Executor<'_, Database = Database>,
loop_: bool,
) -> Result<TrackHandle, Box<dyn std::error::Error + Send + Sync>> {
let (track, track_handler) = create_player(sound.playable(mysql_pool.clone()).await?.into());
let (track, track_handler) = create_player(sound.playable(db_pool).await?.into());
let _ = track_handler.set_volume(volume as f32 / 100.0);
@ -99,8 +99,6 @@ pub async fn play_from_query(
match channel_to_join {
Some(user_channel) => {
let pool = data.database.clone();
let mut sound_vec = data
.search_for_sound(query, guild_id, user_id, true)
.await
@ -122,7 +120,7 @@ pub async fn play_from_query(
sound,
guild_data.read().await.volume,
&mut lock,
pool,
&data.database,
loop_,
)
.await