Use transactions for certain routes

This commit is contained in:
jude 2023-09-24 13:57:27 +01:00
parent 4bad1324b9
commit 25286da5e0
4 changed files with 77 additions and 37 deletions

View File

@ -9,6 +9,16 @@ use crate::Database;
pub(crate) struct Transaction<'a>(sqlx::Transaction<'a, Database>); pub(crate) struct Transaction<'a>(sqlx::Transaction<'a, Database>);
impl Transaction<'_> {
pub(crate) fn executor(&mut self) -> impl sqlx::Executor<'_, Database = Database> {
&mut *(self.0)
}
pub(crate) async fn commit(self) -> Result<(), sqlx::Error> {
self.0.commit().await
}
}
#[derive(Debug)] #[derive(Debug)]
pub(crate) enum TransactionError { pub(crate) enum TransactionError {
Error(sqlx::Error), Error(sqlx::Error),

View File

@ -10,12 +10,15 @@ use serenity::{
}; };
use sqlx::{MySql, Pool}; use sqlx::{MySql, Pool};
use crate::routes::{ use crate::{
guards::transaction::Transaction,
routes::{
dashboard::{ dashboard::{
create_reminder, generate_uid, ImportBody, Reminder, ReminderCsv, ReminderTemplateCsv, create_reminder, generate_uid, ImportBody, Reminder, ReminderCsv, ReminderTemplateCsv,
TodoCsv, TodoCsv,
}, },
JsonResult, JsonResult,
},
}; };
#[get("/api/guild/<id>/export/reminders")] #[get("/api/guild/<id>/export/reminders")]
@ -118,12 +121,12 @@ pub async fn export_reminders(
} }
#[put("/api/guild/<id>/export/reminders", data = "<body>")] #[put("/api/guild/<id>/export/reminders", data = "<body>")]
pub async fn import_reminders( pub(crate) async fn import_reminders(
id: u64, id: u64,
cookies: &CookieJar<'_>, cookies: &CookieJar<'_>,
body: Json<ImportBody>, body: Json<ImportBody>,
ctx: &State<Context>, ctx: &State<Context>,
pool: &State<Pool<MySql>>, mut transaction: Transaction<'_>,
) -> JsonResult { ) -> JsonResult {
check_authorization!(cookies, ctx.inner(), id); check_authorization!(cookies, ctx.inner(), id);
@ -175,7 +178,7 @@ pub async fn import_reminders(
create_reminder( create_reminder(
ctx.inner(), ctx.inner(),
pool.inner(), &mut transaction,
GuildId(id), GuildId(id),
UserId(user_id), UserId(user_id),
reminder, reminder,
@ -200,7 +203,14 @@ pub async fn import_reminders(
} }
} }
Ok(json!({})) match transaction.commit().await {
Ok(_) => Ok(json!({})),
Err(e) => {
warn!("Failed to commit transaction: {:?}", e);
json_err!("Couldn't commit transaction")
}
}
} }
Err(_) => { Err(_) => {

View File

@ -23,6 +23,7 @@ use crate::{
MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH,
MIN_INTERVAL, MIN_INTERVAL,
}, },
guards::transaction::Transaction,
routes::{ routes::{
dashboard::{ dashboard::{
create_database_channel, create_reminder, template_name_default, DeleteReminder, create_database_channel, create_reminder, template_name_default, DeleteReminder,
@ -30,6 +31,7 @@ use crate::{
}, },
JsonResult, JsonResult,
}, },
Database,
}; };
#[derive(Serialize)] #[derive(Serialize)]
@ -302,26 +304,37 @@ pub async fn delete_reminder_template(
} }
#[post("/api/guild/<id>/reminders", data = "<reminder>")] #[post("/api/guild/<id>/reminders", data = "<reminder>")]
pub async fn create_guild_reminder( pub(crate) async fn create_guild_reminder(
id: u64, id: u64,
reminder: Json<Reminder>, reminder: Json<Reminder>,
cookies: &CookieJar<'_>, cookies: &CookieJar<'_>,
serenity_context: &State<Context>, serenity_context: &State<Context>,
pool: &State<Pool<MySql>>, mut transaction: Transaction<'_>,
) -> JsonResult { ) -> JsonResult {
check_authorization!(cookies, serenity_context.inner(), id); check_authorization!(cookies, serenity_context.inner(), id);
let user_id = let user_id =
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap(); cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap();
create_reminder( match create_reminder(
serenity_context.inner(), serenity_context.inner(),
pool.inner(), &mut transaction,
GuildId(id), GuildId(id),
UserId(user_id), UserId(user_id),
reminder.into_inner(), reminder.into_inner(),
) )
.await .await
{
Ok(r) => match transaction.commit().await {
Ok(_) => Ok(r),
Err(e) => {
warn!("Could'nt commit transaction: {:?}", e);
json_err!("Couldn't commit transaction.")
}
},
Err(e) => Err(e),
}
} }
#[get("/api/guild/<id>/reminders")] #[get("/api/guild/<id>/reminders")]
@ -397,11 +410,12 @@ pub async fn get_reminders(
} }
#[patch("/api/guild/<id>/reminders", data = "<reminder>")] #[patch("/api/guild/<id>/reminders", data = "<reminder>")]
pub async fn edit_reminder( pub(crate) async fn edit_reminder(
id: u64, id: u64,
reminder: Json<PatchReminder>, reminder: Json<PatchReminder>,
serenity_context: &State<Context>, serenity_context: &State<Context>,
pool: &State<Pool<MySql>>, mut transaction: Transaction<'_>,
pool: &State<Pool<Database>>,
cookies: &CookieJar<'_>, cookies: &CookieJar<'_>,
) -> JsonResult { ) -> JsonResult {
check_authorization!(cookies, serenity_context.inner(), id); check_authorization!(cookies, serenity_context.inner(), id);
@ -412,7 +426,7 @@ pub async fn edit_reminder(
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap(); cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap();
if reminder.message_ok() { if reminder.message_ok() {
update_field!(pool.inner(), error, reminder.[ update_field!(transaction.executor(), error, reminder.[
content, content,
embed_author, embed_author,
embed_description, embed_description,
@ -425,7 +439,7 @@ pub async fn edit_reminder(
error.push("Message exceeds limits.".to_string()); error.push("Message exceeds limits.".to_string());
} }
update_field!(pool.inner(), error, reminder.[ update_field!(transaction.executor(), error, reminder.[
attachment, attachment,
attachment_name, attachment_name,
avatar, avatar,
@ -455,7 +469,7 @@ pub async fn edit_reminder(
"SELECT interval_days AS days FROM reminders WHERE uid = ?", "SELECT interval_days AS days FROM reminders WHERE uid = ?",
reminder.uid reminder.uid
) )
.fetch_one(pool.inner()) .fetch_one(transaction.executor())
.await .await
.map_err(|e| { .map_err(|e| {
warn!("Error updating reminder interval: {:?}", e); warn!("Error updating reminder interval: {:?}", e);
@ -469,7 +483,7 @@ pub async fn edit_reminder(
"SELECT interval_months AS months FROM reminders WHERE uid = ?", "SELECT interval_months AS months FROM reminders WHERE uid = ?",
reminder.uid reminder.uid
) )
.fetch_one(pool.inner()) .fetch_one(transaction.executor())
.await .await
.map_err(|e| { .map_err(|e| {
warn!("Error updating reminder interval: {:?}", e); warn!("Error updating reminder interval: {:?}", e);
@ -483,7 +497,7 @@ pub async fn edit_reminder(
"SELECT interval_seconds AS seconds FROM reminders WHERE uid = ?", "SELECT interval_seconds AS seconds FROM reminders WHERE uid = ?",
reminder.uid reminder.uid
) )
.fetch_one(pool.inner()) .fetch_one(transaction.executor())
.await .await
.map_err(|e| { .map_err(|e| {
warn!("Error updating reminder interval: {:?}", e); warn!("Error updating reminder interval: {:?}", e);
@ -496,7 +510,7 @@ pub async fn edit_reminder(
if new_interval_length < *MIN_INTERVAL { if new_interval_length < *MIN_INTERVAL {
error.push(String::from("New interval is too short.")); error.push(String::from("New interval is too short."));
} else { } else {
update_field!(pool.inner(), error, reminder.[ update_field!(transaction.executor(), error, reminder.[
interval_days, interval_days,
interval_months, interval_months,
interval_seconds interval_seconds
@ -523,7 +537,7 @@ pub async fn edit_reminder(
let channel = create_database_channel( let channel = create_database_channel(
serenity_context.inner(), serenity_context.inner(),
ChannelId(reminder.channel), ChannelId(reminder.channel),
pool.inner(), &mut transaction,
) )
.await; .await;
@ -542,7 +556,7 @@ pub async fn edit_reminder(
channel, channel,
reminder.uid reminder.uid
) )
.execute(pool.inner()) .execute(transaction.executor())
.await .await
{ {
Ok(_) => {} Ok(_) => {}
@ -565,6 +579,11 @@ pub async fn edit_reminder(
} }
} }
if let Err(e) = transaction.commit().await {
warn!("Couldn't commit transaction: {:?}", e);
return json_err!("Couldn't commit transaction");
}
match sqlx::query_as_unchecked!( match sqlx::query_as_unchecked!(
Reminder, Reminder,
"SELECT reminders.attachment, "SELECT reminders.attachment,

View File

@ -10,7 +10,7 @@ use serenity::{
http::Http, http::Http,
model::id::{ChannelId, GuildId, UserId}, model::id::{ChannelId, GuildId, UserId},
}; };
use sqlx::{types::Json, Executor}; use sqlx::types::Json;
use crate::{ use crate::{
check_guild_subscription, check_subscription, check_guild_subscription, check_subscription,
@ -20,8 +20,9 @@ use crate::{
MAX_EMBED_FIELD_VALUE_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_EMBED_FIELD_VALUE_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH,
MAX_NAME_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MIN_INTERVAL, MAX_NAME_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MIN_INTERVAL,
}, },
guards::transaction::Transaction,
routes::JsonResult, routes::JsonResult,
Database, Error, Error,
}; };
pub mod export; pub mod export;
@ -353,21 +354,21 @@ pub struct TodoCsv {
channel_id: Option<String>, channel_id: Option<String>,
} }
pub async fn create_reminder( pub(crate) async fn create_reminder(
ctx: &Context, ctx: &Context,
pool: impl sqlx::Executor<'_, Database = Database> + Copy, transaction: &mut Transaction<'_>,
guild_id: GuildId, guild_id: GuildId,
user_id: UserId, user_id: UserId,
reminder: Reminder, reminder: Reminder,
) -> JsonResult { ) -> JsonResult {
// check guild in db // check guild in db
match sqlx::query!("SELECT 1 as A FROM guilds WHERE guild = ?", guild_id.0) match sqlx::query!("SELECT 1 as A FROM guilds WHERE guild = ?", guild_id.0)
.fetch_one(pool) .fetch_one(transaction.executor())
.await .await
{ {
Err(sqlx::Error::RowNotFound) => { Err(sqlx::Error::RowNotFound) => {
if sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id.0) if sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id.0)
.execute(pool) .execute(transaction.executor())
.await .await
.is_err() .is_err()
{ {
@ -393,7 +394,7 @@ pub async fn create_reminder(
return Err(json!({"error": "Channel not found"})); return Err(json!({"error": "Channel not found"}));
} }
let channel = create_database_channel(&ctx, ChannelId(reminder.channel), pool).await; let channel = create_database_channel(&ctx, ChannelId(reminder.channel), transaction).await;
if let Err(e) = channel { if let Err(e) = channel {
warn!("`create_database_channel` returned an error code: {:?}", e); warn!("`create_database_channel` returned an error code: {:?}", e);
@ -535,7 +536,7 @@ pub async fn create_reminder(
username, username,
reminder.utc_time, reminder.utc_time,
) )
.execute(pool) .execute(transaction.executor())
.await .await
{ {
Ok(_) => sqlx::query_as_unchecked!( Ok(_) => sqlx::query_as_unchecked!(
@ -572,7 +573,7 @@ pub async fn create_reminder(
WHERE uid = ?", WHERE uid = ?",
new_uid new_uid
) )
.fetch_one(pool) .fetch_one(transaction.executor())
.await .await
.map(|r| Ok(json!(r))) .map(|r| Ok(json!(r)))
.unwrap_or_else(|e| { .unwrap_or_else(|e| {
@ -592,11 +593,11 @@ pub async fn create_reminder(
async fn create_database_channel( async fn create_database_channel(
ctx: impl AsRef<Http>, ctx: impl AsRef<Http>,
channel: ChannelId, channel: ChannelId,
pool: impl Executor<'_, Database = Database> + Copy, transaction: &mut Transaction<'_>,
) -> Result<u32, crate::Error> { ) -> Result<u32, crate::Error> {
let row = let row =
sqlx::query!("SELECT webhook_token, webhook_id FROM channels WHERE channel = ?", channel.0) sqlx::query!("SELECT webhook_token, webhook_id FROM channels WHERE channel = ?", channel.0)
.fetch_one(pool) .fetch_one(transaction.executor())
.await; .await;
match row { match row {
@ -613,7 +614,7 @@ async fn create_database_channel(
webhook.token, webhook.token,
channel.0 channel.0
) )
.execute(pool) .execute(transaction.executor())
.await .await
.map_err(|e| Error::SQLx(e))?; .map_err(|e| Error::SQLx(e))?;
} }
@ -639,7 +640,7 @@ async fn create_database_channel(
webhook.token, webhook.token,
channel.0 channel.0
) )
.execute(pool) .execute(transaction.executor())
.await .await
.map_err(|e| Error::SQLx(e))?; .map_err(|e| Error::SQLx(e))?;
@ -650,7 +651,7 @@ async fn create_database_channel(
}?; }?;
let row = sqlx::query!("SELECT id FROM channels WHERE channel = ?", channel.0) let row = sqlx::query!("SELECT id FROM channels WHERE channel = ?", channel.0)
.fetch_one(pool) .fetch_one(transaction.executor())
.await .await
.map_err(|e| Error::SQLx(e))?; .map_err(|e| Error::SQLx(e))?;