Files
reminder-bot/src/web/mod.rs
2024-04-09 21:21:46 +01:00

311 lines
9.5 KiB
Rust

mod consts;
#[macro_use]
mod macros;
mod catchers;
mod guards;
mod metrics;
mod routes;
/// Serde helpers that carry a value across the wire as its string form.
///
/// The `serialize`/`deserialize` pair has the signature shape that serde's
/// `#[serde(with = "...")]` field attribute expects.
pub mod string {
    use std::{fmt::Display, str::FromStr};

    use serde::{de, Deserialize, Deserializer, Serializer};

    /// Writes `value` to `serializer` as a string, using its `Display` output.
    pub fn serialize<T: Display, S: Serializer>(
        value: &T,
        serializer: S,
    ) -> Result<S::Ok, S::Error> {
        serializer.collect_str(value)
    }

    /// Reads a string from `deserializer` and parses it into `T` via `FromStr`.
    ///
    /// # Errors
    ///
    /// Fails if the input is not a string, or if parsing fails; the parse
    /// error's `Display` output becomes the custom deserialization error.
    pub fn deserialize<'de, T: FromStr, D: Deserializer<'de>>(
        deserializer: D,
    ) -> Result<T, D::Error>
    where
        T::Err: Display,
    {
        let raw = String::deserialize(deserializer)?;
        raw.parse::<T>().map_err(de::Error::custom)
    }
}
/// Serde helpers for optional values that travel as strings.
///
/// `Some(v)` is written as the `Display` form of `v`; `None` is written with
/// the serializer's "none" representation. The function pair has the
/// signature shape that serde's `#[serde(with = "...")]` attribute expects.
pub mod string_opt {
    use std::{fmt::Display, str::FromStr};

    use serde::{de, Deserialize, Deserializer, Serializer};

    /// Writes `Some(v)` as a string (via `Display`) and `None` as "none".
    pub fn serialize<T: Display, S: Serializer>(
        value: &Option<T>,
        serializer: S,
    ) -> Result<S::Ok, S::Error> {
        match value {
            Some(inner) => serializer.collect_str(inner),
            None => serializer.serialize_none(),
        }
    }

    /// Reads an optional string and, when present, parses it into `T`.
    ///
    /// # Errors
    ///
    /// Fails if the input is neither absent nor a string, or if a present
    /// string cannot be parsed; the parse error's `Display` output becomes
    /// the custom deserialization error.
    pub fn deserialize<'de, T: FromStr, D: Deserializer<'de>>(
        deserializer: D,
    ) -> Result<Option<T>, D::Error>
    where
        T::Err: Display,
    {
        match Option::<String>::deserialize(deserializer)? {
            Some(raw) => raw.parse::<T>().map(Some).map_err(de::Error::custom),
            None => Ok(None),
        }
    }
}
use std::{env, path::Path};
use log::{error, info, warn};
use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
use poise::serenity_prelude::{
client::Context,
http::CacheHttp,
model::id::{GuildId, UserId},
};
use rocket::{
catchers,
fs::FileServer,
http::CookieJar,
routes,
serde::json::{json, Value as JsonValue},
tokio::sync::broadcast::Sender,
};
use rocket_dyn_templates::Template;
use sqlx::{MySql, Pool};
use crate::web::{
consts::{CNC_GUILD, DISCORD_OAUTH_AUTHORIZE, DISCORD_OAUTH_TOKEN, SUBSCRIPTION_ROLES},
metrics::MetricProducer,
};
/// The sqlx database driver used by this module; `initialize` takes its
/// connection pool as `Pool<Database>`.
type Database = MySql;
/// Module-private error type wrapping failures from the two underlying
/// libraries.
#[derive(Debug)]
enum Error {
/// A database error reported by sqlx.
// NOTE(review): the payload is never read in this file; presumably it is
// only surfaced through the derived `Debug` output, hence the lint
// suppression — confirm.
#[allow(unused)]
SQLx(sqlx::Error),
/// A Discord API/client error reported by serenity.
// NOTE(review): same as above — the payload appears to be kept only for
// `Debug` output; confirm.
#[allow(unused)]
Serenity(serenity::Error),
}
/// Builds the Rocket web server, runs it until it shuts down, then notifies
/// the rest of the process through `kill_channel`.
///
/// * `kill_channel` - broadcast sender that receives `()` once Rocket exits.
/// * `serenity_context` - Discord client context, stored as Rocket managed state.
/// * `db_pool` - database pool, stored as Rocket managed state.
///
/// Static files are served from `./static` when that path exists, otherwise
/// from `/lib/reminder-rs/static`.
///
/// # Panics
///
/// Unless the `OFFLINE` environment variable is exactly `"1"`, panics when any
/// of `OAUTH2_CLIENT_ID`, `OAUTH2_CLIENT_SECRET`, `OAUTH2_DISCORD_CALLBACK` or
/// `PATREON_GUILD_ID` is missing.
///
/// # Errors
///
/// Returns an error if an OAuth2 environment variable cannot be read, if one
/// of the OAuth2 URLs fails to parse, or if Rocket fails to launch. In all of
/// these cases the function returns early and the kill signal is *not* sent.
pub async fn initialize(
    kill_channel: Sender<()>,
    serenity_context: Context,
    db_pool: Pool<Database>,
) -> Result<(), Box<dyn std::error::Error>> {
    info!("Checking environment variables...");

    let offline = env::var("OFFLINE").map_or(false, |v| v == "1");
    if !offline {
        env::var("OAUTH2_CLIENT_ID").expect("`OAUTH2_CLIENT_ID' not supplied");
        env::var("OAUTH2_CLIENT_SECRET").expect("`OAUTH2_CLIENT_SECRET' not supplied");
        env::var("OAUTH2_DISCORD_CALLBACK").expect("`OAUTH2_DISCORD_CALLBACK' not supplied");
        env::var("PATREON_GUILD_ID").expect("`PATREON_GUILD_ID' not supplied");
    }

    info!("Done!");

    // The fallible lookups below run in the same order as the arguments they
    // feed, so the first failure reported is always the same one.
    let client_id = ClientId::new(env::var("OAUTH2_CLIENT_ID")?);
    let client_secret = ClientSecret::new(env::var("OAUTH2_CLIENT_SECRET")?);
    let auth_url = AuthUrl::new(DISCORD_OAUTH_AUTHORIZE.to_string())?;
    let token_url = TokenUrl::new(DISCORD_OAUTH_TOKEN.to_string())?;
    let redirect_url = RedirectUrl::new(env::var("OAUTH2_DISCORD_CALLBACK")?)?;

    let oauth2_client =
        BasicClient::new(client_id, Some(client_secret), auth_url, Some(token_url))
            .set_redirect_uri(redirect_url);

    let reqwest_client = reqwest::Client::new();

    // Prefer a `static` directory relative to the working directory; fall
    // back to the system-wide location otherwise.
    let static_path = match Path::new("static").exists() {
        true => "static",
        false => "/lib/reminder-rs/static",
    };

    let error_catchers = catchers![
        catchers::not_authorized,
        catchers::forbidden,
        catchers::not_found,
        catchers::internal_server_error,
        catchers::unprocessable_entity,
        catchers::payload_too_large,
    ];

    let root_routes = routes![
        routes::cookies,
        routes::index,
        routes::privacy,
        routes::report::report_error,
        routes::return_to_same_site,
        routes::terms,
    ];

    let help_routes = routes![
        routes::help,
        routes::help_timezone,
        routes::help_create_reminder,
        routes::help_delete_reminder,
        routes::help_timers,
        routes::help_todo_lists,
        routes::help_macros,
        routes::help_intervals,
        routes::help_dashboard,
        routes::help_iemanager,
    ];

    let login_routes = routes![
        routes::login::discord_login,
        routes::login::discord_logout,
        routes::login::discord_callback
    ];

    let dashboard_routes = routes![
        routes::dashboard::dashboard,
        routes::dashboard::dashboard_home,
        routes::dashboard::api::delete_reminder,
        routes::dashboard::api::user::get_user_info,
        routes::dashboard::api::user::update_user_info,
        routes::dashboard::api::user::get_user_guilds,
        routes::dashboard::api::user::get_reminders,
        routes::dashboard::api::user::edit_reminder,
        routes::dashboard::api::user::create_user_reminder,
        routes::dashboard::api::guild::get_guild_info,
        routes::dashboard::api::guild::get_guild_channels,
        routes::dashboard::api::guild::get_guild_roles,
        routes::dashboard::api::guild::get_guild_emojis,
        routes::dashboard::api::guild::get_reminder_templates,
        routes::dashboard::api::guild::create_reminder_template,
        routes::dashboard::api::guild::delete_reminder_template,
        routes::dashboard::api::guild::create_guild_reminder,
        routes::dashboard::api::guild::get_reminders,
        routes::dashboard::api::guild::edit_reminder,
        routes::dashboard::api::guild::todos::create_todo,
        routes::dashboard::api::guild::todos::get_todo,
        routes::dashboard::api::guild::todos::update_todo,
        routes::dashboard::api::guild::todos::delete_todo,
        routes::dashboard::export::export_reminders,
        routes::dashboard::export::export_reminder_templates,
        routes::dashboard::export::export_todos,
        routes::dashboard::export::import_reminders,
        routes::dashboard::export::import_todos,
    ];

    let server = rocket::build()
        .attach(MetricProducer)
        .attach(Template::fairing())
        .register("/", error_catchers)
        .manage(oauth2_client)
        .manage(reqwest_client)
        .manage(serenity_context)
        .manage(db_pool)
        .mount("/static", FileServer::from(static_path))
        .mount("/", root_routes)
        .mount("/help", help_routes)
        .mount("/login", login_routes)
        .mount("/dashboard", dashboard_routes);

    // Runs until Rocket shuts down; a launch failure is propagated here.
    server.launch().await?;

    warn!("Exiting rocket runtime");

    // Tell every subscriber of the kill channel that the web server is gone.
    if let Err(e) = kill_channel.send(()) {
        error!("Failed to issue kill signal: {:?}", e);
    }

    Ok(())
}
/// Reports whether `user_id` counts as a subscriber.
///
/// When no subscription guild is configured (`CNC_GUILD` is `None`) every
/// user is treated as subscribed. Otherwise the user must be a member of that
/// guild and hold at least one role listed in `SUBSCRIPTION_ROLES`; a failed
/// member lookup yields `false`.
///
/// NOTE(review): `offline!` is defined in the `macros` module, which is not
/// visible here; presumably it returns `true` early when running in offline
/// mode — confirm.
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
    offline!(true);

    match *CNC_GUILD {
        // No subscription guild configured: nothing to check against.
        None => true,
        Some(guild) => match GuildId::new(guild).member(cache_http, user_id).await {
            Ok(member) => member
                .roles
                .into_iter()
                .any(|role| SUBSCRIPTION_ROLES.contains(&role.get())),
            Err(_) => false,
        },
    }
}
pub async fn check_guild_subscription(
cache_http: impl CacheHttp,
guild_id: impl Into<GuildId>,
) -> bool {
offline!(true);
if let Some(owner) = cache_http.cache().unwrap().guild(guild_id).map(|guild| guild.owner_id) {
check_subscription(&cache_http, owner).await
} else {
false
}
}
pub async fn check_authorization(
cookies: &CookieJar<'_>,
ctx: &Context,
guild_id: u64,
) -> Result<(), JsonValue> {
let user_id = cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten();
if std::env::var("OFFLINE").map_or(true, |v| v != "1") {
match user_id {
Some(user_id) => {
let admin_id = std::env::var("ADMIN_ID")
.map_or(false, |u| u.parse::<u64>().map_or(false, |u| u == user_id));
if admin_id {
return Ok(());
}
let guild_id = GuildId::new(guild_id);
match guild_id.member(ctx, UserId::new(user_id)).await {
Ok(member) => {
let permissions_res = member.permissions(ctx);
match permissions_res {
Err(_) => {
return Err(json!({"error": "Couldn't fetch permissions"}));
}
Ok(permissions) => {
if !(permissions.manage_messages()
|| permissions.manage_guild()
|| permissions.administrator())
{
return Err(json!({"error": "Incorrect permissions"}));
}
}
}
}
Err(_) => {
return Err(json!({"error": "User not in guild"}));
}
}
}
None => {
return Err(json!({"error": "User not authorized"}));
}
}
}
Ok(())
}