Restructure

Move some code out to other files. Add transaction guard
This commit is contained in:
jude 2023-09-24 13:11:53 +01:00
parent bd1462a00c
commit 4bad1324b9
4 changed files with 85 additions and 46 deletions

40
web/src/catchers.rs Normal file
View File

@ -0,0 +1,40 @@
use std::collections::HashMap;
use rocket::serde::json::json;
use rocket_dyn_templates::Template;
use crate::JsonValue;
#[catch(403)]
pub(crate) async fn forbidden() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/403", &map)
}
#[catch(500)]
pub(crate) async fn internal_server_error() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/500", &map)
}
#[catch(401)]
pub(crate) async fn not_authorized() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/401", &map)
}
#[catch(404)]
pub(crate) async fn not_found() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/404", &map)
}
#[catch(413)]
pub(crate) async fn payload_too_large() -> JsonValue {
json!({"error": "Data too large.", "errors": ["Data too large."]})
}
#[catch(422)]
pub(crate) async fn unprocessable_entity() -> JsonValue {
json!({"error": "Invalid request.", "errors": ["Invalid request."]})
}

1
web/src/guards/mod.rs Normal file
View File

@ -0,0 +1 @@
pub(crate) mod transaction;

View File

@ -0,0 +1,34 @@
use rocket::{
http::Status,
request::{FromRequest, Outcome},
Request, State,
};
use sqlx::Pool;
use crate::Database;
pub(crate) struct Transaction<'a>(sqlx::Transaction<'a, Database>);
#[derive(Debug)]
pub(crate) enum TransactionError {
Error(sqlx::Error),
Missing,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Transaction<'r> {
type Error = TransactionError;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.guard::<&State<Pool<Database>>>().await {
Outcome::Success(pool) => match pool.begin().await {
Ok(transaction) => Outcome::Success(Transaction(transaction)),
Err(e) => {
Outcome::Failure((Status::InternalServerError, TransactionError::Error(e)))
}
},
Outcome::Failure(e) => Outcome::Failure((e.0, TransactionError::Missing)),
Outcome::Forward(f) => Outcome::Forward(f),
}
}
}

View File

@ -4,16 +4,14 @@ extern crate rocket;
mod consts; mod consts;
#[macro_use] #[macro_use]
mod macros; mod macros;
mod catchers;
mod guards;
mod routes; mod routes;
use std::{collections::HashMap, env, path::Path}; use std::{env, path::Path};
use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl}; use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
use rocket::{ use rocket::{fs::FileServer, serde::json::Value as JsonValue, tokio::sync::broadcast::Sender};
fs::FileServer,
serde::json::{json, Value as JsonValue},
tokio::sync::broadcast::Sender,
};
use rocket_dyn_templates::Template; use rocket_dyn_templates::Template;
use serenity::{ use serenity::{
client::Context, client::Context,
@ -32,40 +30,6 @@ enum Error {
Serenity(serenity::Error), Serenity(serenity::Error),
} }
#[catch(401)]
async fn not_authorized() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/401", &map)
}
#[catch(403)]
async fn forbidden() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/403", &map)
}
#[catch(404)]
async fn not_found() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/404", &map)
}
#[catch(413)]
async fn payload_too_large() -> JsonValue {
json!({"error": "Data too large.", "errors": ["Data too large."]})
}
#[catch(422)]
async fn unprocessable_entity() -> JsonValue {
json!({"error": "Invalid request.", "errors": ["Invalid request."]})
}
#[catch(500)]
async fn internal_server_error() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/500", &map)
}
pub async fn initialize( pub async fn initialize(
kill_channel: Sender<()>, kill_channel: Sender<()>,
serenity_context: Context, serenity_context: Context,
@ -100,12 +64,12 @@ pub async fn initialize(
.register( .register(
"/", "/",
catchers![ catchers![
not_authorized, catchers::not_authorized,
forbidden, catchers::forbidden,
not_found, catchers::not_found,
internal_server_error, catchers::internal_server_error,
unprocessable_entity, catchers::unprocessable_entity,
payload_too_large, catchers::payload_too_large,
], ],
) )
.manage(oauth2_client) .manage(oauth2_client)