From 4bad1324b9ec4e9ab7ced61374b469536cac82fa Mon Sep 17 00:00:00 2001 From: jude Date: Sun, 24 Sep 2023 13:11:53 +0100 Subject: [PATCH] Restructure Move some code out to other files. Add transaction guard --- web/src/catchers.rs | 40 +++++++++++++++++++++++++ web/src/guards/mod.rs | 1 + web/src/guards/transaction.rs | 34 +++++++++++++++++++++ web/src/lib.rs | 56 +++++++---------------------------- 4 files changed, 85 insertions(+), 46 deletions(-) create mode 100644 web/src/catchers.rs create mode 100644 web/src/guards/mod.rs create mode 100644 web/src/guards/transaction.rs diff --git a/web/src/catchers.rs b/web/src/catchers.rs new file mode 100644 index 0000000..e646edf --- /dev/null +++ b/web/src/catchers.rs @@ -0,0 +1,40 @@ +use std::collections::HashMap; + +use rocket::serde::json::json; +use rocket_dyn_templates::Template; + +use crate::JsonValue; + +#[catch(403)] +pub(crate) async fn forbidden() -> Template { + let map: HashMap = HashMap::new(); + Template::render("errors/403", &map) +} + +#[catch(500)] +pub(crate) async fn internal_server_error() -> Template { + let map: HashMap = HashMap::new(); + Template::render("errors/500", &map) +} + +#[catch(401)] +pub(crate) async fn not_authorized() -> Template { + let map: HashMap = HashMap::new(); + Template::render("errors/401", &map) +} + +#[catch(404)] +pub(crate) async fn not_found() -> Template { + let map: HashMap = HashMap::new(); + Template::render("errors/404", &map) +} + +#[catch(413)] +pub(crate) async fn payload_too_large() -> JsonValue { + json!({"error": "Data too large.", "errors": ["Data too large."]}) +} + +#[catch(422)] +pub(crate) async fn unprocessable_entity() -> JsonValue { + json!({"error": "Invalid request.", "errors": ["Invalid request."]}) +} diff --git a/web/src/guards/mod.rs b/web/src/guards/mod.rs new file mode 100644 index 0000000..3f46026 --- /dev/null +++ b/web/src/guards/mod.rs @@ -0,0 +1 @@ +pub(crate) mod transaction; diff --git a/web/src/guards/transaction.rs b/web/src/guards/transaction.rs new file mode 100644 index 0000000..7585ed5 --- /dev/null +++ b/web/src/guards/transaction.rs @@ -0,0 +1,34 @@ +use rocket::{ + http::Status, + request::{FromRequest, Outcome}, + Request, State, +}; +use sqlx::Pool; + +use crate::Database; + +pub(crate) struct Transaction<'a>(sqlx::Transaction<'a, Database>); + +#[derive(Debug)] +pub(crate) enum TransactionError { + Error(sqlx::Error), + Missing, +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for Transaction<'r> { + type Error = TransactionError; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + match request.guard::<&State>>().await { + Outcome::Success(pool) => match pool.begin().await { + Ok(transaction) => Outcome::Success(Transaction(transaction)), + Err(e) => { + Outcome::Failure((Status::InternalServerError, TransactionError::Error(e))) + } + }, + Outcome::Failure(e) => Outcome::Failure((e.0, TransactionError::Missing)), + Outcome::Forward(f) => Outcome::Forward(f), + } + } +} diff --git a/web/src/lib.rs b/web/src/lib.rs index 54631cd..85ac791 100644 --- a/web/src/lib.rs +++ b/web/src/lib.rs @@ -4,16 +4,14 @@ extern crate rocket; mod consts; #[macro_use] mod macros; +mod catchers; +mod guards; mod routes; -use std::{collections::HashMap, env, path::Path}; +use std::{env, path::Path}; use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl}; -use rocket::{ - fs::FileServer, - serde::json::{json, Value as JsonValue}, - tokio::sync::broadcast::Sender, -}; +use rocket::{fs::FileServer, serde::json::Value as JsonValue, tokio::sync::broadcast::Sender}; use rocket_dyn_templates::Template; use serenity::{ client::Context, @@ -32,40 +30,6 @@ enum Error { Serenity(serenity::Error), } -#[catch(401)] -async fn not_authorized() -> Template { - let map: HashMap = HashMap::new(); - Template::render("errors/401", &map) -} - -#[catch(403)] -async fn forbidden() -> Template { - let map: HashMap = HashMap::new(); - Template::render("errors/403", &map) -} - -#[catch(404)] -async fn not_found() -> Template { - let map: HashMap = HashMap::new(); - Template::render("errors/404", &map) -} - -#[catch(413)] -async fn payload_too_large() -> JsonValue { - json!({"error": "Data too large.", "errors": ["Data too large."]}) -} - -#[catch(422)] -async fn unprocessable_entity() -> JsonValue { - json!({"error": "Invalid request.", "errors": ["Invalid request."]}) -} - -#[catch(500)] -async fn internal_server_error() -> Template { - let map: HashMap = HashMap::new(); - Template::render("errors/500", &map) -} - pub async fn initialize( kill_channel: Sender<()>, serenity_context: Context, @@ -100,12 +64,12 @@ pub async fn initialize( .register( "/", catchers![ - not_authorized, - forbidden, - not_found, - internal_server_error, - unprocessable_entity, - payload_too_large, + catchers::not_authorized, + catchers::forbidden, + catchers::not_found, + catchers::internal_server_error, + catchers::unprocessable_entity, + catchers::payload_too_large, ], ) .manage(oauth2_client)