Move postman and web inside src

This commit is contained in:
jude
2024-03-24 20:23:16 +00:00
parent 075fde71df
commit 4a80d42f86
152 changed files with 833 additions and 779 deletions

View File

@ -1,4 +1,4 @@
use chrono::NaiveDateTime;
use chrono::DateTime;
use serde::{Deserialize, Serialize};
use crate::{
@ -24,10 +24,10 @@ impl Recordable for Options {
let parsed = natural_parser(&until, &timezone.to_string()).await;
if let Some(timestamp) = parsed {
match NaiveDateTime::from_timestamp_opt(timestamp, 0) {
match DateTime::from_timestamp(timestamp, 0) {
Some(dt) => {
channel.paused = true;
channel.paused_until = Some(dt);
channel.paused_until = Some(dt.naive_utc());
channel.commit_changes(&ctx.data().database).await;

View File

@ -12,12 +12,15 @@ mod event_handlers;
#[cfg(not(test))]
mod hooks;
mod interval_parser;
mod metrics;
#[cfg(not(test))]
mod models;
mod postman;
#[cfg(test)]
mod test;
mod time_parser;
mod utils;
mod web;
use std::{
collections::HashMap,
@ -28,7 +31,7 @@ use std::{
};
use chrono_tz::Tz;
use log::{error, warn};
use log::warn;
use poise::serenity_prelude::{
model::{
gateway::GatewayIntents,
@ -39,6 +42,7 @@ use poise::serenity_prelude::{
use sqlx::{MySql, Pool};
use tokio::sync::{broadcast, broadcast::Sender, RwLock};
use crate::metrics::init_metrics;
#[cfg(test)]
use crate::test::TestContext;
#[cfg(not(test))]
@ -206,6 +210,10 @@ async fn _main(tx: Sender<()>) -> Result<(), Box<dyn StdError + Send + Sync>> {
..Default::default()
};
// Start metrics
init_metrics();
tokio::spawn(async { metrics::serve().await });
let database =
Pool::connect(&env::var("DATABASE_URL").expect("No database URL provided")).await.unwrap();
@ -249,7 +257,7 @@ async fn _main(tx: Sender<()>) -> Result<(), Box<dyn StdError + Send + Sync>> {
match postman::initialize(kill_recv, ctx1, &pool1).await {
Ok(_) => {}
Err(e) => {
error!("postman exiting: {}", e);
panic!("postman exiting: {}", e);
}
};
});
@ -259,7 +267,7 @@ async fn _main(tx: Sender<()>) -> Result<(), Box<dyn StdError + Send + Sync>> {
if !run_settings.contains("web") {
tokio::spawn(async move {
reminder_web::initialize(kill_tx, ctx2, pool2).await.unwrap();
web::initialize(kill_tx, ctx2, pool2).await.unwrap();
});
} else {
warn!("Not running web");

36
src/metrics.rs Normal file
View File

@ -0,0 +1,36 @@
use axum::{routing::get, Router};
use lazy_static::lazy_static;
use log::warn;
use prometheus::{IntCounterVec, Opts, Registry};
lazy_static! {
pub static ref REGISTRY: Registry = Registry::new();
pub static ref REQUEST_COUNTER: IntCounterVec =
IntCounterVec::new(Opts::new("requests", "Requests"), &["method", "status", "route"])
.unwrap();
}
pub fn init_metrics() {
REGISTRY.register(Box::new(REQUEST_COUNTER.clone())).unwrap();
}
pub async fn serve() {
let app = Router::new().route("/metrics", get(metrics));
let listener = tokio::net::TcpListener::bind("localhost:31756").await.unwrap();
axum::serve(listener, app).await.unwrap();
}
async fn metrics() -> String {
let encoder = prometheus::TextEncoder::new();
let res_custom = encoder.encode_to_string(&REGISTRY.gather());
match res_custom {
Ok(s) => s,
Err(e) => {
warn!("Error encoding metrics: {:?}", e);
String::new()
}
}
}

View File

@ -1,6 +1,6 @@
use std::collections::HashSet;
use chrono::{Duration, NaiveDateTime, Utc};
use chrono::{DateTime, NaiveDateTime, TimeDelta, Utc};
use chrono_tz::Tz;
use poise::serenity_prelude::{
model::id::{ChannelId, GuildId, UserId},
@ -73,7 +73,7 @@ impl ReminderBuilder {
match queried_time.utc_time {
Some(utc_time) => {
if utc_time < (Utc::now() - Duration::seconds(60)).naive_local() {
if utc_time < (Utc::now() - TimeDelta::try_minutes(1).unwrap()).naive_local() {
Err(ReminderError::PastTime)
} else {
sqlx::query!(
@ -165,7 +165,7 @@ impl<'a> MultiReminderBuilder<'a> {
}
pub fn time<T: Into<i64>>(mut self, time: T) -> Self {
if let Some(utc_time) = NaiveDateTime::from_timestamp_opt(time.into(), 0) {
if let Some(utc_time) = DateTime::from_timestamp(time.into(), 0).map(|d| d.naive_utc()) {
self.utc_time = utc_time;
}
@ -173,7 +173,8 @@ impl<'a> MultiReminderBuilder<'a> {
}
pub fn expires<T: Into<i64>>(mut self, time: Option<T>) -> Self {
self.expires = time.map(|t| NaiveDateTime::from_timestamp_opt(t.into(), 0)).flatten();
self.expires =
time.map(|t| DateTime::from_timestamp(t.into(), 0)).flatten().map(|d| d.naive_utc());
self
}

View File

@ -38,6 +38,7 @@ use crate::{
};
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct Reminder {
pub id: u32,
pub uid: String,

View File

@ -4,6 +4,7 @@ use sqlx::MySqlPool;
pub struct Timer {
pub name: String,
pub start_time: DateTime<Utc>,
#[allow(dead_code)]
pub owner: u64,
}

View File

@ -7,6 +7,7 @@ use crate::consts::LOCAL_TIMEZONE;
pub struct UserData {
pub id: u32,
#[allow(dead_code)]
pub user: u64,
pub dm_channel: u32,
pub timezone: String,
@ -22,7 +23,7 @@ impl UserData {
match sqlx::query!(
"
SELECT IFNULL(timezone, 'UTC') AS timezone FROM users WHERE user = ?
SELECT IFNULL(timezone, 'UTC') AS timezone FROM users WHERE user = ?
",
user_id
)

40
src/postman/metrics.rs Normal file
View File

@ -0,0 +1,40 @@
use axum::{routing::get, Router};
use lazy_static;
use log::warn;
use prometheus::{register_int_counter, IntCounter, Registry};
lazy_static! {
static ref REGISTRY: Registry = Registry::new();
pub static ref REMINDERS_SENT: IntCounter =
register_int_counter!("reminders_sent", "Number of reminders sent").unwrap();
pub static ref REMINDERS_FAILED: IntCounter =
register_int_counter!("reminders_failed", "Number of reminders failed").unwrap();
}
pub fn init_metrics() {
REGISTRY.register(Box::new(REMINDERS_SENT.clone())).unwrap();
REGISTRY.register(Box::new(REMINDERS_FAILED.clone())).unwrap();
}
pub async fn serve() {
let app = Router::new().route("/metrics", get(metrics));
let metrics_port = std::env("PROMETHEUS_PORT").unwrap();
let listener =
tokio::net::TcpListener::bind(format!("localhost:{}", metrics_port)).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
async fn metrics() -> String {
let encoder = prometheus::TextEncoder::new();
let res_custom = encoder.encode_to_string(&REGISTRY.gather());
match res_custom {
Ok(s) => s,
Err(e) => {
warn!("Error encoding metrics: {:?}", e);
String::new()
}
}
}

50
src/postman/mod.rs Normal file
View File

@ -0,0 +1,50 @@
mod sender;
use std::env;
use log::{info, warn};
use poise::serenity_prelude::client::Context;
use sqlx::{Executor, MySql};
use tokio::{
sync::broadcast::Receiver,
time::{sleep_until, Duration, Instant},
};
type Database = MySql;
pub async fn initialize(
mut kill: Receiver<()>,
ctx: Context,
pool: impl Executor<'_, Database = Database> + Copy,
) -> Result<(), &'static str> {
tokio::select! {
output = _initialize(ctx, pool) => Ok(output),
_ = kill.recv() => {
warn!("Received terminate signal. Goodbye");
Err("Received terminate signal. Goodbye")
}
}
}
async fn _initialize(ctx: Context, pool: impl Executor<'_, Database = Database> + Copy) {
let remind_interval = env::var("REMIND_INTERVAL")
.map(|inner| inner.parse::<u64>().ok())
.ok()
.flatten()
.unwrap_or(10);
loop {
let sleep_to = Instant::now() + Duration::from_secs(remind_interval);
let reminders = sender::Reminder::fetch_reminders(pool).await;
if reminders.len() > 0 {
info!("Preparing to send {} reminders.", reminders.len());
for reminder in reminders {
reminder.send(pool, ctx.clone()).await;
}
}
sleep_until(sleep_to).await;
}
}

703
src/postman/sender.rs Normal file
View File

@ -0,0 +1,703 @@
use std::env;
use chrono::{DateTime, Days, Months, TimeDelta};
use chrono_tz::Tz;
use lazy_static::lazy_static;
use log::{error, info, warn};
use num_integer::Integer;
use poise::serenity_prelude::{
all::{CreateAttachment, CreateEmbedFooter},
builder::{CreateEmbed, CreateEmbedAuthor, CreateMessage, ExecuteWebhook},
http::{CacheHttp, Http, HttpError},
model::{
channel::Channel,
id::{ChannelId, MessageId},
webhook::Webhook,
},
Error, Result,
};
use regex::{Captures, Regex};
use serde::Deserialize;
use sqlx::{
types::{
chrono::{NaiveDateTime, Utc},
Json,
},
Executor,
};
use crate::Database;
lazy_static! {
pub static ref TIMEFROM_REGEX: Regex =
Regex::new(r#"<<timefrom:(?P<time>\d+):(?P<format>.+)?>>"#).unwrap();
pub static ref TIMENOW_REGEX: Regex =
Regex::new(r#"<<timenow:(?P<timezone>(?:\w|/|_)+):(?P<format>.+)?>>"#).unwrap();
pub static ref LOG_TO_DATABASE: bool = env::var("LOG_TO_DATABASE").map_or(true, |v| v == "1");
}
fn fmt_displacement(format: &str, seconds: u64) -> String {
let mut seconds = seconds;
let mut days: u64 = 0;
let mut hours: u64 = 0;
let mut minutes: u64 = 0;
for (rep, time_type, div) in
[("%d", &mut days, 86400), ("%h", &mut hours, 3600), ("%m", &mut minutes, 60)].iter_mut()
{
if format.contains(*rep) {
let (divided, new_seconds) = seconds.div_rem(&div);
**time_type = divided;
seconds = new_seconds;
}
}
format
.replace("%s", &seconds.to_string())
.replace("%m", &minutes.to_string())
.replace("%h", &hours.to_string())
.replace("%d", &days.to_string())
}
pub fn substitute(string: &str) -> String {
let new = TIMEFROM_REGEX.replace(string, |caps: &Captures| {
let final_time = caps.name("time").map(|m| m.as_str().parse::<i64>().ok()).flatten();
let format = caps.name("format").map(|m| m.as_str());
if let (Some(final_time), Some(format)) = (final_time, format) {
match DateTime::from_timestamp(final_time, 0) {
Some(dt) => {
let now = Utc::now();
let difference = {
if now < dt {
dt - Utc::now()
} else {
Utc::now() - dt
}
};
fmt_displacement(format, difference.num_seconds() as u64)
}
None => String::new(),
}
} else {
String::new()
}
});
TIMENOW_REGEX
.replace(&new, |caps: &Captures| {
let timezone = caps.name("timezone").map(|m| m.as_str().parse::<Tz>().ok()).flatten();
let format = caps.name("format").map(|m| m.as_str());
if let (Some(timezone), Some(format)) = (timezone, format) {
let now = Utc::now().with_timezone(&timezone);
now.format(format).to_string()
} else {
String::new()
}
})
.to_string()
}
struct Embed {
title: String,
description: String,
image_url: Option<String>,
thumbnail_url: Option<String>,
footer: String,
footer_url: Option<String>,
author: String,
author_url: Option<String>,
color: u32,
fields: Json<Vec<EmbedField>>,
}
#[derive(Deserialize)]
struct EmbedField {
title: String,
value: String,
inline: bool,
}
impl Embed {
pub async fn from_id(
pool: impl Executor<'_, Database = Database> + Copy,
id: u32,
) -> Option<Self> {
match sqlx::query_as!(
Self,
r#"
SELECT
`embed_title` AS title,
`embed_description` AS description,
`embed_image_url` AS image_url,
`embed_thumbnail_url` AS thumbnail_url,
`embed_footer` AS footer,
`embed_footer_url` AS footer_url,
`embed_author` AS author,
`embed_author_url` AS author_url,
`embed_color` AS color,
IFNULL(`embed_fields`, '[]') AS "fields:_"
FROM reminders
WHERE `id` = ?"#,
id
)
.fetch_one(pool)
.await
{
Ok(mut embed) => {
embed.title = substitute(&embed.title);
embed.description = substitute(&embed.description);
embed.footer = substitute(&embed.footer);
embed.fields.iter_mut().for_each(|field| {
field.title = substitute(&field.title);
field.value = substitute(&field.value);
});
if embed.has_content() {
Some(embed)
} else {
None
}
}
Err(e) => {
warn!("Error loading embed from reminder: {:?}", e);
None
}
}
}
pub fn has_content(&self) -> bool {
if self.title.is_empty()
&& self.description.is_empty()
&& self.image_url.is_none()
&& self.thumbnail_url.is_none()
&& self.footer.is_empty()
&& self.footer_url.is_none()
&& self.author.is_empty()
&& self.author_url.is_none()
&& self.fields.0.is_empty()
{
false
} else {
true
}
}
}
impl Into<CreateEmbed> for Embed {
fn into(self) -> CreateEmbed {
let mut author = CreateEmbedAuthor::new(&self.author);
if let Some(author_icon) = &self.author_url {
author = author.icon_url(author_icon);
}
let mut footer = CreateEmbedFooter::new(&self.footer);
if let Some(footer_icon) = &self.footer_url {
footer = footer.icon_url(footer_icon);
}
let mut embed = CreateEmbed::default()
.title(&self.title)
.description(&self.description)
.color(self.color)
.author(author)
.footer(footer);
for field in &self.fields.0 {
embed = embed.field(&field.title, &field.value, field.inline);
}
if let Some(image_url) = &self.image_url {
embed = embed.image(image_url);
}
if let Some(thumbnail_url) = &self.thumbnail_url {
embed = embed.thumbnail(thumbnail_url);
}
embed
}
}
pub struct Reminder {
id: u32,
channel_id: u64,
thread_id: Option<u64>,
webhook_id: Option<u64>,
webhook_token: Option<String>,
channel_paused: bool,
channel_paused_until: Option<NaiveDateTime>,
enabled: bool,
tts: bool,
pin: bool,
content: String,
attachment: Option<Vec<u8>>,
attachment_name: Option<String>,
utc_time: DateTime<Utc>,
timezone: String,
restartable: bool,
expires: Option<DateTime<Utc>>,
interval_seconds: Option<u32>,
interval_days: Option<u32>,
interval_months: Option<u32>,
avatar: Option<String>,
username: Option<String>,
}
impl Reminder {
pub async fn fetch_reminders(pool: impl Executor<'_, Database = Database> + Copy) -> Vec<Self> {
match sqlx::query_as_unchecked!(
Reminder,
r#"
SELECT
reminders.`id` AS id,
channels.`channel` AS channel_id,
reminders.`thread_id` AS thread_id,
channels.`webhook_id` AS webhook_id,
channels.`webhook_token` AS webhook_token,
channels.`paused` AS 'channel_paused',
channels.`paused_until` AS 'channel_paused_until',
reminders.`enabled` AS 'enabled',
reminders.`tts` AS tts,
reminders.`pin` AS pin,
reminders.`content` AS content,
reminders.`attachment` AS attachment,
reminders.`attachment_name` AS attachment_name,
reminders.`utc_time` AS 'utc_time',
reminders.`timezone` AS timezone,
reminders.`restartable` AS restartable,
reminders.`expires` AS 'expires',
reminders.`interval_seconds` AS 'interval_seconds',
reminders.`interval_days` AS 'interval_days',
reminders.`interval_months` AS 'interval_months',
reminders.`avatar` AS avatar,
reminders.`username` AS username
FROM
reminders
INNER JOIN
channels
ON
reminders.channel_id = channels.id
WHERE
reminders.`status` = 'pending' AND
reminders.`id` IN (
SELECT
MIN(id)
FROM
reminders
WHERE
reminders.`utc_time` <= NOW() AND
`status` = 'pending' AND
(
reminders.`interval_seconds` IS NOT NULL
OR reminders.`interval_months` IS NOT NULL
OR reminders.`interval_days` IS NOT NULL
OR reminders.enabled
)
GROUP BY channel_id
)
"#,
)
.fetch_all(pool)
.await
{
Ok(reminders) => reminders
.into_iter()
.map(|mut rem| {
rem.content = substitute(&rem.content);
rem
})
.collect::<Vec<Self>>(),
Err(e) => {
warn!("Could not fetch reminders: {:?}", e);
vec![]
}
}
}
async fn reset_webhook(&self, pool: impl Executor<'_, Database = Database> + Copy) {
let _ = sqlx::query!(
"
UPDATE channels SET webhook_id = NULL, webhook_token = NULL WHERE channel = ?
",
self.channel_id
)
.execute(pool)
.await;
}
async fn refresh(&self, pool: impl Executor<'_, Database = Database> + Copy) {
if self.interval_seconds.is_some()
|| self.interval_months.is_some()
|| self.interval_days.is_some()
{
// If all intervals are zero then dont care
if self.interval_seconds == Some(0)
&& self.interval_days == Some(0)
&& self.interval_months == Some(0)
{
self.set_sent(pool).await;
}
let now = Utc::now();
let mut updated_reminder_time =
self.utc_time.with_timezone(&self.timezone.parse().unwrap_or(Tz::UTC));
let mut fail_count = 0;
while updated_reminder_time < now && fail_count < 4 {
if let Some(interval) = self.interval_months {
if interval != 0 {
updated_reminder_time = updated_reminder_time
.checked_add_months(Months::new(interval))
.unwrap_or_else(|| {
warn!(
"{}: Could not add {} months to a reminder",
interval, self.id
);
fail_count += 1;
updated_reminder_time
});
}
}
if let Some(interval) = self.interval_days {
if interval != 0 {
updated_reminder_time = updated_reminder_time
.checked_add_days(Days::new(interval as u64))
.unwrap_or_else(|| {
warn!("{}: Could not add {} days to a reminder", self.id, interval);
fail_count += 1;
updated_reminder_time
})
}
}
if let Some(interval) = self.interval_seconds {
updated_reminder_time += TimeDelta::try_seconds(interval as i64)
.unwrap_or_else(|| {
warn!("{}: Could not add {} seconds to a reminder", self.id, interval);
fail_count += 1;
TimeDelta::zero()
});
}
}
if fail_count >= 4 {
self.log_error(
"Failed to update 4 times and so is being deleted",
None::<&'static str>,
)
.await;
self.set_failed(pool, "Failed to update 4 times and so is being deleted").await;
} else if self.expires.map_or(false, |expires| updated_reminder_time > expires) {
self.set_sent(pool).await;
} else {
sqlx::query!(
"UPDATE reminders SET `utc_time` = ? WHERE `id` = ?",
updated_reminder_time.with_timezone(&Utc),
self.id
)
.execute(pool)
.await
.expect(&format!("Could not update time on Reminder {}", self.id));
}
} else {
self.set_sent(pool).await;
}
}
async fn log_error(&self, error: &'static str, debug_info: Option<impl std::fmt::Debug>) {
let message = match debug_info {
Some(info) => format!(
"{}
{:?}",
error, info
),
None => error.to_string(),
};
error!("[Reminder {}] {}", self.id, message);
}
async fn log_success(&self) {}
async fn set_sent(&self, pool: impl Executor<'_, Database = Database> + Copy) {
sqlx::query!("UPDATE reminders SET `status` = 'sent' WHERE `id` = ?", self.id)
.execute(pool)
.await
.expect(&format!("Could not delete Reminder {}", self.id));
}
async fn set_failed(
&self,
pool: impl Executor<'_, Database = Database> + Copy,
message: &'static str,
) {
sqlx::query!(
"UPDATE reminders SET `status` = 'failed', `status_message` = ? WHERE `id` = ?",
message,
self.id
)
.execute(pool)
.await
.expect(&format!("Could not delete Reminder {}", self.id));
}
async fn pin_message<M: Into<MessageId>>(&self, message_id: M, http: impl AsRef<Http>) {
let _ = http.as_ref().pin_message(self.channel_id.into(), message_id.into(), None).await;
}
pub async fn send(
&self,
pool: impl Executor<'_, Database = Database> + Copy,
cache_http: impl CacheHttp,
) {
async fn send_to_channel(
cache_http: impl CacheHttp,
reminder: &Reminder,
embed: Option<CreateEmbed>,
) -> Result<()> {
let channel = if let Some(thread_id) = reminder.thread_id {
ChannelId::new(thread_id).to_channel(&cache_http).await
} else {
ChannelId::new(reminder.channel_id).to_channel(&cache_http).await
};
let mut message = CreateMessage::new().content(&reminder.content).tts(reminder.tts);
if let (Some(attachment), Some(name)) =
(&reminder.attachment, &reminder.attachment_name)
{
message =
message.add_file(CreateAttachment::bytes(attachment as &[u8], name.as_str()));
}
if let Some(embed) = embed {
message = message.embed(embed);
}
match channel {
Ok(Channel::Guild(channel)) => {
match channel.send_message(&cache_http, message).await {
Ok(m) => {
if reminder.pin {
reminder.pin_message(m.id, cache_http.http()).await;
}
Ok(())
}
Err(e) => Err(e),
}
}
Ok(Channel::Private(channel)) => {
match channel.send_message(&cache_http.http(), message).await {
Ok(m) => {
if reminder.pin {
reminder.pin_message(m.id, cache_http.http()).await;
}
Ok(())
}
Err(e) => Err(e),
}
}
Err(e) => Err(e),
_ => Err(Error::Other("Channel not of valid type")),
}
}
async fn send_to_webhook(
cache_http: impl CacheHttp,
reminder: &Reminder,
webhook: Webhook,
embed: Option<CreateEmbed>,
) -> Result<()> {
let mut builder = if let Some(thread_id) = reminder.thread_id {
ExecuteWebhook::new()
.content(&reminder.content)
.tts(reminder.tts)
.in_thread(thread_id)
} else {
ExecuteWebhook::new().content(&reminder.content).tts(reminder.tts)
};
if let Some(username) = &reminder.username {
if !username.is_empty() {
builder = builder.username(username);
}
}
if let Some(avatar) = &reminder.avatar {
builder = builder.avatar_url(avatar);
}
if let (Some(attachment), Some(name)) =
(&reminder.attachment, &reminder.attachment_name)
{
builder =
builder.add_file(CreateAttachment::bytes(attachment as &[u8], name.as_str()));
}
if let Some(embed) = embed {
builder = builder.embeds(vec![embed]);
}
match webhook
.execute(&cache_http.http(), reminder.pin || reminder.restartable, builder)
.await
{
Ok(m) => {
if reminder.pin {
if let Some(message) = m {
reminder.pin_message(message.id, cache_http.http()).await;
}
}
Ok(())
}
Err(e) => Err(e),
}
}
if self.enabled
&& !(self.channel_paused
&& self
.channel_paused_until
.map_or(true, |inner| inner >= Utc::now().naive_local()))
{
let _ = sqlx::query!(
"
UPDATE `channels` SET paused = 0, paused_until = NULL WHERE `channel` = ?
",
self.channel_id
)
.execute(pool)
.await;
let embed = Embed::from_id(pool, self.id).await.map(|e| e.into());
let result = if let (Some(webhook_id), Some(webhook_token)) =
(self.webhook_id, &self.webhook_token)
{
let webhook_res = cache_http
.http()
.get_webhook_with_token(webhook_id.into(), webhook_token)
.await;
if let Ok(webhook) = webhook_res {
send_to_webhook(cache_http, &self, webhook, embed).await
} else {
warn!("Webhook vanished for reminder {}: {:?}", self.id, webhook_res);
self.reset_webhook(pool).await;
send_to_channel(cache_http, &self, embed).await
}
} else {
send_to_channel(cache_http, &self, embed).await
};
if let Err(e) = result {
if let Error::Http(error) = e {
if let HttpError::UnsuccessfulRequest(http_error) = error {
match http_error.error.code {
10003 => {
self.log_error(
"Could not be sent as channel does not exist",
None::<&'static str>,
)
.await;
self.set_failed(
pool,
"Could not be sent as channel does not exist",
)
.await;
}
10004 => {
self.log_error(
"Could not be sent as guild does not exist",
None::<&'static str>,
)
.await;
self.set_failed(pool, "Could not be sent as guild does not exist")
.await;
}
50001 => {
self.log_error(
"Could not be sent as missing access",
None::<&'static str>,
)
.await;
self.set_failed(pool, "Could not be sent as missing access").await;
}
50007 => {
self.log_error(
"Could not be sent as user has DMs disabled",
None::<&'static str>,
)
.await;
self.set_failed(pool, "Could not be sent as user has DMs disabled")
.await;
}
50013 => {
self.log_error(
"Could not be sent as permissions are invalid",
None::<&'static str>,
)
.await;
self.set_failed(
pool,
"Could not be sent as permissions are invalid",
)
.await;
}
_ => {
self.log_error("HTTP error sending reminder", Some(http_error))
.await;
self.refresh(pool).await;
}
}
} else {
self.log_error("(Likely) a parsing error", Some(error)).await;
self.refresh(pool).await;
}
} else {
self.log_error("Non-HTTP error", Some(e)).await;
self.refresh(pool).await;
}
} else {
self.log_success().await;
self.refresh(pool).await;
}
} else {
info!("Reminder {} is paused", self.id);
self.refresh(pool).await;
}
}
}

40
src/web/catchers.rs Normal file
View File

@ -0,0 +1,40 @@
use std::collections::HashMap;
use rocket::{catch, serde::json::json};
use rocket_dyn_templates::Template;
use crate::web::JsonValue;
#[catch(403)]
pub(crate) async fn forbidden() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/403", &map)
}
#[catch(500)]
pub(crate) async fn internal_server_error() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/500", &map)
}
#[catch(401)]
pub(crate) async fn not_authorized() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/401", &map)
}
#[catch(404)]
pub(crate) async fn not_found() -> Template {
let map: HashMap<String, String> = HashMap::new();
Template::render("errors/404", &map)
}
#[catch(413)]
pub(crate) async fn payload_too_large() -> JsonValue {
json!({"error": "Data too large.", "errors": ["Data too large."]})
}
#[catch(422)]
pub(crate) async fn unprocessable_entity() -> JsonValue {
json!({"error": "Invalid request.", "errors": ["Invalid request."]})
}

48
src/web/consts.rs Normal file
View File

@ -0,0 +1,48 @@
pub const DISCORD_OAUTH_TOKEN: &'static str = "https://discord.com/api/oauth2/token";
pub const DISCORD_OAUTH_AUTHORIZE: &'static str = "https://discord.com/api/oauth2/authorize";
pub const DISCORD_API: &'static str = "https://discord.com/api";
pub const MAX_NAME_LENGTH: usize = 100;
pub const MAX_CONTENT_LENGTH: usize = 2000;
pub const MAX_EMBED_DESCRIPTION_LENGTH: usize = 4096;
pub const MAX_EMBED_TITLE_LENGTH: usize = 256;
pub const MAX_EMBED_AUTHOR_LENGTH: usize = 256;
pub const MAX_EMBED_FOOTER_LENGTH: usize = 2048;
pub const MAX_URL_LENGTH: usize = 512;
pub const MAX_USERNAME_LENGTH: usize = 100;
pub const MAX_EMBED_FIELDS: usize = 25;
pub const MAX_EMBED_FIELD_TITLE_LENGTH: usize = 256;
pub const MAX_EMBED_FIELD_VALUE_LENGTH: usize = 1024;
pub const MINUTE: usize = 60;
pub const HOUR: usize = 60 * MINUTE;
pub const DAY: usize = 24 * HOUR;
pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
use std::{collections::HashSet, env};
use lazy_static::lazy_static;
use serenity::builder::CreateAttachment;
lazy_static! {
pub static ref DEFAULT_AVATAR: CreateAttachment = CreateAttachment::bytes(
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/webhook.jpg")) as &[u8],
"webhook.jpg",
);
pub static ref SUBSCRIPTION_ROLES: HashSet<u64> = HashSet::from_iter(
env::var("PATREON_ROLE_ID")
.map(|var| var
.split(',')
.filter_map(|item| { item.parse::<u64>().ok() })
.collect::<Vec<u64>>())
.unwrap_or_else(|_| Vec::new())
);
pub static ref CNC_GUILD: Option<u64> =
env::var("PATREON_GUILD_ID").map(|var| var.parse::<u64>().ok()).ok().flatten();
pub static ref MIN_INTERVAL: u32 = env::var("MIN_INTERVAL")
.ok()
.map(|inner| inner.parse::<u32>().ok())
.flatten()
.unwrap_or(600);
}

1
src/web/guards/mod.rs Normal file
View File

@ -0,0 +1 @@
pub(crate) mod transaction;

View File

@ -0,0 +1,43 @@
use rocket::{
http::Status,
request::{FromRequest, Outcome},
Request, State,
};
use sqlx::Pool;
use crate::Database;
pub struct Transaction<'a>(sqlx::Transaction<'a, Database>);
impl Transaction<'_> {
pub fn executor(&mut self) -> impl sqlx::Executor<'_, Database = Database> {
&mut *(self.0)
}
pub async fn commit(self) -> Result<(), sqlx::Error> {
self.0.commit().await
}
}
#[derive(Debug)]
#[allow(dead_code)]
pub enum TransactionError {
Error(sqlx::Error),
Missing,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Transaction<'r> {
type Error = TransactionError;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.guard::<&State<Pool<Database>>>().await {
Outcome::Success(pool) => match pool.begin().await {
Ok(transaction) => Outcome::Success(Transaction(transaction)),
Err(e) => Outcome::Error((Status::InternalServerError, TransactionError::Error(e))),
},
Outcome::Error(e) => Outcome::Error((e.0, TransactionError::Missing)),
Outcome::Forward(f) => Outcome::Forward(f),
}
}
}

111
src/web/macros.rs Normal file
View File

@ -0,0 +1,111 @@
macro_rules! offline {
($field:expr) => {
if std::env::var("OFFLINE").map_or(false, |v| v == "1") {
return $field;
}
};
}
macro_rules! check_length {
($max:ident, $field:expr) => {
if $field.len() > $max {
return Err(json!({ "error": format!("{} exceeded", stringify!($max)) }));
}
};
($max:ident, $field:expr, $($fields:expr),+) => {
check_length!($max, $field);
check_length!($max, $($fields),+);
};
}
macro_rules! check_length_opt {
($max:ident, $field:expr) => {
if let Some(field) = &$field {
check_length!($max, field);
}
};
($max:ident, $field:expr, $($fields:expr),+) => {
check_length_opt!($max, $field);
check_length_opt!($max, $($fields),+);
};
}
macro_rules! check_url {
($field:expr) => {
if !($field.starts_with("http://") || $field.starts_with("https://")) {
return Err(json!({ "error": "URL invalid" }));
}
};
($field:expr, $($fields:expr),+) => {
check_url!($max, $field);
check_url!($max, $($fields),+);
};
}
macro_rules! check_url_opt {
($field:expr) => {
if let Some(field) = &$field {
check_url!(field);
}
};
($field:expr, $($fields:expr),+) => {
check_url_opt!($field);
check_url_opt!($($fields),+);
};
}
macro_rules! update_field {
($pool:expr, $error:ident, $reminder:ident.[$field:ident]) => {
if let Some(value) = &$reminder.$field {
match sqlx::query(concat!(
"UPDATE reminders SET `",
stringify!($field),
"` = ? WHERE uid = ?"
))
.bind(value)
.bind(&$reminder.uid)
.execute($pool)
.await
{
Ok(_) => {}
Err(e) => {
warn!(
concat!(
"Error in `update_field!(",
stringify!($pool),
stringify!($reminder),
stringify!($field),
")': {:?}"
),
e
);
$error.push(format!("Error setting field {}", stringify!($field)));
}
}
}
};
($pool:expr, $error:ident, $reminder:ident.[$field:ident, $($fields:ident),+]) => {
update_field!($pool, $error, $reminder.[$field]);
update_field!($pool, $error, $reminder.[$($fields),+]);
};
}
macro_rules! json_err {
($message:expr) => {
Err(json!({ "error": $message }))
};
}
macro_rules! path {
($path:expr) => {{
use rocket::fs::relative;
if Path::new(concat!("/lib/reminder-rs/", $path)).exists() {
concat!("/lib/reminder-rs/", $path)
} else {
relative!($path)
}
}};
}

27
src/web/metrics.rs Normal file
View File

@ -0,0 +1,27 @@
use rocket::{
fairing::{Fairing, Info, Kind},
Request, Response,
};
use crate::metrics::REQUEST_COUNTER;
pub struct MetricProducer;
#[rocket::async_trait]
impl Fairing for MetricProducer {
fn info(&self) -> Info {
Info { name: "Metrics fairing", kind: Kind::Response }
}
async fn on_response<'r>(&self, req: &'r Request<'_>, resp: &mut Response<'r>) {
if let Some(route) = req.route() {
REQUEST_COUNTER
.with_label_values(&[
req.method().as_str(),
&resp.status().code.to_string(),
&route.uri.to_string(),
])
.inc();
}
}
}

254
src/web/mod.rs Normal file
View File

@ -0,0 +1,254 @@
mod consts;
#[macro_use]
mod macros;
mod catchers;
mod guards;
mod metrics;
mod routes;
use std::{env, path::Path};
use log::{error, info, warn};
use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
use poise::serenity_prelude::{
client::Context,
http::CacheHttp,
model::id::{GuildId, UserId},
};
use rocket::{
catchers,
fs::FileServer,
http::CookieJar,
routes,
serde::json::{json, Value as JsonValue},
tokio::sync::broadcast::Sender,
};
use rocket_dyn_templates::Template;
use sqlx::{MySql, Pool};
use crate::web::{
consts::{CNC_GUILD, DISCORD_OAUTH_AUTHORIZE, DISCORD_OAUTH_TOKEN, SUBSCRIPTION_ROLES},
metrics::MetricProducer,
};
type Database = MySql;
#[derive(Debug)]
enum Error {
#[allow(unused)]
SQLx(sqlx::Error),
#[allow(unused)]
Serenity(serenity::Error),
}
pub async fn initialize(
kill_channel: Sender<()>,
serenity_context: Context,
db_pool: Pool<Database>,
) -> Result<(), Box<dyn std::error::Error>> {
info!("Checking environment variables...");
if env::var("OFFLINE").map_or(true, |v| v != "1") {
env::var("OAUTH2_CLIENT_ID").expect("`OAUTH2_CLIENT_ID' not supplied");
env::var("OAUTH2_CLIENT_SECRET").expect("`OAUTH2_CLIENT_SECRET' not supplied");
env::var("OAUTH2_DISCORD_CALLBACK").expect("`OAUTH2_DISCORD_CALLBACK' not supplied");
env::var("PATREON_GUILD_ID").expect("`PATREON_GUILD_ID' not supplied");
}
info!("Done!");
let oauth2_client = BasicClient::new(
ClientId::new(env::var("OAUTH2_CLIENT_ID")?),
Some(ClientSecret::new(env::var("OAUTH2_CLIENT_SECRET")?)),
AuthUrl::new(DISCORD_OAUTH_AUTHORIZE.to_string())?,
Some(TokenUrl::new(DISCORD_OAUTH_TOKEN.to_string())?),
)
.set_redirect_uri(RedirectUrl::new(env::var("OAUTH2_DISCORD_CALLBACK")?)?);
let reqwest_client = reqwest::Client::new();
let static_path =
if Path::new("static").exists() { "static" } else { "/lib/reminder-rs/static" };
rocket::build()
.attach(MetricProducer)
.attach(Template::fairing())
.register(
"/",
catchers![
catchers::not_authorized,
catchers::forbidden,
catchers::not_found,
catchers::internal_server_error,
catchers::unprocessable_entity,
catchers::payload_too_large,
],
)
.manage(oauth2_client)
.manage(reqwest_client)
.manage(serenity_context)
.manage(db_pool)
.mount("/static", FileServer::from(static_path))
.mount(
"/",
routes![
routes::cookies,
routes::index,
routes::privacy,
routes::report::report_error,
routes::return_to_same_site,
routes::terms,
],
)
.mount(
"/help",
routes![
routes::help,
routes::help_timezone,
routes::help_create_reminder,
routes::help_delete_reminder,
routes::help_timers,
routes::help_todo_lists,
routes::help_macros,
routes::help_intervals,
routes::help_dashboard,
routes::help_iemanager,
],
)
.mount(
"/login",
routes![
routes::login::discord_login,
routes::login::discord_logout,
routes::login::discord_callback
],
)
.mount(
"/dashboard",
routes![
routes::dashboard::dashboard,
routes::dashboard::dashboard_home,
routes::dashboard::api::delete_reminder,
routes::dashboard::api::user::get_user_info,
routes::dashboard::api::user::update_user_info,
routes::dashboard::api::user::get_user_guilds,
routes::dashboard::api::user::get_reminders,
routes::dashboard::api::user::edit_reminder,
routes::dashboard::api::user::create_user_reminder,
routes::dashboard::api::guild::get_guild_info,
routes::dashboard::api::guild::get_guild_channels,
routes::dashboard::api::guild::get_guild_roles,
routes::dashboard::api::guild::get_reminder_templates,
routes::dashboard::api::guild::create_reminder_template,
routes::dashboard::api::guild::delete_reminder_template,
routes::dashboard::api::guild::create_guild_reminder,
routes::dashboard::api::guild::get_reminders,
routes::dashboard::api::guild::edit_reminder,
routes::dashboard::export::export_reminders,
routes::dashboard::export::export_reminder_templates,
routes::dashboard::export::export_todos,
routes::dashboard::export::import_reminders,
routes::dashboard::export::import_todos,
],
)
.launch()
.await?;
warn!("Exiting rocket runtime");
// distribute kill signal
match kill_channel.send(()) {
Ok(_) => {}
Err(e) => {
error!("Failed to issue kill signal: {:?}", e);
}
}
Ok(())
}
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
offline!(true);
if let Some(subscription_guild) = *CNC_GUILD {
let guild_member = GuildId::new(subscription_guild).member(cache_http, user_id).await;
if let Ok(member) = guild_member {
for role in member.roles {
if SUBSCRIPTION_ROLES.contains(&role.get()) {
return true;
}
}
}
false
} else {
true
}
}
pub async fn check_guild_subscription(
cache_http: impl CacheHttp,
guild_id: impl Into<GuildId>,
) -> bool {
offline!(true);
if let Some(owner) = cache_http.cache().unwrap().guild(guild_id).map(|guild| guild.owner_id) {
check_subscription(&cache_http, owner).await
} else {
false
}
}
pub async fn check_authorization(
cookies: &CookieJar<'_>,
ctx: &Context,
guild_id: u64,
) -> Result<(), JsonValue> {
let user_id = cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten();
if std::env::var("OFFLINE").map_or(true, |v| v != "1") {
match user_id {
Some(user_id) => {
let admin_id = std::env::var("ADMIN_ID")
.map_or(false, |u| u.parse::<u64>().map_or(false, |u| u == user_id));
if admin_id {
return Ok(());
}
let guild_id = GuildId::new(guild_id);
match guild_id.member(ctx, UserId::new(user_id)).await {
Ok(member) => {
let permissions_res = member.permissions(ctx);
match permissions_res {
Err(_) => {
return Err(json!({"error": "Couldn't fetch permissions"}));
}
Ok(permissions) => {
if !(permissions.manage_messages()
|| permissions.manage_guild()
|| permissions.administrator())
{
return Err(json!({"error": "Incorrect permissions"}));
}
}
}
}
Err(_) => {
return Err(json!({"error": "User not in guild"}));
}
}
}
None => {
return Err(json!({"error": "User not authorized"}));
}
}
}
Ok(())
}

View File

@ -0,0 +1,61 @@
use rocket::{get, http::CookieJar, serde::json::json, State};
use serde::Serialize;
use serenity::{
client::Context,
model::{
channel::GuildChannel,
id::{ChannelId, GuildId},
},
};
use crate::web::{check_authorization, routes::JsonResult};
#[derive(Serialize)]
struct ChannelInfo {
id: String,
name: String,
webhook_avatar: Option<String>,
webhook_name: Option<String>,
}
#[get("/api/guild/<id>/channels")]
pub async fn get_guild_channels(
id: u64,
cookies: &CookieJar<'_>,
ctx: &State<Context>,
) -> JsonResult {
offline!(Ok(json!(vec![ChannelInfo {
name: "general".to_string(),
id: "1".to_string(),
webhook_avatar: None,
webhook_name: None,
}])));
check_authorization(cookies, ctx.inner(), id).await?;
match GuildId::new(id).to_guild_cached(ctx.inner()) {
Some(guild) => {
let mut channels = guild
.channels
.iter()
.filter(|(_, channel)| channel.is_text_based())
.map(|(id, channel)| (id.to_owned(), channel.to_owned()))
.collect::<Vec<(ChannelId, GuildChannel)>>();
channels.sort_by(|(_, c1), (_, c2)| c1.position.cmp(&c2.position));
let channel_info = channels
.iter()
.map(|(channel_id, channel)| ChannelInfo {
name: channel.name.to_string(),
id: channel_id.to_string(),
webhook_avatar: None,
webhook_name: None,
})
.collect::<Vec<ChannelInfo>>();
Ok(json!(channel_info))
}
None => json_err!("Bot not in guild"),
}
}

View File

@ -0,0 +1,45 @@
mod channels;
mod reminders;
mod roles;
mod templates;
use std::env;
pub use channels::*;
pub use reminders::*;
use rocket::{get, http::CookieJar, serde::json::json, State};
pub use roles::*;
use serenity::{
client::Context,
model::id::{GuildId, RoleId},
};
pub use templates::*;
use crate::web::{check_authorization, routes::JsonResult};
#[get("/api/guild/<id>")]
pub async fn get_guild_info(id: u64, cookies: &CookieJar<'_>, ctx: &State<Context>) -> JsonResult {
offline!(Ok(json!({ "patreon": true, "name": "Guild" })));
check_authorization(cookies, ctx.inner(), id).await?;
match GuildId::new(id)
.to_guild_cached(ctx.inner())
.map(|guild| (guild.owner_id, guild.name.clone()))
{
Some((owner_id, name)) => {
let member_res = GuildId::new(env::var("PATREON_GUILD_ID").unwrap().parse().unwrap())
.member(&ctx.inner(), owner_id)
.await;
let patreon = member_res.map_or(false, |member| {
member
.roles
.contains(&RoleId::new(env::var("PATREON_ROLE_ID").unwrap().parse().unwrap()))
});
Ok(json!({ "patreon": patreon, "name": name }))
}
None => json_err!("Bot not in guild"),
}
}

View File

@ -0,0 +1,367 @@
use log::warn;
use rocket::{
get,
http::CookieJar,
patch, post,
serde::json::{json, Json},
State,
};
use serenity::{
client::Context,
model::id::{ChannelId, GuildId, UserId},
};
use sqlx::{MySql, Pool};
use crate::web::{
check_authorization, check_guild_subscription, check_subscription,
consts::MIN_INTERVAL,
guards::transaction::Transaction,
routes::{
dashboard::{create_database_channel, create_reminder, PatchReminder, Reminder},
JsonResult,
},
Database,
};
#[post("/api/guild/<id>/reminders", data = "<reminder>")]
pub async fn create_guild_reminder(
id: u64,
reminder: Json<Reminder>,
cookies: &CookieJar<'_>,
ctx: &State<Context>,
mut transaction: Transaction<'_>,
) -> JsonResult {
check_authorization(cookies, ctx.inner(), id).await?;
let user_id =
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap();
match create_reminder(
ctx.inner(),
&mut transaction,
GuildId::new(id),
UserId::new(user_id),
reminder.into_inner(),
)
.await
{
Ok(r) => match transaction.commit().await {
Ok(_) => Ok(r),
Err(e) => {
warn!("Couldn't commit transaction: {:?}", e);
json_err!("Couldn't commit transaction.")
}
},
Err(e) => Err(e),
}
}
#[get("/api/guild/<id>/reminders")]
pub async fn get_reminders(
id: u64,
cookies: &CookieJar<'_>,
ctx: &State<Context>,
pool: &State<Pool<MySql>>,
) -> JsonResult {
check_authorization(cookies, ctx.inner(), id).await?;
let channels_res = GuildId::new(id).channels(&ctx.inner()).await;
match channels_res {
Ok(channels) => {
let channels = channels
.keys()
.into_iter()
.map(|k| k.get().to_string())
.collect::<Vec<String>>()
.join(",");
sqlx::query_as_unchecked!(
Reminder,
"SELECT
reminders.attachment,
reminders.attachment_name,
reminders.avatar,
channels.channel,
reminders.content,
reminders.embed_author,
reminders.embed_author_url,
reminders.embed_color,
reminders.embed_description,
reminders.embed_footer,
reminders.embed_footer_url,
reminders.embed_image_url,
reminders.embed_thumbnail_url,
reminders.embed_title,
IFNULL(reminders.embed_fields, '[]') AS embed_fields,
reminders.enabled,
reminders.expires,
reminders.interval_seconds,
reminders.interval_days,
reminders.interval_months,
reminders.name,
reminders.restartable,
reminders.tts,
reminders.uid,
reminders.username,
reminders.utc_time
FROM reminders
INNER JOIN channels ON channels.id = reminders.channel_id
WHERE `status` = 'pending' AND FIND_IN_SET(channels.channel, ?)",
channels
)
.fetch_all(pool.inner())
.await
.map(|r| Ok(json!(r)))
.unwrap_or_else(|e| {
warn!("Failed to complete SQL query: {:?}", e);
json_err!("Could not load reminders")
})
}
Err(e) => {
warn!("Could not fetch channels from {}: {:?}", id, e);
Ok(json!([]))
}
}
}
#[patch("/api/guild/<id>/reminders", data = "<reminder>")]
pub async fn edit_reminder(
id: u64,
reminder: Json<PatchReminder>,
ctx: &State<Context>,
mut transaction: Transaction<'_>,
pool: &State<Pool<Database>>,
cookies: &CookieJar<'_>,
) -> JsonResult {
check_authorization(cookies, ctx.inner(), id).await?;
let mut error = vec![];
let user_id =
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap();
if reminder.message_ok() {
update_field!(transaction.executor(), error, reminder.[
content,
embed_author,
embed_description,
embed_footer,
embed_title,
embed_fields,
username
]);
} else {
error.push("Message exceeds limits.".to_string());
}
update_field!(transaction.executor(), error, reminder.[
attachment,
attachment_name,
avatar,
embed_author_url,
embed_color,
embed_footer_url,
embed_image_url,
embed_thumbnail_url,
enabled,
expires,
name,
restartable,
tts,
utc_time
]);
if reminder.interval_days.flatten().is_some()
|| reminder.interval_months.flatten().is_some()
|| reminder.interval_seconds.flatten().is_some()
{
if check_guild_subscription(&ctx.inner(), id).await
|| check_subscription(&ctx.inner(), user_id).await
{
let new_interval_length = match reminder.interval_days {
Some(interval) => interval.unwrap_or(0),
None => sqlx::query!(
"SELECT interval_days AS days FROM reminders WHERE uid = ?",
reminder.uid
)
.fetch_one(transaction.executor())
.await
.map_err(|e| {
warn!("Error updating reminder interval: {:?}", e);
json!({ "reminder": Option::<Reminder>::None, "errors": vec!["Unknown error"] })
})?
.days
.unwrap_or(0),
} * 86400 + match reminder.interval_months {
Some(interval) => interval.unwrap_or(0),
None => sqlx::query!(
"SELECT interval_months AS months FROM reminders WHERE uid = ?",
reminder.uid
)
.fetch_one(transaction.executor())
.await
.map_err(|e| {
warn!("Error updating reminder interval: {:?}", e);
json!({ "reminder": Option::<Reminder>::None, "errors": vec!["Unknown error"] })
})?
.months
.unwrap_or(0),
} * 2592000 + match reminder.interval_seconds {
Some(interval) => interval.unwrap_or(0),
None => sqlx::query!(
"SELECT interval_seconds AS seconds FROM reminders WHERE uid = ?",
reminder.uid
)
.fetch_one(transaction.executor())
.await
.map_err(|e| {
warn!("Error updating reminder interval: {:?}", e);
json!({ "reminder": Option::<Reminder>::None, "errors": vec!["Unknown error"] })
})?
.seconds
.unwrap_or(0),
};
if new_interval_length < *MIN_INTERVAL {
error.push(String::from("New interval is too short."));
} else {
update_field!(transaction.executor(), error, reminder.[
interval_days,
interval_months,
interval_seconds
]);
}
}
} else {
sqlx::query!(
"
UPDATE reminders
SET interval_seconds = NULL, interval_days = NULL, interval_months = NULL
WHERE uid = ?
",
reminder.uid
)
.execute(transaction.executor())
.await
.map_err(|e| {
warn!("Error updating reminder interval: {:?}", e);
json!({ "reminder": Option::<Reminder>::None, "errors": vec!["Unknown error"] })
})?;
}
if reminder.channel > 0 {
let channel_guild = ChannelId::new(reminder.channel)
.to_channel_cached(&ctx.cache)
.map(|channel| channel.guild_id);
match channel_guild {
Some(channel_guild) => {
let channel_matches_guild = channel_guild.get() == id;
if !channel_matches_guild {
warn!(
"Error in `edit_reminder`: channel {:?} not found for guild {}",
reminder.channel, id
);
return Err(json!({"error": "Channel not found"}));
}
let channel = create_database_channel(
ctx.inner(),
ChannelId::new(reminder.channel),
&mut transaction,
)
.await;
if let Err(e) = channel {
warn!("`create_database_channel` returned an error code: {:?}", e);
return Err(
json!({"error": "Failed to configure channel for reminders. Please check the bot permissions"}),
);
}
let channel = channel.unwrap();
match sqlx::query!(
"UPDATE reminders SET channel_id = ? WHERE uid = ?",
channel,
reminder.uid
)
.execute(transaction.executor())
.await
{
Ok(_) => {}
Err(e) => {
warn!("Error setting channel: {:?}", e);
error.push("Couldn't set channel".to_string())
}
}
}
None => {
warn!(
"Error in `edit_reminder`: channel {:?} not found for guild {}",
reminder.channel, id
);
return Err(json!({"error": "Channel not found"}));
}
}
}
if let Err(e) = transaction.commit().await {
warn!("Couldn't commit transaction: {:?}", e);
return json_err!("Couldn't commit transaction");
}
match sqlx::query_as_unchecked!(
Reminder,
"SELECT reminders.attachment,
reminders.attachment_name,
reminders.avatar,
channels.channel,
reminders.content,
reminders.embed_author,
reminders.embed_author_url,
reminders.embed_color,
reminders.embed_description,
reminders.embed_footer,
reminders.embed_footer_url,
reminders.embed_image_url,
reminders.embed_thumbnail_url,
reminders.embed_title,
reminders.embed_fields,
reminders.enabled,
reminders.expires,
reminders.interval_seconds,
reminders.interval_days,
reminders.interval_months,
reminders.name,
reminders.restartable,
reminders.tts,
reminders.uid,
reminders.username,
reminders.utc_time
FROM reminders
LEFT JOIN channels ON channels.id = reminders.channel_id
WHERE uid = ?",
reminder.uid
)
.fetch_one(pool.inner())
.await
{
Ok(reminder) => Ok(json!({"reminder": reminder, "errors": error})),
Err(e) => {
warn!("Error exiting `edit_reminder': {:?}", e);
Err(json!({"reminder": Option::<Reminder>::None, "errors": vec!["Unknown error"]}))
}
}
}

View File

@ -0,0 +1,36 @@
use log::warn;
use rocket::{get, http::CookieJar, serde::json::json, State};
use serde::Serialize;
use serenity::client::Context;
use crate::web::{check_authorization, routes::JsonResult};
#[derive(Serialize)]
struct RoleInfo {
id: String,
name: String,
}
#[get("/api/guild/<id>/roles")]
pub async fn get_guild_roles(id: u64, cookies: &CookieJar<'_>, ctx: &State<Context>) -> JsonResult {
offline!(Ok(json!(vec![RoleInfo { name: "@everyone".to_string(), id: "1".to_string() }])));
check_authorization(cookies, ctx.inner(), id).await?;
let roles_res = ctx.cache.guild_roles(id);
match roles_res {
Some(roles) => {
let roles = roles
.iter()
.map(|(_, r)| RoleInfo { id: r.id.to_string(), name: r.name.to_string() })
.collect::<Vec<RoleInfo>>();
Ok(json!(roles))
}
None => {
warn!("Could not fetch roles from {}", id);
json_err!("Could not get roles")
}
}
}

View File

@ -0,0 +1,184 @@
use log::warn;
use rocket::{
delete, get,
http::CookieJar,
post,
serde::json::{json, Json},
State,
};
use serenity::client::Context;
use sqlx::{MySql, Pool};
use crate::web::{
check_authorization,
consts::{
MAX_CONTENT_LENGTH, MAX_EMBED_AUTHOR_LENGTH, MAX_EMBED_DESCRIPTION_LENGTH,
MAX_EMBED_FIELDS, MAX_EMBED_FIELD_TITLE_LENGTH, MAX_EMBED_FIELD_VALUE_LENGTH,
MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH,
},
routes::{
dashboard::{template_name_default, DeleteReminderTemplate, ReminderTemplate},
JsonResult,
},
};
#[get("/api/guild/<id>/templates")]
pub async fn get_reminder_templates(
id: u64,
cookies: &CookieJar<'_>,
ctx: &State<Context>,
pool: &State<Pool<MySql>>,
) -> JsonResult {
check_authorization(cookies, ctx.inner(), id).await?;
match sqlx::query_as_unchecked!(
ReminderTemplate,
"SELECT * FROM reminder_template WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)",
id
)
.fetch_all(pool.inner())
.await
{
Ok(templates) => Ok(json!(templates)),
Err(e) => {
warn!("Could not fetch templates from {}: {:?}", id, e);
json_err!("Could not get templates")
}
}
}
#[post("/api/guild/<id>/templates", data = "<reminder_template>")]
pub async fn create_reminder_template(
id: u64,
reminder_template: Json<ReminderTemplate>,
cookies: &CookieJar<'_>,
ctx: &State<Context>,
pool: &State<Pool<MySql>>,
) -> JsonResult {
check_authorization(cookies, ctx.inner(), id).await?;
// validate lengths
check_length!(MAX_CONTENT_LENGTH, reminder_template.content);
check_length!(MAX_EMBED_DESCRIPTION_LENGTH, reminder_template.embed_description);
check_length!(MAX_EMBED_TITLE_LENGTH, reminder_template.embed_title);
check_length!(MAX_EMBED_AUTHOR_LENGTH, reminder_template.embed_author);
check_length!(MAX_EMBED_FOOTER_LENGTH, reminder_template.embed_footer);
check_length_opt!(MAX_EMBED_FIELDS, reminder_template.embed_fields);
if let Some(fields) = &reminder_template.embed_fields {
for field in &fields.0 {
check_length!(MAX_EMBED_FIELD_VALUE_LENGTH, field.value);
check_length!(MAX_EMBED_FIELD_TITLE_LENGTH, field.title);
}
}
check_length_opt!(MAX_USERNAME_LENGTH, reminder_template.username);
check_length_opt!(
MAX_URL_LENGTH,
reminder_template.embed_footer_url,
reminder_template.embed_thumbnail_url,
reminder_template.embed_author_url,
reminder_template.embed_image_url,
reminder_template.avatar
);
// validate urls
check_url_opt!(
reminder_template.embed_footer_url,
reminder_template.embed_thumbnail_url,
reminder_template.embed_author_url,
reminder_template.embed_image_url,
reminder_template.avatar
);
let name = if reminder_template.name.is_empty() {
template_name_default()
} else {
reminder_template.name.clone()
};
match sqlx::query!(
"INSERT INTO reminder_template
(guild_id,
name,
attachment,
attachment_name,
avatar,
content,
embed_author,
embed_author_url,
embed_color,
embed_description,
embed_footer,
embed_footer_url,
embed_image_url,
embed_thumbnail_url,
embed_title,
embed_fields,
interval_seconds,
interval_days,
interval_months,
tts,
username
) VALUES ((SELECT id FROM guilds WHERE guild = ?), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?,
?, ?, ?, ?, ?, ?, ?)",
id,
name,
reminder_template.attachment,
reminder_template.attachment_name,
reminder_template.avatar,
reminder_template.content,
reminder_template.embed_author,
reminder_template.embed_author_url,
reminder_template.embed_color,
reminder_template.embed_description,
reminder_template.embed_footer,
reminder_template.embed_footer_url,
reminder_template.embed_image_url,
reminder_template.embed_thumbnail_url,
reminder_template.embed_title,
reminder_template.embed_fields,
reminder_template.interval_seconds,
reminder_template.interval_days,
reminder_template.interval_months,
reminder_template.tts,
reminder_template.username,
)
.fetch_all(pool.inner())
.await
{
Ok(_) => Ok(json!({})),
Err(e) => {
warn!("Could not create template for {}: {:?}", id, e);
json_err!("Could not create template")
}
}
}
#[delete("/api/guild/<id>/templates", data = "<delete_reminder_template>")]
pub async fn delete_reminder_template(
id: u64,
delete_reminder_template: Json<DeleteReminderTemplate>,
cookies: &CookieJar<'_>,
ctx: &State<Context>,
pool: &State<Pool<MySql>>,
) -> JsonResult {
check_authorization(cookies, ctx.inner(), id).await?;
match sqlx::query!(
"DELETE FROM reminder_template WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND id = ?",
id, delete_reminder_template.id
)
.fetch_all(pool.inner())
.await
{
Ok(_) => {
Ok(json!({}))
}
Err(e) => {
warn!("Could not delete template from {}: {:?}", id, e);
json_err!("Could not delete template")
}
}
}

View File

@ -0,0 +1,42 @@
pub mod guild;
pub mod user;
use log::warn;
use rocket::{
delete,
http::CookieJar,
serde::json::{json, Json},
State,
};
use sqlx::{MySql, Pool};
use crate::web::routes::{dashboard::DeleteReminder, JsonResult};
#[delete("/api/reminders", data = "<reminder>")]
pub async fn delete_reminder(
cookies: &CookieJar<'_>,
reminder: Json<DeleteReminder>,
pool: &State<Pool<MySql>>,
) -> JsonResult {
match cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten() {
Some(_) => {
match sqlx::query!(
"UPDATE reminders SET `status` = 'deleted' WHERE uid = ?",
reminder.uid
)
.execute(pool.inner())
.await
{
Ok(_) => Ok(json!({})),
Err(e) => {
warn!("Error in `delete_reminder`: {:?}", e);
Err(json!({"error": "Could not delete reminder"}))
}
}
}
None => Err(json!({"error": "User not authorized"})),
}
}

View File

@ -0,0 +1,83 @@
use log::warn;
use reqwest::Client;
use rocket::{
get,
http::CookieJar,
serde::json::{json, Value as JsonValue},
State,
};
use serde::{Deserialize, Serialize};
use serenity::model::{id::GuildId, permissions::Permissions};
use crate::web::consts::DISCORD_API;
#[derive(Serialize)]
struct GuildInfo {
id: String,
name: String,
}
#[derive(Deserialize)]
struct PartialGuild {
pub id: GuildId,
pub name: String,
#[serde(default)]
pub owner: bool,
#[serde(rename = "permissions_new")]
pub permissions: Option<String>,
}
#[get("/api/user/guilds")]
pub async fn get_user_guilds(cookies: &CookieJar<'_>, reqwest_client: &State<Client>) -> JsonValue {
offline!(json!(vec![GuildInfo { id: "1".to_string(), name: "Guild".to_string() }]));
if let Some(access_token) = cookies.get_private("access_token") {
let request_res = reqwest_client
.get(format!("{}/users/@me/guilds", DISCORD_API))
.bearer_auth(access_token.value())
.send()
.await;
match request_res {
Ok(response) => {
let guilds_res = response.json::<Vec<PartialGuild>>().await;
match guilds_res {
Ok(guilds) => {
let reduced_guilds = guilds
.iter()
.filter(|g| {
g.owner
|| g.permissions.as_ref().map_or(false, |p| {
let permissions =
Permissions::from_bits_truncate(p.parse().unwrap());
permissions.manage_messages()
|| permissions.manage_guild()
|| permissions.administrator()
})
})
.map(|g| GuildInfo { id: g.id.to_string(), name: g.name.to_string() })
.collect::<Vec<GuildInfo>>();
json!(reduced_guilds)
}
Err(e) => {
warn!("Error constructing user from request: {:?}", e);
json!({"error": "Could not get user details"})
}
}
}
Err(e) => {
warn!("Error getting user guilds: {:?}", e);
json!({"error": "Could not reach Discord"})
}
}
} else {
json!({"error": "Not authorized"})
}
}

View File

@ -0,0 +1,102 @@
mod guilds;
mod models;
mod reminders;
use std::env;
use chrono_tz::Tz;
pub use guilds::*;
pub use reminders::*;
use rocket::{
get,
http::CookieJar,
patch,
serde::json::{json, Json, Value as JsonValue},
State,
};
use serde::{Deserialize, Serialize};
use serenity::{
client::Context,
model::id::{GuildId, RoleId},
};
use sqlx::{MySql, Pool};
#[derive(Serialize)]
struct UserInfo {
name: String,
patreon: bool,
timezone: Option<String>,
}
#[derive(Deserialize)]
pub struct UpdateUser {
timezone: String,
}
#[get("/api/user")]
pub async fn get_user_info(
cookies: &CookieJar<'_>,
ctx: &State<Context>,
pool: &State<Pool<MySql>>,
) -> JsonValue {
offline!(json!(UserInfo { name: "Discord".to_string(), patreon: true, timezone: None }));
if let Some(user_id) =
cookies.get_private("userid").map(|u| u.value().parse::<u64>().ok()).flatten()
{
let member_res = GuildId::new(env::var("PATREON_GUILD_ID").unwrap().parse().unwrap())
.member(&ctx.inner(), user_id)
.await;
let timezone = sqlx::query!(
"SELECT IFNULL(timezone, 'UTC') AS timezone FROM users WHERE user = ?",
user_id
)
.fetch_one(pool.inner())
.await
.map_or(None, |q| Some(q.timezone));
let user_info = UserInfo {
name: cookies
.get_private("username")
.map_or("Discord User".to_string(), |c| c.value().to_string()),
patreon: member_res.map_or(false, |member| {
member
.roles
.contains(&RoleId::new(env::var("PATREON_ROLE_ID").unwrap().parse().unwrap()))
}),
timezone,
};
json!(user_info)
} else {
json!({"error": "Not authorized"})
}
}
#[patch("/api/user", data = "<user>")]
pub async fn update_user_info(
cookies: &CookieJar<'_>,
user: Json<UpdateUser>,
pool: &State<Pool<MySql>>,
) -> JsonValue {
if let Some(user_id) =
cookies.get_private("userid").map(|u| u.value().parse::<u64>().ok()).flatten()
{
if user.timezone.parse::<Tz>().is_ok() {
let _ = sqlx::query!(
"UPDATE users SET timezone = ? WHERE user = ?",
user.timezone,
user_id,
)
.execute(pool.inner())
.await;
json!({})
} else {
json!({"error": "Timezone not recognized"})
}
} else {
json!({"error": "Not authorized"})
}
}

View File

@ -0,0 +1,231 @@
use chrono::{naive::NaiveDateTime, Utc};
use futures::TryFutureExt;
use log::warn;
use rocket::serde::json::json;
use serde::{Deserialize, Serialize};
use serenity::{client::Context, model::id::UserId};
use sqlx::types::Json;
use crate::web::{
check_subscription,
consts::{
DAY, MAX_CONTENT_LENGTH, MAX_EMBED_AUTHOR_LENGTH, MAX_EMBED_DESCRIPTION_LENGTH,
MAX_EMBED_FIELDS, MAX_EMBED_FIELD_TITLE_LENGTH, MAX_EMBED_FIELD_VALUE_LENGTH,
MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH, MAX_NAME_LENGTH, MAX_URL_LENGTH,
MIN_INTERVAL,
},
guards::transaction::Transaction,
routes::{
dashboard::{create_database_channel, generate_uid, name_default, Attachment, EmbedField},
JsonResult,
},
Error,
};
#[derive(Serialize, Deserialize)]
pub struct Reminder {
pub attachment: Option<Attachment>,
pub attachment_name: Option<String>,
pub content: String,
pub embed_author: String,
pub embed_author_url: Option<String>,
pub embed_color: u32,
pub embed_description: String,
pub embed_footer: String,
pub embed_footer_url: Option<String>,
pub embed_image_url: Option<String>,
pub embed_thumbnail_url: Option<String>,
pub embed_title: String,
pub embed_fields: Option<Json<Vec<EmbedField>>>,
pub enabled: bool,
pub expires: Option<NaiveDateTime>,
pub interval_seconds: Option<u32>,
pub interval_days: Option<u32>,
pub interval_months: Option<u32>,
#[serde(default = "name_default")]
pub name: String,
pub tts: bool,
#[serde(default)]
pub uid: String,
pub utc_time: NaiveDateTime,
}
pub async fn create_reminder(
ctx: &Context,
transaction: &mut Transaction<'_>,
user_id: UserId,
reminder: Reminder,
) -> JsonResult {
let channel = user_id
.create_dm_channel(&ctx)
.map_err(|e| Error::Serenity(e))
.and_then(|dm_channel| create_database_channel(&ctx, dm_channel.id, transaction))
.await;
if let Err(e) = channel {
warn!("`create_database_channel` returned an error code: {:?}", e);
return Err(json!({"error": "Failed to configure channel for reminders."}));
}
let channel = channel.unwrap();
// validate lengths
check_length!(MAX_NAME_LENGTH, reminder.name);
check_length!(MAX_CONTENT_LENGTH, reminder.content);
check_length!(MAX_EMBED_DESCRIPTION_LENGTH, reminder.embed_description);
check_length!(MAX_EMBED_TITLE_LENGTH, reminder.embed_title);
check_length!(MAX_EMBED_AUTHOR_LENGTH, reminder.embed_author);
check_length!(MAX_EMBED_FOOTER_LENGTH, reminder.embed_footer);
check_length_opt!(MAX_EMBED_FIELDS, reminder.embed_fields);
if let Some(fields) = &reminder.embed_fields {
for field in &fields.0 {
check_length!(MAX_EMBED_FIELD_VALUE_LENGTH, field.value);
check_length!(MAX_EMBED_FIELD_TITLE_LENGTH, field.title);
}
}
check_length_opt!(
MAX_URL_LENGTH,
reminder.embed_footer_url,
reminder.embed_thumbnail_url,
reminder.embed_author_url,
reminder.embed_image_url
);
// validate urls
check_url_opt!(
reminder.embed_footer_url,
reminder.embed_thumbnail_url,
reminder.embed_author_url,
reminder.embed_image_url
);
// validate time and interval
if reminder.utc_time < Utc::now().naive_utc() {
return Err(json!({"error": "Time must be in the future"}));
}
if reminder.interval_seconds.is_some()
|| reminder.interval_days.is_some()
|| reminder.interval_months.is_some()
{
if reminder.interval_months.unwrap_or(0) * 30 * DAY as u32
+ reminder.interval_days.unwrap_or(0) * DAY as u32
+ reminder.interval_seconds.unwrap_or(0)
< *MIN_INTERVAL
{
return Err(json!({"error": "Interval too short"}));
}
}
// check patreon if necessary
if reminder.interval_seconds.is_some()
|| reminder.interval_days.is_some()
|| reminder.interval_months.is_some()
{
if !check_subscription(&ctx, user_id).await {
return Err(json!({"error": "Patreon is required to set intervals"}));
}
}
let name = if reminder.name.is_empty() { name_default() } else { reminder.name.clone() };
let new_uid = generate_uid();
// write to db
match sqlx::query!(
"INSERT INTO reminders (
uid,
attachment,
attachment_name,
channel_id,
content,
embed_author,
embed_author_url,
embed_color,
embed_description,
embed_footer,
embed_footer_url,
embed_image_url,
embed_thumbnail_url,
embed_title,
embed_fields,
enabled,
expires,
interval_seconds,
interval_days,
interval_months,
name,
tts,
`utc_time`
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
new_uid,
reminder.attachment,
reminder.attachment_name,
channel,
reminder.content,
reminder.embed_author,
reminder.embed_author_url,
reminder.embed_color,
reminder.embed_description,
reminder.embed_footer,
reminder.embed_footer_url,
reminder.embed_image_url,
reminder.embed_thumbnail_url,
reminder.embed_title,
reminder.embed_fields,
reminder.enabled,
reminder.expires,
reminder.interval_seconds,
reminder.interval_days,
reminder.interval_months,
name,
reminder.tts,
reminder.utc_time,
)
.execute(transaction.executor())
.await
{
Ok(_) => sqlx::query_as_unchecked!(
Reminder,
"SELECT
reminders.attachment,
reminders.attachment_name,
reminders.content,
reminders.embed_author,
reminders.embed_author_url,
reminders.embed_color,
reminders.embed_description,
reminders.embed_footer,
reminders.embed_footer_url,
reminders.embed_image_url,
reminders.embed_thumbnail_url,
reminders.embed_title,
reminders.embed_fields,
reminders.enabled,
reminders.expires,
reminders.interval_seconds,
reminders.interval_days,
reminders.interval_months,
reminders.name,
reminders.tts,
reminders.uid,
reminders.utc_time
FROM reminders
WHERE uid = ?",
new_uid
)
.fetch_one(transaction.executor())
.await
.map(|r| Ok(json!(r)))
.unwrap_or_else(|e| {
warn!("Failed to complete SQL query: {:?}", e);
Err(json!({"error": "Could not load reminder"}))
}),
Err(e) => {
warn!("Error in `create_reminder`: Could not execute query: {:?}", e);
Err(json!({"error": "Unknown error"}))
}
}
}

View File

@ -0,0 +1,284 @@
use log::warn;
use rocket::{
get,
http::CookieJar,
patch, post,
serde::json::{json, Json},
State,
};
use serenity::{client::Context, model::id::UserId};
use sqlx::{MySql, Pool};
use crate::web::{
check_subscription,
guards::transaction::Transaction,
routes::{
dashboard::{
api::user::models::{create_reminder, Reminder},
PatchReminder, MIN_INTERVAL,
},
JsonResult,
},
Database,
};
#[post("/api/user/reminders", data = "<reminder>")]
pub async fn create_user_reminder(
reminder: Json<Reminder>,
cookies: &CookieJar<'_>,
ctx: &State<Context>,
mut transaction: Transaction<'_>,
) -> JsonResult {
let user_id =
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap();
match create_reminder(
ctx.inner(),
&mut transaction,
UserId::new(user_id),
reminder.into_inner(),
)
.await
{
Ok(r) => match transaction.commit().await {
Ok(_) => Ok(r),
Err(e) => {
warn!("Couldn't commit transaction: {:?}", e);
json_err!("Couldn't commit transaction.")
}
},
Err(e) => Err(e),
}
}
#[get("/api/user/reminders")]
pub async fn get_reminders(
cookies: &CookieJar<'_>,
ctx: &State<Context>,
pool: &State<Pool<MySql>>,
) -> JsonResult {
let user_id =
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap();
let channel = UserId::new(user_id).create_dm_channel(ctx.inner()).await;
match channel {
Ok(channel) => sqlx::query_as_unchecked!(
Reminder,
"
SELECT
reminders.attachment,
reminders.attachment_name,
reminders.content,
reminders.embed_author,
reminders.embed_author_url,
reminders.embed_color,
reminders.embed_description,
reminders.embed_footer,
reminders.embed_footer_url,
reminders.embed_image_url,
reminders.embed_thumbnail_url,
reminders.embed_title,
IFNULL(reminders.embed_fields, '[]') AS embed_fields,
reminders.enabled,
reminders.expires,
reminders.interval_seconds,
reminders.interval_days,
reminders.interval_months,
reminders.name,
reminders.tts,
reminders.uid,
reminders.utc_time
FROM reminders
INNER JOIN channels ON channels.id = reminders.channel_id
WHERE `status` = 'pending' AND channels.channel = ?
",
channel.id.get()
)
.fetch_all(pool.inner())
.await
.map(|r| Ok(json!(r)))
.unwrap_or_else(|e| {
warn!("Failed to complete SQL query: {:?}", e);
json_err!("Could not load reminders")
}),
Err(e) => {
warn!("Couldn't get DM channel: {:?}", e);
json_err!("Could not find a DM channel")
}
}
}
#[patch("/api/user/reminders", data = "<reminder>")]
pub async fn edit_reminder(
reminder: Json<PatchReminder>,
ctx: &State<Context>,
mut transaction: Transaction<'_>,
pool: &State<Pool<Database>>,
cookies: &CookieJar<'_>,
) -> JsonResult {
let user_id_cookie =
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten();
if user_id_cookie.is_none() {
return Err(json!({"error": "User not authorized"}));
}
let mut error = vec![];
let user_id = user_id_cookie.unwrap();
if reminder.message_ok() {
update_field!(transaction.executor(), error, reminder.[
content,
embed_author,
embed_description,
embed_footer,
embed_title,
embed_fields
]);
} else {
error.push("Message exceeds limits.".to_string());
}
update_field!(transaction.executor(), error, reminder.[
attachment,
attachment_name,
embed_author_url,
embed_color,
embed_footer_url,
embed_image_url,
embed_thumbnail_url,
enabled,
expires,
name,
tts,
utc_time
]);
if reminder.interval_days.flatten().is_some()
|| reminder.interval_months.flatten().is_some()
|| reminder.interval_seconds.flatten().is_some()
{
if check_subscription(&ctx.inner(), user_id).await {
let new_interval_length = match reminder.interval_days {
Some(interval) => interval.unwrap_or(0),
None => sqlx::query!(
"SELECT interval_days AS days FROM reminders WHERE uid = ?",
reminder.uid
)
.fetch_one(transaction.executor())
.await
.map_err(|e| {
warn!("Error updating reminder interval: {:?}", e);
json!({ "reminder": Option::<Reminder>::None, "errors": vec!["Unknown error"] })
})?
.days
.unwrap_or(0),
} * 86400 + match reminder.interval_months {
Some(interval) => interval.unwrap_or(0),
None => sqlx::query!(
"SELECT interval_months AS months FROM reminders WHERE uid = ?",
reminder.uid
)
.fetch_one(transaction.executor())
.await
.map_err(|e| {
warn!("Error updating reminder interval: {:?}", e);
json!({ "reminder": Option::<Reminder>::None, "errors": vec!["Unknown error"] })
})?
.months
.unwrap_or(0),
} * 2592000 + match reminder.interval_seconds {
Some(interval) => interval.unwrap_or(0),
None => sqlx::query!(
"SELECT interval_seconds AS seconds FROM reminders WHERE uid = ?",
reminder.uid
)
.fetch_one(transaction.executor())
.await
.map_err(|e| {
warn!("Error updating reminder interval: {:?}", e);
json!({ "reminder": Option::<Reminder>::None, "errors": vec!["Unknown error"] })
})?
.seconds
.unwrap_or(0),
};
if new_interval_length < *MIN_INTERVAL {
error.push(String::from("New interval is too short."));
} else {
update_field!(transaction.executor(), error, reminder.[
interval_days,
interval_months,
interval_seconds
]);
}
}
} else {
sqlx::query!(
"
UPDATE reminders
SET interval_seconds = NULL, interval_days = NULL, interval_months = NULL
WHERE uid = ?
",
reminder.uid
)
.execute(transaction.executor())
.await
.map_err(|e| {
warn!("Error updating reminder interval: {:?}", e);
json!({ "reminder": Option::<Reminder>::None, "errors": vec!["Unknown error"] })
})?;
}
if let Err(e) = transaction.commit().await {
warn!("Couldn't commit transaction: {:?}", e);
return json_err!("Couldn't commit transaction");
}
match sqlx::query_as_unchecked!(
Reminder,
"
SELECT reminders.attachment,
reminders.attachment_name,
reminders.content,
reminders.embed_author,
reminders.embed_author_url,
reminders.embed_color,
reminders.embed_description,
reminders.embed_footer,
reminders.embed_footer_url,
reminders.embed_image_url,
reminders.embed_thumbnail_url,
reminders.embed_title,
reminders.embed_fields,
reminders.enabled,
reminders.expires,
reminders.interval_seconds,
reminders.interval_days,
reminders.interval_months,
reminders.name,
reminders.tts,
reminders.uid,
reminders.utc_time
FROM reminders
LEFT JOIN channels ON channels.id = reminders.channel_id
WHERE uid = ?
",
reminder.uid
)
.fetch_one(pool.inner())
.await
{
Ok(reminder) => Ok(json!({"reminder": reminder, "errors": error})),
Err(e) => {
warn!("Error exiting `edit_reminder': {:?}", e);
Err(json!({"reminder": Option::<Reminder>::None, "errors": vec!["Unknown error"]}))
}
}
}

View File

@ -0,0 +1,451 @@
use base64::{prelude::BASE64_STANDARD, Engine};
use csv::{QuoteStyle, WriterBuilder};
use log::warn;
use rocket::{
get,
http::CookieJar,
put,
serde::json::{json, Json},
State,
};
use serenity::{
client::Context,
model::id::{ChannelId, GuildId, UserId},
};
use sqlx::{MySql, Pool};
use crate::web::{
check_authorization,
guards::transaction::Transaction,
routes::{
dashboard::{
create_reminder, generate_uid, ImportBody, Reminder, ReminderCsv, ReminderTemplateCsv,
TodoCsv,
},
JsonResult,
},
};
#[get("/api/guild/<id>/export/reminders")]
pub async fn export_reminders(
id: u64,
cookies: &CookieJar<'_>,
ctx: &State<Context>,
pool: &State<Pool<MySql>>,
) -> JsonResult {
check_authorization(cookies, ctx.inner(), id).await?;
let mut csv_writer = WriterBuilder::new().quote_style(QuoteStyle::Always).from_writer(vec![]);
let channels_res = GuildId::new(id).channels(&ctx.inner()).await;
match channels_res {
Ok(channels) => {
let channels = channels
.keys()
.into_iter()
.map(|k| k.get().to_string())
.collect::<Vec<String>>()
.join(",");
let result = sqlx::query_as_unchecked!(
ReminderCsv,
"SELECT
reminders.attachment,
reminders.attachment_name,
reminders.avatar,
CONCAT('#', channels.channel) AS channel,
reminders.content,
reminders.embed_author,
reminders.embed_author_url,
reminders.embed_color,
reminders.embed_description,
reminders.embed_footer,
reminders.embed_footer_url,
reminders.embed_image_url,
reminders.embed_thumbnail_url,
reminders.embed_title,
reminders.embed_fields,
reminders.enabled,
reminders.expires,
reminders.interval_seconds,
reminders.interval_days,
reminders.interval_months,
reminders.name,
reminders.restartable,
reminders.tts,
reminders.username,
reminders.utc_time
FROM reminders
LEFT JOIN channels ON channels.id = reminders.channel_id
WHERE FIND_IN_SET(channels.channel, ?) AND status = 'pending'",
channels
)
.fetch_all(pool.inner())
.await;
match result {
Ok(reminders) => {
reminders.iter().for_each(|reminder| {
csv_writer.serialize(reminder).unwrap();
});
match csv_writer.into_inner() {
Ok(inner) => match String::from_utf8(inner) {
Ok(encoded) => Ok(json!({ "body": encoded })),
Err(e) => {
warn!("Failed to write UTF-8: {:?}", e);
Err(json!({"error": "Failed to write UTF-8"}))
}
},
Err(e) => {
warn!("Failed to extract CSV: {:?}", e);
Err(json!({"error": "Failed to extract CSV"}))
}
}
}
Err(e) => {
warn!("Failed to complete SQL query: {:?}", e);
Err(json!({"error": "Failed to query reminders"}))
}
}
}
Err(e) => {
warn!("Could not fetch channels from {}: {:?}", id, e);
Err(json!({"error": "Failed to get guild channels"}))
}
}
}
#[put("/api/guild/<id>/export/reminders", data = "<body>")]
pub(crate) async fn import_reminders(
id: u64,
cookies: &CookieJar<'_>,
body: Json<ImportBody>,
ctx: &State<Context>,
mut transaction: Transaction<'_>,
) -> JsonResult {
check_authorization(cookies, ctx.inner(), id).await?;
let user_id =
cookies.get_private("userid").map(|c| c.value().parse::<u64>().ok()).flatten().unwrap();
match BASE64_STANDARD.decode(&body.body) {
Ok(body) => {
let mut reader = csv::Reader::from_reader(body.as_slice());
let mut count = 0;
for result in reader.deserialize::<ReminderCsv>() {
match result {
Ok(record) => {
let channel_id = record.channel.split_at(1).1;
match channel_id.parse::<u64>() {
Ok(channel_id) => {
let reminder = Reminder {
attachment: record.attachment,
attachment_name: record.attachment_name,
avatar: record.avatar,
channel: channel_id,
content: record.content,
embed_author: record.embed_author,
embed_author_url: record.embed_author_url,
embed_color: record.embed_color,
embed_description: record.embed_description,
embed_footer: record.embed_footer,
embed_footer_url: record.embed_footer_url,
embed_image_url: record.embed_image_url,
embed_thumbnail_url: record.embed_thumbnail_url,
embed_title: record.embed_title,
embed_fields: record
.embed_fields
.map(|s| serde_json::from_str(&s).ok())
.flatten(),
enabled: record.enabled,
expires: record.expires,
interval_seconds: record.interval_seconds,
interval_days: record.interval_days,
interval_months: record.interval_months,
name: record.name,
restartable: record.restartable,
tts: record.tts,
uid: generate_uid(),
username: record.username,
utc_time: record.utc_time,
};
create_reminder(
ctx.inner(),
&mut transaction,
GuildId::new(id),
UserId::new(user_id),
reminder,
)
.await?;
count += 1;
}
Err(_) => {
return json_err!(format!(
"Failed to parse channel {}",
channel_id
));
}
}
}
Err(e) => {
warn!("Couldn't deserialize CSV row: {:?}", e);
return json_err!("Deserialize error. Aborted");
}
}
}
match transaction.commit().await {
Ok(_) => Ok(json!({
"message": format!("Imported {} reminders", count)
})),
Err(e) => {
warn!("Failed to commit transaction: {:?}", e);
json_err!("Couldn't commit transaction")
}
}
}
Err(_) => {
json_err!("Malformed base64")
}
}
}
#[get("/api/guild/<id>/export/todos")]
pub async fn export_todos(
id: u64,
cookies: &CookieJar<'_>,
ctx: &State<Context>,
pool: &State<Pool<MySql>>,
) -> JsonResult {
check_authorization(cookies, ctx.inner(), id).await?;
let mut csv_writer = WriterBuilder::new().quote_style(QuoteStyle::Always).from_writer(vec![]);
match sqlx::query_as_unchecked!(
TodoCsv,
"SELECT value, CONCAT('#', channels.channel) AS channel_id FROM todos
LEFT JOIN channels ON todos.channel_id = channels.id
INNER JOIN guilds ON todos.guild_id = guilds.id
WHERE guilds.guild = ?",
id
)
.fetch_all(pool.inner())
.await
{
Ok(todos) => {
todos.iter().for_each(|todo| {
csv_writer.serialize(todo).unwrap();
});
match csv_writer.into_inner() {
Ok(inner) => match String::from_utf8(inner) {
Ok(encoded) => Ok(json!({ "body": encoded })),
Err(e) => {
warn!("Failed to write UTF-8: {:?}", e);
json_err!("Failed to write UTF-8")
}
},
Err(e) => {
warn!("Failed to extract CSV: {:?}", e);
json_err!("Failed to extract CSV")
}
}
}
Err(e) => {
warn!("Could not fetch templates from {}: {:?}", id, e);
json_err!("Failed to query templates")
}
}
}
#[put("/api/guild/<id>/export/todos", data = "<body>")]
pub async fn import_todos(
id: u64,
cookies: &CookieJar<'_>,
body: Json<ImportBody>,
ctx: &State<Context>,
pool: &State<Pool<MySql>>,
) -> JsonResult {
check_authorization(cookies, ctx.inner(), id).await?;
let channels_res = GuildId::new(id).channels(&ctx.inner()).await;
match channels_res {
Ok(channels) => match BASE64_STANDARD.decode(&body.body) {
Ok(body) => {
let mut reader = csv::Reader::from_reader(body.as_slice());
let query_placeholder = "(?, (SELECT id FROM channels WHERE channel = ?), (SELECT id FROM guilds WHERE guild = ?))";
let mut query_params = vec![];
for result in reader.deserialize::<TodoCsv>() {
match result {
Ok(record) => match record.channel_id {
Some(channel_id) => {
let channel_id = channel_id.split_at(1).1;
match channel_id.parse::<u64>() {
Ok(channel_id) => {
if channels.contains_key(&ChannelId::new(channel_id)) {
query_params.push((record.value, Some(channel_id), id));
} else {
return json_err!(format!(
"Invalid channel ID {}",
channel_id
));
}
}
Err(_) => {
return json_err!(format!(
"Invalid channel ID {}",
channel_id
));
}
}
}
None => {
query_params.push((record.value, None, id));
}
},
Err(e) => {
warn!("Couldn't deserialize CSV row: {:?}", e);
return json_err!("Deserialize error. Aborted");
}
}
}
let query_str = format!(
"INSERT INTO todos (value, channel_id, guild_id) VALUES {}",
vec![query_placeholder].repeat(query_params.len()).join(",")
);
let mut query = sqlx::query(&query_str);
for param in query_params {
query = query.bind(param.0).bind(param.1).bind(param.2);
}
let res = query.execute(pool.inner()).await;
match res {
Ok(_) => Ok(json!({})),
Err(e) => {
warn!("Couldn't execute todo query: {:?}", e);
json_err!("An unexpected error occured.")
}
}
}
Err(_) => {
json_err!("Malformed base64")
}
},
Err(e) => {
warn!("Couldn't fetch channels for guild {}: {:?}", id, e);
json_err!("Couldn't fetch channels.")
}
}
}
#[get("/api/guild/<id>/export/reminder_templates")]
pub async fn export_reminder_templates(
id: u64,
cookies: &CookieJar<'_>,
ctx: &State<Context>,
pool: &State<Pool<MySql>>,
) -> JsonResult {
check_authorization(cookies, ctx.inner(), id).await?;
let mut csv_writer = WriterBuilder::new().quote_style(QuoteStyle::Always).from_writer(vec![]);
match sqlx::query_as_unchecked!(
ReminderTemplateCsv,
"SELECT
name,
attachment,
attachment_name,
avatar,
content,
embed_author,
embed_author_url,
embed_color,
embed_description,
embed_footer,
embed_footer_url,
embed_image_url,
embed_thumbnail_url,
embed_title,
embed_fields,
interval_seconds,
interval_days,
interval_months,
tts,
username
FROM reminder_template WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)",
id
)
.fetch_all(pool.inner())
.await
{
Ok(templates) => {
templates.iter().for_each(|template| {
csv_writer.serialize(template).unwrap();
});
match csv_writer.into_inner() {
Ok(inner) => match String::from_utf8(inner) {
Ok(encoded) => Ok(json!({ "body": encoded })),
Err(e) => {
warn!("Failed to write UTF-8: {:?}", e);
json_err!("Failed to write UTF-8")
}
},
Err(e) => {
warn!("Failed to extract CSV: {:?}", e);
json_err!("Failed to extract CSV")
}
}
}
Err(e) => {
warn!("Could not fetch templates from {}: {:?}", id, e);
json_err!("Failed to query templates")
}
}
}

View File

@ -0,0 +1,718 @@
use std::path::Path;
use base64::{prelude::BASE64_STANDARD, Engine};
use chrono::{naive::NaiveDateTime, Utc};
use log::warn;
use rand::{rngs::OsRng, seq::IteratorRandom};
use rocket::{
fs::NamedFile, get, http::CookieJar, response::Redirect, serde::json::json, Responder,
};
use rocket_dyn_templates::Template;
use secrecy::ExposeSecret;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use serenity::{
all::CacheHttp,
builder::CreateWebhook,
client::Context,
model::id::{ChannelId, GuildId, UserId},
};
use sqlx::types::Json;
use crate::web::{
catchers::internal_server_error,
check_guild_subscription, check_subscription,
consts::{
CHARACTERS, DAY, DEFAULT_AVATAR, MAX_CONTENT_LENGTH, MAX_EMBED_AUTHOR_LENGTH,
MAX_EMBED_DESCRIPTION_LENGTH, MAX_EMBED_FIELDS, MAX_EMBED_FIELD_TITLE_LENGTH,
MAX_EMBED_FIELD_VALUE_LENGTH, MAX_EMBED_FOOTER_LENGTH, MAX_EMBED_TITLE_LENGTH,
MAX_NAME_LENGTH, MAX_URL_LENGTH, MAX_USERNAME_LENGTH, MIN_INTERVAL,
},
guards::transaction::Transaction,
routes::JsonResult,
Error,
};
pub mod api;
pub mod export;
type Unset<T> = Option<T>;
fn name_default() -> String {
"Reminder".to_string()
}
fn template_name_default() -> String {
"Template".to_string()
}
fn channel_default() -> u64 {
0
}
fn id_default() -> u32 {
0
}
fn interval_default() -> Unset<Option<u32>> {
None
}
#[derive(sqlx::Type)]
#[sqlx(transparent)]
pub struct Attachment(Vec<u8>);
impl<'de> Deserialize<'de> for Attachment {
fn deserialize<D>(deserializer: D) -> Result<Attachment, D::Error>
where
D: Deserializer<'de>,
{
let string = String::deserialize(deserializer)?;
Ok(Attachment(BASE64_STANDARD.decode(string).map_err(de::Error::custom)?))
}
}
impl Serialize for Attachment {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.collect_str(&BASE64_STANDARD.encode(&self.0))
}
}
#[derive(Serialize, Deserialize)]
pub struct ReminderTemplate {
#[serde(default = "id_default")]
id: u32,
#[serde(default = "id_default")]
guild_id: u32,
#[serde(default = "template_name_default")]
name: String,
attachment: Option<Attachment>,
attachment_name: Option<String>,
avatar: Option<String>,
content: String,
embed_author: String,
embed_author_url: Option<String>,
embed_color: u32,
embed_description: String,
embed_footer: String,
embed_footer_url: Option<String>,
embed_image_url: Option<String>,
embed_thumbnail_url: Option<String>,
embed_title: String,
embed_fields: Option<Json<Vec<EmbedField>>>,
interval_seconds: Option<u32>,
interval_days: Option<u32>,
interval_months: Option<u32>,
tts: bool,
username: Option<String>,
}
#[derive(Serialize, Deserialize)]
pub struct ReminderTemplateCsv {
#[serde(default = "template_name_default")]
name: String,
attachment: Option<Attachment>,
attachment_name: Option<String>,
avatar: Option<String>,
content: String,
embed_author: String,
embed_author_url: Option<String>,
embed_color: u32,
embed_description: String,
embed_footer: String,
embed_footer_url: Option<String>,
embed_image_url: Option<String>,
embed_thumbnail_url: Option<String>,
embed_title: String,
embed_fields: Option<String>,
interval_seconds: Option<u32>,
interval_days: Option<u32>,
interval_months: Option<u32>,
tts: bool,
username: Option<String>,
}
#[derive(Deserialize)]
pub struct DeleteReminderTemplate {
id: u32,
}
#[derive(Serialize, Deserialize)]
pub struct EmbedField {
title: String,
value: String,
inline: bool,
}
#[derive(Serialize, Deserialize)]
pub struct Reminder {
attachment: Option<Attachment>,
attachment_name: Option<String>,
avatar: Option<String>,
#[serde(with = "string")]
channel: u64,
content: String,
embed_author: String,
embed_author_url: Option<String>,
embed_color: u32,
embed_description: String,
embed_footer: String,
embed_footer_url: Option<String>,
embed_image_url: Option<String>,
embed_thumbnail_url: Option<String>,
embed_title: String,
embed_fields: Option<Json<Vec<EmbedField>>>,
enabled: bool,
expires: Option<NaiveDateTime>,
interval_seconds: Option<u32>,
interval_days: Option<u32>,
interval_months: Option<u32>,
#[serde(default = "name_default")]
name: String,
restartable: bool,
tts: bool,
#[serde(default)]
uid: String,
username: Option<String>,
utc_time: NaiveDateTime,
}
#[derive(Serialize, Deserialize)]
pub struct ReminderCsv {
attachment: Option<Attachment>,
attachment_name: Option<String>,
avatar: Option<String>,
channel: String,
content: String,
embed_author: String,
embed_author_url: Option<String>,
embed_color: u32,
embed_description: String,
embed_footer: String,
embed_footer_url: Option<String>,
embed_image_url: Option<String>,
embed_thumbnail_url: Option<String>,
embed_title: String,
embed_fields: Option<String>,
enabled: bool,
expires: Option<NaiveDateTime>,
interval_seconds: Option<u32>,
interval_days: Option<u32>,
interval_months: Option<u32>,
#[serde(default = "name_default")]
name: String,
restartable: bool,
tts: bool,
username: Option<String>,
utc_time: NaiveDateTime,
}
#[derive(Deserialize)]
pub struct PatchReminder {
uid: String,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
attachment: Unset<Option<Attachment>>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
attachment_name: Unset<Option<String>>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
avatar: Unset<Option<String>>,
#[serde(default = "channel_default")]
#[serde(with = "string")]
channel: u64,
#[serde(default)]
content: Unset<String>,
#[serde(default)]
embed_author: Unset<String>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
embed_author_url: Unset<Option<String>>,
#[serde(default)]
embed_color: Unset<u32>,
#[serde(default)]
embed_description: Unset<String>,
#[serde(default)]
embed_footer: Unset<String>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
embed_footer_url: Unset<Option<String>>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
embed_image_url: Unset<Option<String>>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
embed_thumbnail_url: Unset<Option<String>>,
#[serde(default)]
embed_title: Unset<String>,
#[serde(default)]
embed_fields: Unset<Json<Vec<EmbedField>>>,
#[serde(default)]
enabled: Unset<bool>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
expires: Unset<Option<NaiveDateTime>>,
#[serde(default = "interval_default")]
#[serde(deserialize_with = "deserialize_optional_field")]
interval_seconds: Unset<Option<u32>>,
#[serde(default = "interval_default")]
#[serde(deserialize_with = "deserialize_optional_field")]
interval_days: Unset<Option<u32>>,
#[serde(default = "interval_default")]
#[serde(deserialize_with = "deserialize_optional_field")]
interval_months: Unset<Option<u32>>,
#[serde(default)]
name: Unset<String>,
#[serde(default)]
restartable: Unset<bool>,
#[serde(default)]
tts: Unset<bool>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_field")]
username: Unset<Option<String>>,
#[serde(default)]
utc_time: Unset<NaiveDateTime>,
}
impl PatchReminder {
fn message_ok(&self) -> bool {
self.content.as_ref().map_or(true, |c| c.len() <= MAX_CONTENT_LENGTH)
&& self.embed_author.as_ref().map_or(true, |c| c.len() <= MAX_EMBED_AUTHOR_LENGTH)
&& self
.embed_description
.as_ref()
.map_or(true, |c| c.len() <= MAX_EMBED_DESCRIPTION_LENGTH)
&& self.embed_footer.as_ref().map_or(true, |c| c.len() <= MAX_EMBED_FOOTER_LENGTH)
&& self.embed_title.as_ref().map_or(true, |c| c.len() <= MAX_EMBED_TITLE_LENGTH)
&& self.embed_fields.as_ref().map_or(true, |c| {
c.0.len() <= MAX_EMBED_FIELDS
&& c.0.iter().all(|f| {
f.title.len() <= MAX_EMBED_FIELD_TITLE_LENGTH
&& f.value.len() <= MAX_EMBED_FIELD_VALUE_LENGTH
})
})
&& self
.username
.as_ref()
.map_or(true, |c| c.as_ref().map_or(true, |v| v.len() <= MAX_USERNAME_LENGTH))
}
}
pub fn generate_uid() -> String {
let mut generator: OsRng = Default::default();
(0..64)
.map(|_| CHARACTERS.chars().choose(&mut generator).unwrap().to_owned().to_string())
.collect::<Vec<String>>()
.join("")
}
fn deserialize_optional_field<'de, T, D>(deserializer: D) -> Result<Option<Option<T>>, D::Error>
where
D: Deserializer<'de>,
T: Deserialize<'de>,
{
Ok(Some(Option::deserialize(deserializer)?))
}
// https://github.com/serde-rs/json/issues/329#issuecomment-305608405
mod string {
use std::{fmt::Display, str::FromStr};
use serde::{de, Deserialize, Deserializer, Serializer};
pub fn serialize<T, S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
where
T: Display,
S: Serializer,
{
serializer.collect_str(value)
}
pub fn deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
T: FromStr,
T::Err: Display,
D: Deserializer<'de>,
{
String::deserialize(deserializer)?.parse().map_err(de::Error::custom)
}
}
#[derive(Deserialize)]
pub struct DeleteReminder {
uid: String,
}
#[derive(Deserialize)]
pub struct ImportBody {
body: String,
}
#[derive(Serialize, Deserialize)]
pub struct TodoCsv {
value: String,
channel_id: Option<String>,
}
pub(crate) async fn create_reminder(
ctx: &Context,
transaction: &mut Transaction<'_>,
guild_id: GuildId,
user_id: UserId,
reminder: Reminder,
) -> JsonResult {
// check guild in db
match sqlx::query!("SELECT 1 as A FROM guilds WHERE guild = ?", guild_id.get())
.fetch_one(transaction.executor())
.await
{
Err(sqlx::Error::RowNotFound) => {
if sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id.get())
.execute(transaction.executor())
.await
.is_err()
{
return Err(json!({"error": "Guild could not be created"}));
}
}
_ => {}
}
{
// validate channel
let channel = ChannelId::new(reminder.channel).to_channel_cached(&ctx.cache);
let channel_exists = channel.is_some();
let channel_matches_guild =
channel.map_or(false, |c| c.guild(&ctx.cache).map_or(false, |c| c.id == guild_id));
if !channel_matches_guild || !channel_exists {
warn!(
"Error in `create_reminder`: channel {} not found for guild {} (channel exists: {})",
reminder.channel, guild_id, channel_exists
);
return Err(json!({"error": "Channel not found"}));
}
}
let channel =
create_database_channel(&ctx, ChannelId::new(reminder.channel), transaction).await;
if let Err(e) = channel {
warn!("`create_database_channel` returned an error code: {:?}", e);
return Err(
json!({"error": "Failed to configure channel for reminders. Please check the bot permissions"}),
);
}
let channel = channel.unwrap();
// validate lengths
check_length!(MAX_NAME_LENGTH, reminder.name);
check_length!(MAX_CONTENT_LENGTH, reminder.content);
check_length!(MAX_EMBED_DESCRIPTION_LENGTH, reminder.embed_description);
check_length!(MAX_EMBED_TITLE_LENGTH, reminder.embed_title);
check_length!(MAX_EMBED_AUTHOR_LENGTH, reminder.embed_author);
check_length!(MAX_EMBED_FOOTER_LENGTH, reminder.embed_footer);
check_length_opt!(MAX_EMBED_FIELDS, reminder.embed_fields);
if let Some(fields) = &reminder.embed_fields {
for field in &fields.0 {
check_length!(MAX_EMBED_FIELD_VALUE_LENGTH, field.value);
check_length!(MAX_EMBED_FIELD_TITLE_LENGTH, field.title);
}
}
check_length_opt!(MAX_USERNAME_LENGTH, reminder.username);
check_length_opt!(
MAX_URL_LENGTH,
reminder.embed_footer_url,
reminder.embed_thumbnail_url,
reminder.embed_author_url,
reminder.embed_image_url,
reminder.avatar
);
// validate urls
check_url_opt!(
reminder.embed_footer_url,
reminder.embed_thumbnail_url,
reminder.embed_author_url,
reminder.embed_image_url,
reminder.avatar
);
// validate time and interval
if reminder.utc_time < Utc::now().naive_utc() {
return Err(json!({"error": "Time must be in the future"}));
}
if reminder.interval_seconds.is_some()
|| reminder.interval_days.is_some()
|| reminder.interval_months.is_some()
{
if reminder.interval_months.unwrap_or(0) * 30 * DAY as u32
+ reminder.interval_days.unwrap_or(0) * DAY as u32
+ reminder.interval_seconds.unwrap_or(0)
< *MIN_INTERVAL
{
return Err(json!({"error": "Interval too short"}));
}
}
// check patreon if necessary
if reminder.interval_seconds.is_some()
|| reminder.interval_days.is_some()
|| reminder.interval_months.is_some()
{
if !check_guild_subscription(&ctx, guild_id).await
&& !check_subscription(&ctx, user_id).await
{
return Err(json!({"error": "Patreon is required to set intervals"}));
}
}
let name = if reminder.name.is_empty() { name_default() } else { reminder.name.clone() };
let username = if reminder.username.as_ref().map(|s| s.is_empty()).unwrap_or(true) {
None
} else {
reminder.username
};
let new_uid = generate_uid();
// write to db
match sqlx::query!(
"INSERT INTO reminders (
uid,
attachment,
attachment_name,
channel_id,
avatar,
content,
embed_author,
embed_author_url,
embed_color,
embed_description,
embed_footer,
embed_footer_url,
embed_image_url,
embed_thumbnail_url,
embed_title,
embed_fields,
enabled,
expires,
interval_seconds,
interval_days,
interval_months,
name,
restartable,
tts,
username,
`utc_time`
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
new_uid,
reminder.attachment,
reminder.attachment_name,
channel,
reminder.avatar,
reminder.content,
reminder.embed_author,
reminder.embed_author_url,
reminder.embed_color,
reminder.embed_description,
reminder.embed_footer,
reminder.embed_footer_url,
reminder.embed_image_url,
reminder.embed_thumbnail_url,
reminder.embed_title,
reminder.embed_fields,
reminder.enabled,
reminder.expires,
reminder.interval_seconds,
reminder.interval_days,
reminder.interval_months,
name,
reminder.restartable,
reminder.tts,
username,
reminder.utc_time,
)
.execute(transaction.executor())
.await
{
Ok(_) => sqlx::query_as_unchecked!(
Reminder,
"SELECT
reminders.attachment,
reminders.attachment_name,
reminders.avatar,
channels.channel,
reminders.content,
reminders.embed_author,
reminders.embed_author_url,
reminders.embed_color,
reminders.embed_description,
reminders.embed_footer,
reminders.embed_footer_url,
reminders.embed_image_url,
reminders.embed_thumbnail_url,
reminders.embed_title,
reminders.embed_fields,
reminders.enabled,
reminders.expires,
reminders.interval_seconds,
reminders.interval_days,
reminders.interval_months,
reminders.name,
reminders.restartable,
reminders.tts,
reminders.uid,
reminders.username,
reminders.utc_time
FROM reminders
LEFT JOIN channels ON channels.id = reminders.channel_id
WHERE uid = ?",
new_uid
)
.fetch_one(transaction.executor())
.await
.map(|r| Ok(json!(r)))
.unwrap_or_else(|e| {
warn!("Failed to complete SQL query: {:?}", e);
Err(json!({"error": "Could not load reminder"}))
}),
Err(e) => {
warn!("Error in `create_reminder`: Could not execute query: {:?}", e);
Err(json!({"error": "Unknown error"}))
}
}
}
async fn create_database_channel(
ctx: impl CacheHttp,
channel: ChannelId,
transaction: &mut Transaction<'_>,
) -> Result<u32, Error> {
let row = sqlx::query!(
"SELECT webhook_token, webhook_id FROM channels WHERE channel = ?",
channel.get()
)
.fetch_one(transaction.executor())
.await;
match row {
Ok(row) => {
let is_dm =
channel.to_channel(&ctx).await.map_err(|e| Error::Serenity(e))?.private().is_some();
if !is_dm && (row.webhook_token.is_none() || row.webhook_id.is_none()) {
let webhook = channel
.create_webhook(&ctx, CreateWebhook::new("Reminder").avatar(&*DEFAULT_AVATAR))
.await
.map_err(|e| Error::Serenity(e))?;
let token = webhook.token.unwrap();
sqlx::query!(
"
UPDATE channels SET webhook_id = ?, webhook_token = ? WHERE channel = ?
",
webhook.id.get(),
token.expose_secret(),
channel.get()
)
.execute(transaction.executor())
.await
.map_err(|e| Error::SQLx(e))?;
}
Ok(())
}
Err(sqlx::Error::RowNotFound) => {
// create webhook
let webhook = channel
.create_webhook(&ctx, CreateWebhook::new("Reminder").avatar(&*DEFAULT_AVATAR))
.await
.map_err(|e| Error::Serenity(e))?;
let token = webhook.token.unwrap();
// create database entry
sqlx::query!(
"
INSERT INTO channels (
webhook_id,
webhook_token,
channel
) VALUES (?, ?, ?)
",
webhook.id.get(),
token.expose_secret(),
channel.get()
)
.execute(transaction.executor())
.await
.map_err(|e| Error::SQLx(e))?;
Ok(())
}
Err(e) => Err(Error::SQLx(e)),
}?;
let row = sqlx::query!("SELECT id FROM channels WHERE channel = ?", channel.get())
.fetch_one(transaction.executor())
.await
.map_err(|e| Error::SQLx(e))?;
Ok(row.id)
}
#[derive(Responder)]
pub enum DashboardPage {
#[response(status = 200)]
Ok(NamedFile),
#[response(status = 302)]
Unauthorised(Redirect),
#[response(status = 500)]
NotConfigured(Template),
}
#[get("/")]
pub async fn dashboard_home(cookies: &CookieJar<'_>) -> DashboardPage {
if cookies.get_private("userid").is_some() {
match NamedFile::open(Path::new(path!("static/index.html"))).await {
Ok(f) => DashboardPage::Ok(f),
Err(e) => {
warn!("Couldn't render dashboard: {:?}", e);
DashboardPage::NotConfigured(internal_server_error().await)
}
}
} else {
DashboardPage::Unauthorised(Redirect::to("/login/discord"))
}
}
#[get("/<_..>")]
pub async fn dashboard(cookies: &CookieJar<'_>) -> DashboardPage {
if cookies.get_private("userid").is_some() {
match NamedFile::open(Path::new(path!("static/index.html"))).await {
Ok(f) => DashboardPage::Ok(f),
Err(e) => {
warn!("Couldn't render dashboard: {:?}", e);
DashboardPage::NotConfigured(internal_server_error().await)
}
}
} else {
DashboardPage::Unauthorised(Redirect::to("/login/discord"))
}
}

154
src/web/routes/login.rs Normal file
View File

@ -0,0 +1,154 @@
use log::warn;
use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthorizationCode, CsrfToken,
PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse,
};
use reqwest::Client;
use rocket::{
get,
http::{private::cookie::Expiration, Cookie, CookieJar, SameSite},
response::{Flash, Redirect},
uri, State,
};
use serenity::model::user::User;
use crate::web::{consts::DISCORD_API, routes};
#[get("/discord")]
pub async fn discord_login(
oauth2_client: &State<BasicClient>,
cookies: &CookieJar<'_>,
) -> Redirect {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token) = oauth2_client
.authorize_url(CsrfToken::new_random)
// Set the desired scopes.
.add_scope(Scope::new("identify".to_string()))
.add_scope(Scope::new("guilds".to_string()))
// Set the PKCE code challenge.
.set_pkce_challenge(pkce_challenge)
.url();
// store the pkce secret to verify the authorization later
cookies.add_private(
Cookie::build(("verify", pkce_verifier.secret().to_string()))
.http_only(true)
.path("/login")
.same_site(SameSite::Lax)
.expires(Expiration::Session),
);
// store the csrf token to verify no interference
cookies.add_private(
Cookie::build(("csrf", csrf_token.secret().to_string()))
.http_only(true)
.path("/login")
.same_site(SameSite::Lax)
.expires(Expiration::Session),
);
Redirect::to(auth_url.to_string())
}
#[get("/discord/logout")]
pub async fn discord_logout(cookies: &CookieJar<'_>) -> Redirect {
cookies.remove_private(Cookie::from("username"));
cookies.remove_private(Cookie::from("userid"));
cookies.remove_private(Cookie::from("access_token"));
Redirect::to(uri!(routes::index))
}
#[get("/discord/authorized?<code>&<state>")]
pub async fn discord_callback(
code: &str,
state: &str,
cookies: &CookieJar<'_>,
oauth2_client: &State<BasicClient>,
reqwest_client: &State<Client>,
) -> Result<Redirect, Flash<Redirect>> {
if let (Some(pkce_secret), Some(csrf_token)) =
(cookies.get_private("verify"), cookies.get_private("csrf"))
{
if state == csrf_token.value() {
let token_result = oauth2_client
.exchange_code(AuthorizationCode::new(code.to_string()))
// Set the PKCE code verifier.
.set_pkce_verifier(PkceCodeVerifier::new(pkce_secret.value().to_string()))
.request_async(async_http_client)
.await;
cookies.remove_private(Cookie::from("verify"));
cookies.remove_private(Cookie::from("csrf"));
match token_result {
Ok(token) => {
cookies.add_private(
Cookie::build(("access_token", token.access_token().secret().to_string()))
.secure(true)
.http_only(true)
.path("/dashboard"),
);
let request_res = reqwest_client
.get(format!("{}/users/@me", DISCORD_API))
.bearer_auth(token.access_token().secret())
.send()
.await;
match request_res {
Ok(response) => {
let user_res = response.json::<User>().await;
match user_res {
Ok(user) => {
let user_id = user.id.get().to_string();
cookies.add_private(Cookie::new("username", user.name));
cookies.add_private(Cookie::new("userid", user_id));
Ok(Redirect::to(uri!(super::return_to_same_site("dashboard"))))
}
Err(e) => {
warn!("Error constructing user from request: {:?}", e);
Err(Flash::new(
Redirect::to(uri!(super::return_to_same_site(""))),
"danger",
"Failed to contact Discord",
))
}
}
}
Err(e) => {
warn!("Error getting user info: {:?}", e);
Err(Flash::new(
Redirect::to(uri!(super::return_to_same_site(""))),
"danger",
"Failed to contact Discord",
))
}
}
}
Err(e) => {
warn!("Error in discord callback: {:?}", e);
Err(Flash::new(
Redirect::to(uri!(super::return_to_same_site(""))),
"warning",
"Your login request was rejected. The server may be misconfigured. Please retry or alert us in Discord.",
))
}
}
} else {
Err(Flash::new(Redirect::to(uri!(super::return_to_same_site(""))), "danger", "Your request failed to validate, and so has been rejected (CSRF Validation Failure)"))
}
} else {
Err(Flash::new(Redirect::to(uri!(super::return_to_same_site(""))), "warning", "Your request was missing information, and so has been rejected (CSRF Validation Tokens Missing)"))
}
}

109
src/web/routes/mod.rs Normal file
View File

@ -0,0 +1,109 @@
pub mod dashboard;
pub mod login;
pub mod report;
use std::collections::HashMap;
use rocket::{get, request::FlashMessage, serde::json::Value as JsonValue};
use rocket_dyn_templates::Template;
pub type JsonResult = Result<JsonValue, JsonValue>;
#[get("/")]
pub async fn index(flash: Option<FlashMessage<'_>>) -> Template {
let mut map: HashMap<&str, String> = HashMap::new();
if let Some(message) = flash {
map.insert("flashed_message", message.message().to_string());
map.insert("flashed_grade", message.kind().to_string());
}
Template::render("index", &map)
}
#[get("/ret?<to>")]
pub async fn return_to_same_site(to: &str) -> Template {
let mut map: HashMap<&str, String> = HashMap::new();
map.insert("to", to.to_string());
Template::render("return", &map)
}
#[get("/cookies")]
pub async fn cookies() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("cookies", &map)
}
#[get("/privacy")]
pub async fn privacy() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("privacy", &map)
}
#[get("/terms")]
pub async fn terms() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("terms", &map)
}
#[get("/")]
pub async fn help() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("help", &map)
}
#[get("/timezone")]
pub async fn help_timezone() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("support/timezone", &map)
}
#[get("/create_reminder")]
pub async fn help_create_reminder() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("support/create_reminder", &map)
}
#[get("/delete_reminder")]
pub async fn help_delete_reminder() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("support/delete_reminder", &map)
}
#[get("/timers")]
pub async fn help_timers() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("support/timers", &map)
}
#[get("/todo_lists")]
pub async fn help_todo_lists() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("support/todo_lists", &map)
}
#[get("/macros")]
pub async fn help_macros() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("support/macros", &map)
}
#[get("/intervals")]
pub async fn help_intervals() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("support/intervals", &map)
}
#[get("/dashboard")]
pub async fn help_dashboard() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("support/dashboard", &map)
}
#[get("/iemanager")]
pub async fn help_iemanager() -> Template {
let map: HashMap<&str, String> = HashMap::new();
Template::render("support/iemanager", &map)
}

50
src/web/routes/report.rs Normal file
View File

@ -0,0 +1,50 @@
use log::error;
use rocket::{
http::CookieJar,
post,
serde::{
json::{json, Json},
Deserialize,
},
};
use crate::web::routes::JsonResult;
#[derive(Deserialize)]
pub struct ClientError {
#[serde(rename = "reporterId")]
reporter_id: String,
url: String,
#[serde(rename = "relativeTimestamp")]
relative_timestamp: i64,
#[serde(rename = "errorMessage")]
error_message: String,
#[serde(rename = "errorLine")]
error_line: u64,
#[serde(rename = "errorFile")]
error_file: String,
#[serde(rename = "errorType")]
error_type: String,
}
#[post("/report", data = "<client_error>")]
pub async fn report_error(cookies: &CookieJar<'_>, client_error: Json<ClientError>) -> JsonResult {
if let Some(user_id) = cookies.get_private("userid") {
error!(
"User {} reports a client-side error.
{}, {}:{} at {}ms
{}: {}
Chain: {}",
user_id,
client_error.url,
client_error.error_file,
client_error.error_line,
client_error.relative_timestamp,
client_error.error_type,
client_error.error_message,
client_error.reporter_id
);
}
Ok(json!({}))
}