diff --git a/Cargo.lock b/Cargo.lock index f144bff..635351c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6,6 +6,15 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d2e7343e7fc9de883d1b0341e0b13970f764c14101234857d2ddafa1cb1cac2" +[[package]] +name = "aho-corasick" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8716408b8bc624ed7f65d223ddb9ac2d044c0547b6fa4b0d554f3a9540496ada" +dependencies = [ + "memchr", +] + [[package]] name = "arc-swap" version = "0.4.6" @@ -507,6 +516,15 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "heck" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "hermit-abi" version = "0.1.10" @@ -701,6 +719,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "matches" version = "0.1.8" @@ -1037,6 +1061,24 @@ version = "0.1.56" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84" +[[package]] +name = "regex" +version = "1.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6020f034922e3194c711b82a627453881bc4682166cabb07134a10c26ba7692" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", + "thread_local", +] + +[[package]] +name = "regex-syntax" +version = "0.6.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fe5bd57d1d7414c6b5ed48563a2c855d995ff777729dcd91c369ec7fea395ae" + [[package]] name = "remove_dir_all" version = "0.5.2" @@ -1351,10 +1393,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] -name = "sqlx" -version = "0.3.4" +name = "sqlformat" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2c268cf2f045f3d8b6b54e50653e66c59d6770373b2b59ba29dea459e8294cf" +checksum = "5ce64a4576e1720a2e511bf3ccdb8c0f6cfed0fc265bcbaa0bd369485e02c631" +dependencies = [ + "lazy_static", + "maplit", + "regex", +] + +[[package]] +name = "sqlx" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8974cacd80085fbe49e778708d660dec6fb351604dc34c3905b26efb2803b038" dependencies = [ "sqlx-core", "sqlx-macros", @@ -1362,9 +1415,9 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.3.4" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "673fb6f30bdc14b7812a5ae22d7cc1e8e3d1debd5384eedbdf745827e9721cf3" +checksum = "88ac5a436f941c42eac509471a730df5c3c58e1450e68cd39afedbd948206273" dependencies = [ "async-native-tls", "async-stream", @@ -1388,18 +1441,20 @@ dependencies = [ "rand", "sha-1", "sha2", + "sqlformat", "tokio", "url", ] [[package]] name = "sqlx-macros" -version = "0.3.4" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0de981547a9e8c15336b30079ca040ca252aa91c071e05298d699981d6dec041" +checksum = "de2ae78b783af5922d811b14665a5a3755e531c3087bb805cf24cf71f15e6780" dependencies = [ "dotenv", "futures", + "heck", "lazy_static", "proc-macro2", "quote", @@ -1478,6 +1533,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread_local" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14" +dependencies = [ + "lazy_static", +] + [[package]] name = "time" version = "0.1.42" @@ -1614,6 +1678,12 @@ dependencies = [ "smallvec", ] +[[package]] +name = "unicode-segmentation" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e83e153d1053cbb5a118eeff7fd5be06ed99153f00dbcd8ae310c5fb2b22edc0" + [[package]] name = "unicode-xid" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index 49e11c7..91c4d0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2018" [dependencies] serenity = {path = "/home/jude/serenity", features = ["voice", "collector"]} -sqlx = {version = "0.3", default-features = false, features = ["runtime-tokio", "macros", "mysql", "bigdecimal"]} +sqlx = {version = "0.3.5", default-features = false, features = ["runtime-tokio", "macros", "mysql", "bigdecimal"]} dotenv = "0.15" tokio = {version = "0.2.19", features = ["fs", "sync", "process", "io-util"]} lazy_static = "1.4.0" diff --git a/src/error.rs b/src/error.rs index 97ec01d..e6a75ca 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,12 +3,11 @@ use std::fmt::Formatter; #[derive(Debug)] pub enum ErrorTypes { InvalidFile, - NotEnoughRoles, } impl std::error::Error for ErrorTypes {} impl std::fmt::Display for ErrorTypes { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "ErrorTypes") + write!(f, "ErrorTypes: InvalidFile") } } diff --git a/src/main.rs b/src/main.rs index e4d0a56..f76078a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,6 @@ mod error; use sound::Sound; use guilddata::GuildData; -use error::ErrorTypes; use serenity::{ client::{ @@ -18,9 +17,9 @@ use serenity::{ Client, Context, }, framework::standard::{ - Args, CommandResult, CheckResult, StandardFramework, Reason, + Args, CommandResult, CheckResult, DispatchError, StandardFramework, Reason, macros::{ - command, group, check, + command, group, check, hook, } }, model::{ @@ -66,6 +65,7 @@ use std::{ sync::Arc, time::Duration, }; +use serenity::framework::standard::CommandError; struct SQLPool; @@ -128,7 +128,7 @@ struct PermissionManagedUsers; #[name("role_check")] async fn role_check(ctx: &Context, msg: &Message, _args: &mut Args) -> CheckResult { - async fn check_for_roles(ctx: &&Context, msg: &&Message) -> Result<(), Box> { + async fn check_for_roles(ctx: &&Context, msg: &&Message) -> CheckResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); @@ -144,34 +144,43 @@ async fn role_check(ctx: &Context, msg: &Message, _args: &mut Args) -> CheckResu let guild_id = *msg.guild_id.unwrap().as_u64(); - let res = sqlx::query!( - " + let role_res = sqlx::query!( + " SELECT COUNT(1) as count FROM roles WHERE guild_id = ? AND role IN (?) - ", - guild_id, user_roles - ).fetch_one(&pool).await?; + ", + guild_id, user_roles + ) + .fetch_one(&pool).await; - if res.count > 0 { - Ok(()) - } - else { - Err(Box::new(ErrorTypes::NotEnoughRoles)) + match role_res { + Ok(role_count) => { + if role_count.count > 0 { + CheckResult::Success + } + else { + CheckResult::Failure(Reason::User("User has not got a sufficient role".to_string())) + } + } + + Err(_) => { + CheckResult::Failure(Reason::User("User has not got a sufficient role".to_string())) + } } } None => { - Err(Box::new(ErrorTypes::NotEnoughRoles)) + CheckResult::Failure(Reason::User("User has not got a sufficient role".to_string())) } } } - if check_for_roles(&ctx, &msg).await.is_ok() { + if perform_permission_check(ctx, &msg).await.is_success() { CheckResult::Success } else { - perform_permission_check(ctx, &msg).await + check_for_roles(&ctx, &msg).await } } @@ -223,7 +232,7 @@ SELECT join_sound_id let mut sound = sqlx::query_as_unchecked!( Sound, " -SELECT * +SELECT name, id, plays, public, server_id, uploader_id FROM sounds WHERE id = ? ", @@ -259,7 +268,7 @@ SELECT * async fn play_audio(sound: &mut Sound, guild: GuildData, handler: &mut VoiceHandler, mut voice_guilds: MutexGuard<'_, HashMap>, pool: MySqlPool) -> Result<(), Box> { - let audio = handler.play_only(sound.store_sound_source().await?); + let audio = handler.play_only(sound.store_sound_source(pool.clone()).await?); { let mut locked = audio.lock().await; @@ -275,9 +284,30 @@ async fn play_audio(sound: &mut Sound, guild: GuildData, handler: &mut VoiceHand Ok(()) } +#[hook] +async fn log_errors(_: &Context, m: &Message, cmd_name: &str, error: Result<(), CommandError>) { + if let Err(e) = error { + println!("Error in command {} ({}): {:?}", cmd_name, m.content, e); + } +} + +#[hook] +async fn dispatch_error_hook(ctx: &Context, msg: &Message, error: DispatchError) { + match error { + DispatchError::CheckFailed(_f, reason) => { + if let Reason::User(description) = reason { + let _ = msg.reply(ctx, format!("You cannot do this command: {}", description)).await; + } + } + + _ => {} + } +} + // entry point #[tokio::main] async fn main() -> Result<(), Box> { + dotenv()?; let voice_guilds = Arc::new(Mutex::new(HashMap::new())); @@ -302,7 +332,9 @@ async fn main() -> Result<(), Box> { ) .group(&ALLUSERS_GROUP) .group(&ROLEMANAGEDUSERS_GROUP) - .group(&PERMISSIONMANAGEDUSERS_GROUP); + .group(&PERMISSIONMANAGEDUSERS_GROUP) + .after(log_errors) + .on_dispatch_error(dispatch_error_hook); let mut client = Client::new(&env::var("DISCORD_TOKEN").expect("Missing token from environment")) .intents(GatewayIntents::GUILD_VOICE_STATES | GatewayIntents::GUILD_MESSAGES | GatewayIntents::GUILDS) @@ -815,7 +847,7 @@ async fn format_search_results(search_results: Vec, msg: &Message, ctx: & let field_iter = search_results.iter().take(25).map(|item| { - (&item.name, format!("ID: {}\nPlays: {}", item.id, item.plays), false) + (&item.name, format!("ID: {}\nPlays: {}", item.id, item.plays), true) }).filter(|item| { @@ -857,7 +889,8 @@ async fn show_popular_sounds(ctx: &Context, msg: &Message, _args: Args) -> Comma let search_results = sqlx::query_as_unchecked!( Sound, " -SELECT * FROM sounds +SELECT name, id, plays, public, server_id, uploader_id + FROM sounds ORDER BY plays DESC LIMIT 25 " @@ -878,7 +911,8 @@ async fn show_random_sounds(ctx: &Context, msg: &Message, _args: Args) -> Comman let search_results = sqlx::query_as_unchecked!( Sound, " -SELECT * FROM sounds +SELECT name, id, plays, public, server_id, uploader_id + FROM sounds ORDER BY rand() LIMIT 25 " diff --git a/src/sound.rs b/src/sound.rs index 1058293..eac9e02 100644 --- a/src/sound.rs +++ b/src/sound.rs @@ -25,7 +25,6 @@ pub struct Sound { pub public: bool, pub server_id: u64, pub uploader_id: Option, - pub src: Vec, } impl Sound { @@ -48,7 +47,7 @@ impl Sound { let sound = sqlx::query_as_unchecked!( Self, " -SELECT * +SELECT name, id, plays, public, server_id, uploader_id FROM sounds WHERE id = ? AND ( public = 1 OR @@ -71,7 +70,7 @@ SELECT * sound = sqlx::query_as_unchecked!( Self, " -SELECT * +SELECT name, id, plays, public, server_id, uploader_id FROM sounds WHERE name = ? AND ( public = 1 OR @@ -90,7 +89,7 @@ SELECT * sound = sqlx::query_as_unchecked!( Self, " -SELECT * +SELECT name, id, plays, public, server_id, uploader_id FROM sounds WHERE name LIKE CONCAT('%', ?, '%') AND ( public = 1 OR @@ -109,7 +108,29 @@ SELECT * } } - pub async fn store_sound_source(&self) -> Result, Box> { + async fn get_self_src(&self, db_pool: MySqlPool) -> Vec { + struct Src { + src: Vec + } + + let record = sqlx::query_as_unchecked!( + Src, + " +SELECT src + FROM sounds + WHERE id = ? + LIMIT 1 + ", + self.id + ) + .fetch_one(&db_pool) + .await.unwrap(); + + return record.src + } + + pub async fn store_sound_source(&self, db_pool: MySqlPool) -> Result, Box> { + let caching_location = env::var("CACHING_LOCATION").unwrap_or(String::from("/tmp")); let path_name = format!("{}/sound-{}", caching_location, self.id); @@ -120,7 +141,7 @@ SELECT * let mut file = File::create(&path).await?; - file.write_all(self.src.as_ref()).await?; + file.write_all(&self.get_self_src(db_pool).await).await?; } Ok(ffmpeg(path_name).await?) @@ -267,7 +288,7 @@ INSERT INTO sounds (name, server_id, uploader_id, public, src) let sounds = sqlx::query_as_unchecked!( Sound, " -SELECT * +SELECT name, id, plays, public, server_id, uploader_id FROM sounds WHERE uploader_id = ? ", @@ -281,7 +302,7 @@ SELECT * let sounds = sqlx::query_as_unchecked!( Sound, " -SELECT * +SELECT name, id, plays, public, server_id, uploader_id FROM sounds WHERE server_id = ? ",