use rocket::{ http::Status, request::{FromRequest, Outcome}, Request, State, }; use sqlx::Pool; use crate::Database; pub struct Transaction<'a>(sqlx::Transaction<'a, Database>); impl Transaction<'_> { pub fn executor(&mut self) -> impl sqlx::Executor<'_, Database = Database> { &mut *(self.0) } pub async fn commit(self) -> Result<(), sqlx::Error> { self.0.commit().await } } #[derive(Debug)] #[allow(dead_code)] pub enum TransactionError { Error(sqlx::Error), Missing, } #[rocket::async_trait] impl<'r> FromRequest<'r> for Transaction<'r> { type Error = TransactionError; async fn from_request(request: &'r Request<'_>) -> Outcome { match request.guard::<&State>>().await { Outcome::Success(pool) => match pool.begin().await { Ok(transaction) => Outcome::Success(Transaction(transaction)), Err(e) => Outcome::Error((Status::InternalServerError, TransactionError::Error(e))), }, Outcome::Error(e) => Outcome::Error((e.0, TransactionError::Missing)), Outcome::Forward(f) => Outcome::Forward(f), } } }