From 5f0aa0f8341368044eb9c5e6db26e8a77e33f90e Mon Sep 17 00:00:00 2001 From: jude Date: Tue, 5 Mar 2024 20:36:38 +0000 Subject: [PATCH] Add routes for getting/posting user reminders --- Cargo.lock | 2 + Cargo.toml | 1 + reminder-dashboard/src/api.ts | 5 + .../Reminder/ButtonRow/CreateButtonRow.tsx | 19 +- .../components/Reminder/CreateReminder.tsx | 3 +- .../src/components/Reminder/Settings.tsx | 34 +-- web/src/lib.rs | 8 +- .../routes/dashboard/api/guild/reminders.rs | 2 +- web/src/routes/dashboard/api/user/mod.rs | 3 + web/src/routes/dashboard/api/user/models.rs | 246 ++++++++++++++++-- .../routes/dashboard/api/user/reminders.rs | 106 ++++++-- web/src/routes/dashboard/mod.rs | 18 +- 12 files changed, 381 insertions(+), 66 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5284aed..61e419f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -942,6 +942,7 @@ checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -2366,6 +2367,7 @@ dependencies = [ "dotenv", "env_logger", "extract_derive", + "futures", "lazy-regex", "lazy_static", "levenshtein", diff --git a/Cargo.toml b/Cargo.toml index 20ad7e0..78dbf6f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ levenshtein = "1.0" sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono", "migrate"] } base64 = "0.21" secrecy = "0.8.0" +futures = "0.3.30" [dependencies.postman] path = "postman" diff --git a/reminder-dashboard/src/api.ts b/reminder-dashboard/src/api.ts index 363f96c..376a42b 100644 --- a/reminder-dashboard/src/api.ts +++ b/reminder-dashboard/src/api.ts @@ -182,3 +182,8 @@ export const fetchUserReminders = () => ({ axios.get(`/dashboard/api/user/reminders`).then((resp) => resp.data) as Promise, staleTime: OTHER_STALE_TIME, }); + +export const postUserReminder = () => ({ + mutationFn: (reminder: Reminder) => + axios.post(`/dashboard/api/user/reminders`, reminder).then((resp) => resp.data), +}); diff --git a/reminder-dashboard/src/components/Reminder/ButtonRow/CreateButtonRow.tsx b/reminder-dashboard/src/components/Reminder/ButtonRow/CreateButtonRow.tsx index 72434f7..b163e93 100644 --- a/reminder-dashboard/src/components/Reminder/ButtonRow/CreateButtonRow.tsx +++ b/reminder-dashboard/src/components/Reminder/ButtonRow/CreateButtonRow.tsx @@ -1,14 +1,15 @@ import { LoadTemplate } from "../LoadTemplate"; import { useReminder } from "../ReminderContext"; import { useMutation, useQueryClient } from "react-query"; -import { postGuildReminder, postGuildTemplate } from "../../../api"; +import { postGuildReminder, postGuildTemplate, postUserReminder } from "../../../api"; import { useParams } from "wouter"; import { useState } from "preact/hooks"; import { ICON_FLASH_TIME } from "../../../consts"; import { useFlash } from "../../App/FlashContext"; +import { useGuild } from "../../App/useGuild"; export const CreateButtonRow = () => { - const { guild } = useParams(); + const guild = useGuild(); const [reminder] = useReminder(); const [recentlyCreated, setRecentlyCreated] = useState(false); @@ -17,7 +18,7 @@ export const CreateButtonRow = () => { const flash = useFlash(); const queryClient = useQueryClient(); const mutation = useMutation({ - ...postGuildReminder(guild), + ...(guild ? postGuildReminder(guild) : postUserReminder()), onSuccess: (data) => { if (data.error) { flash({ @@ -29,9 +30,15 @@ export const CreateButtonRow = () => { message: "Reminder created", type: "success", }); - queryClient.invalidateQueries({ - queryKey: ["GUILD_REMINDERS", guild], - }); + if (guild) { + queryClient.invalidateQueries({ + queryKey: ["GUILD_REMINDERS", guild], + }); + } else { + queryClient.invalidateQueries({ + queryKey: ["USER_REMINDERS"], + }); + } setRecentlyCreated(true); setTimeout(() => { setRecentlyCreated(false); diff --git a/reminder-dashboard/src/components/Reminder/CreateReminder.tsx b/reminder-dashboard/src/components/Reminder/CreateReminder.tsx index 777009c..0ab85d1 100644 --- a/reminder-dashboard/src/components/Reminder/CreateReminder.tsx +++ b/reminder-dashboard/src/components/Reminder/CreateReminder.tsx @@ -9,6 +9,7 @@ import { ReminderContext } from "./ReminderContext"; import { useQuery } from "react-query"; import { useParams } from "wouter"; import "./styles.scss"; +import { useGuild } from "../App/useGuild"; function defaultReminder(): Reminder { return { @@ -42,7 +43,7 @@ function defaultReminder(): Reminder { } export const CreateReminder = () => { - const { guild } = useParams(); + const guild = useGuild(); const [reminder, setReminder] = useState(defaultReminder()); const [collapsed, setCollapsed] = useState(false); diff --git a/reminder-dashboard/src/components/Reminder/Settings.tsx b/reminder-dashboard/src/components/Reminder/Settings.tsx index df29897..463c300 100644 --- a/reminder-dashboard/src/components/Reminder/Settings.tsx +++ b/reminder-dashboard/src/components/Reminder/Settings.tsx @@ -7,8 +7,10 @@ import { useReminder } from "./ReminderContext"; import { Attachment } from "./Attachment"; import { TTS } from "./TTS"; import { TimeInput } from "./TimeInput"; +import { useGuild } from "../App/useGuild"; export const Settings = () => { + const guild = useGuild(); const { isSuccess: userFetched, data: userInfo } = useQuery(fetchUserInfo()); const [reminder, setReminder] = useReminder(); @@ -19,22 +21,24 @@ export const Settings = () => { return (
-
-
- + {guild && ( +
+
+ +
+ { + setReminder((reminder) => ({ + ...reminder, + channel: channel, + })); + }} + />
- { - setReminder((reminder) => ({ - ...reminder, - channel: channel, - })); - }} - /> -
+ )}
diff --git a/web/src/lib.rs b/web/src/lib.rs index fba7e4c..ce75f7a 100644 --- a/web/src/lib.rs +++ b/web/src/lib.rs @@ -35,8 +35,10 @@ type Database = MySql; #[derive(Debug)] enum Error { - SQLx, - Serenity, + #[allow(unused)] + SQLx(sqlx::Error), + #[allow(unused)] + Serenity(serenity::Error), } pub async fn initialize( @@ -132,6 +134,8 @@ pub async fn initialize( routes::dashboard::api::user::get_user_info, routes::dashboard::api::user::update_user_info, routes::dashboard::api::user::get_user_guilds, + routes::dashboard::api::user::get_reminders, + routes::dashboard::api::user::create_user_reminder, routes::dashboard::api::guild::get_guild_info, routes::dashboard::api::guild::get_guild_channels, routes::dashboard::api::guild::get_guild_roles, diff --git a/web/src/routes/dashboard/api/guild/reminders.rs b/web/src/routes/dashboard/api/guild/reminders.rs index 0c9c839..6c94a31 100644 --- a/web/src/routes/dashboard/api/guild/reminders.rs +++ b/web/src/routes/dashboard/api/guild/reminders.rs @@ -106,7 +106,7 @@ pub async fn get_reminders( reminders.username, reminders.utc_time FROM reminders - LEFT JOIN channels ON channels.id = reminders.channel_id + INNER JOIN channels ON channels.id = reminders.channel_id WHERE `status` = 'pending' AND FIND_IN_SET(channels.channel, ?)", channels ) diff --git a/web/src/routes/dashboard/api/user/mod.rs b/web/src/routes/dashboard/api/user/mod.rs index 1e0a588..91b02ba 100644 --- a/web/src/routes/dashboard/api/user/mod.rs +++ b/web/src/routes/dashboard/api/user/mod.rs @@ -1,9 +1,12 @@ mod guilds; +mod models; +mod reminders; use std::env; use chrono_tz::Tz; pub use guilds::*; +pub use reminders::*; use rocket::{ http::CookieJar, serde::json::{json, Json, Value as JsonValue}, diff --git a/web/src/routes/dashboard/api/user/models.rs b/web/src/routes/dashboard/api/user/models.rs index 339e7c3..6b64fc8 100644 --- a/web/src/routes/dashboard/api/user/models.rs +++ b/web/src/routes/dashboard/api/user/models.rs @@ -1,20 +1,230 @@ -use std::env; - -use chrono_tz::Tz; -use reqwest::Client; -use rocket::{ - http::CookieJar, - serde::json::{json, Json, Value as JsonValue}, - State, -}; +use chrono::{naive::NaiveDateTime, Utc}; +use futures::TryFutureExt; +use rocket::serde::json::json; use serde::{Deserialize, Serialize}; -use serenity::{ - client::Context, - model::{ - id::{GuildId, RoleId}, - permissions::Permissions, - }, -}; -use sqlx::{MySql, Pool}; +use serenity::{client::Context, futures, model::id::UserId}; +use sqlx::types::Json; -use crate::{consts::DISCORD_API, routes::JsonResult}; +use crate::{ + check_subscription, + consts::{ + DAY, MAX_CONTENT_LENGTH, MAX_EMBED_AUTHOR_LENGTH, MAX_EMBED_DESCRIPTION_LENGTH, + MAX_EMBED_FIELDS, MAX_EMBED_FIELD_TITLE_LENGTH, MAX_EMBED_FIELD_VALUE_LENGTH, + MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_NAME_LENGTH, MAX_URL_LENGTH, + MIN_INTERVAL, + }, + guards::transaction::Transaction, + routes::{ + dashboard::{create_database_channel, generate_uid, name_default, Attachment, EmbedField}, + JsonResult, + }, + Error, +}; + +#[derive(Serialize, Deserialize)] +pub struct Reminder { + pub attachment: Option, + pub attachment_name: Option, + pub content: String, + pub embed_author: String, + pub embed_author_url: Option, + pub embed_color: u32, + pub embed_description: String, + pub embed_footer: String, + pub embed_footer_url: Option, + pub embed_image_url: Option, + pub embed_thumbnail_url: Option, + pub embed_title: String, + pub embed_fields: Option>>, + pub enabled: bool, + pub expires: Option, + pub interval_seconds: Option, + pub interval_days: Option, + pub interval_months: Option, + #[serde(default = "name_default")] + pub name: String, + pub tts: bool, + #[serde(default)] + pub uid: String, + pub utc_time: NaiveDateTime, +} + +pub async fn create_reminder( + ctx: &Context, + transaction: &mut Transaction<'_>, + user_id: UserId, + reminder: Reminder, +) -> JsonResult { + let channel = user_id + .create_dm_channel(&ctx) + .map_err(|e| Error::Serenity(e)) + .and_then(|dm_channel| create_database_channel(&ctx, dm_channel.id, transaction)) + .await; + + if let Err(e) = channel { + warn!("`create_database_channel` returned an error code: {:?}", e); + + return Err(json!({"error": "Failed to configure channel for reminders."})); + } + + let channel = channel.unwrap(); + + // validate lengths + check_length!(MAX_NAME_LENGTH, reminder.name); + check_length!(MAX_CONTENT_LENGTH, reminder.content); + check_length!(MAX_EMBED_DESCRIPTION_LENGTH, reminder.embed_description); + check_length!(MAX_EMBED_TITLE_LENGTH, reminder.embed_title); + check_length!(MAX_EMBED_AUTHOR_LENGTH, reminder.embed_author); + check_length!(MAX_EMBED_FOOTER_LENGTH, reminder.embed_footer); + check_length_opt!(MAX_EMBED_FIELDS, reminder.embed_fields); + if let Some(fields) = &reminder.embed_fields { + for field in &fields.0 { + check_length!(MAX_EMBED_FIELD_VALUE_LENGTH, field.value); + check_length!(MAX_EMBED_FIELD_TITLE_LENGTH, field.title); + } + } + check_length_opt!( + MAX_URL_LENGTH, + reminder.embed_footer_url, + reminder.embed_thumbnail_url, + reminder.embed_author_url, + reminder.embed_image_url + ); + + // validate urls + check_url_opt!( + reminder.embed_footer_url, + reminder.embed_thumbnail_url, + reminder.embed_author_url, + reminder.embed_image_url + ); + + // validate time and interval + if reminder.utc_time < Utc::now().naive_utc() { + return Err(json!({"error": "Time must be in the future"})); + } + if reminder.interval_seconds.is_some() + || reminder.interval_days.is_some() + || reminder.interval_months.is_some() + { + if reminder.interval_months.unwrap_or(0) * 30 * DAY as u32 + + reminder.interval_days.unwrap_or(0) * DAY as u32 + + reminder.interval_seconds.unwrap_or(0) + < *MIN_INTERVAL + { + return Err(json!({"error": "Interval too short"})); + } + } + + // check patreon if necessary + if reminder.interval_seconds.is_some() + || reminder.interval_days.is_some() + || reminder.interval_months.is_some() + { + if !check_subscription(&ctx, user_id).await { + return Err(json!({"error": "Patreon is required to set intervals"})); + } + } + + let name = if reminder.name.is_empty() { name_default() } else { reminder.name.clone() }; + let new_uid = generate_uid(); + + // write to db + match sqlx::query!( + "INSERT INTO reminders ( + uid, + attachment, + attachment_name, + channel_id, + content, + embed_author, + embed_author_url, + embed_color, + embed_description, + embed_footer, + embed_footer_url, + embed_image_url, + embed_thumbnail_url, + embed_title, + embed_fields, + enabled, + expires, + interval_seconds, + interval_days, + interval_months, + name, + tts, + `utc_time` + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + new_uid, + reminder.attachment, + reminder.attachment_name, + channel, + reminder.content, + reminder.embed_author, + reminder.embed_author_url, + reminder.embed_color, + reminder.embed_description, + reminder.embed_footer, + reminder.embed_footer_url, + reminder.embed_image_url, + reminder.embed_thumbnail_url, + reminder.embed_title, + reminder.embed_fields, + reminder.enabled, + reminder.expires, + reminder.interval_seconds, + reminder.interval_days, + reminder.interval_months, + name, + reminder.tts, + reminder.utc_time, + ) + .execute(transaction.executor()) + .await + { + Ok(_) => sqlx::query_as_unchecked!( + Reminder, + "SELECT + reminders.attachment, + reminders.attachment_name, + reminders.content, + reminders.embed_author, + reminders.embed_author_url, + reminders.embed_color, + reminders.embed_description, + reminders.embed_footer, + reminders.embed_footer_url, + reminders.embed_image_url, + reminders.embed_thumbnail_url, + reminders.embed_title, + reminders.embed_fields, + reminders.enabled, + reminders.expires, + reminders.interval_seconds, + reminders.interval_days, + reminders.interval_months, + reminders.name, + reminders.tts, + reminders.uid, + reminders.utc_time + FROM reminders + WHERE uid = ?", + new_uid + ) + .fetch_one(transaction.executor()) + .await + .map(|r| Ok(json!(r))) + .unwrap_or_else(|e| { + warn!("Failed to complete SQL query: {:?}", e); + + Err(json!({"error": "Could not load reminder"})) + }), + + Err(e) => { + warn!("Error in `create_reminder`: Could not execute query: {:?}", e); + + Err(json!({"error": "Unknown error"})) + } + } +} diff --git a/web/src/routes/dashboard/api/user/reminders.rs b/web/src/routes/dashboard/api/user/reminders.rs index 39d32be..f2da8b8 100644 --- a/web/src/routes/dashboard/api/user/reminders.rs +++ b/web/src/routes/dashboard/api/user/reminders.rs @@ -1,23 +1,48 @@ -use std::env; - -use chrono_tz::Tz; -use reqwest::Client; use rocket::{ http::CookieJar, - serde::json::{json, Json, Value as JsonValue}, + serde::json::{json, Json}, State, }; -use serde::{Deserialize, Serialize}; -use serenity::{ - client::Context, - model::{ - id::{GuildId, RoleId}, - permissions::Permissions, - }, -}; +use serenity::{client::Context, model::id::UserId}; use sqlx::{MySql, Pool}; -use crate::{consts::DISCORD_API, routes::JsonResult}; +use crate::{ + guards::transaction::Transaction, + routes::{ + dashboard::api::user::models::{create_reminder, Reminder}, + JsonResult, + }, +}; + +#[post("/api/user/reminders", data = "")] +pub async fn create_user_reminder( + reminder: Json, + cookies: &CookieJar<'_>, + ctx: &State, + mut transaction: Transaction<'_>, +) -> JsonResult { + let user_id = + cookies.get_private("userid").map(|c| c.value().parse::().ok()).flatten().unwrap(); + + match create_reminder( + ctx.inner(), + &mut transaction, + UserId::new(user_id), + reminder.into_inner(), + ) + .await + { + Ok(r) => match transaction.commit().await { + Ok(_) => Ok(r), + Err(e) => { + warn!("Couldn't commit transaction: {:?}", e); + json_err!("Couldn't commit transaction.") + } + }, + + Err(e) => Err(e), + } +} #[get("/api/user/reminders")] pub async fn get_reminders( @@ -25,5 +50,56 @@ pub async fn get_reminders( ctx: &State, pool: &State>, ) -> JsonResult { - Ok(json! {}) + let user_id = + cookies.get_private("userid").map(|c| c.value().parse::().ok()).flatten().unwrap(); + let channel = UserId::new(user_id).create_dm_channel(ctx.inner()).await; + + match channel { + Ok(channel) => sqlx::query_as_unchecked!( + Reminder, + " + SELECT + reminders.attachment, + reminders.attachment_name, + reminders.content, + reminders.embed_author, + reminders.embed_author_url, + reminders.embed_color, + reminders.embed_description, + reminders.embed_footer, + reminders.embed_footer_url, + reminders.embed_image_url, + reminders.embed_thumbnail_url, + reminders.embed_title, + IFNULL(reminders.embed_fields, '[]') AS embed_fields, + reminders.enabled, + reminders.expires, + reminders.interval_seconds, + reminders.interval_days, + reminders.interval_months, + reminders.name, + reminders.tts, + reminders.uid, + reminders.utc_time + FROM reminders + INNER JOIN channels ON channels.id = reminders.channel_id + WHERE `status` = 'pending' AND channels.channel = ? + ", + channel.id.get() + ) + .fetch_all(pool.inner()) + .await + .map(|r| Ok(json!(r))) + .unwrap_or_else(|e| { + warn!("Failed to complete SQL query: {:?}", e); + + json_err!("Could not load reminders") + }), + + Err(e) => { + warn!("Couldn't get DM channel: {:?}", e); + + json_err!("Could not find a DM channel") + } + } } diff --git a/web/src/routes/dashboard/mod.rs b/web/src/routes/dashboard/mod.rs index ba3f26c..3cd4de4 100644 --- a/web/src/routes/dashboard/mod.rs +++ b/web/src/routes/dashboard/mod.rs @@ -55,7 +55,7 @@ fn interval_default() -> Unset> { #[derive(sqlx::Type)] #[sqlx(transparent)] -struct Attachment(Vec); +pub struct Attachment(Vec); impl<'de> Deserialize<'de> for Attachment { fn deserialize(deserializer: D) -> Result @@ -605,11 +605,13 @@ async fn create_database_channel( match row { Ok(row) => { - if row.webhook_token.is_none() || row.webhook_id.is_none() { + let is_dm = + channel.to_channel(&ctx).await.map_err(|e| Error::Serenity(e))?.private().is_some(); + if !is_dm && (row.webhook_token.is_none() || row.webhook_id.is_none()) { let webhook = channel .create_webhook(&ctx, CreateWebhook::new("Reminder").avatar(&*DEFAULT_AVATAR)) .await - .map_err(|_| Error::Serenity)?; + .map_err(|e| Error::Serenity(e))?; let token = webhook.token.unwrap(); @@ -623,7 +625,7 @@ async fn create_database_channel( ) .execute(transaction.executor()) .await - .map_err(|_| Error::SQLx)?; + .map_err(|e| Error::SQLx(e))?; } Ok(()) @@ -634,7 +636,7 @@ async fn create_database_channel( let webhook = channel .create_webhook(&ctx, CreateWebhook::new("Reminder").avatar(&*DEFAULT_AVATAR)) .await - .map_err(|_| Error::Serenity)?; + .map_err(|e| Error::Serenity(e))?; let token = webhook.token.unwrap(); @@ -653,18 +655,18 @@ async fn create_database_channel( ) .execute(transaction.executor()) .await - .map_err(|_| Error::SQLx)?; + .map_err(|e| Error::SQLx(e))?; Ok(()) } - Err(_) => Err(Error::SQLx), + Err(e) => Err(Error::SQLx(e)), }?; let row = sqlx::query!("SELECT id FROM channels WHERE channel = ?", channel.get()) .fetch_one(transaction.executor()) .await - .map_err(|_| Error::SQLx)?; + .map_err(|e| Error::SQLx(e))?; Ok(row.id) }