create reminder route. formatting on frontend

This commit is contained in:
jude
2022-03-05 19:43:02 +00:00
parent 6ae2353c92
commit e2e5b022a0
12 changed files with 564 additions and 53 deletions

View File

@ -2,3 +2,50 @@ pub const DISCORD_OAUTH_TOKEN: &'static str = "https://discord.com/api/oauth2/to
pub const DISCORD_OAUTH_AUTHORIZE: &'static str = "https://discord.com/api/oauth2/authorize";
pub const DISCORD_API: &'static str = "https://discord.com/api";
pub const DISCORD_CDN: &'static str = "https://cdn.discordapp.com/avatars";
pub const MAX_CONTENT_LENGTH: usize = 2000;
pub const MAX_EMBED_DESCRIPTION_LENGTH: usize = 4096;
pub const MAX_EMBED_TITLE_LENGTH: usize = 256;
pub const MAX_EMBED_AUTHOR_LENGTH: usize = 256;
pub const MAX_EMBED_FOOTER_LENGTH: usize = 2048;
pub const MAX_URL_LENGTH: usize = 512;
pub const MAX_USERNAME_LENGTH: usize = 100;
pub const MAX_EMBED_FIELD_TITLE_LENGTH: usize = 256;
pub const MAX_EMBED_FIELD_VALUE_LENGTH: usize = 1024;
pub const MAX_EMBED_FIELDS: usize = 25;
pub const MINUTE: usize = 60;
pub const HOUR: usize = 60 * MINUTE;
pub const DAY: usize = 24 * HOUR;
use std::{collections::HashSet, env, iter::FromIterator};
use lazy_static::lazy_static;
use serenity::model::prelude::AttachmentType;
lazy_static! {
pub static ref DEFAULT_AVATAR: AttachmentType<'static> = (
include_bytes!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/../assets/",
env!("WEBHOOK_AVATAR", "WEBHOOK_AVATAR not provided for compilation")
)) as &[u8],
env!("WEBHOOK_AVATAR"),
)
.into();
pub static ref SUBSCRIPTION_ROLES: HashSet<u64> = HashSet::from_iter(
env::var("SUBSCRIPTION_ROLES")
.map(|var| var
.split(',')
.filter_map(|item| { item.parse::<u64>().ok() })
.collect::<Vec<u64>>())
.unwrap_or_else(|_| Vec::new())
);
pub static ref CNC_GUILD: Option<u64> =
env::var("CNC_GUILD").map(|var| var.parse::<u64>().ok()).ok().flatten();
pub static ref MIN_INTERVAL: u32 = env::var("MIN_INTERVAL")
.ok()
.map(|inner| inner.parse::<u32>().ok())
.flatten()
.unwrap_or(600);
}

View File

@ -2,6 +2,8 @@
extern crate rocket;
mod consts;
#[macro_use]
mod macros;
mod routes;
use std::{collections::HashMap, env};
@ -9,13 +11,23 @@ use std::{collections::HashMap, env};
use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
use rocket::fs::FileServer;
use rocket_dyn_templates::Template;
use serenity::client::Context;
use serenity::{
client::Context,
http::CacheHttp,
model::id::{GuildId, UserId},
};
use sqlx::{MySql, Pool};
use crate::consts::{DISCORD_OAUTH_AUTHORIZE, DISCORD_OAUTH_TOKEN};
use crate::consts::{CNC_GUILD, DISCORD_OAUTH_AUTHORIZE, DISCORD_OAUTH_TOKEN, SUBSCRIPTION_ROLES};
type Database = MySql;
#[derive(Debug)]
enum Error {
SQLx(sqlx::Error),
serenity(serenity::Error),
}
#[catch(401)]
async fn not_authorized() -> Template {
let map: HashMap<String, String> = HashMap::new();
@ -98,3 +110,34 @@ pub async fn initialize(
Ok(())
}
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
if let Some(subscription_guild) = *CNC_GUILD {
let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await;
if let Ok(member) = guild_member {
for role in member.roles {
if SUBSCRIPTION_ROLES.contains(role.as_u64()) {
return true;
}
}
}
false
} else {
true
}
}
pub async fn check_guild_subscription(
cache_http: impl CacheHttp,
guild_id: impl Into<GuildId>,
) -> bool {
if let Some(guild) = cache_http.cache().unwrap().guild(guild_id) {
let owner = guild.owner_id;
check_subscription(&cache_http, owner).await
} else {
false
}
}

47
web/src/macros.rs Normal file
View File

@ -0,0 +1,47 @@
macro_rules! check_length {
($max:ident, $field:expr) => {
if $field.len() > $max {
return json!({ "error": format!("{} exceeded", stringify!($max)) });
}
};
($max:ident, $field:expr, $($fields:expr),+) => {
check_length!($max, $field);
check_length!($max, $($fields),+);
};
}
macro_rules! check_length_opt {
($max:ident, $field:expr) => {
if let Some(field) = &$field {
check_length!($max, field);
}
};
($max:ident, $field:expr, $($fields:expr),+) => {
check_length_opt!($max, $field);
check_length_opt!($max, $($fields),+);
};
}
macro_rules! check_url {
($field:expr) => {
if $field.starts_with("http://") || $field.starts_with("https://") {
return json!({ "error": "URL invalid" });
}
};
($field:expr, $($fields:expr),+) => {
check_url!($max, $field);
check_url!($max, $($fields),+);
};
}
macro_rules! check_url_opt {
($field:expr) => {
if let Some(field) = &$field {
check_url!(field);
}
};
($field:expr, $($fields:expr),+) => {
check_url_opt!($field);
check_url_opt!($($fields),+);
};
}

View File

@ -1,13 +1,25 @@
use chrono::Utc;
use rocket::{
http::CookieJar,
serde::json::{json, Json, Value as JsonValue},
State,
};
use serde::Serialize;
use serenity::{client::Context, model::id::GuildId};
use serenity::{
client::Context,
model::id::{ChannelId, GuildId},
};
use sqlx::{MySql, Pool};
use super::Reminder;
use crate::consts::DISCORD_CDN;
use crate::{
check_guild_subscription, check_subscription,
consts::{
DAY, DISCORD_CDN, MAX_CONTENT_LENGTH, MAX_EMBED_AUTHOR_LENGTH,
MAX_EMBED_DESCRIPTION_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH,
MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MIN_INTERVAL,
},
routes::dashboard::{create_database_channel, DeleteReminder, Reminder},
};
#[derive(Serialize)]
struct ChannelInfo {
@ -108,10 +120,178 @@ pub async fn get_guild_roles(id: u64, ctx: &State<Context>) -> JsonValue {
pub async fn create_reminder(
id: u64,
reminder: Json<Reminder>,
cookies: &CookieJar<'_>,
serenity_context: &State<Context>,
pool: &State<Pool<MySql>>,
) -> JsonValue {
json!({"error": "Not implemented"})
// get userid from cookies
let user_id = cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten();
if user_id.is_none() {
return json!({"error": "User not authorized"});
}
let user_id = user_id.unwrap();
// validate channel
let channel = ChannelId(reminder.channel).to_channel_cached(&serenity_context.inner());
let channel_exists = channel.is_some();
let channel_matches_guild =
channel.map_or(false, |c| c.guild().map_or(false, |c| c.guild_id.0 == id));
if !channel_matches_guild || !channel_exists {
warn!(
"Error in `create_reminder`: channel {} not found for guild {} (channel exists: {})",
reminder.channel, id, channel_exists
);
return json!({"error": "Channel not found"});
}
let channel = create_database_channel(
serenity_context.inner(),
ChannelId(reminder.channel),
pool.inner(),
)
.await;
if let Err(e) = channel {
warn!("`create_database_channel` returned an error code: {:?}", e);
return json!({"error": "Failed to configure channel for reminders. Please check the bot permissions"});
}
let channel = channel.unwrap();
// validate lengths
check_length!(MAX_CONTENT_LENGTH, reminder.content);
check_length!(MAX_EMBED_DESCRIPTION_LENGTH, reminder.embed_description);
check_length!(MAX_EMBED_TITLE_LENGTH, reminder.embed_title);
check_length!(MAX_EMBED_AUTHOR_LENGTH, reminder.embed_author);
check_length!(MAX_EMBED_FOOTER_LENGTH, reminder.embed_footer);
check_length_opt!(MAX_USERNAME_LENGTH, reminder.username);
check_length_opt!(
MAX_URL_LENGTH,
reminder.embed_footer_url,
reminder.embed_thumbnail_url,
reminder.embed_author_url,
reminder.embed_image_url,
reminder.avatar
);
// validate urls
check_url_opt!(
reminder.embed_footer_url,
reminder.embed_thumbnail_url,
reminder.embed_author_url,
reminder.embed_image_url,
reminder.avatar
);
// validate time and interval
if reminder.utc_time < Utc::now().naive_utc() {
return json!({"error": "Time must be in the future"});
}
if reminder.interval_months.unwrap_or(0) * 30 * DAY as u32
+ reminder.interval_seconds.unwrap_or(0)
< *MIN_INTERVAL
{
return json!({"error": "Interval too short"});
}
// check patreon if necessary
if reminder.interval_seconds.is_some() || reminder.interval_months.is_some() {
if !check_guild_subscription(serenity_context.inner(), GuildId(id)).await
&& !check_subscription(serenity_context.inner(), user_id).await
{
return json!({"error": "Patreon is required to set intervals"});
}
}
// write to db
match sqlx::query!(
"INSERT INTO reminders (
channel_id,
avatar,
content,
embed_author,
embed_author_url,
embed_color,
embed_description,
embed_footer,
embed_footer_url,
embed_image_url,
embed_thumbnail_url,
embed_title,
enabled,
expires,
interval_seconds,
interval_months,
name,
pin,
restartable,
tts,
username,
`utc_time`
) VALUES (
channel_id = ?,
avatar = ?,
content = ?,
embed_author = ?,
embed_author_url = ?,
embed_color = ?,
embed_description = ?,
embed_footer = ?,
embed_footer_url = ?,
embed_image_url = ?,
embed_thumbnail_url = ?,
embed_title = ?,
enabled = ?,
expires = ?,
interval_seconds = ?,
interval_months = ?,
name = ?,
pin = ?,
restartable = ?,
tts = ?,
username = ?,
`utc_time` = ?
)",
channel,
reminder.avatar,
reminder.content,
reminder.embed_author,
reminder.embed_author_url,
reminder.embed_color,
reminder.embed_description,
reminder.embed_footer,
reminder.embed_footer_url,
reminder.embed_image_url,
reminder.embed_thumbnail_url,
reminder.embed_title,
reminder.enabled,
reminder.expires,
reminder.interval_seconds,
reminder.interval_months,
reminder.name,
reminder.pin,
reminder.restartable,
reminder.tts,
reminder.username,
reminder.utc_time,
)
.execute(pool.inner())
.await
{
Ok(_) => json!({}),
Err(e) => {
warn!("Error in `create_reminder`: Could not execute query: {:?}", e);
json!({"error": "Unknown error"})
}
}
}
#[get("/api/guild/<id>/reminders")]
@ -197,8 +377,21 @@ pub async fn edit_reminder(
#[delete("/api/guild/<id>/reminders", data = "<reminder>")]
pub async fn delete_reminder(
id: u64,
reminder: Json<Reminder>,
reminder: Json<DeleteReminder>,
pool: &State<Pool<MySql>>,
) -> JsonValue {
json!({"error": "Not implemented"})
match sqlx::query!("DELETE FROM reminders WHERE uid = ?", reminder.uid)
.execute(pool.inner())
.await
{
Ok(_) => {
json!({})
}
Err(e) => {
warn!("Error in `delete_reminder`: {:?}", e);
json!({"error": "Could not delete reminder"})
}
}
}

View File

@ -1,9 +1,17 @@
use std::collections::HashMap;
use chrono::naive::NaiveDateTime;
use rocket::http::CookieJar;
use rocket::response::Redirect;
use rocket::{http::CookieJar, response::Redirect};
use rocket_dyn_templates::Template;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serenity::{
client::Context,
http::{CacheHttp, Http},
model::id::ChannelId,
};
use sqlx::{Executor, Pool};
use crate::{consts::DEFAULT_AVATAR, Database, Error};
pub mod guild;
pub mod user;
@ -46,8 +54,7 @@ pub struct Reminder {
// https://github.com/serde-rs/json/issues/329#issuecomment-305608405
mod string {
use std::fmt::Display;
use std::str::FromStr;
use std::{fmt::Display, str::FromStr};
use serde::{de, Deserialize, Deserializer, Serializer};
@ -74,6 +81,78 @@ pub struct DeleteReminder {
uid: String,
}
async fn create_database_channel(
ctx: impl AsRef<Http>,
channel: ChannelId,
pool: impl Executor<'_, Database = Database> + Copy,
) -> Result<u32, crate::Error> {
let row =
sqlx::query!("SELECT webhook_token, webhook_id FROM channels WHERE channel = ?", channel.0)
.fetch_one(pool)
.await;
match row {
Ok(row) => {
if row.webhook_token.is_none() || row.webhook_id.is_none() {
let webhook = channel
.create_webhook_with_avatar(&ctx, "Reminder", DEFAULT_AVATAR.clone())
.await
.map_err(|e| Error::serenity(e))?;
sqlx::query!(
"UPDATE channels SET webhook_id = ?, webhook_token = ? WHERE channel = ?",
webhook.id.0,
webhook.token,
channel.0
)
.execute(pool)
.await
.map_err(|e| Error::SQLx(e))?;
}
Ok(())
}
Err(sqlx::Error::RowNotFound) => {
// create webhook
let webhook = channel
.create_webhook_with_avatar(&ctx, "Reminder", DEFAULT_AVATAR.clone())
.await
.map_err(|e| Error::serenity(e))?;
// create database entry
sqlx::query!(
"INSERT INTO channels (
webhook_id,
webhook_token,
channel
) VALUES (
webhook_id = ?,
webhook_token = ?,
channel = ?
)",
webhook.id.0,
webhook.token,
channel.0
)
.execute(pool)
.await
.map_err(|e| Error::SQLx(e))?;
Ok(())
}
Err(e) => Err(Error::SQLx(e)),
}?;
let row = sqlx::query!("SELECT id FROM channels WHERE channel = ?", channel.0)
.fetch_one(pool)
.await
.map_err(|e| Error::SQLx(e))?;
Ok(row.id)
}
#[get("/")]
pub async fn dashboard_home(cookies: &CookieJar<'_>) -> Result<Template, Redirect> {
if cookies.get_private("userid").is_some() {

View File

@ -1,18 +1,18 @@
use crate::consts::DISCORD_API;
use log::warn;
use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client;
use oauth2::{
AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse,
basic::BasicClient, reqwest::async_http_client, AuthorizationCode, CsrfToken,
PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse,
};
use reqwest::Client;
use rocket::http::private::cookie::Expiration;
use rocket::http::{Cookie, CookieJar, SameSite};
use rocket::response::{Flash, Redirect};
use rocket::uri;
use rocket::State;
use rocket::{
http::{private::cookie::Expiration, Cookie, CookieJar, SameSite},
response::{Flash, Redirect},
uri, State,
};
use serenity::model::user::User;
use crate::consts::DISCORD_API;
#[get("/discord")]
pub async fn discord_login(
oauth2_client: &State<BasicClient>,