allow more reminder content inc. attachments, tts

This commit is contained in:
jellywx 2020-12-16 19:20:46 +00:00
parent 02b75dde6a
commit e002984986
2 changed files with 185 additions and 126 deletions

View File

@ -976,6 +976,81 @@ fn generate_uid() -> String {
.join("") .join("")
} }
#[derive(Debug)]
enum ContentError {
TooManyAttachments,
AttachmentTooLarge,
AttachmentDownloadFailed,
}
impl std::fmt::Display for ContentError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl std::error::Error for ContentError {}
struct Content {
content: String,
tts: bool,
attachment: Option<Vec<u8>>,
attachment_name: Option<String>,
}
impl Content {
async fn build<S: ToString>(content: S, message: &Message) -> Result<Self, ContentError> {
if message.attachments.len() > 1 {
Err(ContentError::TooManyAttachments)
} else if let Some(attachment) = message.attachments.get(0) {
if attachment.size > 8_000_000 {
Err(ContentError::AttachmentTooLarge)
} else if let Ok(attachment_bytes) = attachment.download().await {
Ok(Self {
content: content.to_string(),
tts: false,
attachment: Some(attachment_bytes.clone()),
attachment_name: Some(attachment.filename.clone()),
})
} else {
Err(ContentError::AttachmentDownloadFailed)
}
} else {
Ok(Self {
content: content.to_string(),
tts: false,
attachment: None,
attachment_name: None,
})
}
}
fn substitute(&mut self, guild: Guild) {
if self.content.starts_with("/tts ") {
self.tts = true;
self.content = self.content.split_off(5);
}
self.content = REGEX_CONTENT_SUBSTITUTION
.replace(&self.content, |caps: &Captures| {
if let Some(user) = caps.name("user") {
format!("<@{}>", user.as_str())
} else if let Some(role_name) = caps.name("role") {
if let Some(role) = guild.role_by_name(role_name.as_str()) {
role.mention()
} else {
role_name.as_str().to_string()
}
} else {
String::new()
}
})
.to_string()
.replace("<<everyone>>", "@everyone")
.replace("<<here>>", "@here");
}
}
#[command("remind")] #[command("remind")]
#[permission_level(Managed)] #[permission_level(Managed)]
async fn remind(ctx: &Context, msg: &Message, args: String) { async fn remind(ctx: &Context, msg: &Message, args: String) {
@ -1053,8 +1128,13 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
.transpose(); .transpose();
if let Ok(interval) = interval_parser { if let Ok(interval) = interval_parser {
let content = captures.name("content").map(|mat| mat.as_str()).unwrap(); let content_res = Content::build(
captures.name("content").map(|mat| mat.as_str()).unwrap(),
msg,
)
.await;
if let Ok(mut content) = content_res {
let mut ok_locations = vec![]; let mut ok_locations = vec![];
let mut err_locations = vec![]; let mut err_locations = vec![];
let mut err_types = HashSet::new(); let mut err_types = HashSet::new();
@ -1068,7 +1148,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
&scope, &scope,
&time_parser, &time_parser,
interval, interval,
content, &mut content,
) )
.await; .await;
@ -1151,6 +1231,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
}) })
}) })
.await; .await;
}
} else { } else {
let _ = msg let _ = msg
.channel_id .channel_id
@ -1179,6 +1260,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
let title = match command { let title = match command {
RemindCommand::Remind => "Remind Help", RemindCommand::Remind => "Remind Help",
RemindCommand::Interval => "Interval Help", RemindCommand::Interval => "Interval Help",
}; };
@ -1321,6 +1403,9 @@ async fn natural(ctx: &Context, msg: &Message, args: String) {
} }
} }
// todo remove this unwrap
let mut content = Content::build(&content, msg).await.unwrap();
if location_ids.len() == 1 { if location_ids.len() == 1 {
let location_id = location_ids.get(0).unwrap(); let location_id = location_ids.get(0).unwrap();
@ -1332,7 +1417,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) {
&location_id, &location_id,
timestamp, timestamp,
interval, interval,
&content, &mut content,
) )
.await; .await;
@ -1362,7 +1447,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) {
&location, &location,
timestamp, timestamp,
interval, interval,
&content, &mut content,
) )
.await; .await;
@ -1397,30 +1482,6 @@ async fn natural(ctx: &Context, msg: &Message, args: String) {
} }
} }
fn substitute_content(guild: Option<Guild>, content: &str) -> String {
if let Some(guild) = guild {
REGEX_CONTENT_SUBSTITUTION
.replace(content, |caps: &Captures| {
if let Some(user) = caps.name("user") {
format!("<@{}>", user.as_str())
} else if let Some(role_name) = caps.name("role") {
if let Some(role) = guild.role_by_name(role_name.as_str()) {
role.mention()
} else {
role_name.as_str().to_string()
}
} else {
String::new()
}
})
.to_string()
} else {
content.to_string()
}
.replace("<<everyone>>", "@everyone")
.replace("<<here>>", "@here")
}
async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>( async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>(
ctx: impl CacheHttp + AsRef<Cache>, ctx: impl CacheHttp + AsRef<Cache>,
pool: &MySqlPool, pool: &MySqlPool,
@ -1429,15 +1490,15 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>(
scope_id: &ReminderScope, scope_id: &ReminderScope,
time_parser: T, time_parser: T,
interval: Option<i64>, interval: Option<i64>,
content: &str, content: &mut Content,
) -> Result<(), ReminderError> { ) -> Result<(), ReminderError> {
let user_id = user_id.into(); let user_id = user_id.into();
let content_string = if let Some(g_id) = guild_id { if let Some(g_id) = guild_id {
substitute_content(g_id.to_guild_cached(&ctx).await, content) if let Some(guild) = g_id.to_guild_cached(&ctx).await {
} else { content.substitute(guild);
content.to_string() }
}; }
let mut nudge = 0; let mut nudge = 0;
@ -1460,6 +1521,7 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>(
let mut channel_data = ChannelData::from_channel(channel.clone(), &pool) let mut channel_data = ChannelData::from_channel(channel.clone(), &pool)
.await .await
.unwrap(); .unwrap();
nudge = channel_data.nudge; nudge = channel_data.nudge;
if let Some(guild_channel) = channel.guild() { if let Some(guild_channel) = channel.guild() {
@ -1483,10 +1545,8 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>(
} }
}; };
// validate time, channel, content // validate time, channel
if content_string.is_empty() { if interval.map_or(false, |inner| inner < *MIN_INTERVAL) {
Err(ReminderError::NotEnoughArgs)
} else if interval.map_or(false, |inner| inner < *MIN_INTERVAL) {
Err(ReminderError::ShortInterval) Err(ReminderError::ShortInterval)
} else if interval.map_or(false, |inner| inner > *MAX_TIME) { } else if interval.map_or(false, |inner| inner > *MAX_TIME) {
Err(ReminderError::LongInterval) Err(ReminderError::LongInterval)
@ -1506,9 +1566,12 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>(
} else { } else {
sqlx::query!( sqlx::query!(
" "
INSERT INTO messages (content) VALUES (?) INSERT INTO messages (content, tts, attachment, attachment_name) VALUES (?, ?, ?, ?)
", ",
content_string content.content,
content.tts,
content.attachment,
content.attachment_name,
) )
.execute(&pool.clone()) .execute(&pool.clone())
.await .await
@ -1523,7 +1586,7 @@ INSERT INTO reminders (uid, message_id, channel_id, time, `interval`, method, se
(SELECT id FROM users WHERE user = ? LIMIT 1)) (SELECT id FROM users WHERE user = ? LIMIT 1))
", ",
generate_uid(), generate_uid(),
content_string, content.content,
db_channel_id, db_channel_id,
time as u32, time as u32,
interval, interval,

View File

@ -345,11 +345,7 @@ impl Framework for RegexFramework {
} }
// gate to prevent analysing messages unnecessarily // gate to prevent analysing messages unnecessarily
if (msg.author.bot && self.ignore_bots) if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() {
|| msg.tts
|| msg.content.is_empty()
|| !msg.attachments.is_empty()
{
} }
// Guild Command // Guild Command
else if let (Some(guild), Some(Channel::Guild(channel))) = else if let (Some(guild), Some(Channel::Guild(channel))) =