From e002984986fed1340801e909ae24b7648e758699 Mon Sep 17 00:00:00 2001 From: jellywx Date: Wed, 16 Dec 2020 19:20:46 +0000 Subject: [PATCH] allow more reminder content inc. attachments, tts --- src/commands/reminder_cmds.rs | 305 ++++++++++++++++++++-------------- src/framework.rs | 6 +- 2 files changed, 185 insertions(+), 126 deletions(-) diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index 691bbc2..c3819fc 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -976,6 +976,81 @@ fn generate_uid() -> String { .join("") } +#[derive(Debug)] +enum ContentError { + TooManyAttachments, + AttachmentTooLarge, + AttachmentDownloadFailed, +} + +impl std::fmt::Display for ContentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl std::error::Error for ContentError {} + +struct Content { + content: String, + tts: bool, + attachment: Option>, + attachment_name: Option, +} + +impl Content { + async fn build(content: S, message: &Message) -> Result { + if message.attachments.len() > 1 { + Err(ContentError::TooManyAttachments) + } else if let Some(attachment) = message.attachments.get(0) { + if attachment.size > 8_000_000 { + Err(ContentError::AttachmentTooLarge) + } else if let Ok(attachment_bytes) = attachment.download().await { + Ok(Self { + content: content.to_string(), + tts: false, + attachment: Some(attachment_bytes.clone()), + attachment_name: Some(attachment.filename.clone()), + }) + } else { + Err(ContentError::AttachmentDownloadFailed) + } + } else { + Ok(Self { + content: content.to_string(), + tts: false, + attachment: None, + attachment_name: None, + }) + } + } + + fn substitute(&mut self, guild: Guild) { + if self.content.starts_with("/tts ") { + self.tts = true; + self.content = self.content.split_off(5); + } + + self.content = REGEX_CONTENT_SUBSTITUTION + .replace(&self.content, |caps: &Captures| { + if let Some(user) = caps.name("user") { + format!("<@{}>", user.as_str()) + } else if let Some(role_name) = caps.name("role") { + if let Some(role) = guild.role_by_name(role_name.as_str()) { + role.mention() + } else { + role_name.as_str().to_string() + } + } else { + String::new() + } + }) + .to_string() + .replace("<>", "@everyone") + .replace("<>", "@here"); + } +} + #[command("remind")] #[permission_level(Managed)] async fn remind(ctx: &Context, msg: &Message, args: String) { @@ -1053,104 +1128,110 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem .transpose(); if let Ok(interval) = interval_parser { - let content = captures.name("content").map(|mat| mat.as_str()).unwrap(); + let content_res = Content::build( + captures.name("content").map(|mat| mat.as_str()).unwrap(), + msg, + ) + .await; - let mut ok_locations = vec![]; - let mut err_locations = vec![]; - let mut err_types = HashSet::new(); + if let Ok(mut content) = content_res { + let mut ok_locations = vec![]; + let mut err_locations = vec![]; + let mut err_types = HashSet::new(); - for scope in scopes { - let res = create_reminder( - &ctx, - &pool, - msg.author.id, - msg.guild_id, - &scope, - &time_parser, - interval, - content, - ) - .await; - - if let Err(e) = res { - err_locations.push(scope); - err_types.insert(e); - } else { - ok_locations.push(scope); - } - } - - let success_part = match ok_locations.len() { - 0 => "".to_string(), - 1 => lm - .get(&user_data.language, "remind/success") - .replace("{location}", &ok_locations[0].mention()) - .replace( - "{offset}", - &shorthand_displacement(time_parser.displacement().unwrap() as u64), - ), - n => lm - .get(&user_data.language, "remind/success_bulk") - .replace("{number}", &n.to_string()) - .replace( - "{location}", - &ok_locations - .iter() - .map(|l| l.mention()) - .collect::>() - .join(", "), + for scope in scopes { + let res = create_reminder( + &ctx, + &pool, + msg.author.id, + msg.guild_id, + &scope, + &time_parser, + interval, + &mut content, ) - .replace( - "{offset}", - &shorthand_displacement(time_parser.displacement().unwrap() as u64), - ), - }; + .await; - let error_part = format!( - "{}\n{}", - match err_locations.len() { + if let Err(e) = res { + err_locations.push(scope); + err_types.insert(e); + } else { + ok_locations.push(scope); + } + } + + let success_part = match ok_locations.len() { 0 => "".to_string(), 1 => lm - .get(&user_data.language, "remind/issue") - .replace("{location}", &err_locations[0].mention()), + .get(&user_data.language, "remind/success") + .replace("{location}", &ok_locations[0].mention()) + .replace( + "{offset}", + &shorthand_displacement(time_parser.displacement().unwrap() as u64), + ), n => lm - .get(&user_data.language, "remind/issue_bulk") + .get(&user_data.language, "remind/success_bulk") .replace("{number}", &n.to_string()) .replace( "{location}", - &err_locations + &ok_locations .iter() .map(|l| l.mention()) .collect::>() .join(", "), - ), - }, - err_types - .iter() - .map(|err| match err { - ReminderError::DiscordError(s) => lm - .get(&user_data.language, err.to_response()) - .replace("{error}", &s), - - _ => lm.get(&user_data.language, err.to_response()).to_string(), - }) - .collect::>() - .join("\n") - ); - - let _ = msg - .channel_id - .send_message(&ctx, |m| { - m.embed(|e| { - e.title( - lm.get(&user_data.language, "remind/title") - .replace("{number}", &ok_locations.len().to_string()), ) - .description(format!("{}\n\n{}", success_part, error_part)) - .color(*THEME_COLOR) + .replace( + "{offset}", + &shorthand_displacement(time_parser.displacement().unwrap() as u64), + ), + }; + + let error_part = format!( + "{}\n{}", + match err_locations.len() { + 0 => "".to_string(), + 1 => lm + .get(&user_data.language, "remind/issue") + .replace("{location}", &err_locations[0].mention()), + n => lm + .get(&user_data.language, "remind/issue_bulk") + .replace("{number}", &n.to_string()) + .replace( + "{location}", + &err_locations + .iter() + .map(|l| l.mention()) + .collect::>() + .join(", "), + ), + }, + err_types + .iter() + .map(|err| match err { + ReminderError::DiscordError(s) => lm + .get(&user_data.language, err.to_response()) + .replace("{error}", &s), + + _ => lm.get(&user_data.language, err.to_response()).to_string(), + }) + .collect::>() + .join("\n") + ); + + let _ = msg + .channel_id + .send_message(&ctx, |m| { + m.embed(|e| { + e.title( + lm.get(&user_data.language, "remind/title") + .replace("{number}", &ok_locations.len().to_string()), + ) + .description(format!("{}\n\n{}", success_part, error_part)) + .color(*THEME_COLOR) + }) }) - }) - .await; + .await; + } } else { let _ = msg .channel_id @@ -1179,6 +1260,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem let title = match command { RemindCommand::Remind => "Remind Help", + RemindCommand::Interval => "Interval Help", }; @@ -1321,6 +1403,9 @@ async fn natural(ctx: &Context, msg: &Message, args: String) { } } + // todo remove this unwrap + let mut content = Content::build(&content, msg).await.unwrap(); + if location_ids.len() == 1 { let location_id = location_ids.get(0).unwrap(); @@ -1332,7 +1417,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) { &location_id, timestamp, interval, - &content, + &mut content, ) .await; @@ -1362,7 +1447,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) { &location, timestamp, interval, - &content, + &mut content, ) .await; @@ -1397,30 +1482,6 @@ async fn natural(ctx: &Context, msg: &Message, args: String) { } } -fn substitute_content(guild: Option, content: &str) -> String { - if let Some(guild) = guild { - REGEX_CONTENT_SUBSTITUTION - .replace(content, |caps: &Captures| { - if let Some(user) = caps.name("user") { - format!("<@{}>", user.as_str()) - } else if let Some(role_name) = caps.name("role") { - if let Some(role) = guild.role_by_name(role_name.as_str()) { - role.mention() - } else { - role_name.as_str().to_string() - } - } else { - String::new() - } - }) - .to_string() - } else { - content.to_string() - } - .replace("<>", "@everyone") - .replace("<>", "@here") -} - async fn create_reminder<'a, U: Into, T: TryInto>( ctx: impl CacheHttp + AsRef, pool: &MySqlPool, @@ -1429,15 +1490,15 @@ async fn create_reminder<'a, U: Into, T: TryInto>( scope_id: &ReminderScope, time_parser: T, interval: Option, - content: &str, + content: &mut Content, ) -> Result<(), ReminderError> { let user_id = user_id.into(); - let content_string = if let Some(g_id) = guild_id { - substitute_content(g_id.to_guild_cached(&ctx).await, content) - } else { - content.to_string() - }; + if let Some(g_id) = guild_id { + if let Some(guild) = g_id.to_guild_cached(&ctx).await { + content.substitute(guild); + } + } let mut nudge = 0; @@ -1460,6 +1521,7 @@ async fn create_reminder<'a, U: Into, T: TryInto>( let mut channel_data = ChannelData::from_channel(channel.clone(), &pool) .await .unwrap(); + nudge = channel_data.nudge; if let Some(guild_channel) = channel.guild() { @@ -1483,10 +1545,8 @@ async fn create_reminder<'a, U: Into, T: TryInto>( } }; - // validate time, channel, content - if content_string.is_empty() { - Err(ReminderError::NotEnoughArgs) - } else if interval.map_or(false, |inner| inner < *MIN_INTERVAL) { + // validate time, channel + if interval.map_or(false, |inner| inner < *MIN_INTERVAL) { Err(ReminderError::ShortInterval) } else if interval.map_or(false, |inner| inner > *MAX_TIME) { Err(ReminderError::LongInterval) @@ -1506,9 +1566,12 @@ async fn create_reminder<'a, U: Into, T: TryInto>( } else { sqlx::query!( " -INSERT INTO messages (content) VALUES (?) +INSERT INTO messages (content, tts, attachment, attachment_name) VALUES (?, ?, ?, ?) ", - content_string + content.content, + content.tts, + content.attachment, + content.attachment_name, ) .execute(&pool.clone()) .await @@ -1523,7 +1586,7 @@ INSERT INTO reminders (uid, message_id, channel_id, time, `interval`, method, se (SELECT id FROM users WHERE user = ? LIMIT 1)) ", generate_uid(), - content_string, + content.content, db_channel_id, time as u32, interval, diff --git a/src/framework.rs b/src/framework.rs index 6cb36c9..eeb9782 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -345,11 +345,7 @@ impl Framework for RegexFramework { } // gate to prevent analysing messages unnecessarily - if (msg.author.bot && self.ignore_bots) - || msg.tts - || msg.content.is_empty() - || !msg.attachments.is_empty() - { + if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() { } // Guild Command else if let (Some(guild), Some(Channel::Guild(channel))) =