use log::warn; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, }; use reqwest::Client; use rocket::{ http::{private::cookie::Expiration, Cookie, CookieJar, SameSite}, response::{Flash, Redirect}, uri, State, }; use serenity::model::user::User; use crate::{consts::DISCORD_API, routes}; #[get("/discord")] pub async fn discord_login( oauth2_client: &State, cookies: &CookieJar<'_>, ) -> Redirect { let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let (auth_url, csrf_token) = oauth2_client .authorize_url(CsrfToken::new_random) // Set the desired scopes. .add_scope(Scope::new("identify".to_string())) .add_scope(Scope::new("guilds".to_string())) // Set the PKCE code challenge. .set_pkce_challenge(pkce_challenge) .url(); // store the pkce secret to verify the authorization later cookies.add_private( Cookie::build("verify", pkce_verifier.secret().to_string()) .http_only(true) .path("/login") .same_site(SameSite::Lax) .expires(Expiration::Session) .finish(), ); // store the csrf token to verify no interference cookies.add_private( Cookie::build("csrf", csrf_token.secret().to_string()) .http_only(true) .path("/login") .same_site(SameSite::Lax) .expires(Expiration::Session) .finish(), ); Redirect::to(auth_url.to_string()) } #[get("/discord/logout")] pub async fn discord_logout(cookies: &CookieJar<'_>) -> Redirect { cookies.remove_private(Cookie::named("username")); cookies.remove_private(Cookie::named("userid")); cookies.remove_private(Cookie::named("access_token")); Redirect::to(uri!(routes::index)) } #[get("/discord/authorized?&")] pub async fn discord_callback( code: &str, state: &str, cookies: &CookieJar<'_>, oauth2_client: &State, reqwest_client: &State, ) -> Result> { if let (Some(pkce_secret), Some(csrf_token)) = (cookies.get_private("verify"), cookies.get_private("csrf")) { if state == csrf_token.value() { let token_result = oauth2_client .exchange_code(AuthorizationCode::new(code.to_string())) // Set the PKCE code verifier. .set_pkce_verifier(PkceCodeVerifier::new(pkce_secret.value().to_string())) .request_async(async_http_client) .await; cookies.remove_private(Cookie::named("verify")); cookies.remove_private(Cookie::named("csrf")); match token_result { Ok(token) => { cookies.add_private( Cookie::build("access_token", token.access_token().secret().to_string()) .secure(true) .http_only(true) .path("/dashboard") .finish(), ); let request_res = reqwest_client .get(format!("{}/users/@me", DISCORD_API)) .bearer_auth(token.access_token().secret()) .send() .await; match request_res { Ok(response) => { let user_res = response.json::().await; match user_res { Ok(user) => { let user_name = format!("{}#{}", user.name, user.discriminator); let user_id = user.id.as_u64().to_string(); cookies.add_private(Cookie::new("username", user_name)); cookies.add_private(Cookie::new("userid", user_id)); Ok(Redirect::to(uri!(super::return_to_same_site("dashboard")))) } Err(e) => { warn!("Error constructing user from request: {:?}", e); Err(Flash::new( Redirect::to(uri!(super::return_to_same_site(""))), "danger", "Failed to contact Discord", )) } } } Err(e) => { warn!("Error getting user info: {:?}", e); Err(Flash::new( Redirect::to(uri!(super::return_to_same_site(""))), "danger", "Failed to contact Discord", )) } } } Err(e) => { warn!("Error in discord callback: {:?}", e); Err(Flash::new( Redirect::to(uri!(super::return_to_same_site(""))), "warning", "Your login request was rejected. The server may be misconfigured. Please retry or alert us in Discord.", )) } } } else { Err(Flash::new(Redirect::to(uri!(super::return_to_same_site(""))), "danger", "Your request failed to validate, and so has been rejected (CSRF Validation Failure)")) } } else { Err(Flash::new(Redirect::to(uri!(super::return_to_same_site(""))), "warning", "Your request was missing information, and so has been rejected (CSRF Validation Tokens Missing)")) } }