From 79c86d43f2e6f8f951f46890f8f5eb3213917b51 Mon Sep 17 00:00:00 2001 From: jude Date: Sun, 24 Jul 2022 20:06:37 +0100 Subject: [PATCH] Changed return types to results --- web/src/lib.rs | 2 +- web/src/macros.rs | 16 +- web/src/routes/dashboard/export.rs | 143 ++++++++++----- web/src/routes/dashboard/guild.rs | 285 +++++------------------------ web/src/routes/dashboard/mod.rs | 230 ++++++++++++++++++++++- 5 files changed, 378 insertions(+), 298 deletions(-) diff --git a/web/src/lib.rs b/web/src/lib.rs index 783658d..81584f2 100644 --- a/web/src/lib.rs +++ b/web/src/lib.rs @@ -146,7 +146,7 @@ pub async fn initialize( routes::dashboard::guild::get_reminder_templates, routes::dashboard::guild::create_reminder_template, routes::dashboard::guild::delete_reminder_template, - routes::dashboard::guild::create_reminder, + routes::dashboard::guild::create_guild_reminder, routes::dashboard::guild::get_reminders, routes::dashboard::guild::edit_reminder, routes::dashboard::guild::delete_reminder, diff --git a/web/src/macros.rs b/web/src/macros.rs index 8b33616..288617c 100644 --- a/web/src/macros.rs +++ b/web/src/macros.rs @@ -1,7 +1,7 @@ macro_rules! check_length { ($max:ident, $field:expr) => { if $field.len() > $max { - return json!({ "error": format!("{} exceeded", stringify!($max)) }); + return Err(json!({ "error": format!("{} exceeded", stringify!($max)) })); } }; ($max:ident, $field:expr, $($fields:expr),+) => { @@ -25,7 +25,7 @@ macro_rules! check_length_opt { macro_rules! check_url { ($field:expr) => { if !($field.starts_with("http://") || $field.starts_with("https://")) { - return json!({ "error": "URL invalid" }); + return Err(json!({ "error": "URL invalid" })); } }; ($field:expr, $($fields:expr),+) => { @@ -60,7 +60,7 @@ macro_rules! check_authorization { match member { Err(_) => { - return json!({"error": "User not in guild"}) + return Err(json!({"error": "User not in guild"})); } Ok(_) => {} @@ -68,13 +68,13 @@ macro_rules! check_authorization { } None => { - return json!({"error": "Bot not in guild"}) + return Err(json!({"error": "Bot not in guild"})); } } } None => { - return json!({"error": "User not authorized"}); + return Err(json!({"error": "User not authorized"})); } } } @@ -117,3 +117,9 @@ macro_rules! update_field { update_field!($pool, $error, $reminder.[$($fields),+]); }; } + +macro_rules! json_err { + ($message:expr) => { + Err(json!({ "error": $message })) + }; +} diff --git a/web/src/routes/dashboard/export.rs b/web/src/routes/dashboard/export.rs index b0005bf..6dd530b 100644 --- a/web/src/routes/dashboard/export.rs +++ b/web/src/routes/dashboard/export.rs @@ -1,7 +1,7 @@ use csv::{QuoteStyle, WriterBuilder}; use rocket::{ http::CookieJar, - serde::json::{json, Json, Value as JsonValue}, + serde::json::{json, serde_json, Json}, State, }; use serenity::{ @@ -10,7 +10,10 @@ use serenity::{ }; use sqlx::{MySql, Pool}; -use crate::routes::dashboard::{ImportBody, ReminderCsv, ReminderTemplateCsv, TodoCsv}; +use crate::routes::dashboard::{ + create_reminder, generate_uid, ImportBody, JsonResult, Reminder, ReminderCsv, + ReminderTemplateCsv, TodoCsv, +}; #[get("/api/guild//export/reminders")] pub async fn export_reminders( @@ -18,7 +21,7 @@ pub async fn export_reminders( cookies: &CookieJar<'_>, ctx: &State, pool: &State>, -) -> JsonValue { +) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); let mut csv_writer = WriterBuilder::new().quote_style(QuoteStyle::Always).from_writer(vec![]); @@ -40,7 +43,7 @@ pub async fn export_reminders( reminders.attachment, reminders.attachment_name, reminders.avatar, - channels.channel, + CONCAT('#', channels.channel) AS channel, reminders.content, reminders.embed_author, reminders.embed_author_url, @@ -77,21 +80,19 @@ pub async fn export_reminders( match csv_writer.into_inner() { Ok(inner) => match String::from_utf8(inner) { - Ok(encoded) => { - json!({ "body": encoded }) - } + Ok(encoded) => Ok(json!({ "body": encoded })), Err(e) => { warn!("Failed to write UTF-8: {:?}", e); - json!({"error": "Failed to write UTF-8"}) + Err(json!({"error": "Failed to write UTF-8"})) } }, Err(e) => { warn!("Failed to extract CSV: {:?}", e); - json!({"error": "Failed to extract CSV"}) + Err(json!({"error": "Failed to extract CSV"})) } } } @@ -99,7 +100,7 @@ pub async fn export_reminders( Err(e) => { warn!("Failed to complete SQL query: {:?}", e); - json!({"error": "Failed to query reminders"}) + Err(json!({"error": "Failed to query reminders"})) } } } @@ -107,7 +108,7 @@ pub async fn export_reminders( Err(e) => { warn!("Could not fetch channels from {}: {:?}", id, e); - json!({"error": "Failed to get guild channels"}) + Err(json!({"error": "Failed to get guild channels"})) } } } @@ -119,28 +120,86 @@ pub async fn import_reminders( body: Json, ctx: &State, pool: &State>, -) -> JsonValue { +) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); + let user_id = + cookies.get_private("userid").map(|c| c.value().parse::().ok()).flatten().unwrap(); + match base64::decode(&body.body) { Ok(body) => { let mut reader = csv::Reader::from_reader(body.as_slice()); for result in reader.deserialize::() { match result { - Ok(record) => {} + Ok(record) => { + let channel_id = record.channel.split_at(1).1; + + match channel_id.parse::() { + Ok(channel_id) => { + let reminder = Reminder { + attachment: record.attachment, + attachment_name: record.attachment_name, + avatar: record.avatar, + channel: channel_id, + content: record.content, + embed_author: record.embed_author, + embed_author_url: record.embed_author_url, + embed_color: record.embed_color, + embed_description: record.embed_description, + embed_footer: record.embed_footer, + embed_footer_url: record.embed_footer_url, + embed_image_url: record.embed_image_url, + embed_thumbnail_url: record.embed_thumbnail_url, + embed_title: record.embed_title, + embed_fields: record + .embed_fields + .map(|s| serde_json::from_str(&s).ok()) + .flatten(), + enabled: record.enabled, + expires: record.expires, + interval_seconds: record.interval_seconds, + interval_months: record.interval_months, + name: record.name, + restartable: record.restartable, + tts: record.tts, + uid: generate_uid(), + username: record.username, + utc_time: record.utc_time, + }; + + create_reminder( + ctx.inner(), + pool.inner(), + GuildId(id), + UserId(user_id), + reminder, + ) + .await?; + } + + Err(_) => { + return json_err!(format!( + "Failed to parse channel {}", + channel_id + )); + } + } + } Err(e) => { warn!("Couldn't deserialize CSV row: {:?}", e); + + return json_err!("Deserialize error. Aborted"); } } } - json!({"error": "Not implemented"}) + Ok(json!({})) } Err(_) => { - json!({"error": "Malformed base64"}) + json_err!("Malformed base64") } } } @@ -151,7 +210,7 @@ pub async fn export_todos( cookies: &CookieJar<'_>, ctx: &State, pool: &State>, -) -> JsonValue { +) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); let mut csv_writer = WriterBuilder::new().quote_style(QuoteStyle::Always).from_writer(vec![]); @@ -174,28 +233,27 @@ pub async fn export_todos( match csv_writer.into_inner() { Ok(inner) => match String::from_utf8(inner) { - Ok(encoded) => { - json!({ "body": encoded }) - } + Ok(encoded) => Ok(json!({ "body": encoded })), Err(e) => { warn!("Failed to write UTF-8: {:?}", e); - json!({"error": "Failed to write UTF-8"}) + json_err!("Failed to write UTF-8") } }, Err(e) => { warn!("Failed to extract CSV: {:?}", e); - json!({"error": "Failed to extract CSV"}) + json_err!("Failed to extract CSV") } } } + Err(e) => { warn!("Could not fetch templates from {}: {:?}", id, e); - json!({"error": "Failed to query templates"}) + json_err!("Failed to query templates") } } } @@ -207,7 +265,7 @@ pub async fn import_todos( body: Json, ctx: &State, pool: &State>, -) -> JsonValue { +) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); let channels_res = GuildId(id).channels(&ctx.inner()).await; @@ -231,17 +289,18 @@ pub async fn import_todos( if channels.contains_key(&ChannelId(channel_id)) { query_params.push((record.value, Some(channel_id), id)); } else { - return json!({ - "error": - format!("Invalid channel ID {}", channel_id) - }); + return json_err!(format!( + "Invalid channel ID {}", + channel_id + )); } } Err(_) => { - return json!({ - "error": format!("Invalid channel ID {}", channel_id) - }); + return json_err!(format!( + "Invalid channel ID {}", + channel_id + )); } } } @@ -254,7 +313,7 @@ pub async fn import_todos( Err(e) => { warn!("Couldn't deserialize CSV row: {:?}", e); - return json!({"error": "Deserialize error. Aborted"}); + return json_err!("Deserialize error. Aborted"); } } } @@ -279,27 +338,25 @@ pub async fn import_todos( let res = query.execute(pool.inner()).await; match res { - Ok(_) => { - json!({}) - } + Ok(_) => Ok(json!({})), Err(e) => { warn!("Couldn't execute todo query: {:?}", e); - json!({"error": "An unexpected error occured."}) + json_err!("An unexpected error occured.") } } } Err(_) => { - json!({"error": "Malformed base64"}) + json_err!("Malformed base64") } }, Err(e) => { warn!("Couldn't fetch channels for guild {}: {:?}", id, e); - json!({"error": "Couldn't fetch channels."}) + json_err!("Couldn't fetch channels.") } } } @@ -310,7 +367,7 @@ pub async fn export_reminder_templates( cookies: &CookieJar<'_>, ctx: &State, pool: &State>, -) -> JsonValue { +) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); let mut csv_writer = WriterBuilder::new().quote_style(QuoteStyle::Always).from_writer(vec![]); @@ -348,28 +405,26 @@ pub async fn export_reminder_templates( match csv_writer.into_inner() { Ok(inner) => match String::from_utf8(inner) { - Ok(encoded) => { - json!({ "body": encoded }) - } + Ok(encoded) => Ok(json!({ "body": encoded })), Err(e) => { warn!("Failed to write UTF-8: {:?}", e); - json!({"error": "Failed to write UTF-8"}) + json_err!("Failed to write UTF-8") } }, Err(e) => { warn!("Failed to extract CSV: {:?}", e); - json!({"error": "Failed to extract CSV"}) + json_err!("Failed to extract CSV") } } } Err(e) => { warn!("Could not fetch templates from {}: {:?}", id, e); - json!({"error": "Failed to query templates"}) + json_err!("Failed to query templates") } } } diff --git a/web/src/routes/dashboard/guild.rs b/web/src/routes/dashboard/guild.rs index 9df732b..83db5f8 100644 --- a/web/src/routes/dashboard/guild.rs +++ b/web/src/routes/dashboard/guild.rs @@ -1,10 +1,8 @@ use std::env; -use base64; -use chrono::Utc; use rocket::{ http::CookieJar, - serde::json::{json, Json, Value as JsonValue}, + serde::json::{json, Json}, State, }; use serde::Serialize; @@ -18,16 +16,14 @@ use serenity::{ use sqlx::{MySql, Pool}; use crate::{ - check_guild_subscription, check_subscription, consts::{ - DAY, MAX_CONTENT_LENGTH, MAX_EMBED_AUTHOR_LENGTH, MAX_EMBED_DESCRIPTION_LENGTH, + MAX_CONTENT_LENGTH, MAX_EMBED_AUTHOR_LENGTH, MAX_EMBED_DESCRIPTION_LENGTH, MAX_EMBED_FIELDS, MAX_EMBED_FIELD_TITLE_LENGTH, MAX_EMBED_FIELD_VALUE_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH, - MIN_INTERVAL, }, routes::dashboard::{ - create_database_channel, generate_uid, name_default, template_name_default, DeleteReminder, - DeleteReminderTemplate, PatchReminder, Reminder, ReminderTemplate, + create_database_channel, create_reminder, template_name_default, DeleteReminder, + DeleteReminderTemplate, JsonResult, PatchReminder, Reminder, ReminderTemplate, }, }; @@ -44,7 +40,7 @@ pub async fn get_guild_patreon( id: u64, cookies: &CookieJar<'_>, ctx: &State, -) -> JsonValue { +) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); match GuildId(id).to_guild_cached(ctx.inner()) { @@ -59,12 +55,10 @@ pub async fn get_guild_patreon( .contains(&RoleId(env::var("PATREON_ROLE_ID").unwrap().parse().unwrap())) }); - json!({ "patreon": patreon }) + Ok(json!({ "patreon": patreon })) } - None => { - json!({"error": "Bot not in guild"}) - } + None => json_err!("Bot not in guild"), } } @@ -73,7 +67,7 @@ pub async fn get_guild_channels( id: u64, cookies: &CookieJar<'_>, ctx: &State, -) -> JsonValue { +) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); match GuildId(id).to_guild_cached(ctx.inner()) { @@ -97,12 +91,10 @@ pub async fn get_guild_channels( }) .collect::>(); - json!(channel_info) + Ok(json!(channel_info)) } - None => { - json!({"error": "Bot not in guild"}) - } + None => json_err!("Bot not in guild"), } } @@ -113,7 +105,7 @@ struct RoleInfo { } #[get("/api/guild//roles")] -pub async fn get_guild_roles(id: u64, cookies: &CookieJar<'_>, ctx: &State) -> JsonValue { +pub async fn get_guild_roles(id: u64, cookies: &CookieJar<'_>, ctx: &State) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); let roles_res = ctx.cache.guild_roles(id); @@ -125,12 +117,12 @@ pub async fn get_guild_roles(id: u64, cookies: &CookieJar<'_>, ctx: &State>(); - json!(roles) + Ok(json!(roles)) } None => { warn!("Could not fetch roles from {}", id); - json!({"error": "Could not get roles"}) + json_err!("Could not get roles") } } } @@ -141,7 +133,7 @@ pub async fn get_reminder_templates( cookies: &CookieJar<'_>, ctx: &State, pool: &State>, -) -> JsonValue { +) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); match sqlx::query_as_unchecked!( @@ -152,13 +144,11 @@ pub async fn get_reminder_templates( .fetch_all(pool.inner()) .await { - Ok(templates) => { - json!(templates) - } + Ok(templates) => Ok(json!(templates)), Err(e) => { warn!("Could not fetch templates from {}: {:?}", id, e); - json!({"error": "Could not get templates"}) + json_err!("Could not get templates") } } } @@ -170,7 +160,7 @@ pub async fn create_reminder_template( cookies: &CookieJar<'_>, ctx: &State, pool: &State>, -) -> JsonValue { +) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); // validate lengths @@ -254,12 +244,12 @@ pub async fn create_reminder_template( .await { Ok(_) => { - json!({}) + Ok(json!({})) } Err(e) => { warn!("Could not fetch templates from {}: {:?}", id, e); - json!({"error": "Could not get templates"}) + json_err!("Could not get templates") } } } @@ -271,7 +261,7 @@ pub async fn delete_reminder_template( cookies: &CookieJar<'_>, ctx: &State, pool: &State>, -) -> JsonValue { +) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); match sqlx::query!( @@ -282,230 +272,41 @@ pub async fn delete_reminder_template( .await { Ok(_) => { - json!({}) + Ok(json!({})) } Err(e) => { warn!("Could not delete template from {}: {:?}", id, e); - json!({"error": "Could not delete template"}) + json_err!("Could not delete template") } } } #[post("/api/guild//reminders", data = "")] -pub async fn create_reminder( +pub async fn create_guild_reminder( id: u64, reminder: Json, cookies: &CookieJar<'_>, serenity_context: &State, pool: &State>, -) -> JsonValue { +) -> JsonResult { check_authorization!(cookies, serenity_context.inner(), id); let user_id = cookies.get_private("userid").map(|c| c.value().parse::().ok()).flatten().unwrap(); - // validate channel - let channel = ChannelId(reminder.channel).to_channel_cached(&serenity_context.inner()); - let channel_exists = channel.is_some(); - - let channel_matches_guild = - channel.map_or(false, |c| c.guild().map_or(false, |c| c.guild_id.0 == id)); - - if !channel_matches_guild || !channel_exists { - warn!( - "Error in `create_reminder`: channel {} not found for guild {} (channel exists: {})", - reminder.channel, id, channel_exists - ); - - return json!({"error": "Channel not found"}); - } - - let channel = create_database_channel( + create_reminder( serenity_context.inner(), - ChannelId(reminder.channel), pool.inner(), + GuildId(id), + UserId(user_id), + reminder.into_inner(), ) - .await; - - if let Err(e) = channel { - warn!("`create_database_channel` returned an error code: {:?}", e); - - return json!({"error": "Failed to configure channel for reminders. Please check the bot permissions"}); - } - - let channel = channel.unwrap(); - - // validate lengths - check_length!(MAX_CONTENT_LENGTH, reminder.content); - check_length!(MAX_EMBED_DESCRIPTION_LENGTH, reminder.embed_description); - check_length!(MAX_EMBED_TITLE_LENGTH, reminder.embed_title); - check_length!(MAX_EMBED_AUTHOR_LENGTH, reminder.embed_author); - check_length!(MAX_EMBED_FOOTER_LENGTH, reminder.embed_footer); - check_length_opt!(MAX_EMBED_FIELDS, reminder.embed_fields); - if let Some(fields) = &reminder.embed_fields { - for field in &fields.0 { - check_length!(MAX_EMBED_FIELD_VALUE_LENGTH, field.value); - check_length!(MAX_EMBED_FIELD_TITLE_LENGTH, field.title); - } - } - check_length_opt!(MAX_USERNAME_LENGTH, reminder.username); - check_length_opt!( - MAX_URL_LENGTH, - reminder.embed_footer_url, - reminder.embed_thumbnail_url, - reminder.embed_author_url, - reminder.embed_image_url, - reminder.avatar - ); - - // validate urls - check_url_opt!( - reminder.embed_footer_url, - reminder.embed_thumbnail_url, - reminder.embed_author_url, - reminder.embed_image_url, - reminder.avatar - ); - - // validate time and interval - if reminder.utc_time < Utc::now().naive_utc() { - return json!({"error": "Time must be in the future"}); - } - if reminder.interval_seconds.is_some() || reminder.interval_months.is_some() { - if reminder.interval_months.unwrap_or(0) * 30 * DAY as u32 - + reminder.interval_seconds.unwrap_or(0) - < *MIN_INTERVAL - { - return json!({"error": "Interval too short"}); - } - } - - // check patreon if necessary - if reminder.interval_seconds.is_some() || reminder.interval_months.is_some() { - if !check_guild_subscription(serenity_context.inner(), GuildId(id)).await - && !check_subscription(serenity_context.inner(), user_id).await - { - return json!({"error": "Patreon is required to set intervals"}); - } - } - - // base64 decode error dropped here - let attachment_data = reminder.attachment.as_ref().map(|s| base64::decode(s).ok()).flatten(); - let name = if reminder.name.is_empty() { name_default() } else { reminder.name.clone() }; - - let new_uid = generate_uid(); - - // write to db - match sqlx::query!( - "INSERT INTO reminders ( - uid, - attachment, - attachment_name, - channel_id, - avatar, - content, - embed_author, - embed_author_url, - embed_color, - embed_description, - embed_footer, - embed_footer_url, - embed_image_url, - embed_thumbnail_url, - embed_title, - embed_fields, - enabled, - expires, - interval_seconds, - interval_months, - name, - restartable, - tts, - username, - `utc_time` - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - new_uid, - attachment_data, - reminder.attachment_name, - channel, - reminder.avatar, - reminder.content, - reminder.embed_author, - reminder.embed_author_url, - reminder.embed_color, - reminder.embed_description, - reminder.embed_footer, - reminder.embed_footer_url, - reminder.embed_image_url, - reminder.embed_thumbnail_url, - reminder.embed_title, - reminder.embed_fields, - reminder.enabled, - reminder.expires, - reminder.interval_seconds, - reminder.interval_months, - name, - reminder.restartable, - reminder.tts, - reminder.username, - reminder.utc_time, - ) - .execute(pool.inner()) .await - { - Ok(_) => sqlx::query_as_unchecked!( - Reminder, - "SELECT - reminders.attachment, - reminders.attachment_name, - reminders.avatar, - channels.channel, - reminders.content, - reminders.embed_author, - reminders.embed_author_url, - reminders.embed_color, - reminders.embed_description, - reminders.embed_footer, - reminders.embed_footer_url, - reminders.embed_image_url, - reminders.embed_thumbnail_url, - reminders.embed_title, - reminders.embed_fields, - reminders.enabled, - reminders.expires, - reminders.interval_seconds, - reminders.interval_months, - reminders.name, - reminders.restartable, - reminders.tts, - reminders.uid, - reminders.username, - reminders.utc_time - FROM reminders - LEFT JOIN channels ON channels.id = reminders.channel_id - WHERE uid = ?", - new_uid - ) - .fetch_one(pool.inner()) - .await - .map(|r| json!(r)) - .unwrap_or_else(|e| { - warn!("Failed to complete SQL query: {:?}", e); - - json!({"error": "Could not load reminder"}) - }), - - Err(e) => { - warn!("Error in `create_reminder`: Could not execute query: {:?}", e); - - json!({"error": "Unknown error"}) - } - } } #[get("/api/guild//reminders")] -pub async fn get_reminders(id: u64, ctx: &State, pool: &State>) -> JsonValue { +pub async fn get_reminders(id: u64, ctx: &State, pool: &State>) -> JsonResult { let channels_res = GuildId(id).channels(&ctx.inner()).await; match channels_res { @@ -552,17 +353,17 @@ pub async fn get_reminders(id: u64, ctx: &State, pool: &State { warn!("Could not fetch channels from {}: {:?}", id, e); - json!([]) + Ok(json!([])) } } } @@ -573,7 +374,7 @@ pub async fn edit_reminder( reminder: Json, serenity_context: &State, pool: &State>, -) -> JsonValue { +) -> JsonResult { let mut error = vec![]; update_field!(pool.inner(), error, reminder.[ @@ -614,7 +415,7 @@ pub async fn edit_reminder( reminder.channel, id ); - return json!({"error": "Channel not found"}); + return Err(json!({"error": "Channel not found"})); } let channel = create_database_channel( @@ -627,7 +428,9 @@ pub async fn edit_reminder( if let Err(e) = channel { warn!("`create_database_channel` returned an error code: {:?}", e); - return json!({"error": "Failed to configure channel for reminders. Please check the bot permissions"}); + return Err( + json!({"error": "Failed to configure channel for reminders. Please check the bot permissions"}), + ); } let channel = channel.unwrap(); @@ -655,7 +458,7 @@ pub async fn edit_reminder( reminder.channel, id ); - return json!({"error": "Channel not found"}); + return Err(json!({"error": "Channel not found"})); } } } @@ -695,12 +498,12 @@ pub async fn edit_reminder( .fetch_one(pool.inner()) .await { - Ok(reminder) => json!({"reminder": reminder, "errors": error}), + Ok(reminder) => Ok(json!({"reminder": reminder, "errors": error})), Err(e) => { warn!("Error exiting `edit_reminder': {:?}", e); - json!({"reminder": Option::::None, "errors": vec!["Unknown error"]}) + Err(json!({"reminder": Option::::None, "errors": vec!["Unknown error"]})) } } } @@ -709,19 +512,17 @@ pub async fn edit_reminder( pub async fn delete_reminder( reminder: Json, pool: &State>, -) -> JsonValue { +) -> JsonResult { match sqlx::query!("DELETE FROM reminders WHERE uid = ?", reminder.uid) .execute(pool.inner()) .await { - Ok(_) => { - json!({}) - } + Ok(_) => Ok(json!({})), Err(e) => { warn!("Error in `delete_reminder`: {:?}", e); - json!({"error": "Could not delete reminder"}) + Err(json!({"error": "Could not delete reminder"})) } } } diff --git a/web/src/routes/dashboard/mod.rs b/web/src/routes/dashboard/mod.rs index a04ee50..843f7c3 100644 --- a/web/src/routes/dashboard/mod.rs +++ b/web/src/routes/dashboard/mod.rs @@ -1,15 +1,29 @@ use std::collections::HashMap; -use chrono::naive::NaiveDateTime; +use chrono::{naive::NaiveDateTime, Utc}; use rand::{rngs::OsRng, seq::IteratorRandom}; -use rocket::{http::CookieJar, response::Redirect}; +use rocket::{ + http::CookieJar, + response::Redirect, + serde::json::{json, Value as JsonValue}, +}; use rocket_dyn_templates::Template; use serde::{Deserialize, Serialize}; -use serenity::{http::Http, model::id::ChannelId}; -use sqlx::{types::Json, Executor}; +use serenity::{ + client::Context, + http::Http, + model::id::{ChannelId, GuildId, UserId}, +}; +use sqlx::{types::Json, Executor, MySql, Pool}; use crate::{ - consts::{CHARACTERS, DEFAULT_AVATAR}, + check_guild_subscription, check_subscription, + consts::{ + CHARACTERS, DAY, DEFAULT_AVATAR, MAX_CONTENT_LENGTH, MAX_EMBED_AUTHOR_LENGTH, + MAX_EMBED_DESCRIPTION_LENGTH, MAX_EMBED_FIELDS, MAX_EMBED_FIELD_TITLE_LENGTH, + MAX_EMBED_FIELD_VALUE_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, + MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MIN_INTERVAL, + }, Database, Error, }; @@ -17,6 +31,7 @@ pub mod export; pub mod guild; pub mod user; +pub type JsonResult = Result; type Unset = Option; fn name_default() -> String { @@ -134,7 +149,7 @@ pub struct ReminderCsv { attachment: Option>, attachment_name: Option, avatar: Option, - channel: u64, + channel: String, content: String, embed_author: String, embed_author_url: Option, @@ -284,6 +299,209 @@ pub struct TodoCsv { channel_id: Option, } +pub async fn create_reminder( + ctx: &Context, + pool: &Pool, + guild_id: GuildId, + user_id: UserId, + reminder: Reminder, +) -> JsonResult { + // validate channel + let channel = ChannelId(reminder.channel).to_channel_cached(&ctx); + let channel_exists = channel.is_some(); + + let channel_matches_guild = + channel.map_or(false, |c| c.guild().map_or(false, |c| c.guild_id == guild_id)); + + if !channel_matches_guild || !channel_exists { + warn!( + "Error in `create_reminder`: channel {} not found for guild {} (channel exists: {})", + reminder.channel, guild_id, channel_exists + ); + + return Err(json!({"error": "Channel not found"})); + } + + let channel = create_database_channel(&ctx, ChannelId(reminder.channel), pool).await; + + if let Err(e) = channel { + warn!("`create_database_channel` returned an error code: {:?}", e); + + return Err( + json!({"error": "Failed to configure channel for reminders. Please check the bot permissions"}), + ); + } + + let channel = channel.unwrap(); + + // validate lengths + check_length!(MAX_CONTENT_LENGTH, reminder.content); + check_length!(MAX_EMBED_DESCRIPTION_LENGTH, reminder.embed_description); + check_length!(MAX_EMBED_TITLE_LENGTH, reminder.embed_title); + check_length!(MAX_EMBED_AUTHOR_LENGTH, reminder.embed_author); + check_length!(MAX_EMBED_FOOTER_LENGTH, reminder.embed_footer); + check_length_opt!(MAX_EMBED_FIELDS, reminder.embed_fields); + if let Some(fields) = &reminder.embed_fields { + for field in &fields.0 { + check_length!(MAX_EMBED_FIELD_VALUE_LENGTH, field.value); + check_length!(MAX_EMBED_FIELD_TITLE_LENGTH, field.title); + } + } + check_length_opt!(MAX_USERNAME_LENGTH, reminder.username); + check_length_opt!( + MAX_URL_LENGTH, + reminder.embed_footer_url, + reminder.embed_thumbnail_url, + reminder.embed_author_url, + reminder.embed_image_url, + reminder.avatar + ); + + // validate urls + check_url_opt!( + reminder.embed_footer_url, + reminder.embed_thumbnail_url, + reminder.embed_author_url, + reminder.embed_image_url, + reminder.avatar + ); + + // validate time and interval + if reminder.utc_time < Utc::now().naive_utc() { + return Err(json!({"error": "Time must be in the future"})); + } + if reminder.interval_seconds.is_some() || reminder.interval_months.is_some() { + if reminder.interval_months.unwrap_or(0) * 30 * DAY as u32 + + reminder.interval_seconds.unwrap_or(0) + < *MIN_INTERVAL + { + return Err(json!({"error": "Interval too short"})); + } + } + + // check patreon if necessary + if reminder.interval_seconds.is_some() || reminder.interval_months.is_some() { + if !check_guild_subscription(&ctx, guild_id).await + && !check_subscription(&ctx, user_id).await + { + return Err(json!({"error": "Patreon is required to set intervals"})); + } + } + + // base64 decode error dropped here + let attachment_data = reminder.attachment.as_ref().map(|s| base64::decode(s).ok()).flatten(); + let name = if reminder.name.is_empty() { name_default() } else { reminder.name.clone() }; + + let new_uid = generate_uid(); + + // write to db + match sqlx::query!( + "INSERT INTO reminders ( + uid, + attachment, + attachment_name, + channel_id, + avatar, + content, + embed_author, + embed_author_url, + embed_color, + embed_description, + embed_footer, + embed_footer_url, + embed_image_url, + embed_thumbnail_url, + embed_title, + embed_fields, + enabled, + expires, + interval_seconds, + interval_months, + name, + restartable, + tts, + username, + `utc_time` + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + new_uid, + attachment_data, + reminder.attachment_name, + channel, + reminder.avatar, + reminder.content, + reminder.embed_author, + reminder.embed_author_url, + reminder.embed_color, + reminder.embed_description, + reminder.embed_footer, + reminder.embed_footer_url, + reminder.embed_image_url, + reminder.embed_thumbnail_url, + reminder.embed_title, + reminder.embed_fields, + reminder.enabled, + reminder.expires, + reminder.interval_seconds, + reminder.interval_months, + name, + reminder.restartable, + reminder.tts, + reminder.username, + reminder.utc_time, + ) + .execute(pool) + .await + { + Ok(_) => sqlx::query_as_unchecked!( + Reminder, + "SELECT + reminders.attachment, + reminders.attachment_name, + reminders.avatar, + channels.channel, + reminders.content, + reminders.embed_author, + reminders.embed_author_url, + reminders.embed_color, + reminders.embed_description, + reminders.embed_footer, + reminders.embed_footer_url, + reminders.embed_image_url, + reminders.embed_thumbnail_url, + reminders.embed_title, + reminders.embed_fields, + reminders.enabled, + reminders.expires, + reminders.interval_seconds, + reminders.interval_months, + reminders.name, + reminders.restartable, + reminders.tts, + reminders.uid, + reminders.username, + reminders.utc_time + FROM reminders + LEFT JOIN channels ON channels.id = reminders.channel_id + WHERE uid = ?", + new_uid + ) + .fetch_one(pool) + .await + .map(|r| Ok(json!(r))) + .unwrap_or_else(|e| { + warn!("Failed to complete SQL query: {:?}", e); + + Err(json!({"error": "Could not load reminder"})) + }), + + Err(e) => { + warn!("Error in `create_reminder`: Could not execute query: {:?}", e); + + Err(json!({"error": "Unknown error"})) + } + } +} + async fn create_database_channel( ctx: impl AsRef, channel: ChannelId,