moved dashboard crate into here
This commit is contained in:
205
web/src/routes/dashboard/guild.rs
Normal file
205
web/src/routes/dashboard/guild.rs
Normal file
@ -0,0 +1,205 @@
|
||||
use rocket::State;
|
||||
|
||||
use crate::consts::DISCORD_CDN;
|
||||
use serde::Serialize;
|
||||
use sqlx::{MySql, Pool};
|
||||
|
||||
use super::Reminder;
|
||||
use rocket::serde::json::{json, Json, Value as JsonValue};
|
||||
use serenity::client::Context;
|
||||
use serenity::http::CacheHttp;
|
||||
use serenity::model::id::GuildId;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChannelInfo {
|
||||
id: String,
|
||||
name: String,
|
||||
webhook_avatar: Option<String>,
|
||||
webhook_name: Option<String>,
|
||||
}
|
||||
|
||||
// todo check the user can access this guild
|
||||
#[get("/api/guild/<id>/channels")]
|
||||
pub async fn get_guild_channels(
|
||||
id: u64,
|
||||
ctx: &State<Context>,
|
||||
pool: &State<Pool<MySql>>,
|
||||
) -> JsonValue {
|
||||
let channels_res = GuildId(id).channels(ctx.inner()).await;
|
||||
|
||||
match channels_res {
|
||||
Ok(channels) => {
|
||||
let mut channel_info = vec![];
|
||||
|
||||
for (channel_id, channel) in
|
||||
channels.iter().filter(|(_, channel)| channel.is_text_based())
|
||||
{
|
||||
let mut ch = ChannelInfo {
|
||||
name: channel.name.to_string(),
|
||||
id: channel_id.to_string(),
|
||||
webhook_avatar: None,
|
||||
webhook_name: None,
|
||||
};
|
||||
|
||||
if let Ok(webhook_details) = sqlx::query!(
|
||||
"SELECT webhook_id, webhook_token FROM channels WHERE channel = ?",
|
||||
channel.id.as_u64()
|
||||
)
|
||||
.fetch_one(pool.inner())
|
||||
.await
|
||||
{
|
||||
if let (Some(webhook_id), Some(webhook_token)) =
|
||||
(webhook_details.webhook_id, webhook_details.webhook_token)
|
||||
{
|
||||
let webhook_res =
|
||||
ctx.http.get_webhook_with_token(webhook_id, &webhook_token).await;
|
||||
|
||||
if let Ok(webhook) = webhook_res {
|
||||
ch.webhook_avatar = webhook.avatar.map(|a| {
|
||||
format!("{}/{}/{}.webp?size=128", DISCORD_CDN, webhook_id, a)
|
||||
});
|
||||
|
||||
ch.webhook_name = webhook.name;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
channel_info.push(ch);
|
||||
}
|
||||
|
||||
json!(channel_info)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Could not fetch channels from {}: {:?}", id, e);
|
||||
|
||||
json!({"error": "Could not get channels"})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RoleInfo {
|
||||
id: String,
|
||||
name: String,
|
||||
}
|
||||
|
||||
// todo check the user can access this guild
|
||||
#[get("/api/guild/<id>/roles")]
|
||||
pub async fn get_guild_roles(id: u64, ctx: &State<Context>) -> JsonValue {
|
||||
let roles_res = ctx.cache.guild_roles(id);
|
||||
|
||||
match roles_res {
|
||||
Some(roles) => {
|
||||
let roles = roles
|
||||
.iter()
|
||||
.map(|(_, r)| RoleInfo { id: r.id.to_string(), name: r.name.to_string() })
|
||||
.collect::<Vec<RoleInfo>>();
|
||||
|
||||
json!(roles)
|
||||
}
|
||||
None => {
|
||||
warn!("Could not fetch roles from {}", id);
|
||||
|
||||
json!({"error": "Could not get roles"})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/api/guild/<id>/reminders", data = "<reminder>")]
|
||||
pub async fn create_reminder(
|
||||
id: u64,
|
||||
reminder: Json<Reminder>,
|
||||
serenity_context: &State<Context>,
|
||||
pool: &State<Pool<MySql>>,
|
||||
) -> JsonValue {
|
||||
json!({"error": "Not implemented"})
|
||||
}
|
||||
|
||||
#[get("/api/guild/<id>/reminders")]
|
||||
pub async fn get_reminders(id: u64, ctx: &State<Context>, pool: &State<Pool<MySql>>) -> JsonValue {
|
||||
let channels_res = GuildId(id).channels(&ctx.inner()).await;
|
||||
|
||||
match channels_res {
|
||||
Ok(channels) => {
|
||||
let channels = channels
|
||||
.keys()
|
||||
.into_iter()
|
||||
.map(|k| k.as_u64().to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(",");
|
||||
|
||||
sqlx::query_as_unchecked!(
|
||||
Reminder,
|
||||
"
|
||||
SELECT
|
||||
reminders.attachment,
|
||||
reminders.attachment_name,
|
||||
reminders.avatar,
|
||||
channels.channel,
|
||||
reminders.content,
|
||||
reminders.embed_author,
|
||||
reminders.embed_author_url,
|
||||
reminders.embed_color,
|
||||
reminders.embed_description,
|
||||
reminders.embed_footer,
|
||||
reminders.embed_footer_url,
|
||||
reminders.embed_image_url,
|
||||
reminders.embed_thumbnail_url,
|
||||
reminders.embed_title,
|
||||
reminders.enabled,
|
||||
reminders.expires,
|
||||
reminders.interval_seconds,
|
||||
reminders.interval_months,
|
||||
reminders.name,
|
||||
reminders.pin,
|
||||
reminders.restartable,
|
||||
reminders.tts,
|
||||
reminders.uid,
|
||||
reminders.username,
|
||||
reminders.utc_time
|
||||
FROM
|
||||
reminders
|
||||
LEFT JOIN
|
||||
channels
|
||||
ON
|
||||
channels.id = reminders.channel_id
|
||||
WHERE
|
||||
FIND_IN_SET(channels.channel, ?)
|
||||
",
|
||||
channels
|
||||
)
|
||||
.fetch_all(pool.inner())
|
||||
.await
|
||||
.map(|r| json!(r))
|
||||
.unwrap_or_else(|e| {
|
||||
warn!("Failed to complete SQL query: {:?}", e);
|
||||
|
||||
json!({"error": "Could not load reminders"})
|
||||
})
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Could not fetch channels from {}: {:?}", id, e);
|
||||
|
||||
json!([])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[patch("/api/guild/<id>/reminders", data = "<reminder>")]
|
||||
pub async fn edit_reminder(
|
||||
id: u64,
|
||||
reminder: Json<Reminder>,
|
||||
serenity_context: &State<Context>,
|
||||
pool: &State<Pool<MySql>>,
|
||||
) -> JsonValue {
|
||||
json!({"error": "Not implemented"})
|
||||
}
|
||||
|
||||
#[delete("/api/guild/<id>/reminders", data = "<reminder>")]
|
||||
pub async fn delete_reminder(
|
||||
id: u64,
|
||||
reminder: Json<Reminder>,
|
||||
pool: &State<Pool<MySql>>,
|
||||
) -> JsonValue {
|
||||
json!({"error": "Not implemented"})
|
||||
}
|
59
web/src/routes/dashboard/mod.rs
Normal file
59
web/src/routes/dashboard/mod.rs
Normal file
@ -0,0 +1,59 @@
|
||||
use chrono::naive::NaiveDateTime;
|
||||
use rocket::http::CookieJar;
|
||||
use rocket::response::Redirect;
|
||||
use rocket_dyn_templates::Template;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub mod guild;
|
||||
pub mod user;
|
||||
|
||||
fn name_default() -> String {
|
||||
"Reminder".to_string()
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Reminder {
|
||||
attachment: Option<Vec<u8>>,
|
||||
attachment_name: Option<String>,
|
||||
avatar: Option<String>,
|
||||
channel: u64,
|
||||
content: String,
|
||||
embed_author: String,
|
||||
embed_author_url: Option<String>,
|
||||
embed_color: u32,
|
||||
embed_description: String,
|
||||
embed_footer: String,
|
||||
embed_footer_url: Option<String>,
|
||||
embed_image_url: Option<String>,
|
||||
embed_thumbnail_url: Option<String>,
|
||||
embed_title: String,
|
||||
enabled: i8,
|
||||
expires: Option<NaiveDateTime>,
|
||||
interval_seconds: Option<u32>,
|
||||
interval_months: Option<u32>,
|
||||
#[serde(default = "name_default")]
|
||||
name: String,
|
||||
pin: i8,
|
||||
restartable: i8,
|
||||
tts: i8,
|
||||
#[serde(default)]
|
||||
uid: String,
|
||||
username: Option<String>,
|
||||
utc_time: NaiveDateTime,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct DeleteReminder {
|
||||
uid: String,
|
||||
}
|
||||
|
||||
#[get("/")]
|
||||
pub async fn dashboard_home(cookies: &CookieJar<'_>) -> Result<Template, Redirect> {
|
||||
if cookies.get_private("userid").is_some() {
|
||||
let map: HashMap<&str, String> = HashMap::new();
|
||||
Ok(Template::render("dashboard", &map))
|
||||
} else {
|
||||
Err(Redirect::to("/login/discord"))
|
||||
}
|
||||
}
|
402
web/src/routes/dashboard/user.rs
Normal file
402
web/src/routes/dashboard/user.rs
Normal file
@ -0,0 +1,402 @@
|
||||
use rocket::serde::json::{json, Json, Value as JsonValue};
|
||||
use rocket::{http::CookieJar, State};
|
||||
|
||||
use reqwest::Client;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serenity::model::{
|
||||
id::{GuildId, RoleId},
|
||||
permissions::Permissions,
|
||||
};
|
||||
use sqlx::{MySql, Pool};
|
||||
use std::env;
|
||||
|
||||
use super::Reminder;
|
||||
use crate::consts::DISCORD_API;
|
||||
use crate::routes::dashboard::DeleteReminder;
|
||||
use chrono_tz::Tz;
|
||||
use serenity::client::Context;
|
||||
use serenity::model::id::UserId;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct UserInfo {
|
||||
name: String,
|
||||
patreon: bool,
|
||||
timezone: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UpdateUser {
|
||||
timezone: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct GuildInfo {
|
||||
id: String,
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct PartialGuild {
|
||||
pub id: GuildId,
|
||||
pub icon: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub owner: bool,
|
||||
#[serde(rename = "permissions_new")]
|
||||
pub permissions: Option<String>,
|
||||
}
|
||||
|
||||
#[get("/api/user")]
|
||||
pub async fn get_user_info(
|
||||
cookies: &CookieJar<'_>,
|
||||
ctx: &State<Context>,
|
||||
pool: &State<Pool<MySql>>,
|
||||
) -> JsonValue {
|
||||
if let Some(user_id) =
|
||||
cookies.get_private("userid").map(|u| u.value().parse::<u64>().ok()).flatten()
|
||||
{
|
||||
let member_res = GuildId(env::var("PATREON_GUILD_ID").unwrap().parse().unwrap())
|
||||
.member(&ctx.inner(), user_id)
|
||||
.await;
|
||||
|
||||
let timezone = sqlx::query!("SELECT timezone FROM users WHERE user = ?", user_id)
|
||||
.fetch_one(pool.inner())
|
||||
.await
|
||||
.map_or(None, |q| Some(q.timezone));
|
||||
|
||||
let user_info = UserInfo {
|
||||
name: cookies
|
||||
.get_private("username")
|
||||
.map_or("DiscordUser#0000".to_string(), |c| c.value().to_string()),
|
||||
patreon: member_res.map_or(false, |member| {
|
||||
member
|
||||
.roles
|
||||
.contains(&RoleId(env::var("PATREON_ROLE_ID").unwrap().parse().unwrap()))
|
||||
}),
|
||||
timezone,
|
||||
};
|
||||
|
||||
json!(user_info)
|
||||
} else {
|
||||
json!({"error": "Not authorized"})
|
||||
}
|
||||
}
|
||||
|
||||
#[patch("/api/user", data = "<user>")]
|
||||
pub async fn update_user_info(
|
||||
cookies: &CookieJar<'_>,
|
||||
user: Json<UpdateUser>,
|
||||
pool: &State<Pool<MySql>>,
|
||||
) -> JsonValue {
|
||||
if let Some(user_id) =
|
||||
cookies.get_private("userid").map(|u| u.value().parse::<u64>().ok()).flatten()
|
||||
{
|
||||
if user.timezone.parse::<Tz>().is_ok() {
|
||||
let _ = sqlx::query!(
|
||||
"UPDATE users SET timezone = ? WHERE user = ?",
|
||||
user.timezone,
|
||||
user_id,
|
||||
)
|
||||
.execute(pool.inner())
|
||||
.await;
|
||||
|
||||
json!({})
|
||||
} else {
|
||||
json!({"error": "Timezone not recognized"})
|
||||
}
|
||||
} else {
|
||||
json!({"error": "Not authorized"})
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/api/user/guilds")]
|
||||
pub async fn get_user_guilds(cookies: &CookieJar<'_>, reqwest_client: &State<Client>) -> JsonValue {
|
||||
if let Some(access_token) = cookies.get_private("access_token") {
|
||||
let request_res = reqwest_client
|
||||
.get(format!("{}/users/@me/guilds", DISCORD_API))
|
||||
.bearer_auth(access_token.value())
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match request_res {
|
||||
Ok(response) => {
|
||||
let guilds_res = response.json::<Vec<PartialGuild>>().await;
|
||||
|
||||
match guilds_res {
|
||||
Ok(guilds) => {
|
||||
let reduced_guilds = guilds
|
||||
.iter()
|
||||
.filter(|g| {
|
||||
g.owner
|
||||
|| g.permissions.as_ref().map_or(false, |p| {
|
||||
let permissions =
|
||||
Permissions::from_bits_truncate(p.parse().unwrap());
|
||||
|
||||
permissions.manage_messages()
|
||||
|| permissions.manage_guild()
|
||||
|| permissions.administrator()
|
||||
})
|
||||
})
|
||||
.map(|g| GuildInfo { id: g.id.to_string(), name: g.name.to_string() })
|
||||
.collect::<Vec<GuildInfo>>();
|
||||
|
||||
json!(reduced_guilds)
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
warn!("Error constructing user from request: {:?}", e);
|
||||
|
||||
json!({"error": "Could not get user details"})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
warn!("Error getting user guilds: {:?}", e);
|
||||
|
||||
json!({"error": "Could not reach Discord"})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
json!({"error": "Not authorized"})
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/api/user/reminders", data = "<reminder>")]
|
||||
pub async fn create_reminder(
|
||||
reminder: Json<Reminder>,
|
||||
ctx: &State<Context>,
|
||||
pool: &State<Pool<MySql>>,
|
||||
) -> JsonValue {
|
||||
match sqlx::query!(
|
||||
"INSERT INTO reminders (
|
||||
avatar,
|
||||
content,
|
||||
embed_author,
|
||||
embed_author_url,
|
||||
embed_color,
|
||||
embed_description,
|
||||
embed_footer,
|
||||
embed_footer_url,
|
||||
embed_image_url,
|
||||
embed_thumbnail_url,
|
||||
embed_title,
|
||||
enabled,
|
||||
expires,
|
||||
interval_seconds,
|
||||
interval_months,
|
||||
name,
|
||||
pin,
|
||||
restartable,
|
||||
tts,
|
||||
username,
|
||||
`utc_time`
|
||||
) VALUES (
|
||||
avatar = ?,
|
||||
content = ?,
|
||||
embed_author = ?,
|
||||
embed_author_url = ?,
|
||||
embed_color = ?,
|
||||
embed_description = ?,
|
||||
embed_footer = ?,
|
||||
embed_footer_url = ?,
|
||||
embed_image_url = ?,
|
||||
embed_thumbnail_url = ?,
|
||||
embed_title = ?,
|
||||
enabled = ?,
|
||||
expires = ?,
|
||||
interval_seconds = ?,
|
||||
interval_months = ?,
|
||||
name = ?,
|
||||
pin = ?,
|
||||
restartable = ?,
|
||||
tts = ?,
|
||||
username = ?,
|
||||
`utc_time` = ?
|
||||
)",
|
||||
reminder.avatar,
|
||||
reminder.content,
|
||||
reminder.embed_author,
|
||||
reminder.embed_author_url,
|
||||
reminder.embed_color,
|
||||
reminder.embed_description,
|
||||
reminder.embed_footer,
|
||||
reminder.embed_footer_url,
|
||||
reminder.embed_image_url,
|
||||
reminder.embed_thumbnail_url,
|
||||
reminder.embed_title,
|
||||
reminder.enabled,
|
||||
reminder.expires,
|
||||
reminder.interval_seconds,
|
||||
reminder.interval_months,
|
||||
reminder.name,
|
||||
reminder.pin,
|
||||
reminder.restartable,
|
||||
reminder.tts,
|
||||
reminder.username,
|
||||
reminder.utc_time,
|
||||
)
|
||||
.execute(pool.inner())
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
json!({})
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error in `create_reminder`: {:?}", e);
|
||||
|
||||
json!({"error": "Could not create reminder"})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/api/user/reminders")]
|
||||
pub async fn get_reminders(
|
||||
pool: &State<Pool<MySql>>,
|
||||
cookies: &CookieJar<'_>,
|
||||
ctx: &State<Context>,
|
||||
) -> JsonValue {
|
||||
if let Some(user_id) =
|
||||
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten()
|
||||
{
|
||||
let query_res = sqlx::query!(
|
||||
"SELECT channel FROM channels INNER JOIN users ON users.dm_channel = channels.id WHERE users.user = ?",
|
||||
user_id
|
||||
)
|
||||
.fetch_one(pool.inner())
|
||||
.await;
|
||||
|
||||
let dm_channel = if let Ok(query) = query_res {
|
||||
Some(query.channel)
|
||||
} else {
|
||||
if let Ok(dm_channel) = UserId(user_id).create_dm_channel(&ctx.inner()).await {
|
||||
Some(dm_channel.id.as_u64().to_owned())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(channel_id) = dm_channel {
|
||||
let reminders = sqlx::query_as!(
|
||||
Reminder,
|
||||
"SELECT
|
||||
reminders.attachment,
|
||||
reminders.attachment_name,
|
||||
reminders.avatar,
|
||||
channels.channel,
|
||||
reminders.content,
|
||||
reminders.embed_author,
|
||||
reminders.embed_author_url,
|
||||
reminders.embed_color,
|
||||
reminders.embed_description,
|
||||
reminders.embed_footer,
|
||||
reminders.embed_footer_url,
|
||||
reminders.embed_image_url,
|
||||
reminders.embed_thumbnail_url,
|
||||
reminders.embed_title,
|
||||
reminders.enabled,
|
||||
reminders.expires,
|
||||
reminders.interval_seconds,
|
||||
reminders.interval_months,
|
||||
reminders.name,
|
||||
reminders.pin,
|
||||
reminders.restartable,
|
||||
reminders.tts,
|
||||
reminders.uid,
|
||||
reminders.username,
|
||||
reminders.utc_time
|
||||
FROM reminders INNER JOIN channels ON channels.id = reminders.channel_id WHERE channels.channel = ?",
|
||||
channel_id
|
||||
)
|
||||
.fetch_all(pool.inner())
|
||||
.await
|
||||
.unwrap_or(vec![]);
|
||||
|
||||
json!(reminders)
|
||||
} else {
|
||||
json!({"error": "User's DM channel could not be determined"})
|
||||
}
|
||||
} else {
|
||||
json!({"error": "Not authorized"})
|
||||
}
|
||||
}
|
||||
|
||||
#[put("/api/user/reminders", data = "<reminder>")]
|
||||
pub async fn overwrite_reminder(reminder: Json<Reminder>, pool: &State<Pool<MySql>>) -> JsonValue {
|
||||
match sqlx::query!(
|
||||
"UPDATE reminders SET
|
||||
avatar = ?,
|
||||
content = ?,
|
||||
embed_author = ?,
|
||||
embed_author_url = ?,
|
||||
embed_color = ?,
|
||||
embed_description = ?,
|
||||
embed_footer = ?,
|
||||
embed_footer_url = ?,
|
||||
embed_image_url = ?,
|
||||
embed_thumbnail_url = ?,
|
||||
embed_title = ?,
|
||||
enabled = ?,
|
||||
expires = ?,
|
||||
interval_seconds = ?,
|
||||
interval_months = ?,
|
||||
name = ?,
|
||||
pin = ?,
|
||||
restartable = ?,
|
||||
tts = ?,
|
||||
username = ?,
|
||||
`utc_time` = ?
|
||||
WHERE uid = ?",
|
||||
reminder.avatar,
|
||||
reminder.content,
|
||||
reminder.embed_author,
|
||||
reminder.embed_author_url,
|
||||
reminder.embed_color,
|
||||
reminder.embed_description,
|
||||
reminder.embed_footer,
|
||||
reminder.embed_footer_url,
|
||||
reminder.embed_image_url,
|
||||
reminder.embed_thumbnail_url,
|
||||
reminder.embed_title,
|
||||
reminder.enabled,
|
||||
reminder.expires,
|
||||
reminder.interval_seconds,
|
||||
reminder.interval_months,
|
||||
reminder.name,
|
||||
reminder.pin,
|
||||
reminder.restartable,
|
||||
reminder.tts,
|
||||
reminder.username,
|
||||
reminder.utc_time,
|
||||
reminder.uid
|
||||
)
|
||||
.execute(pool.inner())
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
json!({})
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error in `overwrite_reminder`: {:?}", e);
|
||||
|
||||
json!({"error": "Could not modify reminder"})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[delete("/api/user/reminders", data = "<reminder>")]
|
||||
pub async fn delete_reminder(
|
||||
reminder: Json<DeleteReminder>,
|
||||
pool: &State<Pool<MySql>>,
|
||||
) -> JsonValue {
|
||||
if sqlx::query!("DELETE FROM reminders WHERE uid = ?", reminder.uid)
|
||||
.execute(pool.inner())
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
json!({})
|
||||
} else {
|
||||
json!({"error": "Could not delete reminder"})
|
||||
}
|
||||
}
|
149
web/src/routes/login.rs
Normal file
149
web/src/routes/login.rs
Normal file
@ -0,0 +1,149 @@
|
||||
use crate::consts::DISCORD_API;
|
||||
use log::warn;
|
||||
use oauth2::basic::BasicClient;
|
||||
use oauth2::reqwest::async_http_client;
|
||||
use oauth2::{
|
||||
AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use rocket::http::private::cookie::Expiration;
|
||||
use rocket::http::{Cookie, CookieJar, SameSite};
|
||||
use rocket::response::{Flash, Redirect};
|
||||
use rocket::uri;
|
||||
use rocket::State;
|
||||
use serenity::model::user::User;
|
||||
|
||||
#[get("/discord")]
|
||||
pub async fn discord_login(
|
||||
oauth2_client: &State<BasicClient>,
|
||||
cookies: &CookieJar<'_>,
|
||||
) -> Redirect {
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
let (auth_url, csrf_token) = oauth2_client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
// Set the desired scopes.
|
||||
.add_scope(Scope::new("identify".to_string()))
|
||||
.add_scope(Scope::new("guilds".to_string()))
|
||||
.add_scope(Scope::new("email".to_string()))
|
||||
// Set the PKCE code challenge.
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.url();
|
||||
|
||||
// store the pkce secret to verify the authorization later
|
||||
cookies.add_private(
|
||||
Cookie::build("verify", pkce_verifier.secret().to_string())
|
||||
.http_only(true)
|
||||
.path("/login")
|
||||
.same_site(SameSite::Lax)
|
||||
.expires(Expiration::Session)
|
||||
.finish(),
|
||||
);
|
||||
|
||||
// store the csrf token to verify no interference
|
||||
cookies.add_private(
|
||||
Cookie::build("csrf", csrf_token.secret().to_string())
|
||||
.http_only(true)
|
||||
.path("/login")
|
||||
.same_site(SameSite::Lax)
|
||||
.expires(Expiration::Session)
|
||||
.finish(),
|
||||
);
|
||||
|
||||
Redirect::to(auth_url.to_string())
|
||||
}
|
||||
|
||||
#[get("/discord/authorized?<code>&<state>")]
|
||||
pub async fn discord_callback(
|
||||
code: &str,
|
||||
state: &str,
|
||||
cookies: &CookieJar<'_>,
|
||||
oauth2_client: &State<BasicClient>,
|
||||
reqwest_client: &State<Client>,
|
||||
) -> Result<Redirect, Flash<Redirect>> {
|
||||
if let (Some(pkce_secret), Some(csrf_token)) =
|
||||
(cookies.get_private("verify"), cookies.get_private("csrf"))
|
||||
{
|
||||
if state == csrf_token.value() {
|
||||
let token_result = oauth2_client
|
||||
.exchange_code(AuthorizationCode::new(code.to_string()))
|
||||
// Set the PKCE code verifier.
|
||||
.set_pkce_verifier(PkceCodeVerifier::new(pkce_secret.value().to_string()))
|
||||
.request_async(async_http_client)
|
||||
.await;
|
||||
|
||||
cookies.remove_private(Cookie::named("verify"));
|
||||
cookies.remove_private(Cookie::named("csrf"));
|
||||
|
||||
match token_result {
|
||||
Ok(token) => {
|
||||
cookies.add_private(
|
||||
Cookie::build("access_token", token.access_token().secret().to_string())
|
||||
.secure(true)
|
||||
.http_only(true)
|
||||
.path("/dashboard")
|
||||
.finish(),
|
||||
);
|
||||
|
||||
let request_res = reqwest_client
|
||||
.get(format!("{}/users/@me", DISCORD_API))
|
||||
.bearer_auth(token.access_token().secret())
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match request_res {
|
||||
Ok(response) => {
|
||||
let user_res = response.json::<User>().await;
|
||||
|
||||
match user_res {
|
||||
Ok(user) => {
|
||||
let user_name = format!("{}#{}", user.name, user.discriminator);
|
||||
let user_id = user.id.as_u64().to_string();
|
||||
|
||||
cookies.add_private(Cookie::new("username", user_name));
|
||||
cookies.add_private(Cookie::new("userid", user_id));
|
||||
|
||||
Ok(Redirect::to(uri!(super::return_to_same_site("dashboard"))))
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
warn!("Error constructing user from request: {:?}", e);
|
||||
|
||||
Err(Flash::new(
|
||||
Redirect::to(uri!(super::return_to_same_site(""))),
|
||||
"danger",
|
||||
"Failed to contact Discord",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
warn!("Error getting user info: {:?}", e);
|
||||
|
||||
Err(Flash::new(
|
||||
Redirect::to(uri!(super::return_to_same_site(""))),
|
||||
"danger",
|
||||
"Failed to contact Discord",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
warn!("Error in discord callback: {:?}", e);
|
||||
|
||||
Err(Flash::new(
|
||||
Redirect::to(uri!(super::return_to_same_site(""))),
|
||||
"warning",
|
||||
"Your login request was rejected",
|
||||
))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err(Flash::new(Redirect::to(uri!(super::return_to_same_site(""))), "danger", "Your request failed to validate, and so has been rejected (error: CSRF Validation Failure)"))
|
||||
}
|
||||
} else {
|
||||
Err(Flash::new(Redirect::to(uri!(super::return_to_same_site(""))), "warning", "Your request was missing information, and so has been rejected (error: CSRF Validation Tokens Missing)"))
|
||||
}
|
||||
}
|
51
web/src/routes/mod.rs
Normal file
51
web/src/routes/mod.rs
Normal file
@ -0,0 +1,51 @@
|
||||
pub mod dashboard;
|
||||
pub mod login;
|
||||
|
||||
use rocket::request::FlashMessage;
|
||||
use rocket_dyn_templates::Template;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[get("/")]
|
||||
pub async fn index(flash: Option<FlashMessage<'_>>) -> Template {
|
||||
let mut map: HashMap<&str, String> = HashMap::new();
|
||||
|
||||
if let Some(message) = flash {
|
||||
map.insert("flashed_message", message.message().to_string());
|
||||
map.insert("flashed_grade", message.kind().to_string());
|
||||
}
|
||||
|
||||
Template::render("index", &map)
|
||||
}
|
||||
|
||||
#[get("/ret?<to>")]
|
||||
pub async fn return_to_same_site(to: &str) -> Template {
|
||||
let mut map: HashMap<&str, String> = HashMap::new();
|
||||
|
||||
map.insert("to", to.to_string());
|
||||
|
||||
Template::render("return", &map)
|
||||
}
|
||||
|
||||
#[get("/cookies")]
|
||||
pub async fn cookies() -> Template {
|
||||
let map: HashMap<&str, String> = HashMap::new();
|
||||
Template::render("cookies", &map)
|
||||
}
|
||||
|
||||
#[get("/privacy")]
|
||||
pub async fn privacy() -> Template {
|
||||
let map: HashMap<&str, String> = HashMap::new();
|
||||
Template::render("privacy", &map)
|
||||
}
|
||||
|
||||
#[get("/terms")]
|
||||
pub async fn terms() -> Template {
|
||||
let map: HashMap<&str, String> = HashMap::new();
|
||||
Template::render("terms", &map)
|
||||
}
|
||||
|
||||
#[get("/help")]
|
||||
pub async fn help() -> Template {
|
||||
let map: HashMap<&str, String> = HashMap::new();
|
||||
Template::render("help", &map)
|
||||
}
|
Reference in New Issue
Block a user