generics for Database pool
This commit is contained in:
parent
c364343fe9
commit
eb5ea3167d
@ -29,11 +29,10 @@ pub async fn upload_new_sound(
|
||||
|
||||
if !name.is_empty() && name.len() <= 20 {
|
||||
if !is_numeric(&name) {
|
||||
let pool = ctx.data().database.clone();
|
||||
|
||||
// need to check the name is not currently in use by the user
|
||||
let count_name =
|
||||
Sound::count_named_user_sounds(ctx.author().id, &name, pool.clone()).await?;
|
||||
Sound::count_named_user_sounds(ctx.author().id, &name, &ctx.data().database)
|
||||
.await?;
|
||||
if count_name > 0 {
|
||||
ctx.say(
|
||||
"You are already using that name. Please choose a unique name for your upload.",
|
||||
@ -41,7 +40,7 @@ pub async fn upload_new_sound(
|
||||
.await?;
|
||||
} else {
|
||||
// need to check how many sounds user currently has
|
||||
let count = Sound::count_user_sounds(ctx.author().id, pool.clone()).await?;
|
||||
let count = Sound::count_user_sounds(ctx.author().id, &ctx.data().database).await?;
|
||||
let mut permit_upload = true;
|
||||
|
||||
// need to check if user is patreon or nah
|
||||
@ -93,7 +92,7 @@ pub async fn upload_new_sound(
|
||||
url.as_str(),
|
||||
ctx.guild_id().unwrap(),
|
||||
ctx.author().id,
|
||||
pool,
|
||||
&ctx.data().database,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@ -160,7 +159,7 @@ pub async fn delete_sound(
|
||||
};
|
||||
|
||||
if sound.uploader_id == Some(uid) || has_perms {
|
||||
sound.delete(pool).await?;
|
||||
sound.delete(&pool).await?;
|
||||
|
||||
ctx.say("Sound has been deleted").await?;
|
||||
} else {
|
||||
@ -209,7 +208,7 @@ pub async fn change_public(
|
||||
ctx.say("Sound has been set to public 🔓").await?;
|
||||
}
|
||||
|
||||
sound.commit(pool).await?
|
||||
sound.commit(&pool).await?
|
||||
}
|
||||
}
|
||||
|
||||
@ -238,9 +237,7 @@ pub async fn download_file(
|
||||
|
||||
match sound.first() {
|
||||
Some(sound) => {
|
||||
let source = sound
|
||||
.store_sound_source(ctx.data().database.clone())
|
||||
.await?;
|
||||
let source = sound.store_sound_source(&ctx.data().database).await?;
|
||||
|
||||
let file = File::open(&source).await?;
|
||||
let name = format!("{}-{}.opus", sound.id, sound.name);
|
||||
|
@ -15,11 +15,7 @@ pub async fn change_volume(
|
||||
if let Some(volume) = volume {
|
||||
guild_data.write().await.volume = volume as u8;
|
||||
|
||||
guild_data
|
||||
.read()
|
||||
.await
|
||||
.commit(ctx.data().database.clone())
|
||||
.await?;
|
||||
guild_data.read().await.commit(&ctx.data().database).await?;
|
||||
|
||||
ctx.say(format!("Volume changed to {}%", volume)).await?;
|
||||
} else {
|
||||
@ -91,11 +87,7 @@ pub async fn disable_greet_sound(ctx: Context<'_>) -> Result<(), Error> {
|
||||
if let Ok(guild_data) = guild_data_opt {
|
||||
guild_data.write().await.allow_greets = false;
|
||||
|
||||
guild_data
|
||||
.read()
|
||||
.await
|
||||
.commit(ctx.data().database.clone())
|
||||
.await?;
|
||||
guild_data.read().await.commit(&ctx.data().database).await?;
|
||||
}
|
||||
|
||||
ctx.say("Greet sounds have been disabled in this server")
|
||||
@ -112,11 +104,7 @@ pub async fn enable_greet_sound(ctx: Context<'_>) -> Result<(), Error> {
|
||||
if let Ok(guild_data) = guild_data_opt {
|
||||
guild_data.write().await.allow_greets = true;
|
||||
|
||||
guild_data
|
||||
.read()
|
||||
.await
|
||||
.commit(ctx.data().database.clone())
|
||||
.await?;
|
||||
guild_data.read().await.commit(&ctx.data().database).await?;
|
||||
}
|
||||
|
||||
ctx.say("Greet sounds have been enable in this server")
|
||||
|
@ -72,8 +72,6 @@ pub async fn listener(ctx: &Context, event: &poise::Event<'_>, data: &Data) -> R
|
||||
}
|
||||
} else if let (Some(guild_id), Some(user_channel)) = (new.guild_id, new.channel_id) {
|
||||
if let Some(guild) = ctx.cache.guild(guild_id) {
|
||||
let pool = data.database.clone();
|
||||
|
||||
let guild_data_opt = data.guild_data(guild.id).await;
|
||||
|
||||
if let Ok(guild_data) = guild_data_opt {
|
||||
@ -98,7 +96,7 @@ SELECT name, id, public, server_id, uploader_id
|
||||
",
|
||||
join_id
|
||||
)
|
||||
.fetch_one(&pool)
|
||||
.fetch_one(&data.database)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@ -108,7 +106,7 @@ SELECT name, id, public, server_id, uploader_id
|
||||
&mut sound,
|
||||
volume,
|
||||
&mut handler.lock().await,
|
||||
pool,
|
||||
&data.database,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
@ -20,13 +20,16 @@ use poise::serenity::{
|
||||
},
|
||||
};
|
||||
use songbird::SerenityInit;
|
||||
use sqlx::mysql::MySqlPool;
|
||||
use sqlx::{MySql, Pool};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::{event_handlers::listener, models::guild_data::GuildData};
|
||||
|
||||
// Which database driver are we using?
|
||||
type Database = MySql;
|
||||
|
||||
pub struct Data {
|
||||
database: MySqlPool,
|
||||
database: Pool<Database>,
|
||||
http: reqwest::Client,
|
||||
guild_data_cache: DashMap<GuildId, Arc<RwLock<GuildData>>>,
|
||||
join_sound_cache: DashMap<UserId, Option<u32>>,
|
||||
@ -112,7 +115,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let database = MySqlPool::connect(&env::var("DATABASE_URL").expect("No database URL provided"))
|
||||
let database = Pool::connect(&env::var("DATABASE_URL").expect("No database URL provided"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use poise::serenity::{async_trait, model::id::GuildId};
|
||||
use sqlx::mysql::MySqlPool;
|
||||
use sqlx::Executor;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::{Context, Data};
|
||||
use crate::{Context, Data, Database};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct GuildData {
|
||||
@ -44,9 +44,7 @@ impl CtxGuildData for Data {
|
||||
let x = if let Some(guild_data) = self.guild_data_cache.get(&guild_id) {
|
||||
Ok(guild_data.clone())
|
||||
} else {
|
||||
let pool = self.database.clone();
|
||||
|
||||
match GuildData::from_id(guild_id, pool).await {
|
||||
match GuildData::from_id(guild_id, &self.database).await {
|
||||
Ok(d) => {
|
||||
let lock = Arc::new(RwLock::new(d));
|
||||
|
||||
@ -66,7 +64,7 @@ impl CtxGuildData for Data {
|
||||
impl GuildData {
|
||||
pub async fn from_id<G: Into<GuildId>>(
|
||||
guild_id: G,
|
||||
db_pool: MySqlPool,
|
||||
db_pool: impl Executor<'_, Database = Database> + Copy,
|
||||
) -> Result<GuildData, sqlx::Error> {
|
||||
let guild_id = guild_id.into();
|
||||
|
||||
@ -79,7 +77,7 @@ SELECT id, prefix, volume, allow_greets, allowed_role
|
||||
",
|
||||
guild_id.as_u64()
|
||||
)
|
||||
.fetch_one(&db_pool)
|
||||
.fetch_one(db_pool)
|
||||
.await;
|
||||
|
||||
match guild_data {
|
||||
@ -93,7 +91,7 @@ SELECT id, prefix, volume, allow_greets, allowed_role
|
||||
|
||||
async fn create_from_guild<G: Into<GuildId>>(
|
||||
guild_id: G,
|
||||
db_pool: MySqlPool,
|
||||
db_pool: impl Executor<'_, Database = Database>,
|
||||
) -> Result<GuildData, sqlx::Error> {
|
||||
let guild_id = guild_id.into();
|
||||
|
||||
@ -104,7 +102,7 @@ INSERT INTO servers (id)
|
||||
",
|
||||
guild_id.as_u64()
|
||||
)
|
||||
.execute(&db_pool)
|
||||
.execute(db_pool)
|
||||
.await?;
|
||||
|
||||
Ok(GuildData {
|
||||
@ -118,7 +116,7 @@ INSERT INTO servers (id)
|
||||
|
||||
pub async fn commit(
|
||||
&self,
|
||||
db_pool: MySqlPool,
|
||||
db_pool: impl Executor<'_, Database = Database>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
sqlx::query!(
|
||||
"
|
||||
@ -137,7 +135,7 @@ WHERE
|
||||
self.allowed_role,
|
||||
self.id
|
||||
)
|
||||
.execute(&db_pool)
|
||||
.execute(db_pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
|
@ -21,8 +21,6 @@ impl JoinSoundCtx for Data {
|
||||
join_sound_id.value().clone()
|
||||
} else {
|
||||
let join_sound_id = {
|
||||
let pool = self.database.clone();
|
||||
|
||||
let join_id_res = sqlx::query!(
|
||||
"
|
||||
SELECT join_sound_id
|
||||
@ -31,7 +29,7 @@ SELECT join_sound_id
|
||||
",
|
||||
user_id.as_u64()
|
||||
)
|
||||
.fetch_one(&pool)
|
||||
.fetch_one(&self.database)
|
||||
.await;
|
||||
|
||||
if let Ok(row) = join_id_res {
|
||||
|
@ -2,10 +2,10 @@ use std::{env, path::Path};
|
||||
|
||||
use poise::serenity::async_trait;
|
||||
use songbird::input::restartable::Restartable;
|
||||
use sqlx::{mysql::MySqlPool, Error};
|
||||
use sqlx::{Error, Executor};
|
||||
use tokio::{fs::File, io::AsyncWriteExt, process::Command};
|
||||
|
||||
use crate::{consts::UPLOAD_MAX_SIZE, error::ErrorTypes, Data};
|
||||
use crate::{consts::UPLOAD_MAX_SIZE, error::ErrorTypes, Data, Database};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Sound {
|
||||
@ -208,7 +208,7 @@ SELECT name, id, public, server_id, uploader_id
|
||||
}
|
||||
|
||||
impl Sound {
|
||||
async fn src(&self, db_pool: MySqlPool) -> Vec<u8> {
|
||||
async fn src(&self, db_pool: impl Executor<'_, Database = Database>) -> Vec<u8> {
|
||||
struct Src {
|
||||
src: Vec<u8>,
|
||||
}
|
||||
@ -223,7 +223,7 @@ SELECT src
|
||||
",
|
||||
self.id
|
||||
)
|
||||
.fetch_one(&db_pool)
|
||||
.fetch_one(db_pool)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@ -232,7 +232,7 @@ SELECT src
|
||||
|
||||
pub async fn store_sound_source(
|
||||
&self,
|
||||
db_pool: MySqlPool,
|
||||
db_pool: impl Executor<'_, Database = Database>,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let caching_location = env::var("CACHING_LOCATION").unwrap_or(String::from("/tmp"));
|
||||
|
||||
@ -250,7 +250,7 @@ SELECT src
|
||||
|
||||
pub async fn playable(
|
||||
&self,
|
||||
db_pool: MySqlPool,
|
||||
db_pool: impl Executor<'_, Database = Database>,
|
||||
) -> Result<Restartable, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let path_name = self.store_sound_source(db_pool).await?;
|
||||
|
||||
@ -261,7 +261,7 @@ SELECT src
|
||||
|
||||
pub async fn count_user_sounds<U: Into<u64>>(
|
||||
user_id: U,
|
||||
db_pool: MySqlPool,
|
||||
db_pool: impl Executor<'_, Database = Database>,
|
||||
) -> Result<u32, sqlx::error::Error> {
|
||||
let user_id = user_id.into();
|
||||
|
||||
@ -273,7 +273,7 @@ SELECT COUNT(1) as count
|
||||
",
|
||||
user_id
|
||||
)
|
||||
.fetch_one(&db_pool)
|
||||
.fetch_one(db_pool)
|
||||
.await?
|
||||
.count;
|
||||
|
||||
@ -283,7 +283,7 @@ SELECT COUNT(1) as count
|
||||
pub async fn count_named_user_sounds<U: Into<u64>>(
|
||||
user_id: U,
|
||||
name: &String,
|
||||
db_pool: MySqlPool,
|
||||
db_pool: impl Executor<'_, Database = Database>,
|
||||
) -> Result<u32, sqlx::error::Error> {
|
||||
let user_id = user_id.into();
|
||||
|
||||
@ -298,7 +298,7 @@ SELECT COUNT(1) as count
|
||||
user_id,
|
||||
name
|
||||
)
|
||||
.fetch_one(&db_pool)
|
||||
.fetch_one(db_pool)
|
||||
.await?
|
||||
.count;
|
||||
|
||||
@ -307,7 +307,7 @@ SELECT COUNT(1) as count
|
||||
|
||||
pub async fn commit(
|
||||
&self,
|
||||
db_pool: MySqlPool,
|
||||
db_pool: impl Executor<'_, Database = Database>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
sqlx::query!(
|
||||
"
|
||||
@ -320,7 +320,7 @@ WHERE
|
||||
self.public,
|
||||
self.id
|
||||
)
|
||||
.execute(&db_pool)
|
||||
.execute(db_pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
@ -328,7 +328,7 @@ WHERE
|
||||
|
||||
pub async fn delete(
|
||||
&self,
|
||||
db_pool: MySqlPool,
|
||||
db_pool: impl Executor<'_, Database = Database>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
sqlx::query!(
|
||||
"
|
||||
@ -338,7 +338,7 @@ DELETE
|
||||
",
|
||||
self.id
|
||||
)
|
||||
.execute(&db_pool)
|
||||
.execute(db_pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
@ -349,7 +349,7 @@ DELETE
|
||||
src_url: &str,
|
||||
server_id: G,
|
||||
user_id: U,
|
||||
db_pool: MySqlPool,
|
||||
db_pool: impl Executor<'_, Database = Database>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync + Send>> {
|
||||
let server_id = server_id.into();
|
||||
let user_id = user_id.into();
|
||||
@ -396,7 +396,7 @@ INSERT INTO sounds (name, server_id, uploader_id, public, src)
|
||||
user_id,
|
||||
data
|
||||
)
|
||||
.execute(&db_pool)
|
||||
.execute(db_pool)
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(()),
|
||||
|
12
src/utils.rs
12
src/utils.rs
@ -6,7 +6,7 @@ use poise::serenity::model::{
|
||||
id::{ChannelId, UserId},
|
||||
};
|
||||
use songbird::{create_player, error::JoinResult, tracks::TrackHandle, Call};
|
||||
use sqlx::MySqlPool;
|
||||
use sqlx::Executor;
|
||||
use tokio::sync::{Mutex, MutexGuard};
|
||||
|
||||
use crate::{
|
||||
@ -14,17 +14,17 @@ use crate::{
|
||||
guild_data::CtxGuildData,
|
||||
sound::{Sound, SoundCtx},
|
||||
},
|
||||
Data,
|
||||
Data, Database,
|
||||
};
|
||||
|
||||
pub async fn play_audio(
|
||||
sound: &mut Sound,
|
||||
volume: u8,
|
||||
call_handler: &mut MutexGuard<'_, Call>,
|
||||
mysql_pool: MySqlPool,
|
||||
db_pool: impl Executor<'_, Database = Database>,
|
||||
loop_: bool,
|
||||
) -> Result<TrackHandle, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let (track, track_handler) = create_player(sound.playable(mysql_pool.clone()).await?.into());
|
||||
let (track, track_handler) = create_player(sound.playable(db_pool).await?.into());
|
||||
|
||||
let _ = track_handler.set_volume(volume as f32 / 100.0);
|
||||
|
||||
@ -99,8 +99,6 @@ pub async fn play_from_query(
|
||||
|
||||
match channel_to_join {
|
||||
Some(user_channel) => {
|
||||
let pool = data.database.clone();
|
||||
|
||||
let mut sound_vec = data
|
||||
.search_for_sound(query, guild_id, user_id, true)
|
||||
.await
|
||||
@ -122,7 +120,7 @@ pub async fn play_from_query(
|
||||
sound,
|
||||
guild_data.read().await.volume,
|
||||
&mut lock,
|
||||
pool,
|
||||
&data.database,
|
||||
loop_,
|
||||
)
|
||||
.await
|
||||
|
Loading…
Reference in New Issue
Block a user