allow more reminder content inc. attachments, tts
This commit is contained in:
		| @@ -976,6 +976,81 @@ fn generate_uid() -> String { | ||||
|         .join("") | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| enum ContentError { | ||||
|     TooManyAttachments, | ||||
|     AttachmentTooLarge, | ||||
|     AttachmentDownloadFailed, | ||||
| } | ||||
|  | ||||
| impl std::fmt::Display for ContentError { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         write!(f, "{:?}", self) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl std::error::Error for ContentError {} | ||||
|  | ||||
| struct Content { | ||||
|     content: String, | ||||
|     tts: bool, | ||||
|     attachment: Option<Vec<u8>>, | ||||
|     attachment_name: Option<String>, | ||||
| } | ||||
|  | ||||
| impl Content { | ||||
|     async fn build<S: ToString>(content: S, message: &Message) -> Result<Self, ContentError> { | ||||
|         if message.attachments.len() > 1 { | ||||
|             Err(ContentError::TooManyAttachments) | ||||
|         } else if let Some(attachment) = message.attachments.get(0) { | ||||
|             if attachment.size > 8_000_000 { | ||||
|                 Err(ContentError::AttachmentTooLarge) | ||||
|             } else if let Ok(attachment_bytes) = attachment.download().await { | ||||
|                 Ok(Self { | ||||
|                     content: content.to_string(), | ||||
|                     tts: false, | ||||
|                     attachment: Some(attachment_bytes.clone()), | ||||
|                     attachment_name: Some(attachment.filename.clone()), | ||||
|                 }) | ||||
|             } else { | ||||
|                 Err(ContentError::AttachmentDownloadFailed) | ||||
|             } | ||||
|         } else { | ||||
|             Ok(Self { | ||||
|                 content: content.to_string(), | ||||
|                 tts: false, | ||||
|                 attachment: None, | ||||
|                 attachment_name: None, | ||||
|             }) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn substitute(&mut self, guild: Guild) { | ||||
|         if self.content.starts_with("/tts ") { | ||||
|             self.tts = true; | ||||
|             self.content = self.content.split_off(5); | ||||
|         } | ||||
|  | ||||
|         self.content = REGEX_CONTENT_SUBSTITUTION | ||||
|             .replace(&self.content, |caps: &Captures| { | ||||
|                 if let Some(user) = caps.name("user") { | ||||
|                     format!("<@{}>", user.as_str()) | ||||
|                 } else if let Some(role_name) = caps.name("role") { | ||||
|                     if let Some(role) = guild.role_by_name(role_name.as_str()) { | ||||
|                         role.mention() | ||||
|                     } else { | ||||
|                         role_name.as_str().to_string() | ||||
|                     } | ||||
|                 } else { | ||||
|                     String::new() | ||||
|                 } | ||||
|             }) | ||||
|             .to_string() | ||||
|             .replace("<<everyone>>", "@everyone") | ||||
|             .replace("<<here>>", "@here"); | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[command("remind")] | ||||
| #[permission_level(Managed)] | ||||
| async fn remind(ctx: &Context, msg: &Message, args: String) { | ||||
| @@ -1053,8 +1128,13 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem | ||||
|                 .transpose(); | ||||
|  | ||||
|             if let Ok(interval) = interval_parser { | ||||
|                 let content = captures.name("content").map(|mat| mat.as_str()).unwrap(); | ||||
|                 let content_res = Content::build( | ||||
|                     captures.name("content").map(|mat| mat.as_str()).unwrap(), | ||||
|                     msg, | ||||
|                 ) | ||||
|                 .await; | ||||
|  | ||||
|                 if let Ok(mut content) = content_res { | ||||
|                     let mut ok_locations = vec![]; | ||||
|                     let mut err_locations = vec![]; | ||||
|                     let mut err_types = HashSet::new(); | ||||
| @@ -1068,7 +1148,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem | ||||
|                             &scope, | ||||
|                             &time_parser, | ||||
|                             interval, | ||||
|                         content, | ||||
|                             &mut content, | ||||
|                         ) | ||||
|                         .await; | ||||
|  | ||||
| @@ -1151,6 +1231,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem | ||||
|                             }) | ||||
|                         }) | ||||
|                         .await; | ||||
|                 } | ||||
|             } else { | ||||
|                 let _ = msg | ||||
|                     .channel_id | ||||
| @@ -1179,6 +1260,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem | ||||
|  | ||||
|             let title = match command { | ||||
|                 RemindCommand::Remind => "Remind Help", | ||||
|  | ||||
|                 RemindCommand::Interval => "Interval Help", | ||||
|             }; | ||||
|  | ||||
| @@ -1321,6 +1403,9 @@ async fn natural(ctx: &Context, msg: &Message, args: String) { | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // todo remove this unwrap | ||||
|             let mut content = Content::build(&content, msg).await.unwrap(); | ||||
|  | ||||
|             if location_ids.len() == 1 { | ||||
|                 let location_id = location_ids.get(0).unwrap(); | ||||
|  | ||||
| @@ -1332,7 +1417,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) { | ||||
|                     &location_id, | ||||
|                     timestamp, | ||||
|                     interval, | ||||
|                     &content, | ||||
|                     &mut content, | ||||
|                 ) | ||||
|                 .await; | ||||
|  | ||||
| @@ -1362,7 +1447,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) { | ||||
|                         &location, | ||||
|                         timestamp, | ||||
|                         interval, | ||||
|                         &content, | ||||
|                         &mut content, | ||||
|                     ) | ||||
|                     .await; | ||||
|  | ||||
| @@ -1397,30 +1482,6 @@ async fn natural(ctx: &Context, msg: &Message, args: String) { | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn substitute_content(guild: Option<Guild>, content: &str) -> String { | ||||
|     if let Some(guild) = guild { | ||||
|         REGEX_CONTENT_SUBSTITUTION | ||||
|             .replace(content, |caps: &Captures| { | ||||
|                 if let Some(user) = caps.name("user") { | ||||
|                     format!("<@{}>", user.as_str()) | ||||
|                 } else if let Some(role_name) = caps.name("role") { | ||||
|                     if let Some(role) = guild.role_by_name(role_name.as_str()) { | ||||
|                         role.mention() | ||||
|                     } else { | ||||
|                         role_name.as_str().to_string() | ||||
|                     } | ||||
|                 } else { | ||||
|                     String::new() | ||||
|                 } | ||||
|             }) | ||||
|             .to_string() | ||||
|     } else { | ||||
|         content.to_string() | ||||
|     } | ||||
|     .replace("<<everyone>>", "@everyone") | ||||
|     .replace("<<here>>", "@here") | ||||
| } | ||||
|  | ||||
| async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>( | ||||
|     ctx: impl CacheHttp + AsRef<Cache>, | ||||
|     pool: &MySqlPool, | ||||
| @@ -1429,15 +1490,15 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>( | ||||
|     scope_id: &ReminderScope, | ||||
|     time_parser: T, | ||||
|     interval: Option<i64>, | ||||
|     content: &str, | ||||
|     content: &mut Content, | ||||
| ) -> Result<(), ReminderError> { | ||||
|     let user_id = user_id.into(); | ||||
|  | ||||
|     let content_string = if let Some(g_id) = guild_id { | ||||
|         substitute_content(g_id.to_guild_cached(&ctx).await, content) | ||||
|     } else { | ||||
|         content.to_string() | ||||
|     }; | ||||
|     if let Some(g_id) = guild_id { | ||||
|         if let Some(guild) = g_id.to_guild_cached(&ctx).await { | ||||
|             content.substitute(guild); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     let mut nudge = 0; | ||||
|  | ||||
| @@ -1460,6 +1521,7 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>( | ||||
|             let mut channel_data = ChannelData::from_channel(channel.clone(), &pool) | ||||
|                 .await | ||||
|                 .unwrap(); | ||||
|  | ||||
|             nudge = channel_data.nudge; | ||||
|  | ||||
|             if let Some(guild_channel) = channel.guild() { | ||||
| @@ -1483,10 +1545,8 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>( | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     // validate time, channel, content | ||||
|     if content_string.is_empty() { | ||||
|         Err(ReminderError::NotEnoughArgs) | ||||
|     } else if interval.map_or(false, |inner| inner < *MIN_INTERVAL) { | ||||
|     // validate time, channel | ||||
|     if interval.map_or(false, |inner| inner < *MIN_INTERVAL) { | ||||
|         Err(ReminderError::ShortInterval) | ||||
|     } else if interval.map_or(false, |inner| inner > *MAX_TIME) { | ||||
|         Err(ReminderError::LongInterval) | ||||
| @@ -1506,9 +1566,12 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>( | ||||
|                     } else { | ||||
|                         sqlx::query!( | ||||
|                             " | ||||
| INSERT INTO messages (content) VALUES (?) | ||||
| INSERT INTO messages (content, tts, attachment, attachment_name) VALUES (?, ?, ?, ?) | ||||
|                             ", | ||||
|                             content_string | ||||
|                             content.content, | ||||
|                             content.tts, | ||||
|                             content.attachment, | ||||
|                             content.attachment_name, | ||||
|                         ) | ||||
|                         .execute(&pool.clone()) | ||||
|                         .await | ||||
| @@ -1523,7 +1586,7 @@ INSERT INTO reminders (uid, message_id, channel_id, time, `interval`, method, se | ||||
|     (SELECT id FROM users WHERE user = ? LIMIT 1)) | ||||
|                             ", | ||||
|                             generate_uid(), | ||||
|                             content_string, | ||||
|                             content.content, | ||||
|                             db_channel_id, | ||||
|                             time as u32, | ||||
|                             interval, | ||||
|   | ||||
| @@ -345,11 +345,7 @@ impl Framework for RegexFramework { | ||||
|         } | ||||
|  | ||||
|         // gate to prevent analysing messages unnecessarily | ||||
|         if (msg.author.bot && self.ignore_bots) | ||||
|             || msg.tts | ||||
|             || msg.content.is_empty() | ||||
|             || !msg.attachments.is_empty() | ||||
|         { | ||||
|         if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() { | ||||
|         } | ||||
|         // Guild Command | ||||
|         else if let (Some(guild), Some(Channel::Guild(channel))) = | ||||
|   | ||||
		Reference in New Issue
	
	Block a user