diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index fd17838..a438ad6 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -1,7 +1,6 @@ use regex_command_attr::command; use serenity::{ - cache::Cache, client::Context, http::CacheHttp, model::{ @@ -9,6 +8,7 @@ use serenity::{ channel::{Channel, GuildChannel}, guild::Guild, id::{ChannelId, GuildId, UserId}, + interactions::ButtonStyle, misc::Mentionable, webhook::Webhook, }, @@ -26,7 +26,7 @@ use crate::{ models::{ channel_data::ChannelData, guild_data::GuildData, - reminder::{LookFlags, Reminder}, + reminder::{LookFlags, Reminder, ReminderAction}, timer::Timer, user_data::UserData, CtxGuildData, @@ -54,43 +54,6 @@ use std::{ use regex::Captures; -use ring::hmac; - -fn generate_signed_payload(reminder_id: u32, member_id: u64) -> String { - let s_key = hmac::Key::new( - hmac::HMAC_SHA256, - env::var("SECRET_KEY") - .expect("No SECRET_KEY provided") - .as_bytes(), - ); - - let mut context = hmac::Context::with_key(&s_key); - - context.update(&reminder_id.to_le_bytes()); - context.update(&member_id.to_le_bytes()); - - let signature = context.sign(); - - format!( - "{}.{}", - base64::encode(reminder_id.to_le_bytes()), - base64::encode(&signature) - ) -} - -fn validate_signature(payload: String, member_id: u64) -> bool { - let (a, _b) = payload.split_once('.').expect("Payload format incorrect"); - - let reminder_id = u32::from_le_bytes( - base64::decode(a) - .expect("Payload format incorrect") - .try_into() - .expect("Payload format incorrect"), - ); - - payload == generate_signed_payload(reminder_id, member_id) -} - async fn create_webhook( ctx: impl CacheHttp, channel: GuildChannel, @@ -961,6 +924,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem match content_res { Ok(mut content) => { let mut ok_locations = vec![]; + let mut ok_reminders = vec![]; let mut err_locations = vec![]; let mut err_types = HashSet::new(); @@ -978,11 +942,16 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem ) .await; - if let Err(e) = res { - err_locations.push(scope); - err_types.insert(e); - } else { - ok_locations.push(scope); + match res { + Err(e) => { + err_locations.push(scope); + err_types.insert(e); + } + + Ok(id) => { + ok_locations.push(scope); + ok_reminders.push(id); + } } } @@ -1059,6 +1028,22 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem .description(format!("{}\n\n{}", success_part, error_part)) .color(*THEME_COLOR) }) + .components(|c| { + if ok_locations.len() == 1 { + c.create_action_row(|r| { + r.create_button(|b| { + b.style(ButtonStyle::Danger) + .label("Delete") + .custom_id(ok_reminders[0].signed_action( + msg.author.id, + ReminderAction::Delete, + )) + }) + }); + } + + c + }) }) .await; } @@ -1323,7 +1308,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) { } async fn create_reminder<'a, U: Into, T: TryInto>( - ctx: impl CacheHttp + AsRef, + ctx: &Context, pool: &MySqlPool, user_id: U, guild_id: Option, @@ -1332,7 +1317,7 @@ async fn create_reminder<'a, U: Into, T: TryInto>( expires_parser: Option, interval: Option, content: &mut Content, -) -> Result { +) -> Result { let user_id = user_id.into(); if let Some(g_id) = guild_id { @@ -1349,7 +1334,7 @@ async fn create_reminder<'a, U: Into, T: TryInto>( let user_data = UserData::from_user(&user, &ctx, &pool).await.unwrap(); if let Some(guild_id) = guild_id { - if guild_id.member(ctx, user).await.is_err() { + if guild_id.member(&ctx, user).await.is_err() { return Err(ReminderError::InvalidTag); } } @@ -1455,7 +1440,9 @@ INSERT INTO reminders ( .await .unwrap(); - Ok(uid) + let reminder = Reminder::from_uid(ctx, uid).await.unwrap(); + + Ok(reminder) } else if time < 0 { // case required for if python returns -1 Err(ReminderError::InvalidTime) diff --git a/src/main.rs b/src/main.rs index 0887754..8c4c17c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -46,6 +46,7 @@ use dashmap::DashMap; use tokio::sync::RwLock; +use crate::models::reminder::{Reminder, ReminderAction}; use chrono::Utc; use chrono_tz::Tz; use serenity::model::prelude::{ @@ -336,10 +337,45 @@ DELETE FROM guilds WHERE guild = ? lm.get(&user_data.language, "lang/set_p"), ) }) + .flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL) }) }) .await; } + } else { + match Reminder::from_interaction(&ctx, member.user.id, data.custom_id).await + { + Ok((reminder, action)) => { + let response = match action { + ReminderAction::Delete => { + reminder.delete(&ctx).await; + "Reminder has been deleted" + } + }; + + let _ = interaction + .create_interaction_response(&ctx, |r| { + r.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|d| d + .content(response) + .flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL) + ) + }) + .await; + } + + Err(ie) => { + let _ = interaction + .create_interaction_response(&ctx, |r| { + r.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|d| d + .content(ie.to_string()) + .flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL) + ) + }) + .await; + } + } } } } diff --git a/src/models/reminder.rs b/src/models/reminder.rs index 4e3ab9f..f80784e 100644 --- a/src/models/reminder.rs +++ b/src/models/reminder.rs @@ -11,6 +11,9 @@ use crate::{ }; use num_integer::Integer; +use ring::hmac; +use std::convert::{TryFrom, TryInto}; +use std::env; fn longhand_displacement(seconds: u64) -> String { let (days, seconds) = seconds.div_rem(&DAY); @@ -31,13 +34,14 @@ fn longhand_displacement(seconds: u64) -> String { sections.join(", ") } +#[derive(Debug)] pub struct Reminder { pub id: u32, pub uid: String, pub channel: u64, pub utc_time: NaiveDateTime, pub interval: Option, - pub expires: NaiveDateTime, + pub expires: Option, pub enabled: bool, pub content: String, pub embed_description: String, @@ -45,6 +49,80 @@ pub struct Reminder { } impl Reminder { + pub async fn from_uid(ctx: &Context, uid: String) -> Option { + let pool = ctx.data.read().await.get::().cloned().unwrap(); + + sqlx::query_as_unchecked!( + Self, + " +SELECT + reminders.id, + reminders.uid, + channels.channel, + reminders.utc_time, + reminders.interval, + reminders.expires, + reminders.enabled, + reminders.content, + reminders.embed_description, + users.user AS set_by +FROM + reminders +INNER JOIN + channels +ON + reminders.channel_id = channels.id +LEFT JOIN + users +ON + reminders.set_by = users.id +WHERE + reminders.uid = ? + ", + uid + ) + .fetch_one(&pool) + .await + .ok() + } + + pub async fn from_id(ctx: &Context, id: u32) -> Option { + let pool = ctx.data.read().await.get::().cloned().unwrap(); + + sqlx::query_as_unchecked!( + Self, + " +SELECT + reminders.id, + reminders.uid, + channels.channel, + reminders.utc_time, + reminders.interval, + reminders.expires, + reminders.enabled, + reminders.content, + reminders.embed_description, + users.user AS set_by +FROM + reminders +INNER JOIN + channels +ON + reminders.channel_id = channels.id +LEFT JOIN + users +ON + reminders.set_by = users.id +WHERE + reminders.id = ? + ", + id + ) + .fetch_one(&pool) + .await + .ok() + } + pub async fn from_channel>( ctx: &Context, channel_id: C, @@ -249,6 +327,127 @@ WHERE ) } } + + pub async fn from_interaction>( + ctx: &Context, + member_id: U, + payload: String, + ) -> Result<(Self, ReminderAction), InteractionError> { + let sections = payload.split(".").collect::>(); + + if sections.len() != 3 { + Err(InteractionError::InvalidFormat) + } else { + let action = ReminderAction::try_from(sections[0]) + .map_err(|_| InteractionError::InvalidAction)?; + + let reminder_id = u32::from_le_bytes( + base64::decode(sections[1]) + .map_err(|_| InteractionError::InvalidBase64)? + .try_into() + .map_err(|_| InteractionError::InvalidSize)?, + ); + + if let Some(reminder) = Self::from_id(ctx, reminder_id).await { + if reminder.signed_action(member_id, action) == payload { + Ok((reminder, action)) + } else { + Err(InteractionError::SignatureMismatch) + } + } else { + Err(InteractionError::NoReminder) + } + } + } + + pub fn signed_action>(&self, member_id: U, action: ReminderAction) -> String { + let s_key = hmac::Key::new( + hmac::HMAC_SHA256, + env::var("SECRET_KEY") + .expect("No SECRET_KEY provided") + .as_bytes(), + ); + + let mut context = hmac::Context::with_key(&s_key); + + context.update(&self.id.to_le_bytes()); + context.update(&member_id.into().to_le_bytes()); + + let signature = context.sign(); + + format!( + "{}.{}.{}", + action.to_string(), + base64::encode(self.id.to_le_bytes()), + base64::encode(&signature) + ) + } + + pub async fn delete(&self, ctx: &Context) { + let pool = ctx.data.read().await.get::().cloned().unwrap(); + + sqlx::query!( + " +DELETE FROM reminders WHERE id = ? + ", + self.id + ) + .execute(&pool) + .await + .unwrap(); + } +} + +#[derive(Debug)] +pub enum InteractionError { + InvalidFormat, + InvalidBase64, + InvalidSize, + NoReminder, + SignatureMismatch, + InvalidAction, +} + +impl ToString for InteractionError { + fn to_string(&self) -> String { + match self { + InteractionError::InvalidFormat => { + String::from("The interaction data was improperly formatted") + } + InteractionError::InvalidBase64 => String::from("The interaction data was invalid"), + InteractionError::InvalidSize => String::from("The interaction data was invalid"), + InteractionError::NoReminder => String::from("Reminder could not be found"), + InteractionError::SignatureMismatch => { + String::from("Only the user who did the command can use interactions") + } + InteractionError::InvalidAction => String::from("The action was invalid"), + } + } +} + +#[derive(Clone, Copy)] +pub enum ReminderAction { + Delete, +} + +impl ToString for ReminderAction { + fn to_string(&self) -> String { + match self { + Self::Delete => String::from("del"), + } + } +} + +impl TryFrom<&str> for ReminderAction { + type Error = (); + + fn try_from(value: &str) -> Result { + match value { + "del" => Ok(Self::Delete), + + _ => Err(()), + } + } } enum TimeDisplayType {