2024-03-24 20:23:16 +00:00

719 lines
21 KiB
Rust

use std::path::Path;
use base64::{prelude::BASE64_STANDARD, Engine};
use chrono::{naive::NaiveDateTime, Utc};
use log::warn;
use rand::{rngs::OsRng, seq::IteratorRandom};
use rocket::{
fs::NamedFile, get, http::CookieJar, response::Redirect, serde::json::json, Responder,
};
use rocket_dyn_templates::Template;
use secrecy::ExposeSecret;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use serenity::{
all::CacheHttp,
builder::CreateWebhook,
client::Context,
model::id::{ChannelId, GuildId, UserId},
};
use sqlx::types::Json;
use crate::web::{
catchers::internal_server_error,
check_guild_subscription, check_subscription,
consts::{
CHARACTERS, DAY, DEFAULT_AVATAR, MAX_CONTENT_LENGTH, MAX_EMBED_AUTHOR_LENGTH,
MAX_EMBED_DESCRIPTION_LENGTH, MAX_EMBED_FIELDS, MAX_EMBED_FIELD_TITLE_LENGTH,
MAX_EMBED_FIELD_VALUE_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH,
MAX_NAME_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MIN_INTERVAL,
},
guards::transaction::Transaction,
routes::JsonResult,
Error,
};
pub mod api;
pub mod export;
/// A request field that may be missing entirely: `None` when the key was not supplied.
type Unset<T> = Option<T>;

/// Serde default for `Reminder::name` / `ReminderCsv::name` when the key is omitted.
fn name_default() -> String {
    String::from("Reminder")
}

/// Serde default for the template `name` fields when the key is omitted.
fn template_name_default() -> String {
    String::from("Template")
}

/// Serde default for `PatchReminder::channel` when the key is omitted.
fn channel_default() -> u64 {
    u64::default()
}

/// Serde default for the `id` / `guild_id` fields of `ReminderTemplate` when omitted.
fn id_default() -> u32 {
    u32::default()
}

/// Serde default for the `interval_*` fields of `PatchReminder`: the key was not supplied.
fn interval_default() -> Unset<Option<u32>> {
    Default::default()
}
#[derive(sqlx::Type)]
#[sqlx(transparent)]
pub struct Attachment(Vec<u8>);
impl<'de> Deserialize<'de> for Attachment {
fn deserialize<D>(deserializer: D) -> Result<Attachment, D::Error>
where
D: Deserializer<'de>,
{
let string = String::deserialize(deserializer)?;
Ok(Attachment(BASE64_STANDARD.decode(string).map_err(de::Error::custom)?))
}
}
impl Serialize for Attachment {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.collect_str(&BASE64_STANDARD.encode(&self.0))
}
}
/// A stored reminder template. Its content, embed and interval fields mirror those of
/// `Reminder`.
///
/// Serde (de)serializes fields in declaration order, so reordering fields changes the
/// serialized layout.
#[derive(Serialize, Deserialize)]
pub struct ReminderTemplate {
/// Template id; `0` when omitted from the input.
#[serde(default = "id_default")]
id: u32,
/// Owning guild; `0` when omitted. NOTE(review): a `u32` cannot hold a Discord guild
/// snowflake (`GuildId` is 64-bit), so this is presumably an internal database id — confirm.
#[serde(default = "id_default")]
guild_id: u32,
/// Template name; `"Template"` when omitted.
#[serde(default = "template_name_default")]
name: String,
/// Optional attachment bytes, exchanged as a base64 string (see `Attachment`).
attachment: Option<Attachment>,
attachment_name: Option<String>,
avatar: Option<String>,
content: String,
embed_author: String,
embed_author_url: Option<String>,
/// Embed colour. NOTE(review): presumably a packed RGB integer — confirm.
embed_color: u32,
embed_description: String,
embed_footer: String,
embed_footer_url: Option<String>,
embed_image_url: Option<String>,
embed_thumbnail_url: Option<String>,
embed_title: String,
/// Embed fields, wrapped in `sqlx::types::Json` so the list is stored as JSON.
embed_fields: Option<Json<Vec<EmbedField>>>,
/// Interval components; `None` means that component is not set.
interval_seconds: Option<u32>,
interval_days: Option<u32>,
interval_months: Option<u32>,
tts: bool,
username: Option<String>,
}
/// Flat counterpart of `ReminderTemplate` without `id`/`guild_id`. Presumably used for CSV
/// import/export (per its name — confirm against the `export` module).
///
/// Serde (de)serializes fields in declaration order, so reordering fields changes the
/// column order.
#[derive(Serialize, Deserialize)]
pub struct ReminderTemplateCsv {
/// Template name; `"Template"` when omitted.
#[serde(default = "template_name_default")]
name: String,
/// Optional attachment bytes, exchanged as a base64 string (see `Attachment`).
attachment: Option<Attachment>,
attachment_name: Option<String>,
avatar: Option<String>,
content: String,
embed_author: String,
embed_author_url: Option<String>,
embed_color: u32,
embed_description: String,
embed_footer: String,
embed_footer_url: Option<String>,
embed_image_url: Option<String>,
embed_thumbnail_url: Option<String>,
embed_title: String,
/// Embed fields as a plain string. NOTE(review): presumably the JSON encoding of
/// `Vec<EmbedField>` — confirm.
embed_fields: Option<String>,
interval_seconds: Option<u32>,
interval_days: Option<u32>,
interval_months: Option<u32>,
tts: bool,
username: Option<String>,
}
/// Deserialized payload naming a template by `id`, presumably for deletion (per its name).
#[derive(Deserialize)]
pub struct DeleteReminderTemplate {
id: u32,
}
/// One embed field.
///
/// `create_reminder` and `PatchReminder::message_ok` limit `title` to
/// `MAX_EMBED_FIELD_TITLE_LENGTH` and `value` to `MAX_EMBED_FIELD_VALUE_LENGTH`.
#[derive(Serialize, Deserialize)]
pub struct EmbedField {
title: String,
value: String,
/// Inline display flag for the field.
inline: bool,
}
/// A reminder as accepted by `create_reminder`, and the shape that function reads back from
/// the database (via `query_as_unchecked!`) to return to the caller.
///
/// Serde (de)serializes fields in declaration order, so reordering fields changes the
/// serialized layout.
#[derive(Serialize, Deserialize)]
pub struct Reminder {
/// Optional attachment bytes, exchanged as a base64 string (see `Attachment`).
attachment: Option<Attachment>,
attachment_name: Option<String>,
/// Optional avatar; `create_reminder` validates it as a URL.
avatar: Option<String>,
/// Discord channel id, (de)serialized as a decimal string through the `string` adapter
/// (presumably so JavaScript clients do not lose precision on 64-bit ids).
#[serde(with = "string")]
channel: u64,
content: String,
embed_author: String,
embed_author_url: Option<String>,
embed_color: u32,
embed_description: String,
embed_footer: String,
embed_footer_url: Option<String>,
embed_image_url: Option<String>,
embed_thumbnail_url: Option<String>,
embed_title: String,
/// Embed fields, wrapped in `sqlx::types::Json` so the list is stored as JSON.
embed_fields: Option<Json<Vec<EmbedField>>>,
enabled: bool,
/// Optional expiry timestamp. NOTE(review): naive; presumably UTC like `utc_time` — confirm.
expires: Option<NaiveDateTime>,
/// Interval components. `create_reminder` requires
/// `months * 30 days + days + seconds >= MIN_INTERVAL` when any of them is set.
interval_seconds: Option<u32>,
interval_days: Option<u32>,
interval_months: Option<u32>,
/// Reminder name; `"Reminder"` when omitted.
#[serde(default = "name_default")]
name: String,
restartable: bool,
tts: bool,
/// Reminder uid; empty when omitted. `create_reminder` ignores any input value and
/// generates a fresh one with `generate_uid`.
#[serde(default)]
uid: String,
/// Optional username; `create_reminder` stores an empty string as `NULL`.
username: Option<String>,
/// Scheduled time as a naive UTC timestamp; `create_reminder` rejects past values.
utc_time: NaiveDateTime,
}
/// Flat counterpart of `Reminder`, presumably for CSV import/export (per its name).
///
/// Differences from `Reminder`:
/// - `channel` is a plain string;
/// - `embed_fields` is a string (presumably JSON-encoded — confirm);
/// - there is no `uid`.
///
/// Serde (de)serializes fields in declaration order, so reordering fields changes the
/// column order.
#[derive(Serialize, Deserialize)]
pub struct ReminderCsv {
attachment: Option<Attachment>,
attachment_name: Option<String>,
avatar: Option<String>,
channel: String,
content: String,
embed_author: String,
embed_author_url: Option<String>,
embed_color: u32,
embed_description: String,
embed_footer: String,
embed_footer_url: Option<String>,
embed_image_url: Option<String>,
embed_thumbnail_url: Option<String>,
embed_title: String,
embed_fields: Option<String>,
enabled: bool,
expires: Option<NaiveDateTime>,
interval_seconds: Option<u32>,
interval_days: Option<u32>,
interval_months: Option<u32>,
/// Reminder name; `"Reminder"` when omitted.
#[serde(default = "name_default")]
name: String,
restartable: bool,
tts: bool,
username: Option<String>,
utc_time: NaiveDateTime,
}
/// Partial update for an existing reminder; every key except `uid` may be omitted.
///
/// Field conventions:
/// - `Unset<T>` with `#[serde(default)]`: `None` when the key is absent. An explicit `null`
///   also deserializes to `None`.
/// - `Unset<Option<T>>` with `deserialize_optional_field`:
///   - `None` when the key is absent;
///   - `Some(None)` for an explicit `null`;
///   - `Some(Some(v))` for a value.
///   This lets the caller tell "leave unchanged" apart from "clear".
/// - `channel` is a decimal string and is `0` when omitted.
///
/// `message_ok` checks the length limits of the message fields that are present.
#[derive(Deserialize)]
pub struct PatchReminder {
/// Identifies the reminder to update (presumably matched against `reminders.uid`).
uid: String,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
attachment: Unset<Option<Attachment>>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
attachment_name: Unset<Option<String>>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
avatar: Unset<Option<String>>,
/// Target channel id as a decimal string; `0` when the key is omitted.
#[serde(default = "channel_default")]
#[serde(with = "string")]
channel: u64,
#[serde(default)]
content: Unset<String>,
#[serde(default)]
embed_author: Unset<String>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
embed_author_url: Unset<Option<String>>,
#[serde(default)]
embed_color: Unset<u32>,
#[serde(default)]
embed_description: Unset<String>,
#[serde(default)]
embed_footer: Unset<String>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
embed_footer_url: Unset<Option<String>>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
embed_image_url: Unset<Option<String>>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
embed_thumbnail_url: Unset<Option<String>>,
#[serde(default)]
embed_title: Unset<String>,
/// No `deserialize_with` here, so an explicit `null` is indistinguishable from omission.
#[serde(default)]
embed_fields: Unset<Json<Vec<EmbedField>>>,
#[serde(default)]
enabled: Unset<bool>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
expires: Unset<Option<NaiveDateTime>>,
#[serde(default = "interval_default")]
#[serde(deserialize_with = "deserialize_optional_field")]
interval_seconds: Unset<Option<u32>>,
#[serde(default = "interval_default")]
#[serde(deserialize_with = "deserialize_optional_field")]
interval_days: Unset<Option<u32>>,
#[serde(default = "interval_default")]
#[serde(deserialize_with = "deserialize_optional_field")]
interval_months: Unset<Option<u32>>,
#[serde(default)]
name: Unset<String>,
#[serde(default)]
restartable: Unset<bool>,
#[serde(default)]
tts: Unset<bool>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
username: Unset<Option<String>>,
#[serde(default)]
utc_time: Unset<NaiveDateTime>,
}
impl PatchReminder {
    /// Returns `true` when every message field present in this patch fits its length limit.
    ///
    /// Absent fields (and an explicit `null` username) always pass. Lengths are byte lengths
    /// (`str::len`).
    fn message_ok(&self) -> bool {
        // A missing field is fine; a supplied one must not exceed `limit` bytes.
        fn fits(field: &Option<String>, limit: usize) -> bool {
            match field {
                Some(text) => text.len() <= limit,
                None => true,
            }
        }

        let fields_ok = match &self.embed_fields {
            Some(fields) => {
                fields.0.len() <= MAX_EMBED_FIELDS
                    && fields.0.iter().all(|field| {
                        field.title.len() <= MAX_EMBED_FIELD_TITLE_LENGTH
                            && field.value.len() <= MAX_EMBED_FIELD_VALUE_LENGTH
                    })
            }
            None => true,
        };

        let username_ok = match &self.username {
            Some(Some(name)) => name.len() <= MAX_USERNAME_LENGTH,
            _ => true,
        };

        fits(&self.content, MAX_CONTENT_LENGTH)
            && fits(&self.embed_author, MAX_EMBED_AUTHOR_LENGTH)
            && fits(&self.embed_description, MAX_EMBED_DESCRIPTION_LENGTH)
            && fits(&self.embed_footer, MAX_EMBED_FOOTER_LENGTH)
            && fits(&self.embed_title, MAX_EMBED_TITLE_LENGTH)
            && fields_ok
            && username_ok
    }
}
/// Generates a random 64-character identifier whose characters are drawn uniformly from
/// `CHARACTERS`, using the operating-system RNG.
///
/// # Panics
///
/// Panics if `CHARACTERS` is empty.
pub fn generate_uid() -> String {
    // Collect the alphabet once. A slice iterator reports an exact length, so
    // `IteratorRandom::choose` can pick with a single index draw. On a `Chars` iterator it
    // would instead walk the whole alphabet and draw repeatedly, for each of the 64
    // characters.
    let alphabet: Vec<char> = CHARACTERS.chars().collect();
    let mut generator: OsRng = Default::default();

    (0..64)
        .map(|_| *alphabet.iter().choose(&mut generator).expect("CHARACTERS must not be empty"))
        .collect()
}
/// Deserializes a key that is *present* in the input and wraps what was found in `Some`.
///
/// Combined with `#[serde(default)]`, this gives a tri-state:
/// - an absent key stays `None` (this function is not called);
/// - an explicit `null` becomes `Some(None)`;
/// - a value `v` becomes `Some(Some(v))`.
fn deserialize_optional_field<'de, T, D>(deserializer: D) -> Result<Option<Option<T>>, D::Error>
where
    D: Deserializer<'de>,
    T: Deserialize<'de>,
{
    Option::<T>::deserialize(deserializer).map(Some)
}
// https://github.com/serde-rs/json/issues/329#issuecomment-305608405
/// Serde `with` adapter that (de)serializes a value through its textual form: `Display` on the
/// way out, `FromStr` on the way in.
mod string {
    use std::{fmt::Display, str::FromStr};
    use serde::{de, Deserialize, Deserializer, Serializer};

    /// Emits `value` as a string produced by its `Display` implementation.
    pub fn serialize<T, S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
    where
        T: Display,
        S: Serializer,
    {
        serializer.collect_str(value)
    }

    /// Reads a string and parses it with `FromStr`; parse failures are reported as custom
    /// deserialization errors.
    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
    where
        T: FromStr,
        T::Err: Display,
        D: Deserializer<'de>,
    {
        let text = String::deserialize(deserializer)?;
        text.parse::<T>().map_err(de::Error::custom)
    }
}
/// Deserialized payload naming a reminder by `uid`, presumably for deletion (per its name).
#[derive(Deserialize)]
pub struct DeleteReminder {
uid: String,
}
/// Deserialized payload carrying raw import text in `body`.
/// NOTE(review): presumably CSV matching the `*Csv` row types — confirm.
#[derive(Deserialize)]
pub struct ImportBody {
body: String,
}
/// Row type for todo data, presumably CSV (per its name): the todo text and an optional
/// channel id kept as a string.
///
/// Serde (de)serializes fields in declaration order, so reordering fields changes the
/// column order.
#[derive(Serialize, Deserialize)]
pub struct TodoCsv {
value: String,
channel_id: Option<String>,
}
pub(crate) async fn create_reminder(
ctx: &Context,
transaction: &mut Transaction<'_>,
guild_id: GuildId,
user_id: UserId,
reminder: Reminder,
) -> JsonResult {
// check guild in db
match sqlx::query!("SELECT 1 as A FROM guilds WHERE guild = ?", guild_id.get())
.fetch_one(transaction.executor())
.await
{
Err(sqlx::Error::RowNotFound) => {
if sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id.get())
.execute(transaction.executor())
.await
.is_err()
{
return Err(json!({"error": "Guild could not be created"}));
}
}
_ => {}
}
{
// validate channel
let channel = ChannelId::new(reminder.channel).to_channel_cached(&ctx.cache);
let channel_exists = channel.is_some();
let channel_matches_guild =
channel.map_or(false, |c| c.guild(&ctx.cache).map_or(false, |c| c.id == guild_id));
if !channel_matches_guild || !channel_exists {
warn!(
"Error in `create_reminder`: channel {} not found for guild {} (channel exists: {})",
reminder.channel, guild_id, channel_exists
);
return Err(json!({"error": "Channel not found"}));
}
}
let channel =
create_database_channel(&ctx, ChannelId::new(reminder.channel), transaction).await;
if let Err(e) = channel {
warn!("`create_database_channel` returned an error code: {:?}", e);
return Err(
json!({"error": "Failed to configure channel for reminders. Please check the bot permissions"}),
);
}
let channel = channel.unwrap();
// validate lengths
check_length!(MAX_NAME_LENGTH, reminder.name);
check_length!(MAX_CONTENT_LENGTH, reminder.content);
check_length!(MAX_EMBED_DESCRIPTION_LENGTH, reminder.embed_description);
check_length!(MAX_EMBED_TITLE_LENGTH, reminder.embed_title);
check_length!(MAX_EMBED_AUTHOR_LENGTH, reminder.embed_author);
check_length!(MAX_EMBED_FOOTER_LENGTH, reminder.embed_footer);
check_length_opt!(MAX_EMBED_FIELDS, reminder.embed_fields);
if let Some(fields) = &reminder.embed_fields {
for field in &fields.0 {
check_length!(MAX_EMBED_FIELD_VALUE_LENGTH, field.value);
check_length!(MAX_EMBED_FIELD_TITLE_LENGTH, field.title);
}
}
check_length_opt!(MAX_USERNAME_LENGTH, reminder.username);
check_length_opt!(
MAX_URL_LENGTH,
reminder.embed_footer_url,
reminder.embed_thumbnail_url,
reminder.embed_author_url,
reminder.embed_image_url,
reminder.avatar
);
// validate urls
check_url_opt!(
reminder.embed_footer_url,
reminder.embed_thumbnail_url,
reminder.embed_author_url,
reminder.embed_image_url,
reminder.avatar
);
// validate time and interval
if reminder.utc_time < Utc::now().naive_utc() {
return Err(json!({"error": "Time must be in the future"}));
}
if reminder.interval_seconds.is_some()
|| reminder.interval_days.is_some()
|| reminder.interval_months.is_some()
{
if reminder.interval_months.unwrap_or(0) * 30 * DAY as u32
+ reminder.interval_days.unwrap_or(0) * DAY as u32
+ reminder.interval_seconds.unwrap_or(0)
< *MIN_INTERVAL
{
return Err(json!({"error": "Interval too short"}));
}
}
// check patreon if necessary
if reminder.interval_seconds.is_some()
|| reminder.interval_days.is_some()
|| reminder.interval_months.is_some()
{
if !check_guild_subscription(&ctx, guild_id).await
&& !check_subscription(&ctx, user_id).await
{
return Err(json!({"error": "Patreon is required to set intervals"}));
}
}
let name = if reminder.name.is_empty() { name_default() } else { reminder.name.clone() };
let username = if reminder.username.as_ref().map(|s| s.is_empty()).unwrap_or(true) {
None
} else {
reminder.username
};
let new_uid = generate_uid();
// write to db
match sqlx::query!(
"INSERT INTO reminders (
uid,
attachment,
attachment_name,
channel_id,
avatar,
content,
embed_author,
embed_author_url,
embed_color,
embed_description,
embed_footer,
embed_footer_url,
embed_image_url,
embed_thumbnail_url,
embed_title,
embed_fields,
enabled,
expires,
interval_seconds,
interval_days,
interval_months,
name,
restartable,
tts,
username,
`utc_time`
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
new_uid,
reminder.attachment,
reminder.attachment_name,
channel,
reminder.avatar,
reminder.content,
reminder.embed_author,
reminder.embed_author_url,
reminder.embed_color,
reminder.embed_description,
reminder.embed_footer,
reminder.embed_footer_url,
reminder.embed_image_url,
reminder.embed_thumbnail_url,
reminder.embed_title,
reminder.embed_fields,
reminder.enabled,
reminder.expires,
reminder.interval_seconds,
reminder.interval_days,
reminder.interval_months,
name,
reminder.restartable,
reminder.tts,
username,
reminder.utc_time,
)
.execute(transaction.executor())
.await
{
Ok(_) => sqlx::query_as_unchecked!(
Reminder,
"SELECT
reminders.attachment,
reminders.attachment_name,
reminders.avatar,
channels.channel,
reminders.content,
reminders.embed_author,
reminders.embed_author_url,
reminders.embed_color,
reminders.embed_description,
reminders.embed_footer,
reminders.embed_footer_url,
reminders.embed_image_url,
reminders.embed_thumbnail_url,
reminders.embed_title,
reminders.embed_fields,
reminders.enabled,
reminders.expires,
reminders.interval_seconds,
reminders.interval_days,
reminders.interval_months,
reminders.name,
reminders.restartable,
reminders.tts,
reminders.uid,
reminders.username,
reminders.utc_time
FROM reminders
LEFT JOIN channels ON channels.id = reminders.channel_id
WHERE uid = ?",
new_uid
)
.fetch_one(transaction.executor())
.await
.map(|r| Ok(json!(r)))
.unwrap_or_else(|e| {
warn!("Failed to complete SQL query: {:?}", e);
Err(json!({"error": "Could not load reminder"}))
}),
Err(e) => {
warn!("Error in `create_reminder`: Could not execute query: {:?}", e);
Err(json!({"error": "Unknown error"}))
}
}
}
async fn create_database_channel(
ctx: impl CacheHttp,
channel: ChannelId,
transaction: &mut Transaction<'_>,
) -> Result<u32, Error> {
let row = sqlx::query!(
"SELECT webhook_token, webhook_id FROM channels WHERE channel = ?",
channel.get()
)
.fetch_one(transaction.executor())
.await;
match row {
Ok(row) => {
let is_dm =
channel.to_channel(&ctx).await.map_err(|e| Error::Serenity(e))?.private().is_some();
if !is_dm && (row.webhook_token.is_none() || row.webhook_id.is_none()) {
let webhook = channel
.create_webhook(&ctx, CreateWebhook::new("Reminder").avatar(&*DEFAULT_AVATAR))
.await
.map_err(|e| Error::Serenity(e))?;
let token = webhook.token.unwrap();
sqlx::query!(
"
UPDATE channels SET webhook_id = ?, webhook_token = ? WHERE channel = ?
",
webhook.id.get(),
token.expose_secret(),
channel.get()
)
.execute(transaction.executor())
.await
.map_err(|e| Error::SQLx(e))?;
}
Ok(())
}
Err(sqlx::Error::RowNotFound) => {
// create webhook
let webhook = channel
.create_webhook(&ctx, CreateWebhook::new("Reminder").avatar(&*DEFAULT_AVATAR))
.await
.map_err(|e| Error::Serenity(e))?;
let token = webhook.token.unwrap();
// create database entry
sqlx::query!(
"
INSERT INTO channels (
webhook_id,
webhook_token,
channel
) VALUES (?, ?, ?)
",
webhook.id.get(),
token.expose_secret(),
channel.get()
)
.execute(transaction.executor())
.await
.map_err(|e| Error::SQLx(e))?;
Ok(())
}
Err(e) => Err(Error::SQLx(e)),
}?;
let row = sqlx::query!("SELECT id FROM channels WHERE channel = ?", channel.get())
.fetch_one(transaction.executor())
.await
.map_err(|e| Error::SQLx(e))?;
Ok(row.id)
}
/// Response returned by the dashboard routes; each variant carries its HTTP status.
#[derive(Responder)]
pub enum DashboardPage {
/// The dashboard's `static/index.html`, served with status 200.
#[response(status = 200)]
Ok(NamedFile),
/// A 302 redirect, used to send requests without a `userid` private cookie to
/// `/login/discord`.
#[response(status = 302)]
Unauthorised(Redirect),
/// The internal-server-error page with status 500, used when `index.html` cannot be opened.
#[response(status = 500)]
NotConfigured(Template),
}
/// Shared body of the dashboard routes.
///
/// - Requests without a `userid` private cookie are redirected to `/login/discord`.
/// - Otherwise `static/index.html` is served.
/// - If that file cannot be opened, the error is logged and the 500 page is returned.
async fn serve_dashboard(cookies: &CookieJar<'_>) -> DashboardPage {
    if cookies.get_private("userid").is_none() {
        return DashboardPage::Unauthorised(Redirect::to("/login/discord"));
    }

    match NamedFile::open(Path::new(path!("static/index.html"))).await {
        Ok(file) => DashboardPage::Ok(file),
        Err(e) => {
            warn!("Couldn't render dashboard: {:?}", e);
            DashboardPage::NotConfigured(internal_server_error().await)
        }
    }
}

/// Dashboard root (`/`).
#[get("/")]
pub async fn dashboard_home(cookies: &CookieJar<'_>) -> DashboardPage {
    serve_dashboard(cookies).await
}

/// Catch-all for every dashboard sub-path; serves the same `index.html` as the root
/// (presumably so client-side routes resolve).
#[get("/<_..>")]
pub async fn dashboard(cookies: &CookieJar<'_>) -> DashboardPage {
    serve_dashboard(cookies).await
}