diff --git a/web/src/guards/transaction.rs b/web/src/guards/transaction.rs index 7585ed5..91061e9 100644 --- a/web/src/guards/transaction.rs +++ b/web/src/guards/transaction.rs @@ -9,6 +9,16 @@ use crate::Database; pub(crate) struct Transaction<'a>(sqlx::Transaction<'a, Database>); +impl Transaction<'_> { + pub(crate) fn executor(&mut self) -> impl sqlx::Executor<'_, Database = Database> { + &mut *(self.0) + } + + pub(crate) async fn commit(self) -> Result<(), sqlx::Error> { + self.0.commit().await + } +} + #[derive(Debug)] pub(crate) enum TransactionError { Error(sqlx::Error), diff --git a/web/src/routes/dashboard/export.rs b/web/src/routes/dashboard/export.rs index 0b5905f..b0b0a17 100644 --- a/web/src/routes/dashboard/export.rs +++ b/web/src/routes/dashboard/export.rs @@ -10,12 +10,15 @@ use serenity::{ }; use sqlx::{MySql, Pool}; -use crate::routes::{ - dashboard::{ - create_reminder, generate_uid, ImportBody, Reminder, ReminderCsv, ReminderTemplateCsv, - TodoCsv, +use crate::{ + guards::transaction::Transaction, + routes::{ + dashboard::{ + create_reminder, generate_uid, ImportBody, Reminder, ReminderCsv, ReminderTemplateCsv, + TodoCsv, + }, + JsonResult, }, - JsonResult, }; #[get("/api/guild//export/reminders")] @@ -118,12 +121,12 @@ pub async fn export_reminders( } #[put("/api/guild//export/reminders", data = "")] -pub async fn import_reminders( +pub(crate) async fn import_reminders( id: u64, cookies: &CookieJar<'_>, body: Json, ctx: &State, - pool: &State>, + mut transaction: Transaction<'_>, ) -> JsonResult { check_authorization!(cookies, ctx.inner(), id); @@ -175,7 +178,7 @@ pub async fn import_reminders( create_reminder( ctx.inner(), - pool.inner(), + &mut transaction, GuildId(id), UserId(user_id), reminder, @@ -200,7 +203,14 @@ pub async fn import_reminders( } } - Ok(json!({})) + match transaction.commit().await { + Ok(_) => Ok(json!({})), + + Err(e) => { + warn!("Failed to commit transaction: {:?}", e); + json_err!("Couldn't commit transaction") + } + } } Err(_) => { diff --git a/web/src/routes/dashboard/guild.rs b/web/src/routes/dashboard/guild.rs index 6cbdb9e..776c6bc 100644 --- a/web/src/routes/dashboard/guild.rs +++ b/web/src/routes/dashboard/guild.rs @@ -23,6 +23,7 @@ use crate::{ MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MIN_INTERVAL, }, + guards::transaction::Transaction, routes::{ dashboard::{ create_database_channel, create_reminder, template_name_default, DeleteReminder, @@ -30,6 +31,7 @@ use crate::{ }, JsonResult, }, + Database, }; #[derive(Serialize)] @@ -302,26 +304,37 @@ pub async fn delete_reminder_template( } #[post("/api/guild//reminders", data = "")] -pub async fn create_guild_reminder( +pub(crate) async fn create_guild_reminder( id: u64, reminder: Json, cookies: &CookieJar<'_>, serenity_context: &State, - pool: &State>, + mut transaction: Transaction<'_>, ) -> JsonResult { check_authorization!(cookies, serenity_context.inner(), id); let user_id = cookies.get_private("userid").map(|c| c.value().parse::().ok()).flatten().unwrap(); - create_reminder( + match create_reminder( serenity_context.inner(), - pool.inner(), + &mut transaction, GuildId(id), UserId(user_id), reminder.into_inner(), ) .await + { + Ok(r) => match transaction.commit().await { + Ok(_) => Ok(r), + Err(e) => { + warn!("Could'nt commit transaction: {:?}", e); + json_err!("Couldn't commit transaction.") + } + }, + + Err(e) => Err(e), + } } #[get("/api/guild//reminders")] @@ -397,11 +410,12 @@ pub async fn get_reminders( } #[patch("/api/guild//reminders", data = "")] -pub async fn edit_reminder( +pub(crate) async fn edit_reminder( id: u64, reminder: Json, serenity_context: &State, - pool: &State>, + mut transaction: Transaction<'_>, + pool: &State>, cookies: &CookieJar<'_>, ) -> JsonResult { check_authorization!(cookies, serenity_context.inner(), id); @@ -412,7 +426,7 @@ pub async fn edit_reminder( cookies.get_private("userid").map(|c| c.value().parse::().ok()).flatten().unwrap(); if reminder.message_ok() { - update_field!(pool.inner(), error, reminder.[ + update_field!(transaction.executor(), error, reminder.[ content, embed_author, embed_description, @@ -425,7 +439,7 @@ pub async fn edit_reminder( error.push("Message exceeds limits.".to_string()); } - update_field!(pool.inner(), error, reminder.[ + update_field!(transaction.executor(), error, reminder.[ attachment, attachment_name, avatar, @@ -455,7 +469,7 @@ pub async fn edit_reminder( "SELECT interval_days AS days FROM reminders WHERE uid = ?", reminder.uid ) - .fetch_one(pool.inner()) + .fetch_one(transaction.executor()) .await .map_err(|e| { warn!("Error updating reminder interval: {:?}", e); @@ -469,7 +483,7 @@ pub async fn edit_reminder( "SELECT interval_months AS months FROM reminders WHERE uid = ?", reminder.uid ) - .fetch_one(pool.inner()) + .fetch_one(transaction.executor()) .await .map_err(|e| { warn!("Error updating reminder interval: {:?}", e); @@ -483,7 +497,7 @@ pub async fn edit_reminder( "SELECT interval_seconds AS seconds FROM reminders WHERE uid = ?", reminder.uid ) - .fetch_one(pool.inner()) + .fetch_one(transaction.executor()) .await .map_err(|e| { warn!("Error updating reminder interval: {:?}", e); @@ -496,7 +510,7 @@ pub async fn edit_reminder( if new_interval_length < *MIN_INTERVAL { error.push(String::from("New interval is too short.")); } else { - update_field!(pool.inner(), error, reminder.[ + update_field!(transaction.executor(), error, reminder.[ interval_days, interval_months, interval_seconds @@ -523,7 +537,7 @@ pub async fn edit_reminder( let channel = create_database_channel( serenity_context.inner(), ChannelId(reminder.channel), - pool.inner(), + &mut transaction, ) .await; @@ -542,7 +556,7 @@ pub async fn edit_reminder( channel, reminder.uid ) - .execute(pool.inner()) + .execute(transaction.executor()) .await { Ok(_) => {} @@ -565,6 +579,11 @@ pub async fn edit_reminder( } } + if let Err(e) = transaction.commit().await { + warn!("Couldn't commit transaction: {:?}", e); + return json_err!("Couldn't commit transaction"); + } + match sqlx::query_as_unchecked!( Reminder, "SELECT reminders.attachment, diff --git a/web/src/routes/dashboard/mod.rs b/web/src/routes/dashboard/mod.rs index 40069e8..46ff4e0 100644 --- a/web/src/routes/dashboard/mod.rs +++ b/web/src/routes/dashboard/mod.rs @@ -10,7 +10,7 @@ use serenity::{ http::Http, model::id::{ChannelId, GuildId, UserId}, }; -use sqlx::{types::Json, Executor}; +use sqlx::types::Json; use crate::{ check_guild_subscription, check_subscription, @@ -20,8 +20,9 @@ use crate::{ MAX_EMBED_FIELD_VALUE_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_NAME_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MIN_INTERVAL, }, + guards::transaction::Transaction, routes::JsonResult, - Database, Error, + Error, }; pub mod export; @@ -353,21 +354,21 @@ pub struct TodoCsv { channel_id: Option, } -pub async fn create_reminder( +pub(crate) async fn create_reminder( ctx: &Context, - pool: impl sqlx::Executor<'_, Database = Database> + Copy, + transaction: &mut Transaction<'_>, guild_id: GuildId, user_id: UserId, reminder: Reminder, ) -> JsonResult { // check guild in db match sqlx::query!("SELECT 1 as A FROM guilds WHERE guild = ?", guild_id.0) - .fetch_one(pool) + .fetch_one(transaction.executor()) .await { Err(sqlx::Error::RowNotFound) => { if sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id.0) - .execute(pool) + .execute(transaction.executor()) .await .is_err() { @@ -393,7 +394,7 @@ pub async fn create_reminder( return Err(json!({"error": "Channel not found"})); } - let channel = create_database_channel(&ctx, ChannelId(reminder.channel), pool).await; + let channel = create_database_channel(&ctx, ChannelId(reminder.channel), transaction).await; if let Err(e) = channel { warn!("`create_database_channel` returned an error code: {:?}", e); @@ -535,7 +536,7 @@ pub async fn create_reminder( username, reminder.utc_time, ) - .execute(pool) + .execute(transaction.executor()) .await { Ok(_) => sqlx::query_as_unchecked!( @@ -572,7 +573,7 @@ pub async fn create_reminder( WHERE uid = ?", new_uid ) - .fetch_one(pool) + .fetch_one(transaction.executor()) .await .map(|r| Ok(json!(r))) .unwrap_or_else(|e| { @@ -592,11 +593,11 @@ pub async fn create_reminder( async fn create_database_channel( ctx: impl AsRef, channel: ChannelId, - pool: impl Executor<'_, Database = Database> + Copy, + transaction: &mut Transaction<'_>, ) -> Result { let row = sqlx::query!("SELECT webhook_token, webhook_id FROM channels WHERE channel = ?", channel.0) - .fetch_one(pool) + .fetch_one(transaction.executor()) .await; match row { @@ -613,7 +614,7 @@ async fn create_database_channel( webhook.token, channel.0 ) - .execute(pool) + .execute(transaction.executor()) .await .map_err(|e| Error::SQLx(e))?; } @@ -639,7 +640,7 @@ async fn create_database_channel( webhook.token, channel.0 ) - .execute(pool) + .execute(transaction.executor()) .await .map_err(|e| Error::SQLx(e))?; @@ -650,7 +651,7 @@ async fn create_database_channel( }?; let row = sqlx::query!("SELECT id FROM channels WHERE channel = ?", channel.0) - .fetch_one(pool) + .fetch_one(transaction.executor()) .await .map_err(|e| Error::SQLx(e))?;