handle failed checks and pass messages back to user

This commit is contained in:
jude-lafitteIII
2020-05-20 17:54:46 +01:00
parent 03e3c4f301
commit 7265976948
5 changed files with 165 additions and 41 deletions

View File

@ -7,7 +7,6 @@ mod error;
use sound::Sound;
use guilddata::GuildData;
use error::ErrorTypes;
use serenity::{
client::{
@ -18,9 +17,9 @@ use serenity::{
Client, Context,
},
framework::standard::{
Args, CommandResult, CheckResult, StandardFramework, Reason,
Args, CommandResult, CheckResult, DispatchError, StandardFramework, Reason,
macros::{
command, group, check,
command, group, check, hook,
}
},
model::{
@ -66,6 +65,7 @@ use std::{
sync::Arc,
time::Duration,
};
use serenity::framework::standard::CommandError;
struct SQLPool;
@ -128,7 +128,7 @@ struct PermissionManagedUsers;
#[name("role_check")]
async fn role_check(ctx: &Context, msg: &Message, _args: &mut Args) -> CheckResult {
async fn check_for_roles(ctx: &&Context, msg: &&Message) -> Result<(), Box<dyn std::error::Error>> {
async fn check_for_roles(ctx: &&Context, msg: &&Message) -> CheckResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
@ -144,34 +144,43 @@ async fn role_check(ctx: &Context, msg: &Message, _args: &mut Args) -> CheckResu
let guild_id = *msg.guild_id.unwrap().as_u64();
let res = sqlx::query!(
"
let role_res = sqlx::query!(
"
SELECT COUNT(1) as count
FROM roles
WHERE guild_id = ? AND role IN (?)
",
guild_id, user_roles
).fetch_one(&pool).await?;
",
guild_id, user_roles
)
.fetch_one(&pool).await;
if res.count > 0 {
Ok(())
}
else {
Err(Box::new(ErrorTypes::NotEnoughRoles))
match role_res {
Ok(role_count) => {
if role_count.count > 0 {
CheckResult::Success
}
else {
CheckResult::Failure(Reason::User("User has not got a sufficient role".to_string()))
}
}
Err(_) => {
CheckResult::Failure(Reason::User("User has not got a sufficient role".to_string()))
}
}
}
None => {
Err(Box::new(ErrorTypes::NotEnoughRoles))
CheckResult::Failure(Reason::User("User has not got a sufficient role".to_string()))
}
}
}
if check_for_roles(&ctx, &msg).await.is_ok() {
if perform_permission_check(ctx, &msg).await.is_success() {
CheckResult::Success
}
else {
perform_permission_check(ctx, &msg).await
check_for_roles(&ctx, &msg).await
}
}
@ -223,7 +232,7 @@ SELECT join_sound_id
let mut sound = sqlx::query_as_unchecked!(
Sound,
"
SELECT *
SELECT name, id, plays, public, server_id, uploader_id
FROM sounds
WHERE id = ?
",
@ -259,7 +268,7 @@ SELECT *
async fn play_audio(sound: &mut Sound, guild: GuildData, handler: &mut VoiceHandler, mut voice_guilds: MutexGuard<'_, HashMap<GuildId, u8>>, pool: MySqlPool)
-> Result<(), Box<dyn std::error::Error>> {
let audio = handler.play_only(sound.store_sound_source().await?);
let audio = handler.play_only(sound.store_sound_source(pool.clone()).await?);
{
let mut locked = audio.lock().await;
@ -275,9 +284,30 @@ async fn play_audio(sound: &mut Sound, guild: GuildData, handler: &mut VoiceHand
Ok(())
}
#[hook]
async fn log_errors(_: &Context, m: &Message, cmd_name: &str, error: Result<(), CommandError>) {
if let Err(e) = error {
println!("Error in command {} ({}): {:?}", cmd_name, m.content, e);
}
}
#[hook]
async fn dispatch_error_hook(ctx: &Context, msg: &Message, error: DispatchError) {
match error {
DispatchError::CheckFailed(_f, reason) => {
if let Reason::User(description) = reason {
let _ = msg.reply(ctx, format!("You cannot do this command: {}", description)).await;
}
}
_ => {}
}
}
// entry point
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
dotenv()?;
let voice_guilds = Arc::new(Mutex::new(HashMap::new()));
@ -302,7 +332,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)
.group(&ALLUSERS_GROUP)
.group(&ROLEMANAGEDUSERS_GROUP)
.group(&PERMISSIONMANAGEDUSERS_GROUP);
.group(&PERMISSIONMANAGEDUSERS_GROUP)
.after(log_errors)
.on_dispatch_error(dispatch_error_hook);
let mut client = Client::new(&env::var("DISCORD_TOKEN").expect("Missing token from environment"))
.intents(GatewayIntents::GUILD_VOICE_STATES | GatewayIntents::GUILD_MESSAGES | GatewayIntents::GUILDS)
@ -815,7 +847,7 @@ async fn format_search_results(search_results: Vec<Sound>, msg: &Message, ctx: &
let field_iter = search_results.iter().take(25).map(|item| {
(&item.name, format!("ID: {}\nPlays: {}", item.id, item.plays), false)
(&item.name, format!("ID: {}\nPlays: {}", item.id, item.plays), true)
}).filter(|item| {
@ -857,7 +889,8 @@ async fn show_popular_sounds(ctx: &Context, msg: &Message, _args: Args) -> Comma
let search_results = sqlx::query_as_unchecked!(
Sound,
"
SELECT * FROM sounds
SELECT name, id, plays, public, server_id, uploader_id
FROM sounds
ORDER BY plays DESC
LIMIT 25
"
@ -878,7 +911,8 @@ async fn show_random_sounds(ctx: &Context, msg: &Message, _args: Args) -> Comman
let search_results = sqlx::query_as_unchecked!(
Sound,
"
SELECT * FROM sounds
SELECT name, id, plays, public, server_id, uploader_id
FROM sounds
ORDER BY rand()
LIMIT 25
"