diff --git a/src/main.rs b/src/main.rs index 070d831..17fbf2b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,9 +10,9 @@ use serenity::{ Client, Context, }, framework::standard::{ - Args, CommandResult, StandardFramework, + Args, CommandResult, CheckResult, StandardFramework, Reason, macros::{ - command, group, + command, group, check, } }, model::{ @@ -87,8 +87,89 @@ lazy_static! { } #[group] -#[commands(play, info, help, change_volume, change_prefix, upload_new_sound)] -struct General; +#[commands(info, help)] +struct AllUsers; + +#[group] +#[commands(play, upload_new_sound, change_volume)] +#[checks(role_check)] +struct RoleManagedUsers; + +#[group] +#[commands(change_prefix)] +#[checks(permission_check)] +struct PermissionManagedUsers; + +#[check] +#[name("role_check")] +async fn role_check(ctx: &mut Context, msg: &Message, _args: &mut Args) -> CheckResult { + + async fn check_for_roles(ctx: &&mut Context, msg: &&Message) -> Result<(), Box> { + let pool = ctx.data.read().await + .get::().cloned().expect("Could not get SQLPool from data"); + + let user_member = msg.member(&ctx).await; + + match user_member { + Some(member) => { + let user_roles: String = member.roles + .iter() + .map(|r| (*r.as_u64()).to_string()) + .collect::>() + .join(", "); + + let guild_id = *msg.guild_id.unwrap().as_u64(); + + let res = sqlx::query!( + " +SELECT COUNT(1) as count + FROM roles + WHERE guild_id = ? AND role IN (?) + ", + guild_id, user_roles + ).fetch_one(&pool).await?; + + if res.count > 0 { + Ok(()) + } + else { + Err(Box::new(ErrorTypes::NotEnoughRoles)) + } + } + + None => { + Err(Box::new(ErrorTypes::NotEnoughRoles)) + } + } + } + + if check_for_roles(&ctx, &msg).await.is_ok() { + CheckResult::Success + } + else { + perform_permission_check(ctx, &msg).await + } +} + +#[check] +#[name("permission_check")] +async fn permission_check(ctx: &mut Context, msg: &Message, _args: &mut Args) -> CheckResult { + perform_permission_check(ctx, &msg).await +} + +async fn perform_permission_check(ctx: &Context, msg: &&Message) -> CheckResult { + if let Some(guild_id) = msg.guild_id { + if let Ok(member) = guild_id.member(ctx.clone(), msg.author.id).await { + if let Ok(perms) = member.permissions(ctx).await { + if perms.manage_guild() { + return CheckResult::Success + } + } + } + } + + CheckResult::Failure(Reason::User(String::from("User needs `Manage Guild` permission"))) +} struct Sound { name: String, @@ -103,6 +184,7 @@ struct Sound { #[derive(Debug)] enum ErrorTypes { InvalidFile, + NotEnoughRoles, } impl std::error::Error for ErrorTypes {} @@ -234,6 +316,8 @@ SELECT COUNT(1) as count .arg("28000") .arg("-f") .arg("opus") + .arg("-fs") + .arg("1048576") .arg("pipe:1") .output(); @@ -241,7 +325,7 @@ SELECT COUNT(1) as count match output { Ok(out) => { - if out.status.success() && out.stdout.len() < 1024 * 1024 { + if out.status.success() { Some(out.stdout) } else { @@ -364,7 +448,9 @@ async fn main() -> Result<(), Box> { None => Some(String::from("?")) } }))) - .group(&GENERAL_GROUP); + .group(&ALLUSERS_GROUP) + .group(&ROLEMANAGEDUSERS_GROUP) + .group(&PERMISSIONMANAGEDUSERS_GROUP); let mut client = Client::new_with_extras( &env::var("DISCORD_TOKEN").expect("Missing token from environment"), @@ -620,7 +706,7 @@ async fn upload_new_sound(ctx: &mut Context, msg: &Message, mut args: Args) -> C } if permit_upload { - msg.channel_id.say(&ctx, "Please now upload an audio file under 1MB in size:").await?; + msg.channel_id.say(&ctx, "Please now upload an audio file under 1MB in size (larger files will be automatically trimmed):").await?; let reply = msg.channel_id.await_reply(&ctx) .author_id(msg.author.id) @@ -641,7 +727,7 @@ async fn upload_new_sound(ctx: &mut Context, msg: &Message, mut args: Args) -> C } Err(_) => { - msg.channel_id.say(&ctx, "Sound failed to upload. Size may be too large").await?; + msg.channel_id.say(&ctx, "Sound failed to upload.").await?; } } } else { @@ -655,7 +741,12 @@ async fn upload_new_sound(ctx: &mut Context, msg: &Message, mut args: Args) -> C } } else { - msg.channel_id.say(&ctx, "You have reached the maximum number of sounds ({}). Either delete some with `{}delete` or join our Patreon for unlimited uploads at **https://patreon.com/jellywx**").await?; + msg.channel_id.say( + &ctx, + format!( + "You have reached the maximum number of sounds ({}). Either delete some with `?delete` or join our Patreon for unlimited uploads at **https://patreon.com/jellywx**", + *MAX_SOUNDS, + )).await?; } } else { @@ -665,3 +756,43 @@ async fn upload_new_sound(ctx: &mut Context, msg: &Message, mut args: Args) -> C Ok(()) } +#[command] +async fn set_allowed_roles(ctx: &mut Context, msg: &Message, args: Args) -> CommandResult { + if args.len() == 0 { + msg.channel_id.say(&ctx, "Usage: `?roles `. Current roles: ").await?; + } + else { + let pool = ctx.data.read().await + .get::().cloned().expect("Could not get SQLPool from data"); + + let guild_id = *msg.guild_id.unwrap().as_u64(); + + sqlx::query!( + " +DELETE FROM roles + WHERE guild_id = ? + ", + guild_id + ).execute(&pool).await?; + + if msg.mention_roles.len() > 0 { + for role in msg.mention_roles.iter().map(|r| *r.as_u64()) { + sqlx::query!( + " +INSERT INTO roles (guild_id, role) + VALUES + (?, ?) + ", + guild_id, role + ).execute(&pool).await?; + } + + msg.channel_id.say(&ctx, "Specified roles whitelisted").await?; + } + else { + msg.channel_id.say(&ctx, "Role whitelisting disabled").await?; + } + } + + Ok(()) +}