restructured all the reminder creation stuff into builders
This commit is contained in:
365
src/models/reminder/builder.rs
Normal file
365
src/models/reminder/builder.rs
Normal file
@ -0,0 +1,365 @@
|
||||
use serenity::{
|
||||
client::Context,
|
||||
http::CacheHttp,
|
||||
model::{
|
||||
channel::GuildChannel,
|
||||
id::{ChannelId, GuildId, UserId},
|
||||
webhook::Webhook,
|
||||
},
|
||||
Result as SerenityResult,
|
||||
};
|
||||
|
||||
use chrono::{Duration, NaiveDateTime, Utc};
|
||||
use chrono_tz::Tz;
|
||||
|
||||
use crate::{
|
||||
consts::{MAX_TIME, MIN_INTERVAL},
|
||||
models::{
|
||||
channel_data::ChannelData,
|
||||
reminder::{content::Content, errors::ReminderError, helper::generate_uid, Reminder},
|
||||
user_data::UserData,
|
||||
},
|
||||
time_parser::TimeParser,
|
||||
SQLPool,
|
||||
};
|
||||
|
||||
use sqlx::MySqlPool;
|
||||
|
||||
use std::{collections::HashSet, fmt::Display};
|
||||
|
||||
async fn create_webhook(
|
||||
ctx: impl CacheHttp,
|
||||
channel: GuildChannel,
|
||||
name: impl Display,
|
||||
) -> SerenityResult<Webhook> {
|
||||
channel
|
||||
.create_webhook_with_avatar(
|
||||
ctx.http(),
|
||||
name,
|
||||
(
|
||||
include_bytes!(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/assets/",
|
||||
env!(
|
||||
"WEBHOOK_AVATAR",
|
||||
"WEBHOOK_AVATAR not provided for compilation"
|
||||
)
|
||||
)) as &[u8],
|
||||
env!("WEBHOOK_AVATAR"),
|
||||
),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq)]
|
||||
pub enum ReminderScope {
|
||||
User(u64),
|
||||
Channel(u64),
|
||||
}
|
||||
|
||||
impl ReminderScope {
|
||||
pub fn mention(&self) -> String {
|
||||
match self {
|
||||
Self::User(id) => format!("<@{}>", id),
|
||||
Self::Channel(id) => format!("<#{}>", id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ReminderBuilder {
|
||||
pool: MySqlPool,
|
||||
uid: String,
|
||||
channel: u32,
|
||||
utc_time: NaiveDateTime,
|
||||
timezone: String,
|
||||
interval: Option<i64>,
|
||||
expires: Option<NaiveDateTime>,
|
||||
content: String,
|
||||
tts: bool,
|
||||
attachment_name: Option<String>,
|
||||
attachment: Option<Vec<u8>>,
|
||||
set_by: Option<u32>,
|
||||
}
|
||||
|
||||
impl ReminderBuilder {
|
||||
pub async fn build(self) -> Result<Reminder, ReminderError> {
|
||||
let queried_time = sqlx::query!(
|
||||
"SELECT DATE_ADD(?, INTERVAL (SELECT nudge FROM channels WHERE id = ?) SECOND) AS `utc_time`",
|
||||
self.utc_time,
|
||||
self.channel,
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match queried_time.utc_time {
|
||||
Some(utc_time) => {
|
||||
if utc_time < (Utc::now() + Duration::seconds(60)).naive_local() {
|
||||
Err(ReminderError::PastTime)
|
||||
} else {
|
||||
sqlx::query!(
|
||||
"
|
||||
INSERT INTO reminders (
|
||||
`uid`,
|
||||
`channel_id`,
|
||||
`utc_time`,
|
||||
`timezone`,
|
||||
`interval`,
|
||||
`expires`,
|
||||
`content`,
|
||||
`tts`,
|
||||
`attachment_name`,
|
||||
`attachment`,
|
||||
`set_by`
|
||||
) VALUES (
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?
|
||||
)
|
||||
",
|
||||
self.uid,
|
||||
self.channel,
|
||||
utc_time,
|
||||
self.timezone,
|
||||
self.interval,
|
||||
self.expires,
|
||||
self.content,
|
||||
self.tts,
|
||||
self.attachment_name,
|
||||
self.attachment,
|
||||
self.set_by
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await;
|
||||
|
||||
Ok(Reminder::from_uid(&self.pool, self.uid).await.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
None => Err(ReminderError::LongTime),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MultiReminderBuilder<'a> {
|
||||
scopes: Vec<ReminderScope>,
|
||||
utc_time: NaiveDateTime,
|
||||
utc_time_parser: Option<TimeParser>,
|
||||
timezone: Tz,
|
||||
interval: Option<i64>,
|
||||
expires: Option<NaiveDateTime>,
|
||||
expires_parser: Option<TimeParser>,
|
||||
content: Content,
|
||||
set_by: Option<u32>,
|
||||
ctx: &'a Context,
|
||||
guild_id: Option<GuildId>,
|
||||
}
|
||||
|
||||
impl<'a> MultiReminderBuilder<'a> {
|
||||
pub fn new(ctx: &'a Context, guild_id: Option<GuildId>) -> Self {
|
||||
MultiReminderBuilder {
|
||||
scopes: vec![],
|
||||
utc_time: Utc::now().naive_utc(),
|
||||
utc_time_parser: None,
|
||||
timezone: Tz::UTC,
|
||||
interval: None,
|
||||
expires: None,
|
||||
expires_parser: None,
|
||||
content: Content::new(),
|
||||
set_by: None,
|
||||
ctx,
|
||||
guild_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn content(mut self, content: Content) -> Self {
|
||||
self.content = content;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn time<T: Into<i64>>(mut self, time: T) -> Self {
|
||||
self.utc_time = NaiveDateTime::from_timestamp(time.into(), 0);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn time_parser(mut self, parser: TimeParser) -> Self {
|
||||
self.utc_time_parser = Some(parser);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn expires<T: Into<i64>>(mut self, time: Option<T>) -> Self {
|
||||
if let Some(t) = time {
|
||||
self.expires = Some(NaiveDateTime::from_timestamp(t.into(), 0));
|
||||
} else {
|
||||
self.expires = None;
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn expires_parser(mut self, parser: Option<TimeParser>) -> Self {
|
||||
self.expires_parser = parser;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn author(mut self, user: UserData) -> Self {
|
||||
self.set_by = Some(user.id);
|
||||
self.timezone = user.timezone();
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn interval(mut self, interval: Option<i64>) -> Self {
|
||||
self.interval = interval;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_scopes(&mut self, scopes: Vec<ReminderScope>) {
|
||||
self.scopes = scopes;
|
||||
}
|
||||
|
||||
pub async fn build(mut self) -> (HashSet<ReminderError>, HashSet<ReminderScope>) {
|
||||
let pool = self
|
||||
.ctx
|
||||
.data
|
||||
.read()
|
||||
.await
|
||||
.get::<SQLPool>()
|
||||
.cloned()
|
||||
.unwrap();
|
||||
|
||||
let mut errors = HashSet::new();
|
||||
|
||||
let mut ok_locs = HashSet::new();
|
||||
|
||||
if let Some(expire_parser) = self.expires_parser {
|
||||
if let Ok(expires) = expire_parser.timestamp() {
|
||||
self.expires = Some(NaiveDateTime::from_timestamp(expires, 0));
|
||||
} else {
|
||||
errors.insert(ReminderError::InvalidExpiration);
|
||||
|
||||
return (errors, ok_locs);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(time_parser) = self.utc_time_parser {
|
||||
if let Ok(time) = time_parser.timestamp() {
|
||||
self.utc_time = NaiveDateTime::from_timestamp(time, 0);
|
||||
} else {
|
||||
errors.insert(ReminderError::InvalidTime);
|
||||
|
||||
return (errors, ok_locs);
|
||||
}
|
||||
}
|
||||
|
||||
if self.interval.map_or(false, |i| (i as i64) < *MIN_INTERVAL) {
|
||||
errors.insert(ReminderError::ShortInterval);
|
||||
} else if self.interval.map_or(false, |i| (i as i64) > *MAX_TIME) {
|
||||
errors.insert(ReminderError::LongInterval);
|
||||
} else {
|
||||
for scope in self.scopes {
|
||||
let db_channel_id = match scope {
|
||||
ReminderScope::User(user_id) => {
|
||||
if let Ok(user) = UserId(user_id).to_user(&self.ctx).await {
|
||||
let user_data =
|
||||
UserData::from_user(&user, &self.ctx, &pool).await.unwrap();
|
||||
|
||||
if let Some(guild_id) = self.guild_id {
|
||||
if guild_id.member(&self.ctx, user).await.is_err() {
|
||||
Err(ReminderError::InvalidTag)
|
||||
} else {
|
||||
Ok(user_data.dm_channel)
|
||||
}
|
||||
} else {
|
||||
Ok(user_data.dm_channel)
|
||||
}
|
||||
} else {
|
||||
Err(ReminderError::InvalidTag)
|
||||
}
|
||||
}
|
||||
ReminderScope::Channel(channel_id) => {
|
||||
let channel = ChannelId(channel_id).to_channel(&self.ctx).await.unwrap();
|
||||
|
||||
if let Some(guild_channel) = channel.clone().guild() {
|
||||
if Some(guild_channel.guild_id) != self.guild_id {
|
||||
Err(ReminderError::InvalidTag)
|
||||
} else {
|
||||
let mut channel_data =
|
||||
ChannelData::from_channel(channel, &pool).await.unwrap();
|
||||
|
||||
if channel_data.webhook_id.is_none()
|
||||
|| channel_data.webhook_token.is_none()
|
||||
{
|
||||
match create_webhook(&self.ctx, guild_channel, "Reminder").await
|
||||
{
|
||||
Ok(webhook) => {
|
||||
channel_data.webhook_id =
|
||||
Some(webhook.id.as_u64().to_owned());
|
||||
channel_data.webhook_token = webhook.token;
|
||||
|
||||
channel_data.commit_changes(&pool).await;
|
||||
|
||||
Ok(channel_data.id)
|
||||
}
|
||||
|
||||
Err(e) => Err(ReminderError::DiscordError(e.to_string())),
|
||||
}
|
||||
} else {
|
||||
Ok(channel_data.id)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err(ReminderError::InvalidTag)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match db_channel_id {
|
||||
Ok(c) => {
|
||||
let builder = ReminderBuilder {
|
||||
pool: pool.clone(),
|
||||
uid: generate_uid(),
|
||||
channel: c,
|
||||
utc_time: self.utc_time,
|
||||
timezone: self.timezone.to_string(),
|
||||
interval: self.interval,
|
||||
expires: self.expires,
|
||||
content: self.content.content.clone(),
|
||||
tts: self.content.tts,
|
||||
attachment_name: self.content.attachment_name.clone(),
|
||||
attachment: self.content.attachment.clone(),
|
||||
set_by: self.set_by,
|
||||
};
|
||||
|
||||
match builder.build().await {
|
||||
Ok(_) => {
|
||||
ok_locs.insert(scope);
|
||||
}
|
||||
Err(e) => {
|
||||
errors.insert(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
errors.insert(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(errors, ok_locs)
|
||||
}
|
||||
}
|
74
src/models/reminder/content.rs
Normal file
74
src/models/reminder/content.rs
Normal file
@ -0,0 +1,74 @@
|
||||
use serenity::model::{channel::Message, guild::Guild, misc::Mentionable};
|
||||
|
||||
use regex::Captures;
|
||||
|
||||
use crate::{consts::REGEX_CONTENT_SUBSTITUTION, models::reminder::errors::ContentError};
|
||||
|
||||
pub struct Content {
|
||||
pub content: String,
|
||||
pub tts: bool,
|
||||
pub attachment: Option<Vec<u8>>,
|
||||
pub attachment_name: Option<String>,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
content: "".to_string(),
|
||||
tts: false,
|
||||
attachment: None,
|
||||
attachment_name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn build<S: ToString>(content: S, message: &Message) -> Result<Self, ContentError> {
|
||||
if message.attachments.len() > 1 {
|
||||
Err(ContentError::TooManyAttachments)
|
||||
} else if let Some(attachment) = message.attachments.get(0) {
|
||||
if attachment.size > 8_000_000 {
|
||||
Err(ContentError::AttachmentTooLarge)
|
||||
} else if let Ok(attachment_bytes) = attachment.download().await {
|
||||
Ok(Self {
|
||||
content: content.to_string(),
|
||||
tts: false,
|
||||
attachment: Some(attachment_bytes),
|
||||
attachment_name: Some(attachment.filename.clone()),
|
||||
})
|
||||
} else {
|
||||
Err(ContentError::AttachmentDownloadFailed)
|
||||
}
|
||||
} else {
|
||||
Ok(Self {
|
||||
content: content.to_string(),
|
||||
tts: false,
|
||||
attachment: None,
|
||||
attachment_name: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn substitute(&mut self, guild: Guild) {
|
||||
if self.content.starts_with("/tts ") {
|
||||
self.tts = true;
|
||||
self.content = self.content.split_off(5);
|
||||
}
|
||||
|
||||
self.content = REGEX_CONTENT_SUBSTITUTION
|
||||
.replace(&self.content, |caps: &Captures| {
|
||||
if let Some(user) = caps.name("user") {
|
||||
format!("<@{}>", user.as_str())
|
||||
} else if let Some(role_name) = caps.name("role") {
|
||||
if let Some(role) = guild.role_by_name(role_name.as_str()) {
|
||||
role.mention().to_string()
|
||||
} else {
|
||||
format!("<<{}>>", role_name.as_str().to_string())
|
||||
}
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.replace("<<everyone>>", "@everyone")
|
||||
.replace("<<here>>", "@here");
|
||||
}
|
||||
}
|
81
src/models/reminder/errors.rs
Normal file
81
src/models/reminder/errors.rs
Normal file
@ -0,0 +1,81 @@
|
||||
use crate::consts::{MAX_TIME, MIN_INTERVAL};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum InteractionError {
|
||||
InvalidFormat,
|
||||
InvalidBase64,
|
||||
InvalidSize,
|
||||
NoReminder,
|
||||
SignatureMismatch,
|
||||
InvalidAction,
|
||||
}
|
||||
|
||||
impl ToString for InteractionError {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
InteractionError::InvalidFormat => {
|
||||
String::from("The interaction data was improperly formatted")
|
||||
}
|
||||
InteractionError::InvalidBase64 => String::from("The interaction data was invalid"),
|
||||
InteractionError::InvalidSize => String::from("The interaction data was invalid"),
|
||||
InteractionError::NoReminder => String::from("Reminder could not be found"),
|
||||
InteractionError::SignatureMismatch => {
|
||||
String::from("Only the user who did the command can use interactions")
|
||||
}
|
||||
InteractionError::InvalidAction => String::from("The action was invalid"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Debug)]
|
||||
pub enum ReminderError {
|
||||
LongTime,
|
||||
LongInterval,
|
||||
PastTime,
|
||||
ShortInterval,
|
||||
InvalidTag,
|
||||
InvalidTime,
|
||||
InvalidExpiration,
|
||||
DiscordError(String),
|
||||
}
|
||||
|
||||
impl ReminderError {
|
||||
pub fn display(&self, is_natural: bool) -> String {
|
||||
match self {
|
||||
ReminderError::LongTime => "That time is too far in the future. Please specify a shorter time.".to_string(),
|
||||
ReminderError::LongInterval => format!("Please ensure the interval specified is less than {max_time} days", max_time = *MAX_TIME / 86_400),
|
||||
ReminderError::PastTime => "Please ensure the time provided is in the future. If the time should be in the future, please be more specific with the definition.".to_string(),
|
||||
ReminderError::ShortInterval => format!("Please ensure the interval provided is longer than {min_interval} seconds", min_interval = *MIN_INTERVAL),
|
||||
ReminderError::InvalidTag => "Couldn't find a location by your tag. Your tag must be either a channel or a user (not a role)".to_string(),
|
||||
ReminderError::InvalidTime => if is_natural {
|
||||
"Your time failed to process. Please make it as clear as possible, for example `\"16th of july\"` or `\"in 20 minutes\"`".to_string()
|
||||
} else {
|
||||
"Make sure the time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] etc. or `day/month/year-hour:minute:second`".to_string()
|
||||
},
|
||||
ReminderError::InvalidExpiration => if is_natural {
|
||||
"Your expiration time failed to process. Please make it as clear as possible, for example `\"16th of july\"` or `\"in 20 minutes\"`".to_string()
|
||||
} else {
|
||||
"Make sure the expiration time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] etc. or `day/month/year-hour:minute:second`".to_string()
|
||||
},
|
||||
ReminderError::DiscordError(s) => format!("A Discord error occurred: **{}**", s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ContentError {
|
||||
TooManyAttachments,
|
||||
AttachmentTooLarge,
|
||||
AttachmentDownloadFailed,
|
||||
}
|
||||
|
||||
impl ToString for ContentError {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
ContentError::TooManyAttachments => "remind/too_many_attachments",
|
||||
ContentError::AttachmentTooLarge => "remind/attachment_too_large",
|
||||
ContentError::AttachmentDownloadFailed => "remind/attachment_download_failed",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
}
|
40
src/models/reminder/helper.rs
Normal file
40
src/models/reminder/helper.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use crate::consts::{CHARACTERS, DAY, HOUR, MINUTE};
|
||||
|
||||
use num_integer::Integer;
|
||||
|
||||
use rand::{rngs::OsRng, seq::IteratorRandom};
|
||||
|
||||
pub fn longhand_displacement(seconds: u64) -> String {
|
||||
let (days, seconds) = seconds.div_rem(&DAY);
|
||||
let (hours, seconds) = seconds.div_rem(&HOUR);
|
||||
let (minutes, seconds) = seconds.div_rem(&MINUTE);
|
||||
|
||||
let mut sections = vec![];
|
||||
|
||||
for (var, name) in [days, hours, minutes, seconds]
|
||||
.iter()
|
||||
.zip(["days", "hours", "minutes", "seconds"].iter())
|
||||
{
|
||||
if *var > 0 {
|
||||
sections.push(format!("{} {}", var, name));
|
||||
}
|
||||
}
|
||||
|
||||
sections.join(", ")
|
||||
}
|
||||
|
||||
pub fn generate_uid() -> String {
|
||||
let mut generator: OsRng = Default::default();
|
||||
|
||||
(0..64)
|
||||
.map(|_| {
|
||||
CHARACTERS
|
||||
.chars()
|
||||
.choose(&mut generator)
|
||||
.unwrap()
|
||||
.to_owned()
|
||||
.to_string()
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("")
|
||||
}
|
59
src/models/reminder/look_flags.rs
Normal file
59
src/models/reminder/look_flags.rs
Normal file
@ -0,0 +1,59 @@
|
||||
use serenity::model::id::ChannelId;
|
||||
|
||||
use crate::consts::REGEX_CHANNEL;
|
||||
|
||||
pub enum TimeDisplayType {
|
||||
Absolute,
|
||||
Relative,
|
||||
}
|
||||
|
||||
pub struct LookFlags {
|
||||
pub limit: u16,
|
||||
pub show_disabled: bool,
|
||||
pub channel_id: Option<ChannelId>,
|
||||
pub time_display: TimeDisplayType,
|
||||
}
|
||||
|
||||
impl Default for LookFlags {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
limit: u16::MAX,
|
||||
show_disabled: true,
|
||||
channel_id: None,
|
||||
time_display: TimeDisplayType::Relative,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LookFlags {
|
||||
pub fn from_string(args: &str) -> Self {
|
||||
let mut new_flags: Self = Default::default();
|
||||
|
||||
for arg in args.split(' ') {
|
||||
match arg {
|
||||
"enabled" => {
|
||||
new_flags.show_disabled = false;
|
||||
}
|
||||
|
||||
"time" => {
|
||||
new_flags.time_display = TimeDisplayType::Absolute;
|
||||
}
|
||||
|
||||
param => {
|
||||
if let Ok(val) = param.parse::<u16>() {
|
||||
new_flags.limit = val;
|
||||
} else if let Some(channel) = REGEX_CHANNEL
|
||||
.captures(arg)
|
||||
.map(|cap| cap.get(1))
|
||||
.flatten()
|
||||
.map(|c| c.as_str().parse::<u64>().unwrap())
|
||||
{
|
||||
new_flags.channel_id = Some(ChannelId(channel));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
new_flags
|
||||
}
|
||||
}
|
416
src/models/reminder/mod.rs
Normal file
416
src/models/reminder/mod.rs
Normal file
@ -0,0 +1,416 @@
|
||||
pub mod builder;
|
||||
pub mod content;
|
||||
pub mod errors;
|
||||
mod helper;
|
||||
pub mod look_flags;
|
||||
|
||||
use serenity::{
|
||||
client::Context,
|
||||
model::id::{ChannelId, GuildId, UserId},
|
||||
};
|
||||
|
||||
use chrono::NaiveDateTime;
|
||||
|
||||
use crate::{
|
||||
models::reminder::{
|
||||
errors::InteractionError,
|
||||
helper::longhand_displacement,
|
||||
look_flags::{LookFlags, TimeDisplayType},
|
||||
},
|
||||
SQLPool,
|
||||
};
|
||||
|
||||
use ring::hmac;
|
||||
|
||||
use sqlx::MySqlPool;
|
||||
use std::{
|
||||
convert::{TryFrom, TryInto},
|
||||
env,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ReminderAction {
|
||||
Delete,
|
||||
}
|
||||
|
||||
impl ToString for ReminderAction {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Self::Delete => String::from("del"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for ReminderAction {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
"del" => Ok(Self::Delete),
|
||||
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Reminder {
|
||||
pub id: u32,
|
||||
pub uid: String,
|
||||
pub channel: u64,
|
||||
pub utc_time: NaiveDateTime,
|
||||
pub interval: Option<u32>,
|
||||
pub expires: Option<NaiveDateTime>,
|
||||
pub enabled: bool,
|
||||
pub content: String,
|
||||
pub embed_description: String,
|
||||
pub set_by: Option<u64>,
|
||||
}
|
||||
|
||||
impl Reminder {
|
||||
pub async fn from_uid(pool: &MySqlPool, uid: String) -> Option<Self> {
|
||||
sqlx::query_as_unchecked!(
|
||||
Self,
|
||||
"
|
||||
SELECT
|
||||
reminders.id,
|
||||
reminders.uid,
|
||||
channels.channel,
|
||||
reminders.utc_time,
|
||||
reminders.interval,
|
||||
reminders.expires,
|
||||
reminders.enabled,
|
||||
reminders.content,
|
||||
reminders.embed_description,
|
||||
users.user AS set_by
|
||||
FROM
|
||||
reminders
|
||||
INNER JOIN
|
||||
channels
|
||||
ON
|
||||
reminders.channel_id = channels.id
|
||||
LEFT JOIN
|
||||
users
|
||||
ON
|
||||
reminders.set_by = users.id
|
||||
WHERE
|
||||
reminders.uid = ?
|
||||
",
|
||||
uid
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.ok()
|
||||
}
|
||||
|
||||
pub async fn from_id(ctx: &Context, id: u32) -> Option<Self> {
|
||||
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
|
||||
|
||||
sqlx::query_as_unchecked!(
|
||||
Self,
|
||||
"
|
||||
SELECT
|
||||
reminders.id,
|
||||
reminders.uid,
|
||||
channels.channel,
|
||||
reminders.utc_time,
|
||||
reminders.interval,
|
||||
reminders.expires,
|
||||
reminders.enabled,
|
||||
reminders.content,
|
||||
reminders.embed_description,
|
||||
users.user AS set_by
|
||||
FROM
|
||||
reminders
|
||||
INNER JOIN
|
||||
channels
|
||||
ON
|
||||
reminders.channel_id = channels.id
|
||||
LEFT JOIN
|
||||
users
|
||||
ON
|
||||
reminders.set_by = users.id
|
||||
WHERE
|
||||
reminders.id = ?
|
||||
",
|
||||
id
|
||||
)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.ok()
|
||||
}
|
||||
|
||||
pub async fn from_channel<C: Into<ChannelId>>(
|
||||
ctx: &Context,
|
||||
channel_id: C,
|
||||
flags: &LookFlags,
|
||||
) -> Vec<Self> {
|
||||
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
|
||||
|
||||
let enabled = if flags.show_disabled { "0,1" } else { "1" };
|
||||
let channel_id = channel_id.into();
|
||||
|
||||
sqlx::query_as_unchecked!(
|
||||
Self,
|
||||
"
|
||||
SELECT
|
||||
reminders.id,
|
||||
reminders.uid,
|
||||
channels.channel,
|
||||
reminders.utc_time,
|
||||
reminders.interval,
|
||||
reminders.expires,
|
||||
reminders.enabled,
|
||||
reminders.content,
|
||||
reminders.embed_description,
|
||||
users.user AS set_by
|
||||
FROM
|
||||
reminders
|
||||
INNER JOIN
|
||||
channels
|
||||
ON
|
||||
reminders.channel_id = channels.id
|
||||
LEFT JOIN
|
||||
users
|
||||
ON
|
||||
reminders.set_by = users.id
|
||||
WHERE
|
||||
channels.channel = ? AND
|
||||
FIND_IN_SET(reminders.enabled, ?)
|
||||
ORDER BY
|
||||
reminders.utc_time
|
||||
LIMIT
|
||||
?
|
||||
",
|
||||
channel_id.as_u64(),
|
||||
enabled,
|
||||
flags.limit
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub async fn from_guild(ctx: &Context, guild_id: Option<GuildId>, user: UserId) -> Vec<Self> {
|
||||
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
|
||||
|
||||
if let Some(guild_id) = guild_id {
|
||||
let guild_opt = guild_id.to_guild_cached(&ctx);
|
||||
|
||||
if let Some(guild) = guild_opt {
|
||||
let channels = guild
|
||||
.channels
|
||||
.keys()
|
||||
.into_iter()
|
||||
.map(|k| k.as_u64().to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(",");
|
||||
|
||||
sqlx::query_as_unchecked!(
|
||||
Self,
|
||||
"
|
||||
SELECT
|
||||
reminders.id,
|
||||
reminders.uid,
|
||||
channels.channel,
|
||||
reminders.utc_time,
|
||||
reminders.interval,
|
||||
reminders.expires,
|
||||
reminders.enabled,
|
||||
reminders.content,
|
||||
reminders.embed_description,
|
||||
users.user AS set_by
|
||||
FROM
|
||||
reminders
|
||||
LEFT JOIN
|
||||
channels
|
||||
ON
|
||||
channels.id = reminders.channel_id
|
||||
LEFT JOIN
|
||||
users
|
||||
ON
|
||||
reminders.set_by = users.id
|
||||
WHERE
|
||||
FIND_IN_SET(channels.channel, ?)
|
||||
",
|
||||
channels
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query_as_unchecked!(
|
||||
Self,
|
||||
"
|
||||
SELECT
|
||||
reminders.id,
|
||||
reminders.uid,
|
||||
channels.channel,
|
||||
reminders.utc_time,
|
||||
reminders.interval,
|
||||
reminders.expires,
|
||||
reminders.enabled,
|
||||
reminders.content,
|
||||
reminders.embed_description,
|
||||
users.user AS set_by
|
||||
FROM
|
||||
reminders
|
||||
LEFT JOIN
|
||||
channels
|
||||
ON
|
||||
channels.id = reminders.channel_id
|
||||
LEFT JOIN
|
||||
users
|
||||
ON
|
||||
reminders.set_by = users.id
|
||||
WHERE
|
||||
channels.guild_id = (SELECT id FROM guilds WHERE guild = ?)
|
||||
",
|
||||
guild_id.as_u64()
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
}
|
||||
} else {
|
||||
sqlx::query_as_unchecked!(
|
||||
Self,
|
||||
"
|
||||
SELECT
|
||||
reminders.id,
|
||||
reminders.uid,
|
||||
channels.channel,
|
||||
reminders.utc_time,
|
||||
reminders.interval,
|
||||
reminders.expires,
|
||||
reminders.enabled,
|
||||
reminders.content,
|
||||
reminders.embed_description,
|
||||
users.user AS set_by
|
||||
FROM
|
||||
reminders
|
||||
INNER JOIN
|
||||
channels
|
||||
ON
|
||||
channels.id = reminders.channel_id
|
||||
LEFT JOIN
|
||||
users
|
||||
ON
|
||||
reminders.set_by = users.id
|
||||
WHERE
|
||||
channels.id = (SELECT dm_channel FROM users WHERE user = ?)
|
||||
",
|
||||
user.as_u64()
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
}
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn display_content(&self) -> &str {
|
||||
if self.content.is_empty() {
|
||||
&self.embed_description
|
||||
} else {
|
||||
&self.content
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display(&self, flags: &LookFlags, inter: &str) -> String {
|
||||
let time_display = match flags.time_display {
|
||||
TimeDisplayType::Absolute => format!("<t:{}>", self.utc_time.timestamp()),
|
||||
|
||||
TimeDisplayType::Relative => format!("<t:{}:R>", self.utc_time.timestamp()),
|
||||
};
|
||||
|
||||
if let Some(interval) = self.interval {
|
||||
format!(
|
||||
"'{}' *{}* **{}**, repeating every **{}** (set by {})",
|
||||
self.display_content(),
|
||||
&inter,
|
||||
time_display,
|
||||
longhand_displacement(interval as u64),
|
||||
self.set_by
|
||||
.map(|i| format!("<@{}>", i))
|
||||
.unwrap_or_else(|| "unknown".to_string())
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"'{}' *{}* **{}** (set by {})",
|
||||
self.display_content(),
|
||||
&inter,
|
||||
time_display,
|
||||
self.set_by
|
||||
.map(|i| format!("<@{}>", i))
|
||||
.unwrap_or_else(|| "unknown".to_string())
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn from_interaction<U: Into<u64>>(
|
||||
ctx: &Context,
|
||||
member_id: U,
|
||||
payload: String,
|
||||
) -> Result<(Self, ReminderAction), InteractionError> {
|
||||
let sections = payload.split('.').collect::<Vec<&str>>();
|
||||
|
||||
if sections.len() != 3 {
|
||||
Err(InteractionError::InvalidFormat)
|
||||
} else {
|
||||
let action = ReminderAction::try_from(sections[0])
|
||||
.map_err(|_| InteractionError::InvalidAction)?;
|
||||
|
||||
let reminder_id = u32::from_le_bytes(
|
||||
base64::decode(sections[1])
|
||||
.map_err(|_| InteractionError::InvalidBase64)?
|
||||
.try_into()
|
||||
.map_err(|_| InteractionError::InvalidSize)?,
|
||||
);
|
||||
|
||||
if let Some(reminder) = Self::from_id(ctx, reminder_id).await {
|
||||
if reminder.signed_action(member_id, action) == payload {
|
||||
Ok((reminder, action))
|
||||
} else {
|
||||
Err(InteractionError::SignatureMismatch)
|
||||
}
|
||||
} else {
|
||||
Err(InteractionError::NoReminder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signed_action<U: Into<u64>>(&self, member_id: U, action: ReminderAction) -> String {
|
||||
let s_key = hmac::Key::new(
|
||||
hmac::HMAC_SHA256,
|
||||
env::var("SECRET_KEY")
|
||||
.expect("No SECRET_KEY provided")
|
||||
.as_bytes(),
|
||||
);
|
||||
|
||||
let mut context = hmac::Context::with_key(&s_key);
|
||||
|
||||
context.update(&self.id.to_le_bytes());
|
||||
context.update(&member_id.into().to_le_bytes());
|
||||
|
||||
let signature = context.sign();
|
||||
|
||||
format!(
|
||||
"{}.{}.{}",
|
||||
action.to_string(),
|
||||
base64::encode(self.id.to_le_bytes()),
|
||||
base64::encode(&signature)
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn delete(&self, ctx: &Context) {
|
||||
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
|
||||
|
||||
sqlx::query!(
|
||||
"
|
||||
DELETE FROM reminders WHERE id = ?
|
||||
",
|
||||
self.id
|
||||
)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user