use std::path::Path; use chrono::{naive::NaiveDateTime, Utc}; use rand::{rngs::OsRng, seq::IteratorRandom}; use rocket::{ fs::{relative, NamedFile}, http::CookieJar, response::Redirect, serde::json::json, }; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use serenity::{ client::Context, http::Http, model::id::{ChannelId, GuildId, UserId}, }; use sqlx::types::Json; use crate::{ check_guild_subscription, check_subscription, consts::{ CHARACTERS, DAY, DEFAULT_AVATAR, MAX_CONTENT_LENGTH, MAX_EMBED_AUTHOR_LENGTH, MAX_EMBED_DESCRIPTION_LENGTH, MAX_EMBED_FIELDS, MAX_EMBED_FIELD_TITLE_LENGTH, MAX_EMBED_FIELD_VALUE_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_NAME_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MIN_INTERVAL, }, guards::transaction::Transaction, routes::JsonResult, Error, }; pub mod api; pub mod export; type Unset = Option; fn name_default() -> String { "Reminder".to_string() } fn template_name_default() -> String { "Template".to_string() } fn channel_default() -> u64 { 0 } fn id_default() -> u32 { 0 } fn interval_default() -> Unset> { None } #[derive(sqlx::Type)] #[sqlx(transparent)] struct Attachment(Vec); impl<'de> Deserialize<'de> for Attachment { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { let string = String::deserialize(deserializer)?; Ok(Attachment(base64::decode(string).map_err(de::Error::custom)?)) } } impl Serialize for Attachment { fn serialize(&self, serializer: S) -> Result where S: Serializer, { serializer.collect_str(&base64::encode(&self.0)) } } #[derive(Serialize, Deserialize)] pub struct ReminderTemplate { #[serde(default = "id_default")] id: u32, #[serde(default = "id_default")] guild_id: u32, #[serde(default = "template_name_default")] name: String, attachment: Option, attachment_name: Option, avatar: Option, content: String, embed_author: String, embed_author_url: Option, embed_color: u32, embed_description: String, embed_footer: String, embed_footer_url: Option, embed_image_url: Option, embed_thumbnail_url: Option, embed_title: String, embed_fields: Option>>, interval_seconds: Option, interval_days: Option, interval_months: Option, tts: bool, username: Option, } #[derive(Serialize, Deserialize)] pub struct ReminderTemplateCsv { #[serde(default = "template_name_default")] name: String, attachment: Option, attachment_name: Option, avatar: Option, content: String, embed_author: String, embed_author_url: Option, embed_color: u32, embed_description: String, embed_footer: String, embed_footer_url: Option, embed_image_url: Option, embed_thumbnail_url: Option, embed_title: String, embed_fields: Option, interval_seconds: Option, interval_days: Option, interval_months: Option, tts: bool, username: Option, } #[derive(Deserialize)] pub struct DeleteReminderTemplate { id: u32, } #[derive(Serialize, Deserialize)] pub struct EmbedField { title: String, value: String, inline: bool, } #[derive(Serialize, Deserialize)] pub struct Reminder { attachment: Option, attachment_name: Option, avatar: Option, #[serde(with = "string")] channel: u64, content: String, embed_author: String, embed_author_url: Option, embed_color: u32, embed_description: String, embed_footer: String, embed_footer_url: Option, embed_image_url: Option, embed_thumbnail_url: Option, embed_title: String, embed_fields: Option>>, enabled: bool, expires: Option, interval_seconds: Option, interval_days: Option, interval_months: Option, #[serde(default = "name_default")] name: String, restartable: bool, tts: bool, #[serde(default)] uid: String, username: Option, utc_time: NaiveDateTime, } #[derive(Serialize, Deserialize)] pub struct ReminderCsv { attachment: Option, attachment_name: Option, avatar: Option, channel: String, content: String, embed_author: String, embed_author_url: Option, embed_color: u32, embed_description: String, embed_footer: String, embed_footer_url: Option, embed_image_url: Option, embed_thumbnail_url: Option, embed_title: String, embed_fields: Option, enabled: bool, expires: Option, interval_seconds: Option, interval_days: Option, interval_months: Option, #[serde(default = "name_default")] name: String, restartable: bool, tts: bool, username: Option, utc_time: NaiveDateTime, } #[derive(Deserialize)] pub struct PatchReminder { uid: String, #[serde(default)] #[serde(deserialize_with = "deserialize_optional_field")] attachment: Unset>, #[serde(default)] #[serde(deserialize_with = "deserialize_optional_field")] attachment_name: Unset>, #[serde(default)] #[serde(deserialize_with = "deserialize_optional_field")] avatar: Unset>, #[serde(default = "channel_default")] #[serde(with = "string")] channel: u64, #[serde(default)] content: Unset, #[serde(default)] embed_author: Unset, #[serde(default)] #[serde(deserialize_with = "deserialize_optional_field")] embed_author_url: Unset>, #[serde(default)] embed_color: Unset, #[serde(default)] embed_description: Unset, #[serde(default)] embed_footer: Unset, #[serde(default)] #[serde(deserialize_with = "deserialize_optional_field")] embed_footer_url: Unset>, #[serde(default)] #[serde(deserialize_with = "deserialize_optional_field")] embed_image_url: Unset>, #[serde(default)] #[serde(deserialize_with = "deserialize_optional_field")] embed_thumbnail_url: Unset>, #[serde(default)] embed_title: Unset, #[serde(default)] embed_fields: Unset>>, #[serde(default)] enabled: Unset, #[serde(default)] #[serde(deserialize_with = "deserialize_optional_field")] expires: Unset>, #[serde(default = "interval_default")] #[serde(deserialize_with = "deserialize_optional_field")] interval_seconds: Unset>, #[serde(default = "interval_default")] #[serde(deserialize_with = "deserialize_optional_field")] interval_days: Unset>, #[serde(default = "interval_default")] #[serde(deserialize_with = "deserialize_optional_field")] interval_months: Unset>, #[serde(default)] name: Unset, #[serde(default)] restartable: Unset, #[serde(default)] tts: Unset, #[serde(default)] #[serde(deserialize_with = "deserialize_optional_field")] username: Unset>, #[serde(default)] utc_time: Unset, } impl PatchReminder { fn message_ok(&self) -> bool { self.content.as_ref().map_or(true, |c| c.len() <= MAX_CONTENT_LENGTH) && self.embed_author.as_ref().map_or(true, |c| c.len() <= MAX_EMBED_AUTHOR_LENGTH) && self .embed_description .as_ref() .map_or(true, |c| c.len() <= MAX_EMBED_DESCRIPTION_LENGTH) && self.embed_footer.as_ref().map_or(true, |c| c.len() <= MAX_EMBED_FOOTER_LENGTH) && self.embed_title.as_ref().map_or(true, |c| c.len() <= MAX_EMBED_TITLE_LENGTH) && self.embed_fields.as_ref().map_or(true, |c| { c.0.len() <= MAX_EMBED_FIELDS && c.0.iter().all(|f| { f.title.len() <= MAX_EMBED_FIELD_TITLE_LENGTH && f.value.len() <= MAX_EMBED_FIELD_VALUE_LENGTH }) }) && self .username .as_ref() .map_or(true, |c| c.as_ref().map_or(true, |v| v.len() <= MAX_USERNAME_LENGTH)) } } pub fn generate_uid() -> String { let mut generator: OsRng = Default::default(); (0..64) .map(|_| CHARACTERS.chars().choose(&mut generator).unwrap().to_owned().to_string()) .collect::>() .join("") } fn deserialize_optional_field<'de, T, D>(deserializer: D) -> Result>, D::Error> where D: Deserializer<'de>, T: Deserialize<'de>, { Ok(Some(Option::deserialize(deserializer)?)) } // https://github.com/serde-rs/json/issues/329#issuecomment-305608405 mod string { use std::{fmt::Display, str::FromStr}; use serde::{de, Deserialize, Deserializer, Serializer}; pub fn serialize(value: &T, serializer: S) -> Result where T: Display, S: Serializer, { serializer.collect_str(value) } pub fn deserialize<'de, T, D>(deserializer: D) -> Result where T: FromStr, T::Err: Display, D: Deserializer<'de>, { String::deserialize(deserializer)?.parse().map_err(de::Error::custom) } } #[derive(Deserialize)] pub struct DeleteReminder { uid: String, } #[derive(Deserialize)] pub struct ImportBody { body: String, } #[derive(Serialize, Deserialize)] pub struct TodoCsv { value: String, channel_id: Option, } pub(crate) async fn create_reminder( ctx: &Context, transaction: &mut Transaction<'_>, guild_id: GuildId, user_id: UserId, reminder: Reminder, ) -> JsonResult { // check guild in db match sqlx::query!("SELECT 1 as A FROM guilds WHERE guild = ?", guild_id.0) .fetch_one(transaction.executor()) .await { Err(sqlx::Error::RowNotFound) => { if sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id.0) .execute(transaction.executor()) .await .is_err() { return Err(json!({"error": "Guild could not be created"})); } } _ => {} } // validate channel let channel = ChannelId(reminder.channel).to_channel_cached(&ctx); let channel_exists = channel.is_some(); let channel_matches_guild = channel.map_or(false, |c| c.guild().map_or(false, |c| c.guild_id == guild_id)); if !channel_matches_guild || !channel_exists { warn!( "Error in `create_reminder`: channel {} not found for guild {} (channel exists: {})", reminder.channel, guild_id, channel_exists ); return Err(json!({"error": "Channel not found"})); } let channel = create_database_channel(&ctx, ChannelId(reminder.channel), transaction).await; if let Err(e) = channel { warn!("`create_database_channel` returned an error code: {:?}", e); return Err( json!({"error": "Failed to configure channel for reminders. Please check the bot permissions"}), ); } let channel = channel.unwrap(); // validate lengths check_length!(MAX_NAME_LENGTH, reminder.name); check_length!(MAX_CONTENT_LENGTH, reminder.content); check_length!(MAX_EMBED_DESCRIPTION_LENGTH, reminder.embed_description); check_length!(MAX_EMBED_TITLE_LENGTH, reminder.embed_title); check_length!(MAX_EMBED_AUTHOR_LENGTH, reminder.embed_author); check_length!(MAX_EMBED_FOOTER_LENGTH, reminder.embed_footer); check_length_opt!(MAX_EMBED_FIELDS, reminder.embed_fields); if let Some(fields) = &reminder.embed_fields { for field in &fields.0 { check_length!(MAX_EMBED_FIELD_VALUE_LENGTH, field.value); check_length!(MAX_EMBED_FIELD_TITLE_LENGTH, field.title); } } check_length_opt!(MAX_USERNAME_LENGTH, reminder.username); check_length_opt!( MAX_URL_LENGTH, reminder.embed_footer_url, reminder.embed_thumbnail_url, reminder.embed_author_url, reminder.embed_image_url, reminder.avatar ); // validate urls check_url_opt!( reminder.embed_footer_url, reminder.embed_thumbnail_url, reminder.embed_author_url, reminder.embed_image_url, reminder.avatar ); // validate time and interval if reminder.utc_time < Utc::now().naive_utc() { return Err(json!({"error": "Time must be in the future"})); } if reminder.interval_seconds.is_some() || reminder.interval_days.is_some() || reminder.interval_months.is_some() { if reminder.interval_months.unwrap_or(0) * 30 * DAY as u32 + reminder.interval_days.unwrap_or(0) * DAY as u32 + reminder.interval_seconds.unwrap_or(0) < *MIN_INTERVAL { return Err(json!({"error": "Interval too short"})); } } // check patreon if necessary if reminder.interval_seconds.is_some() || reminder.interval_days.is_some() || reminder.interval_months.is_some() { if !check_guild_subscription(&ctx, guild_id).await && !check_subscription(&ctx, user_id).await { return Err(json!({"error": "Patreon is required to set intervals"})); } } let name = if reminder.name.is_empty() { name_default() } else { reminder.name.clone() }; let username = if reminder.username.as_ref().map(|s| s.is_empty()).unwrap_or(true) { None } else { reminder.username }; let new_uid = generate_uid(); // write to db match sqlx::query!( "INSERT INTO reminders ( uid, attachment, attachment_name, channel_id, avatar, content, embed_author, embed_author_url, embed_color, embed_description, embed_footer, embed_footer_url, embed_image_url, embed_thumbnail_url, embed_title, embed_fields, enabled, expires, interval_seconds, interval_days, interval_months, name, restartable, tts, username, `utc_time` ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", new_uid, reminder.attachment, reminder.attachment_name, channel, reminder.avatar, reminder.content, reminder.embed_author, reminder.embed_author_url, reminder.embed_color, reminder.embed_description, reminder.embed_footer, reminder.embed_footer_url, reminder.embed_image_url, reminder.embed_thumbnail_url, reminder.embed_title, reminder.embed_fields, reminder.enabled, reminder.expires, reminder.interval_seconds, reminder.interval_days, reminder.interval_months, name, reminder.restartable, reminder.tts, username, reminder.utc_time, ) .execute(transaction.executor()) .await { Ok(_) => sqlx::query_as_unchecked!( Reminder, "SELECT reminders.attachment, reminders.attachment_name, reminders.avatar, channels.channel, reminders.content, reminders.embed_author, reminders.embed_author_url, reminders.embed_color, reminders.embed_description, reminders.embed_footer, reminders.embed_footer_url, reminders.embed_image_url, reminders.embed_thumbnail_url, reminders.embed_title, reminders.embed_fields, reminders.enabled, reminders.expires, reminders.interval_seconds, reminders.interval_days, reminders.interval_months, reminders.name, reminders.restartable, reminders.tts, reminders.uid, reminders.username, reminders.utc_time FROM reminders LEFT JOIN channels ON channels.id = reminders.channel_id WHERE uid = ?", new_uid ) .fetch_one(transaction.executor()) .await .map(|r| Ok(json!(r))) .unwrap_or_else(|e| { warn!("Failed to complete SQL query: {:?}", e); Err(json!({"error": "Could not load reminder"})) }), Err(e) => { warn!("Error in `create_reminder`: Could not execute query: {:?}", e); Err(json!({"error": "Unknown error"})) } } } async fn create_database_channel( ctx: impl AsRef, channel: ChannelId, transaction: &mut Transaction<'_>, ) -> Result { let row = sqlx::query!("SELECT webhook_token, webhook_id FROM channels WHERE channel = ?", channel.0) .fetch_one(transaction.executor()) .await; match row { Ok(row) => { if row.webhook_token.is_none() || row.webhook_id.is_none() { let webhook = channel .create_webhook_with_avatar(&ctx, "Reminder", DEFAULT_AVATAR.clone()) .await .map_err(|e| Error::Serenity(e))?; sqlx::query!( "UPDATE channels SET webhook_id = ?, webhook_token = ? WHERE channel = ?", webhook.id.0, webhook.token, channel.0 ) .execute(transaction.executor()) .await .map_err(|e| Error::SQLx(e))?; } Ok(()) } Err(sqlx::Error::RowNotFound) => { // create webhook let webhook = channel .create_webhook_with_avatar(&ctx, "Reminder", DEFAULT_AVATAR.clone()) .await .map_err(|e| Error::Serenity(e))?; // create database entry sqlx::query!( "INSERT INTO channels ( webhook_id, webhook_token, channel ) VALUES (?, ?, ?)", webhook.id.0, webhook.token, channel.0 ) .execute(transaction.executor()) .await .map_err(|e| Error::SQLx(e))?; Ok(()) } Err(e) => Err(Error::SQLx(e)), }?; let row = sqlx::query!("SELECT id FROM channels WHERE channel = ?", channel.0) .fetch_one(transaction.executor()) .await .map_err(|e| Error::SQLx(e))?; Ok(row.id) } #[get("/")] pub async fn dashboard_home(cookies: &CookieJar<'_>) -> Result { if cookies.get_private("userid").is_some() { NamedFile::open(Path::new(relative!("static/index.html"))).await.map_err(|e| { warn!("Couldn't render dashboard: {:?}", e); Redirect::to("/login/discord") }) } else { Err(Redirect::to("/login/discord")) } } #[get("/<_..>")] pub async fn dashboard(cookies: &CookieJar<'_>) -> Result { if cookies.get_private("userid").is_some() { NamedFile::open(Path::new(relative!("static/index.html"))).await.map_err(|e| { warn!("Couldn't render dashboard: {:?}", e); Redirect::to("/login/discord") }) } else { Err(Redirect::to("/login/discord")) } }