Use transactions for certain routes

This commit is contained in:
jude 2023-09-24 13:57:27 +01:00
parent 4bad1324b9
commit 25286da5e0
4 changed files with 77 additions and 37 deletions

View File

@ -9,6 +9,16 @@ use crate::Database;
pub(crate) struct Transaction<'a>(sqlx::Transaction<'a, Database>);
impl Transaction<'_> {
pub(crate) fn executor(&mut self) -> impl sqlx::Executor<'_, Database = Database> {
&mut *(self.0)
}
pub(crate) async fn commit(self) -> Result<(), sqlx::Error> {
self.0.commit().await
}
}
#[derive(Debug)]
pub(crate) enum TransactionError {
Error(sqlx::Error),

View File

@ -10,12 +10,15 @@ use serenity::{
};
use sqlx::{MySql, Pool};
use crate::routes::{
use crate::{
guards::transaction::Transaction,
routes::{
dashboard::{
create_reminder, generate_uid, ImportBody, Reminder, ReminderCsv, ReminderTemplateCsv,
TodoCsv,
},
JsonResult,
},
};
#[get("/api/guild/<id>/export/reminders")]
@ -118,12 +121,12 @@ pub async fn export_reminders(
}
#[put("/api/guild/<id>/export/reminders", data = "<body>")]
pub async fn import_reminders(
pub(crate) async fn import_reminders(
id: u64,
cookies: &CookieJar<'_>,
body: Json<ImportBody>,
ctx: &State<Context>,
pool: &State<Pool<MySql>>,
mut transaction: Transaction<'_>,
) -> JsonResult {
check_authorization!(cookies, ctx.inner(), id);
@ -175,7 +178,7 @@ pub async fn import_reminders(
create_reminder(
ctx.inner(),
pool.inner(),
&mut transaction,
GuildId(id),
UserId(user_id),
reminder,
@ -200,7 +203,14 @@ pub async fn import_reminders(
}
}
Ok(json!({}))
match transaction.commit().await {
Ok(_) => Ok(json!({})),
Err(e) => {
warn!("Failed to commit transaction: {:?}", e);
json_err!("Couldn't commit transaction")
}
}
}
Err(_) => {

View File

@ -23,6 +23,7 @@ use crate::{
MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH,
MIN_INTERVAL,
},
guards::transaction::Transaction,
routes::{
dashboard::{
create_database_channel, create_reminder, template_name_default, DeleteReminder,
@ -30,6 +31,7 @@ use crate::{
},
JsonResult,
},
Database,
};
#[derive(Serialize)]
@ -302,26 +304,37 @@ pub async fn delete_reminder_template(
}
#[post("/api/guild/<id>/reminders", data = "<reminder>")]
pub async fn create_guild_reminder(
pub(crate) async fn create_guild_reminder(
id: u64,
reminder: Json<Reminder>,
cookies: &CookieJar<'_>,
serenity_context: &State<Context>,
pool: &State<Pool<MySql>>,
mut transaction: Transaction<'_>,
) -> JsonResult {
check_authorization!(cookies, serenity_context.inner(), id);
let user_id =
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap();
create_reminder(
match create_reminder(
serenity_context.inner(),
pool.inner(),
&mut transaction,
GuildId(id),
UserId(user_id),
reminder.into_inner(),
)
.await
{
Ok(r) => match transaction.commit().await {
Ok(_) => Ok(r),
Err(e) => {
warn!("Could'nt commit transaction: {:?}", e);
json_err!("Couldn't commit transaction.")
}
},
Err(e) => Err(e),
}
}
#[get("/api/guild/<id>/reminders")]
@ -397,11 +410,12 @@ pub async fn get_reminders(
}
#[patch("/api/guild/<id>/reminders", data = "<reminder>")]
pub async fn edit_reminder(
pub(crate) async fn edit_reminder(
id: u64,
reminder: Json<PatchReminder>,
serenity_context: &State<Context>,
pool: &State<Pool<MySql>>,
mut transaction: Transaction<'_>,
pool: &State<Pool<Database>>,
cookies: &CookieJar<'_>,
) -> JsonResult {
check_authorization!(cookies, serenity_context.inner(), id);
@ -412,7 +426,7 @@ pub async fn edit_reminder(
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap();
if reminder.message_ok() {
update_field!(pool.inner(), error, reminder.[
update_field!(transaction.executor(), error, reminder.[
content,
embed_author,
embed_description,
@ -425,7 +439,7 @@ pub async fn edit_reminder(
error.push("Message exceeds limits.".to_string());
}
update_field!(pool.inner(), error, reminder.[
update_field!(transaction.executor(), error, reminder.[
attachment,
attachment_name,
avatar,
@ -455,7 +469,7 @@ pub async fn edit_reminder(
"SELECT interval_days AS days FROM reminders WHERE uid = ?",
reminder.uid
)
.fetch_one(pool.inner())
.fetch_one(transaction.executor())
.await
.map_err(|e| {
warn!("Error updating reminder interval: {:?}", e);
@ -469,7 +483,7 @@ pub async fn edit_reminder(
"SELECT interval_months AS months FROM reminders WHERE uid = ?",
reminder.uid
)
.fetch_one(pool.inner())
.fetch_one(transaction.executor())
.await
.map_err(|e| {
warn!("Error updating reminder interval: {:?}", e);
@ -483,7 +497,7 @@ pub async fn edit_reminder(
"SELECT interval_seconds AS seconds FROM reminders WHERE uid = ?",
reminder.uid
)
.fetch_one(pool.inner())
.fetch_one(transaction.executor())
.await
.map_err(|e| {
warn!("Error updating reminder interval: {:?}", e);
@ -496,7 +510,7 @@ pub async fn edit_reminder(
if new_interval_length < *MIN_INTERVAL {
error.push(String::from("New interval is too short."));
} else {
update_field!(pool.inner(), error, reminder.[
update_field!(transaction.executor(), error, reminder.[
interval_days,
interval_months,
interval_seconds
@ -523,7 +537,7 @@ pub async fn edit_reminder(
let channel = create_database_channel(
serenity_context.inner(),
ChannelId(reminder.channel),
pool.inner(),
&mut transaction,
)
.await;
@ -542,7 +556,7 @@ pub async fn edit_reminder(
channel,
reminder.uid
)
.execute(pool.inner())
.execute(transaction.executor())
.await
{
Ok(_) => {}
@ -565,6 +579,11 @@ pub async fn edit_reminder(
}
}
if let Err(e) = transaction.commit().await {
warn!("Couldn't commit transaction: {:?}", e);
return json_err!("Couldn't commit transaction");
}
match sqlx::query_as_unchecked!(
Reminder,
"SELECT reminders.attachment,

View File

@ -10,7 +10,7 @@ use serenity::{
http::Http,
model::id::{ChannelId, GuildId, UserId},
};
use sqlx::{types::Json, Executor};
use sqlx::types::Json;
use crate::{
check_guild_subscription, check_subscription,
@ -20,8 +20,9 @@ use crate::{
MAX_EMBED_FIELD_VALUE_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH,
MAX_NAME_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MIN_INTERVAL,
},
guards::transaction::Transaction,
routes::JsonResult,
Database, Error,
Error,
};
pub mod export;
@ -353,21 +354,21 @@ pub struct TodoCsv {
channel_id: Option<String>,
}
pub async fn create_reminder(
pub(crate) async fn create_reminder(
ctx: &Context,
pool: impl sqlx::Executor<'_, Database = Database> + Copy,
transaction: &mut Transaction<'_>,
guild_id: GuildId,
user_id: UserId,
reminder: Reminder,
) -> JsonResult {
// check guild in db
match sqlx::query!("SELECT 1 as A FROM guilds WHERE guild = ?", guild_id.0)
.fetch_one(pool)
.fetch_one(transaction.executor())
.await
{
Err(sqlx::Error::RowNotFound) => {
if sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id.0)
.execute(pool)
.execute(transaction.executor())
.await
.is_err()
{
@ -393,7 +394,7 @@ pub async fn create_reminder(
return Err(json!({"error": "Channel not found"}));
}
let channel = create_database_channel(&ctx, ChannelId(reminder.channel), pool).await;
let channel = create_database_channel(&ctx, ChannelId(reminder.channel), transaction).await;
if let Err(e) = channel {
warn!("`create_database_channel` returned an error code: {:?}", e);
@ -535,7 +536,7 @@ pub async fn create_reminder(
username,
reminder.utc_time,
)
.execute(pool)
.execute(transaction.executor())
.await
{
Ok(_) => sqlx::query_as_unchecked!(
@ -572,7 +573,7 @@ pub async fn create_reminder(
WHERE uid = ?",
new_uid
)
.fetch_one(pool)
.fetch_one(transaction.executor())
.await
.map(|r| Ok(json!(r)))
.unwrap_or_else(|e| {
@ -592,11 +593,11 @@ pub async fn create_reminder(
async fn create_database_channel(
ctx: impl AsRef<Http>,
channel: ChannelId,
pool: impl Executor<'_, Database = Database> + Copy,
transaction: &mut Transaction<'_>,
) -> Result<u32, crate::Error> {
let row =
sqlx::query!("SELECT webhook_token, webhook_id FROM channels WHERE channel = ?", channel.0)
.fetch_one(pool)
.fetch_one(transaction.executor())
.await;
match row {
@ -613,7 +614,7 @@ async fn create_database_channel(
webhook.token,
channel.0
)
.execute(pool)
.execute(transaction.executor())
.await
.map_err(|e| Error::SQLx(e))?;
}
@ -639,7 +640,7 @@ async fn create_database_channel(
webhook.token,
channel.0
)
.execute(pool)
.execute(transaction.executor())
.await
.map_err(|e| Error::SQLx(e))?;
@ -650,7 +651,7 @@ async fn create_database_channel(
}?;
let row = sqlx::query!("SELECT id FROM channels WHERE channel = ?", channel.0)
.fetch_one(pool)
.fetch_one(transaction.executor())
.await
.map_err(|e| Error::SQLx(e))?;