From 7d43aa59187b7d7e97296937ce742639440ac149 Mon Sep 17 00:00:00 2001 From: jude Date: Fri, 13 May 2022 23:08:52 +0100 Subject: [PATCH] cleared up all unwraps from the reminder sender. cleared up clippy lints. added undo button --- postman/src/sender.rs | 118 ++++++++++++++++++++------------ src/commands/info_cmds.rs | 4 +- src/commands/moderation_cmds.rs | 28 +++----- src/commands/reminder_cmds.rs | 75 +++++++++++++------- src/commands/todo_cmds.rs | 2 +- src/component_models/mod.rs | 91 ++++++++++++++++++++++-- src/consts.rs | 10 +-- src/event_handlers.rs | 11 ++- src/hooks.rs | 23 +++---- src/interval_parser.rs | 8 +-- src/main.rs | 5 +- src/models/command_macro.rs | 18 ++--- src/models/reminder/builder.rs | 8 +-- src/models/reminder/mod.rs | 69 ++++++++++++++++--- src/time_parser.rs | 6 +- 15 files changed, 318 insertions(+), 158 deletions(-) diff --git a/postman/src/sender.rs b/postman/src/sender.rs index 2feb704..c7843f1 100644 --- a/postman/src/sender.rs +++ b/postman/src/sender.rs @@ -58,10 +58,10 @@ fn fmt_displacement(format: &str, seconds: u64) -> String { pub fn substitute(string: &str) -> String { let new = TIMEFROM_REGEX.replace(string, |caps: &Captures| { - let final_time = caps.name("time").unwrap().as_str(); - let format = caps.name("format").unwrap().as_str(); + let final_time = caps.name("time").map(|m| m.as_str().parse::().ok()).flatten(); + let format = caps.name("format").map(|m| m.as_str()); - if let Ok(final_time) = final_time.parse::() { + if let (Some(final_time), Some(format)) = (final_time, format) { let dt = NaiveDateTime::from_timestamp(final_time, 0); let now = Utc::now().naive_utc(); @@ -81,13 +81,11 @@ pub fn substitute(string: &str) -> String { TIMENOW_REGEX .replace(&new, |caps: &Captures| { - let timezone = caps.name("timezone").unwrap().as_str(); + let timezone = caps.name("timezone").map(|m| m.as_str().parse::().ok()).flatten(); + let format = caps.name("format").map(|m| m.as_str()); - println!("{}", timezone); - - if let Ok(tz) = timezone.parse::() { - let format = caps.name("format").unwrap().as_str(); - let now = Utc::now().with_timezone(&tz); + if let (Some(timezone), Some(format)) = (timezone, format) { + let now = Utc::now().with_timezone(&timezone); now.format(format).to_string() } else { @@ -122,7 +120,7 @@ impl Embed { pool: impl Executor<'_, Database = Database> + Copy, id: u32, ) -> Option { - let mut embed = sqlx::query_as!( + match sqlx::query_as!( Self, r#" SELECT @@ -142,21 +140,29 @@ impl Embed { ) .fetch_one(pool) .await - .unwrap(); + { + Ok(mut embed) => { + embed.title = substitute(&embed.title); + embed.description = substitute(&embed.description); + embed.footer = substitute(&embed.footer); - embed.title = substitute(&embed.title); - embed.description = substitute(&embed.description); - embed.footer = substitute(&embed.footer); + embed.fields.iter_mut().for_each(|mut field| { + field.title = substitute(&field.title); + field.value = substitute(&field.value); + }); - embed.fields.iter_mut().for_each(|mut field| { - field.title = substitute(&field.title); - field.value = substitute(&field.value); - }); + if embed.has_content() { + Some(embed) + } else { + None + } + } - if embed.has_content() { - Some(embed) - } else { - None + Err(e) => { + warn!("Error loading embed from reminder: {:?}", e); + + None + } } } @@ -251,9 +257,9 @@ pub struct Reminder { impl Reminder { pub async fn fetch_reminders(pool: impl Executor<'_, Database = Database> + Copy) -> Vec { - sqlx::query_as_unchecked!( + match sqlx::query_as!( Reminder, - " + r#" SELECT reminders.`id` AS id, @@ -261,20 +267,20 @@ SELECT channels.`webhook_id` AS webhook_id, channels.`webhook_token` AS webhook_token, - channels.`paused` AS channel_paused, - channels.`paused_until` AS channel_paused_until, - reminders.`enabled` AS enabled, + channels.`paused` AS "channel_paused:_", + channels.`paused_until` AS "channel_paused_until:_", + reminders.`enabled` AS "enabled:_", - reminders.`tts` AS tts, - reminders.`pin` AS pin, + reminders.`tts` AS "tts:_", + reminders.`pin` AS "pin:_", reminders.`content` AS content, reminders.`attachment` AS attachment, reminders.`attachment_name` AS attachment_name, - reminders.`utc_time` AS 'utc_time', + reminders.`utc_time` AS "utc_time:_", reminders.`timezone` AS timezone, - reminders.`restartable` AS restartable, - reminders.`expires` AS expires, + reminders.`restartable` AS "restartable:_", + reminders.`expires` AS "expires:_", reminders.`interval_seconds` AS 'interval_seconds', reminders.`interval_months` AS 'interval_months', @@ -288,18 +294,26 @@ ON reminders.channel_id = channels.id WHERE reminders.`utc_time` < NOW() - ", + "#, ) .fetch_all(pool) .await - .unwrap() - .into_iter() - .map(|mut rem| { - rem.content = substitute(&rem.content); + { + Ok(reminders) => reminders + .into_iter() + .map(|mut rem| { + rem.content = substitute(&rem.content); - rem - }) - .collect::>() + rem + }) + .collect::>(), + + Err(e) => { + warn!("Could not fetch reminders: {:?}", e); + + vec![] + } + } } async fn reset_webhook(&self, pool: impl Executor<'_, Database = Database> + Copy) { @@ -319,7 +333,7 @@ UPDATE channels SET webhook_id = NULL, webhook_token = NULL WHERE channel = ? let mut updated_reminder_time = self.utc_time; if let Some(interval) = self.interval_months { - let row = sqlx::query!( + match sqlx::query!( // use the second date_add to force return value to datetime "SELECT DATE_ADD(DATE_ADD(?, INTERVAL ? MONTH), INTERVAL 0 SECOND) AS new_time", updated_reminder_time, @@ -327,9 +341,25 @@ UPDATE channels SET webhook_id = NULL, webhook_token = NULL WHERE channel = ? ) .fetch_one(pool) .await - .unwrap(); + { + Ok(row) => match row.new_time { + Some(datetime) => { + updated_reminder_time = datetime; + } + None => { + warn!("Could not update interval by months: got NULL"); - updated_reminder_time = row.new_time.unwrap(); + updated_reminder_time += Duration::days(30); + } + }, + + Err(e) => { + warn!("Could not update interval by months: {:?}", e); + + // naively fallback to adding 30 days + updated_reminder_time += Duration::days(30); + } + } } if let Some(interval) = self.interval_seconds { @@ -538,7 +568,7 @@ UPDATE `channels` SET paused = 0, paused_until = NULL WHERE `channel` = ? error!("Error sending {:?}: {:?}", self, e); if let Error::Http(error) = e { - if error.status_code() == Some(StatusCode::from_u16(404).unwrap()) { + if error.status_code() == Some(StatusCode::NOT_FOUND) { error!("Seeing channel is deleted. Removing reminder"); self.force_delete(pool).await; } else { diff --git a/src/commands/info_cmds.rs b/src/commands/info_cmds.rs index fb8184c..15af5ce 100644 --- a/src/commands/info_cmds.rs +++ b/src/commands/info_cmds.rs @@ -71,7 +71,7 @@ pub async fn info(ctx: Context<'_>) -> Result<(), Error> { .send(|m| { m.ephemeral(true).embed(|e| { e.title("Info") - .description(format!( + .description( "Help: `/help` **Welcome to Reminder Bot!** @@ -81,7 +81,7 @@ Find me on https://discord.jellywx.com and on https://github.com/JellyWX :) Invite the bot: https://invite.reminder-bot.com/ Use our dashboard: https://reminder-bot.com/", - )) + ) .footer(footer) .color(*THEME_COLOR) }) diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index 1f479ea..4fc325f 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -1,3 +1,5 @@ +use std::collections::hash_map::Entry; + use chrono::offset::Utc; use chrono_tz::{Tz, TZ_VARIANTS}; use levenshtein::levenshtein; @@ -52,7 +54,7 @@ pub async fn timezone( .description(format!( "Timezone has been set to **{}**. Your current time should be `{}`", timezone, - now.format("%H:%M").to_string() + now.format("%H:%M") )) .color(*THEME_COLOR) }) @@ -75,10 +77,7 @@ pub async fn timezone( let fields = filtered_tz.iter().map(|tz| { ( tz.to_string(), - format!( - "🕗 `{}`", - Utc::now().with_timezone(tz).format("%H:%M").to_string() - ), + format!("🕗 `{}`", Utc::now().with_timezone(tz).format("%H:%M")), true, ) }); @@ -98,11 +97,7 @@ pub async fn timezone( } } else { let popular_timezones_iter = ctx.data().popular_timezones.iter().map(|t| { - ( - t.to_string(), - format!("🕗 `{}`", Utc::now().with_timezone(t).format("%H:%M").to_string()), - true, - ) + (t.to_string(), format!("🕗 `{}`", Utc::now().with_timezone(t).format("%H:%M")), true) }); ctx.send(|m| { @@ -142,7 +137,7 @@ WHERE ) .fetch_all(&ctx.data().database) .await - .unwrap_or(vec![]) + .unwrap_or_default() .iter() .map(|s| s.name.clone()) .collect() @@ -200,14 +195,11 @@ Please select a unique name for your macro.", let okay = { let mut lock = ctx.data().recording_macros.write().await; - if lock.contains_key(&(guild_id, ctx.author().id)) { - false - } else { - lock.insert( - (guild_id, ctx.author().id), - CommandMacro { guild_id, name, description, commands: vec![] }, - ); + if let Entry::Vacant(e) = lock.entry((guild_id, ctx.author().id)) { + e.insert(CommandMacro { guild_id, name, description, commands: vec![] }); true + } else { + false } }; diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index 1059b4f..f992273 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -9,13 +9,14 @@ use chrono_tz::Tz; use num_integer::Integer; use poise::{ serenity::{builder::CreateEmbed, model::channel::Channel}, + serenity_prelude::{ButtonStyle, ReactionType}, CreateReply, }; use crate::{ component_models::{ pager::{DelPager, LookPager, Pager}, - ComponentDataModel, DelSelector, + ComponentDataModel, DelSelector, UndoReminder, }, consts::{ EMBED_DESCRIPTION_MAX_LENGTH, HOUR, MINUTE, REGEX_CHANNEL_USER, SELECT_MAX_ENTRIES, @@ -500,18 +501,16 @@ pub async fn start_timer( if count >= 25 { ctx.say("You already have 25 timers. Please delete some timers before creating a new one") .await?; - } else { - if name.len() <= 32 { - Timer::create(&name, owner, &ctx.data().database).await; + } else if name.len() <= 32 { + Timer::create(&name, owner, &ctx.data().database).await; - ctx.say("Created a new timer").await?; - } else { - ctx.say(format!( - "Please name your timer something shorted (max. 32 characters, you used {})", - name.len() - )) - .await?; - } + ctx.say("Created a new timer").await?; + } else { + ctx.say(format!( + "Please name your timer something shorted (max. 32 characters, you used {})", + name.len() + )) + .await?; } Ok(()) @@ -589,8 +588,7 @@ pub async fn remind( }; let scopes = { - let list = - channels.map(|arg| parse_mention_list(&arg.to_string())).unwrap_or_default(); + let list = channels.map(|arg| parse_mention_list(&arg)).unwrap_or_default(); if list.is_empty() { if ctx.guild_id().is_some() { @@ -610,7 +608,7 @@ pub async fn remind( { ( parse_duration(repeat) - .or_else(|_| parse_duration(&format!("1 {}", repeat.to_string()))) + .or_else(|_| parse_duration(&format!("1 {}", repeat))) .ok(), { if let Some(arg) = &expires { @@ -653,15 +651,41 @@ pub async fn remind( let (errors, successes) = builder.build().await; - let embed = create_response(successes, errors, time); + let embed = create_response(&successes, &errors, time); - ctx.send(|m| { - m.embed(|c| { - *c = embed; - c + if successes.len() == 1 { + let reminder = successes.iter().next().map(|(r, _)| r.id).unwrap(); + let undo_button = ComponentDataModel::UndoReminder(UndoReminder { + user_id: ctx.author().id, + reminder_id: reminder, + }); + + ctx.send(|m| { + m.embed(|c| { + *c = embed; + c + }) + .components(|c| { + c.create_action_row(|r| { + r.create_button(|b| { + b.emoji(ReactionType::Unicode("🔕".to_string())) + .label("Cancel") + .style(ButtonStyle::Danger) + .custom_id(undo_button.to_custom_id()) + }) + }) + }) }) - }) - .await?; + .await?; + } else { + ctx.send(|m| { + m.embed(|c| { + *c = embed; + c + }) + }) + .await?; + } } } None => { @@ -673,8 +697,8 @@ pub async fn remind( } fn create_response( - successes: HashSet, - errors: HashSet, + successes: &HashSet<(Reminder, ReminderScope)>, + errors: &HashSet, time: i64, ) -> CreateEmbed { let success_part = match successes.len() { @@ -682,7 +706,8 @@ fn create_response( n => format!( "Reminder{s} for {locations} set for ", s = if n > 1 { "s" } else { "" }, - locations = successes.iter().map(|l| l.mention()).collect::>().join(", "), + locations = + successes.iter().map(|(_, l)| l.mention()).collect::>().join(", "), offset = time ), }; diff --git a/src/commands/todo_cmds.rs b/src/commands/todo_cmds.rs index afbbe2a..742813d 100644 --- a/src/commands/todo_cmds.rs +++ b/src/commands/todo_cmds.rs @@ -336,7 +336,7 @@ pub fn show_todo_page( opt.create_option(|o| { o.label(format!("Mark {} complete", count + first_num)) .value(id) - .description(disp.split_once(" ").unwrap_or(("", "")).1) + .description(disp.split_once(' ').unwrap_or(("", "")).1) }); } diff --git a/src/component_models/mod.rs b/src/component_models/mod.rs index 51d240b..a738034 100644 --- a/src/component_models/mod.rs +++ b/src/component_models/mod.rs @@ -3,14 +3,20 @@ pub(crate) mod pager; use std::io::Cursor; use chrono_tz::Tz; -use poise::serenity::{ - builder::CreateEmbed, - client::Context, - model::{ - channel::Channel, - interactions::{message_component::MessageComponentInteraction, InteractionResponseType}, - prelude::InteractionApplicationCommandCallbackDataFlags, +use log::warn; +use poise::{ + serenity::{ + builder::CreateEmbed, + client::Context, + model::{ + channel::Channel, + interactions::{ + message_component::MessageComponentInteraction, InteractionResponseType, + }, + prelude::InteractionApplicationCommandCallbackDataFlags, + }, }, + serenity_prelude as serenity, }; use rmp_serde::Serializer; use serde::{Deserialize, Serialize}; @@ -38,6 +44,7 @@ pub enum ComponentDataModel { DelSelector(DelSelector), TodoSelector(TodoSelector), MacroPager(MacroPager), + UndoReminder(UndoReminder), } impl ComponentDataModel { @@ -334,6 +341,70 @@ WHERE guilds.guild = ?", }) .await; } + ComponentDataModel::UndoReminder(undo_reminder) => { + if component.user.id == undo_reminder.user_id { + let reminder = + Reminder::from_id(&data.database, undo_reminder.reminder_id).await; + + if let Some(reminder) = reminder { + match reminder.delete(&data.database).await { + Ok(()) => { + let _ = component + .create_interaction_response(&ctx, |f| { + f.kind(InteractionResponseType::UpdateMessage) + .interaction_response_data(|d| { + d.embed(|e| { + e.title("Reminder Canceled") + .description( + "This reminder has been canceled.", + ) + .color(*THEME_COLOR) + }) + .components(|c| c) + }) + }) + .await; + } + Err(e) => { + warn!("Error canceling reminder: {:?}", e); + + let _ = component + .create_interaction_response(&ctx, |f| { + f.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|d| { + d.content( + "The reminder could not be canceled: it may have already been deleted. Check `/del`!") + .ephemeral(true) + }) + }) + .await; + } + } + } else { + let _ = component + .create_interaction_response(&ctx, |f| { + f.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|d| { + d.content( + "The reminder could not be canceled: it may have already been deleted. Check `/del`!") + .ephemeral(true) + }) + }) + .await; + } + } else { + let _ = component + .create_interaction_response(&ctx, |f| { + f.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|d| { + d.content( + "Only the user who performed the command can use this button.") + .ephemeral(true) + }) + }) + .await; + } + } } } } @@ -351,3 +422,9 @@ pub struct TodoSelector { pub channel_id: Option, pub guild_id: Option, } + +#[derive(Serialize, Deserialize)] +pub struct UndoReminder { + pub user_id: serenity::UserId, + pub reminder_id: u32, +} diff --git a/src/consts.rs b/src/consts.rs index 8df74b9..746e156 100644 --- a/src/consts.rs +++ b/src/consts.rs @@ -36,15 +36,11 @@ lazy_static! { ); pub static ref CNC_GUILD: Option = env::var("CNC_GUILD").map(|var| var.parse::().ok()).ok().flatten(); - pub static ref MIN_INTERVAL: i64 = env::var("MIN_INTERVAL") - .ok() - .map(|inner| inner.parse::().ok()) - .flatten() - .unwrap_or(600); + pub static ref MIN_INTERVAL: i64 = + env::var("MIN_INTERVAL").ok().and_then(|inner| inner.parse::().ok()).unwrap_or(600); pub static ref MAX_TIME: i64 = env::var("MAX_TIME") .ok() - .map(|inner| inner.parse::().ok()) - .flatten() + .and_then(|inner| inner.parse::().ok()) .unwrap_or(60 * 60 * 24 * 365 * 50); pub static ref LOCAL_TIMEZONE: String = env::var("LOCAL_TIMEZONE").unwrap_or_else(|_| "UTC".to_string()); diff --git a/src/event_handlers.rs b/src/event_handlers.rs index 700fa87..3937d22 100644 --- a/src/event_handlers.rs +++ b/src/event_handlers.rs @@ -42,7 +42,7 @@ pub async fn listener( }; }); } else { - warn!("Not running postman") + warn!("Not running postman"); } if !run_settings.contains("web") { @@ -50,7 +50,7 @@ pub async fn listener( reminder_web::initialize(kill_tx, ctx2, pool2).await.unwrap(); }); } else { - warn!("Not running web") + warn!("Not running web"); } data.is_loop_running.swap(true, Ordering::Relaxed); @@ -114,14 +114,13 @@ pub async fn listener( .execute(&data.database) .await; } - poise::Event::InteractionCreate { interaction } => match interaction { - Interaction::MessageComponent(component) => { + poise::Event::InteractionCreate { interaction } => { + if let Interaction::MessageComponent(component) = interaction { let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id); component_model.act(ctx, data, component).await; } - _ => {} - }, + } _ => {} } diff --git a/src/hooks.rs b/src/hooks.rs index a87a1c1..2d434ac 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -11,10 +11,10 @@ async fn macro_check(ctx: Context<'_>) -> bool { if let Some(command_macro) = lock.get_mut(&(guild_id, ctx.author().id)) { if command_macro.commands.len() >= MACRO_MAX_COMMANDS { let _ = ctx.send(|m| { - m.ephemeral(true).content( - format!("{} commands already recorded. Please use `/macro finish` to end recording.", MACRO_MAX_COMMANDS), - ) - }) + m.ephemeral(true).content( + format!("{} commands already recorded. Please use `/macro finish` to end recording.", MACRO_MAX_COMMANDS), + ) + }) .await; } else { let recorded = RecordedCommand { @@ -30,19 +30,13 @@ async fn macro_check(ctx: Context<'_>) -> bool { .await; } - false - } else { - true + return false; } - } else { - true } - } else { - true } - } else { - true } + + true } async fn check_self_permissions(ctx: Context<'_>) -> bool { @@ -56,14 +50,13 @@ async fn check_self_permissions(ctx: Context<'_>) -> bool { let (view_channel, send_messages, embed_links) = ctx .channel_id() .to_channel_cached(&ctx.discord()) - .map(|c| { + .and_then(|c| { if let Channel::Guild(channel) = c { channel.permissions_for_user(&ctx.discord(), user_id).ok() } else { None } }) - .flatten() .map_or((false, false, false), |p| { (p.view_channel(), p.send_messages(), p.embed_links()) }); diff --git a/src/interval_parser.rs b/src/interval_parser.rs index 4ac8ef0..8befd57 100644 --- a/src/interval_parser.rs +++ b/src/interval_parser.rs @@ -75,7 +75,7 @@ impl fmt::Display for Error { match self { Error::InvalidCharacter(offset) => write!(f, "invalid character at {}", offset), Error::NumberExpected(offset) => write!(f, "expected number at {}", offset), - Error::UnknownUnit { unit, value, .. } if &unit == &"" => { + Error::UnknownUnit { unit, value, .. } if unit.is_empty() => { write!(f, "time unit needed, for example {0}sec or {0}ms", value,) } Error::UnknownUnit { unit, .. } => { @@ -162,11 +162,11 @@ impl<'a> Parser<'a> { }; let mut nsec = self.current.2 + nsec; if nsec > 1_000_000_000 { - sec = sec + nsec / 1_000_000_000; + sec += nsec / 1_000_000_000; nsec %= 1_000_000_000; } - sec = self.current.1 + sec; - month = self.current.0 + month; + sec += self.current.1; + month += self.current.0; self.current = (month, sec, nsec); diff --git a/src/main.rs b/src/main.rs index cda6210..ae77c17 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ #![feature(int_roundings)] + #[macro_use] extern crate lazy_static; @@ -23,7 +24,7 @@ use std::{ use chrono_tz::Tz; use dotenv::dotenv; use poise::serenity::model::{ - gateway::{Activity, GatewayIntents}, + gateway::GatewayIntents, id::{GuildId, UserId}, }; use sqlx::{MySql, Pool}; @@ -52,7 +53,7 @@ pub struct Data { broadcast: Sender<()>, } -impl std::fmt::Debug for Data { +impl Debug for Data { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "Data {{ .. }}") } diff --git a/src/models/command_macro.rs b/src/models/command_macro.rs index 5f7bfb1..08b314e 100644 --- a/src/models/command_macro.rs +++ b/src/models/command_macro.rs @@ -5,11 +5,11 @@ use serde::{Deserialize, Serialize}; use crate::{Context, Data, Error}; -fn default_none() -> Option< - for<'a> fn( - poise::ApplicationContext<'a, U, E>, - ) -> poise::BoxFuture<'a, Result<(), poise::FrameworkError<'a, U, E>>>, -> { +type Func = for<'a> fn( + poise::ApplicationContext<'a, U, E>, +) -> poise::BoxFuture<'a, Result<(), poise::FrameworkError<'a, U, E>>>; + +fn default_none() -> Option> { None } @@ -17,11 +17,7 @@ fn default_none() -> Option< pub struct RecordedCommand { #[serde(skip)] #[serde(default = "default_none::")] - pub action: Option< - for<'a> fn( - poise::ApplicationContext<'a, U, E>, - ) -> poise::BoxFuture<'a, Result<(), poise::FrameworkError<'a, U, E>>>, - >, + pub action: Option>, pub command_name: String, pub options: Vec, } @@ -59,7 +55,7 @@ SELECT * FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND .iter() .find(|c| c.identifying_name == recorded_command.command_name); - recorded_command.action = command.map(|c| c.slash_action).flatten().clone(); + recorded_command.action = command.map(|c| c.slash_action).flatten(); } let command_macro = CommandMacro { diff --git a/src/models/reminder/builder.rs b/src/models/reminder/builder.rs index de53b81..8ddfb4a 100644 --- a/src/models/reminder/builder.rs +++ b/src/models/reminder/builder.rs @@ -126,7 +126,7 @@ INSERT INTO reminders ( .await .unwrap(); - Ok(Reminder::from_uid(&self.pool, self.uid).await.unwrap()) + Ok(Reminder::from_uid(&self.pool, &self.uid).await.unwrap()) } } @@ -207,7 +207,7 @@ impl<'a> MultiReminderBuilder<'a> { self.scopes = scopes; } - pub async fn build(self) -> (HashSet, HashSet) { + pub async fn build(self) -> (HashSet, HashSet<(Reminder, ReminderScope)>) { let mut errors = HashSet::new(); let mut ok_locs = HashSet::new(); @@ -309,8 +309,8 @@ impl<'a> MultiReminderBuilder<'a> { }; match builder.build().await { - Ok(_) => { - ok_locs.insert(scope); + Ok(r) => { + ok_locs.insert((r, scope)); } Err(e) => { errors.insert(e); diff --git a/src/models/reminder/mod.rs b/src/models/reminder/mod.rs index f50203d..eee1ae4 100644 --- a/src/models/reminder/mod.rs +++ b/src/models/reminder/mod.rs @@ -4,6 +4,8 @@ pub mod errors; mod helper; pub mod look_flags; +use std::hash::{Hash, Hasher}; + use chrono::{NaiveDateTime, TimeZone}; use chrono_tz::Tz; use poise::{ @@ -32,11 +34,22 @@ pub struct Reminder { pub set_by: Option, } +impl Hash for Reminder { + fn hash(&self, state: &mut H) { + self.uid.hash(state); + } +} + +impl PartialEq for Reminder { + fn eq(&self, other: &Self) -> bool { + self.uid == other.uid + } +} + +impl Eq for Reminder {} + impl Reminder { - pub async fn from_uid( - pool: impl Executor<'_, Database = Database>, - uid: String, - ) -> Option { + pub async fn from_uid(pool: impl Executor<'_, Database = Database>, uid: &str) -> Option { sqlx::query_as_unchecked!( Self, " @@ -72,6 +85,42 @@ WHERE .ok() } + pub async fn from_id(pool: impl Executor<'_, Database = Database>, id: u32) -> Option { + sqlx::query_as_unchecked!( + Self, + " +SELECT + reminders.id, + reminders.uid, + channels.channel, + reminders.utc_time, + reminders.interval_seconds, + reminders.interval_months, + reminders.expires, + reminders.enabled, + reminders.content, + reminders.embed_description, + users.user AS set_by +FROM + reminders +INNER JOIN + channels +ON + reminders.channel_id = channels.id +LEFT JOIN + users +ON + reminders.set_by = users.id +WHERE + reminders.id = ? + ", + id + ) + .fetch_one(pool) + .await + .ok() + } + pub async fn from_channel>( pool: impl Executor<'_, Database = Database>, channel_id: C, @@ -240,6 +289,13 @@ WHERE .unwrap() } + pub async fn delete( + &self, + db: impl Executor<'_, Database = Database>, + ) -> Result<(), sqlx::Error> { + sqlx::query!("DELETE FROM reminders WHERE uid = ?", self.uid).execute(db).await.map(|_| ()) + } + pub fn display_content(&self) -> &str { if self.content.is_empty() { &self.embed_description @@ -254,10 +310,7 @@ WHERE count + 1, self.display_content(), self.channel, - timezone - .timestamp(self.utc_time.timestamp(), 0) - .format("%Y-%m-%d %H:%M:%S") - .to_string() + timezone.timestamp(self.utc_time.timestamp(), 0).format("%Y-%m-%d %H:%M:%S") ) } diff --git a/src/time_parser.rs b/src/time_parser.rs index 1dd912b..2401d37 100644 --- a/src/time_parser.rs +++ b/src/time_parser.rs @@ -211,14 +211,12 @@ pub async fn natural_parser(time: &str, timezone: &str) -> Option { .output() .await .ok() - .map(|inner| { + .and_then(|inner| { if inner.status.success() { Some(from_utf8(&*inner.stdout).unwrap().parse::().unwrap()) } else { None } }) - .flatten() - .map(|inner| if inner < 0 { None } else { Some(inner) }) - .flatten() + .and_then(|inner| if inner < 0 { None } else { Some(inner) }) }