Use transactions for certain routes
This commit is contained in:
parent
4bad1324b9
commit
25286da5e0
@ -9,6 +9,16 @@ use crate::Database;
|
||||
|
||||
pub(crate) struct Transaction<'a>(sqlx::Transaction<'a, Database>);
|
||||
|
||||
impl Transaction<'_> {
|
||||
pub(crate) fn executor(&mut self) -> impl sqlx::Executor<'_, Database = Database> {
|
||||
&mut *(self.0)
|
||||
}
|
||||
|
||||
pub(crate) async fn commit(self) -> Result<(), sqlx::Error> {
|
||||
self.0.commit().await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum TransactionError {
|
||||
Error(sqlx::Error),
|
||||
|
@ -10,12 +10,15 @@ use serenity::{
|
||||
};
|
||||
use sqlx::{MySql, Pool};
|
||||
|
||||
use crate::routes::{
|
||||
use crate::{
|
||||
guards::transaction::Transaction,
|
||||
routes::{
|
||||
dashboard::{
|
||||
create_reminder, generate_uid, ImportBody, Reminder, ReminderCsv, ReminderTemplateCsv,
|
||||
TodoCsv,
|
||||
},
|
||||
JsonResult,
|
||||
},
|
||||
};
|
||||
|
||||
#[get("/api/guild/<id>/export/reminders")]
|
||||
@ -118,12 +121,12 @@ pub async fn export_reminders(
|
||||
}
|
||||
|
||||
#[put("/api/guild/<id>/export/reminders", data = "<body>")]
|
||||
pub async fn import_reminders(
|
||||
pub(crate) async fn import_reminders(
|
||||
id: u64,
|
||||
cookies: &CookieJar<'_>,
|
||||
body: Json<ImportBody>,
|
||||
ctx: &State<Context>,
|
||||
pool: &State<Pool<MySql>>,
|
||||
mut transaction: Transaction<'_>,
|
||||
) -> JsonResult {
|
||||
check_authorization!(cookies, ctx.inner(), id);
|
||||
|
||||
@ -175,7 +178,7 @@ pub async fn import_reminders(
|
||||
|
||||
create_reminder(
|
||||
ctx.inner(),
|
||||
pool.inner(),
|
||||
&mut transaction,
|
||||
GuildId(id),
|
||||
UserId(user_id),
|
||||
reminder,
|
||||
@ -200,7 +203,14 @@ pub async fn import_reminders(
|
||||
}
|
||||
}
|
||||
|
||||
Ok(json!({}))
|
||||
match transaction.commit().await {
|
||||
Ok(_) => Ok(json!({})),
|
||||
|
||||
Err(e) => {
|
||||
warn!("Failed to commit transaction: {:?}", e);
|
||||
json_err!("Couldn't commit transaction")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(_) => {
|
||||
|
@ -23,6 +23,7 @@ use crate::{
|
||||
MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH,
|
||||
MIN_INTERVAL,
|
||||
},
|
||||
guards::transaction::Transaction,
|
||||
routes::{
|
||||
dashboard::{
|
||||
create_database_channel, create_reminder, template_name_default, DeleteReminder,
|
||||
@ -30,6 +31,7 @@ use crate::{
|
||||
},
|
||||
JsonResult,
|
||||
},
|
||||
Database,
|
||||
};
|
||||
|
||||
#[derive(Serialize)]
|
||||
@ -302,26 +304,37 @@ pub async fn delete_reminder_template(
|
||||
}
|
||||
|
||||
#[post("/api/guild/<id>/reminders", data = "<reminder>")]
|
||||
pub async fn create_guild_reminder(
|
||||
pub(crate) async fn create_guild_reminder(
|
||||
id: u64,
|
||||
reminder: Json<Reminder>,
|
||||
cookies: &CookieJar<'_>,
|
||||
serenity_context: &State<Context>,
|
||||
pool: &State<Pool<MySql>>,
|
||||
mut transaction: Transaction<'_>,
|
||||
) -> JsonResult {
|
||||
check_authorization!(cookies, serenity_context.inner(), id);
|
||||
|
||||
let user_id =
|
||||
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap();
|
||||
|
||||
create_reminder(
|
||||
match create_reminder(
|
||||
serenity_context.inner(),
|
||||
pool.inner(),
|
||||
&mut transaction,
|
||||
GuildId(id),
|
||||
UserId(user_id),
|
||||
reminder.into_inner(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(r) => match transaction.commit().await {
|
||||
Ok(_) => Ok(r),
|
||||
Err(e) => {
|
||||
warn!("Could'nt commit transaction: {:?}", e);
|
||||
json_err!("Couldn't commit transaction.")
|
||||
}
|
||||
},
|
||||
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/api/guild/<id>/reminders")]
|
||||
@ -397,11 +410,12 @@ pub async fn get_reminders(
|
||||
}
|
||||
|
||||
#[patch("/api/guild/<id>/reminders", data = "<reminder>")]
|
||||
pub async fn edit_reminder(
|
||||
pub(crate) async fn edit_reminder(
|
||||
id: u64,
|
||||
reminder: Json<PatchReminder>,
|
||||
serenity_context: &State<Context>,
|
||||
pool: &State<Pool<MySql>>,
|
||||
mut transaction: Transaction<'_>,
|
||||
pool: &State<Pool<Database>>,
|
||||
cookies: &CookieJar<'_>,
|
||||
) -> JsonResult {
|
||||
check_authorization!(cookies, serenity_context.inner(), id);
|
||||
@ -412,7 +426,7 @@ pub async fn edit_reminder(
|
||||
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap();
|
||||
|
||||
if reminder.message_ok() {
|
||||
update_field!(pool.inner(), error, reminder.[
|
||||
update_field!(transaction.executor(), error, reminder.[
|
||||
content,
|
||||
embed_author,
|
||||
embed_description,
|
||||
@ -425,7 +439,7 @@ pub async fn edit_reminder(
|
||||
error.push("Message exceeds limits.".to_string());
|
||||
}
|
||||
|
||||
update_field!(pool.inner(), error, reminder.[
|
||||
update_field!(transaction.executor(), error, reminder.[
|
||||
attachment,
|
||||
attachment_name,
|
||||
avatar,
|
||||
@ -455,7 +469,7 @@ pub async fn edit_reminder(
|
||||
"SELECT interval_days AS days FROM reminders WHERE uid = ?",
|
||||
reminder.uid
|
||||
)
|
||||
.fetch_one(pool.inner())
|
||||
.fetch_one(transaction.executor())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("Error updating reminder interval: {:?}", e);
|
||||
@ -469,7 +483,7 @@ pub async fn edit_reminder(
|
||||
"SELECT interval_months AS months FROM reminders WHERE uid = ?",
|
||||
reminder.uid
|
||||
)
|
||||
.fetch_one(pool.inner())
|
||||
.fetch_one(transaction.executor())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("Error updating reminder interval: {:?}", e);
|
||||
@ -483,7 +497,7 @@ pub async fn edit_reminder(
|
||||
"SELECT interval_seconds AS seconds FROM reminders WHERE uid = ?",
|
||||
reminder.uid
|
||||
)
|
||||
.fetch_one(pool.inner())
|
||||
.fetch_one(transaction.executor())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("Error updating reminder interval: {:?}", e);
|
||||
@ -496,7 +510,7 @@ pub async fn edit_reminder(
|
||||
if new_interval_length < *MIN_INTERVAL {
|
||||
error.push(String::from("New interval is too short."));
|
||||
} else {
|
||||
update_field!(pool.inner(), error, reminder.[
|
||||
update_field!(transaction.executor(), error, reminder.[
|
||||
interval_days,
|
||||
interval_months,
|
||||
interval_seconds
|
||||
@ -523,7 +537,7 @@ pub async fn edit_reminder(
|
||||
let channel = create_database_channel(
|
||||
serenity_context.inner(),
|
||||
ChannelId(reminder.channel),
|
||||
pool.inner(),
|
||||
&mut transaction,
|
||||
)
|
||||
.await;
|
||||
|
||||
@ -542,7 +556,7 @@ pub async fn edit_reminder(
|
||||
channel,
|
||||
reminder.uid
|
||||
)
|
||||
.execute(pool.inner())
|
||||
.execute(transaction.executor())
|
||||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
@ -565,6 +579,11 @@ pub async fn edit_reminder(
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) = transaction.commit().await {
|
||||
warn!("Couldn't commit transaction: {:?}", e);
|
||||
return json_err!("Couldn't commit transaction");
|
||||
}
|
||||
|
||||
match sqlx::query_as_unchecked!(
|
||||
Reminder,
|
||||
"SELECT reminders.attachment,
|
||||
|
@ -10,7 +10,7 @@ use serenity::{
|
||||
http::Http,
|
||||
model::id::{ChannelId, GuildId, UserId},
|
||||
};
|
||||
use sqlx::{types::Json, Executor};
|
||||
use sqlx::types::Json;
|
||||
|
||||
use crate::{
|
||||
check_guild_subscription, check_subscription,
|
||||
@ -20,8 +20,9 @@ use crate::{
|
||||
MAX_EMBED_FIELD_VALUE_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH,
|
||||
MAX_NAME_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MIN_INTERVAL,
|
||||
},
|
||||
guards::transaction::Transaction,
|
||||
routes::JsonResult,
|
||||
Database, Error,
|
||||
Error,
|
||||
};
|
||||
|
||||
pub mod export;
|
||||
@ -353,21 +354,21 @@ pub struct TodoCsv {
|
||||
channel_id: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn create_reminder(
|
||||
pub(crate) async fn create_reminder(
|
||||
ctx: &Context,
|
||||
pool: impl sqlx::Executor<'_, Database = Database> + Copy,
|
||||
transaction: &mut Transaction<'_>,
|
||||
guild_id: GuildId,
|
||||
user_id: UserId,
|
||||
reminder: Reminder,
|
||||
) -> JsonResult {
|
||||
// check guild in db
|
||||
match sqlx::query!("SELECT 1 as A FROM guilds WHERE guild = ?", guild_id.0)
|
||||
.fetch_one(pool)
|
||||
.fetch_one(transaction.executor())
|
||||
.await
|
||||
{
|
||||
Err(sqlx::Error::RowNotFound) => {
|
||||
if sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id.0)
|
||||
.execute(pool)
|
||||
.execute(transaction.executor())
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
@ -393,7 +394,7 @@ pub async fn create_reminder(
|
||||
return Err(json!({"error": "Channel not found"}));
|
||||
}
|
||||
|
||||
let channel = create_database_channel(&ctx, ChannelId(reminder.channel), pool).await;
|
||||
let channel = create_database_channel(&ctx, ChannelId(reminder.channel), transaction).await;
|
||||
|
||||
if let Err(e) = channel {
|
||||
warn!("`create_database_channel` returned an error code: {:?}", e);
|
||||
@ -535,7 +536,7 @@ pub async fn create_reminder(
|
||||
username,
|
||||
reminder.utc_time,
|
||||
)
|
||||
.execute(pool)
|
||||
.execute(transaction.executor())
|
||||
.await
|
||||
{
|
||||
Ok(_) => sqlx::query_as_unchecked!(
|
||||
@ -572,7 +573,7 @@ pub async fn create_reminder(
|
||||
WHERE uid = ?",
|
||||
new_uid
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.fetch_one(transaction.executor())
|
||||
.await
|
||||
.map(|r| Ok(json!(r)))
|
||||
.unwrap_or_else(|e| {
|
||||
@ -592,11 +593,11 @@ pub async fn create_reminder(
|
||||
async fn create_database_channel(
|
||||
ctx: impl AsRef<Http>,
|
||||
channel: ChannelId,
|
||||
pool: impl Executor<'_, Database = Database> + Copy,
|
||||
transaction: &mut Transaction<'_>,
|
||||
) -> Result<u32, crate::Error> {
|
||||
let row =
|
||||
sqlx::query!("SELECT webhook_token, webhook_id FROM channels WHERE channel = ?", channel.0)
|
||||
.fetch_one(pool)
|
||||
.fetch_one(transaction.executor())
|
||||
.await;
|
||||
|
||||
match row {
|
||||
@ -613,7 +614,7 @@ async fn create_database_channel(
|
||||
webhook.token,
|
||||
channel.0
|
||||
)
|
||||
.execute(pool)
|
||||
.execute(transaction.executor())
|
||||
.await
|
||||
.map_err(|e| Error::SQLx(e))?;
|
||||
}
|
||||
@ -639,7 +640,7 @@ async fn create_database_channel(
|
||||
webhook.token,
|
||||
channel.0
|
||||
)
|
||||
.execute(pool)
|
||||
.execute(transaction.executor())
|
||||
.await
|
||||
.map_err(|e| Error::SQLx(e))?;
|
||||
|
||||
@ -650,7 +651,7 @@ async fn create_database_channel(
|
||||
}?;
|
||||
|
||||
let row = sqlx::query!("SELECT id FROM channels WHERE channel = ?", channel.0)
|
||||
.fetch_one(pool)
|
||||
.fetch_one(transaction.executor())
|
||||
.await
|
||||
.map_err(|e| Error::SQLx(e))?;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user