#[macro_use] extern crate rocket; mod consts; #[macro_use] mod macros; mod catchers; mod guards; mod routes; use std::{env, path::Path}; use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl}; use rocket::{ fs::FileServer, http::CookieJar, serde::json::{json, Value as JsonValue}, tokio::sync::broadcast::Sender, }; use rocket_dyn_templates::Template; use serenity::{ client::Context, http::CacheHttp, model::id::{GuildId, UserId}, }; use sqlx::{MySql, Pool}; use crate::consts::{CNC_GUILD, DISCORD_OAUTH_AUTHORIZE, DISCORD_OAUTH_TOKEN, SUBSCRIPTION_ROLES}; type Database = MySql; #[derive(Debug)] enum Error { SQLx(sqlx::Error), Serenity(serenity::Error), } pub async fn initialize( kill_channel: Sender<()>, serenity_context: Context, db_pool: Pool, ) -> Result<(), Box> { info!("Checking environment variables..."); if env::var("OFFLINE").map_or(true, |v| v != "1") { env::var("OAUTH2_CLIENT_ID").expect("`OAUTH2_CLIENT_ID' not supplied"); env::var("OAUTH2_CLIENT_SECRET").expect("`OAUTH2_CLIENT_SECRET' not supplied"); env::var("OAUTH2_DISCORD_CALLBACK").expect("`OAUTH2_DISCORD_CALLBACK' not supplied"); env::var("PATREON_GUILD_ID").expect("`PATREON_GUILD_ID' not supplied"); } info!("Done!"); let oauth2_client = BasicClient::new( ClientId::new(env::var("OAUTH2_CLIENT_ID")?), Some(ClientSecret::new(env::var("OAUTH2_CLIENT_SECRET")?)), AuthUrl::new(DISCORD_OAUTH_AUTHORIZE.to_string())?, Some(TokenUrl::new(DISCORD_OAUTH_TOKEN.to_string())?), ) .set_redirect_uri(RedirectUrl::new(env::var("OAUTH2_DISCORD_CALLBACK")?)?); let reqwest_client = reqwest::Client::new(); let static_path = if Path::new("web/static").exists() { "web/static" } else { "/lib/reminder-rs/static" }; rocket::build() .attach(Template::fairing()) .register( "/", catchers![ catchers::not_authorized, catchers::forbidden, catchers::not_found, catchers::internal_server_error, catchers::unprocessable_entity, catchers::payload_too_large, ], ) .manage(oauth2_client) .manage(reqwest_client) .manage(serenity_context) .manage(db_pool) .mount("/static", FileServer::from(static_path)) .mount( "/", routes![ routes::index, routes::cookies, routes::privacy, routes::terms, routes::return_to_same_site, routes::report::report_error, ], ) .mount( "/help", routes![ routes::help, routes::help_timezone, routes::help_create_reminder, routes::help_delete_reminder, routes::help_timers, routes::help_todo_lists, routes::help_macros, routes::help_intervals, routes::help_dashboard, routes::help_iemanager, ], ) .mount( "/login", routes![ routes::login::discord_login, routes::login::discord_logout, routes::login::discord_callback ], ) .mount( "/dashboard", routes![ routes::dashboard::dashboard, routes::dashboard::dashboard_home, routes::dashboard::api::user::get_user_info, routes::dashboard::api::user::update_user_info, routes::dashboard::api::user::get_user_guilds, routes::dashboard::api::guild::get_guild_info, routes::dashboard::api::guild::get_guild_channels, routes::dashboard::api::guild::get_guild_roles, routes::dashboard::api::guild::get_reminder_templates, routes::dashboard::api::guild::create_reminder_template, routes::dashboard::api::guild::delete_reminder_template, routes::dashboard::api::guild::create_guild_reminder, routes::dashboard::api::guild::get_reminders, routes::dashboard::api::guild::edit_reminder, routes::dashboard::api::guild::delete_reminder, routes::dashboard::export::export_reminders, routes::dashboard::export::export_reminder_templates, routes::dashboard::export::export_todos, routes::dashboard::export::import_reminders, routes::dashboard::export::import_todos, ], ) .mount("/admin", routes![routes::admin::admin_dashboard_home, routes::admin::bot_data]) .launch() .await?; warn!("Exiting rocket runtime"); // distribute kill signal match kill_channel.send(()) { Ok(_) => {} Err(e) => { error!("Failed to issue kill signal: {:?}", e); } } Ok(()) } pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into) -> bool { offline!(true); if let Some(subscription_guild) = *CNC_GUILD { let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await; if let Ok(member) = guild_member { for role in member.roles { if SUBSCRIPTION_ROLES.contains(role.as_u64()) { return true; } } } false } else { true } } pub async fn check_guild_subscription( cache_http: impl CacheHttp, guild_id: impl Into, ) -> bool { offline!(true); if let Some(guild) = cache_http.cache().unwrap().guild(guild_id) { let owner = guild.owner_id; check_subscription(&cache_http, owner).await } else { false } } pub async fn check_authorization( cookies: &CookieJar<'_>, ctx: &Context, guild: u64, ) -> Result<(), JsonValue> { let user_id = cookies.get_private("userid").map(|c| c.value().parse::().ok()).flatten(); if std::env::var("OFFLINE").map_or(true, |v| v != "1") { match user_id { Some(user_id) => { let admin_id = std::env::var("ADMIN_ID") .map_or(false, |u| u.parse::().map_or(false, |u| u == user_id)); if admin_id { return Ok(()); } match GuildId(guild).to_guild_cached(ctx) { Some(guild) => { let member_res = guild.member(ctx, UserId(user_id)).await; match member_res { Err(_) => { return Err(json!({"error": "User not in guild"})); } Ok(member) => { let permissions_res = member.permissions(ctx); match permissions_res { Err(_) => { return Err(json!({"error": "Couldn't fetch permissions"})); } Ok(permissions) => { if !(permissions.manage_messages() || permissions.manage_guild() || permissions.administrator()) { return Err(json!({"error": "Incorrect permissions"})); } } } } } } None => { return Err(json!({"error": "Bot not in guild"})); } } } None => { return Err(json!({"error": "User not authorized"})); } } } Ok(()) }