restructured all the reminder creation stuff into builders
This commit is contained in:
		@@ -4,34 +4,46 @@ pub mod reminder;
 | 
			
		||||
pub mod timer;
 | 
			
		||||
pub mod user_data;
 | 
			
		||||
 | 
			
		||||
use serenity::{async_trait, model::id::GuildId, prelude::Context};
 | 
			
		||||
use serenity::{
 | 
			
		||||
    async_trait,
 | 
			
		||||
    model::id::{GuildId, UserId},
 | 
			
		||||
    prelude::Context,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::{consts::DEFAULT_PREFIX, GuildDataCache, SQLPool};
 | 
			
		||||
 | 
			
		||||
use guild_data::GuildData;
 | 
			
		||||
 | 
			
		||||
use crate::models::user_data::UserData;
 | 
			
		||||
 | 
			
		||||
use std::sync::Arc;
 | 
			
		||||
 | 
			
		||||
use tokio::sync::RwLock;
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
pub trait CtxGuildData {
 | 
			
		||||
pub trait CtxData {
 | 
			
		||||
    async fn guild_data<G: Into<GuildId> + Send + Sync>(
 | 
			
		||||
        &self,
 | 
			
		||||
        guild_id: G,
 | 
			
		||||
    ) -> Result<Arc<RwLock<GuildData>>, sqlx::Error>;
 | 
			
		||||
 | 
			
		||||
    async fn user_data<U: Into<UserId> + Send + Sync>(
 | 
			
		||||
        &self,
 | 
			
		||||
        user_id: U,
 | 
			
		||||
    ) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>>;
 | 
			
		||||
 | 
			
		||||
    async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl CtxGuildData for Context {
 | 
			
		||||
impl CtxData for Context {
 | 
			
		||||
    async fn guild_data<G: Into<GuildId> + Send + Sync>(
 | 
			
		||||
        &self,
 | 
			
		||||
        guild_id: G,
 | 
			
		||||
    ) -> Result<Arc<RwLock<GuildData>>, sqlx::Error> {
 | 
			
		||||
        let guild_id = guild_id.into();
 | 
			
		||||
 | 
			
		||||
        let guild = guild_id.to_guild_cached(&self.cache).await.unwrap();
 | 
			
		||||
        let guild = guild_id.to_guild_cached(&self.cache).unwrap();
 | 
			
		||||
 | 
			
		||||
        let guild_cache = self
 | 
			
		||||
            .data
 | 
			
		||||
@@ -62,6 +74,18 @@ impl CtxGuildData for Context {
 | 
			
		||||
        x
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn user_data<U: Into<UserId> + Send + Sync>(
 | 
			
		||||
        &self,
 | 
			
		||||
        user_id: U,
 | 
			
		||||
    ) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>> {
 | 
			
		||||
        let user_id = user_id.into();
 | 
			
		||||
        let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
        let user = user_id.to_user(self).await.unwrap();
 | 
			
		||||
 | 
			
		||||
        UserData::from_user(&user, &self, &pool).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String {
 | 
			
		||||
        if let Some(guild_id) = guild_id {
 | 
			
		||||
            self.guild_data(guild_id)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										365
									
								
								src/models/reminder/builder.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										365
									
								
								src/models/reminder/builder.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,365 @@
 | 
			
		||||
use serenity::{
 | 
			
		||||
    client::Context,
 | 
			
		||||
    http::CacheHttp,
 | 
			
		||||
    model::{
 | 
			
		||||
        channel::GuildChannel,
 | 
			
		||||
        id::{ChannelId, GuildId, UserId},
 | 
			
		||||
        webhook::Webhook,
 | 
			
		||||
    },
 | 
			
		||||
    Result as SerenityResult,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use chrono::{Duration, NaiveDateTime, Utc};
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    consts::{MAX_TIME, MIN_INTERVAL},
 | 
			
		||||
    models::{
 | 
			
		||||
        channel_data::ChannelData,
 | 
			
		||||
        reminder::{content::Content, errors::ReminderError, helper::generate_uid, Reminder},
 | 
			
		||||
        user_data::UserData,
 | 
			
		||||
    },
 | 
			
		||||
    time_parser::TimeParser,
 | 
			
		||||
    SQLPool,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use sqlx::MySqlPool;
 | 
			
		||||
 | 
			
		||||
use std::{collections::HashSet, fmt::Display};
 | 
			
		||||
 | 
			
		||||
async fn create_webhook(
 | 
			
		||||
    ctx: impl CacheHttp,
 | 
			
		||||
    channel: GuildChannel,
 | 
			
		||||
    name: impl Display,
 | 
			
		||||
) -> SerenityResult<Webhook> {
 | 
			
		||||
    channel
 | 
			
		||||
        .create_webhook_with_avatar(
 | 
			
		||||
            ctx.http(),
 | 
			
		||||
            name,
 | 
			
		||||
            (
 | 
			
		||||
                include_bytes!(concat!(
 | 
			
		||||
                    env!("CARGO_MANIFEST_DIR"),
 | 
			
		||||
                    "/assets/",
 | 
			
		||||
                    env!(
 | 
			
		||||
                        "WEBHOOK_AVATAR",
 | 
			
		||||
                        "WEBHOOK_AVATAR not provided for compilation"
 | 
			
		||||
                    )
 | 
			
		||||
                )) as &[u8],
 | 
			
		||||
                env!("WEBHOOK_AVATAR"),
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        .await
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Hash, PartialEq, Eq)]
 | 
			
		||||
pub enum ReminderScope {
 | 
			
		||||
    User(u64),
 | 
			
		||||
    Channel(u64),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ReminderScope {
 | 
			
		||||
    pub fn mention(&self) -> String {
 | 
			
		||||
        match self {
 | 
			
		||||
            Self::User(id) => format!("<@{}>", id),
 | 
			
		||||
            Self::Channel(id) => format!("<#{}>", id),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct ReminderBuilder {
 | 
			
		||||
    pool: MySqlPool,
 | 
			
		||||
    uid: String,
 | 
			
		||||
    channel: u32,
 | 
			
		||||
    utc_time: NaiveDateTime,
 | 
			
		||||
    timezone: String,
 | 
			
		||||
    interval: Option<i64>,
 | 
			
		||||
    expires: Option<NaiveDateTime>,
 | 
			
		||||
    content: String,
 | 
			
		||||
    tts: bool,
 | 
			
		||||
    attachment_name: Option<String>,
 | 
			
		||||
    attachment: Option<Vec<u8>>,
 | 
			
		||||
    set_by: Option<u32>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ReminderBuilder {
 | 
			
		||||
    pub async fn build(self) -> Result<Reminder, ReminderError> {
 | 
			
		||||
        let queried_time = sqlx::query!(
 | 
			
		||||
            "SELECT DATE_ADD(?, INTERVAL (SELECT nudge FROM channels WHERE id = ?) SECOND) AS `utc_time`",
 | 
			
		||||
            self.utc_time,
 | 
			
		||||
            self.channel,
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(&self.pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
        match queried_time.utc_time {
 | 
			
		||||
            Some(utc_time) => {
 | 
			
		||||
                if utc_time < (Utc::now() + Duration::seconds(60)).naive_local() {
 | 
			
		||||
                    Err(ReminderError::PastTime)
 | 
			
		||||
                } else {
 | 
			
		||||
                    sqlx::query!(
 | 
			
		||||
                        "
 | 
			
		||||
INSERT INTO reminders (
 | 
			
		||||
    `uid`,
 | 
			
		||||
    `channel_id`,
 | 
			
		||||
    `utc_time`,
 | 
			
		||||
    `timezone`,
 | 
			
		||||
    `interval`,
 | 
			
		||||
    `expires`,
 | 
			
		||||
    `content`,
 | 
			
		||||
    `tts`,
 | 
			
		||||
    `attachment_name`,
 | 
			
		||||
    `attachment`,
 | 
			
		||||
    `set_by`
 | 
			
		||||
) VALUES (
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?
 | 
			
		||||
)
 | 
			
		||||
            ",
 | 
			
		||||
                        self.uid,
 | 
			
		||||
                        self.channel,
 | 
			
		||||
                        utc_time,
 | 
			
		||||
                        self.timezone,
 | 
			
		||||
                        self.interval,
 | 
			
		||||
                        self.expires,
 | 
			
		||||
                        self.content,
 | 
			
		||||
                        self.tts,
 | 
			
		||||
                        self.attachment_name,
 | 
			
		||||
                        self.attachment,
 | 
			
		||||
                        self.set_by
 | 
			
		||||
                    )
 | 
			
		||||
                    .execute(&self.pool)
 | 
			
		||||
                    .await;
 | 
			
		||||
 | 
			
		||||
                    Ok(Reminder::from_uid(&self.pool, self.uid).await.unwrap())
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            None => Err(ReminderError::LongTime),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct MultiReminderBuilder<'a> {
 | 
			
		||||
    scopes: Vec<ReminderScope>,
 | 
			
		||||
    utc_time: NaiveDateTime,
 | 
			
		||||
    utc_time_parser: Option<TimeParser>,
 | 
			
		||||
    timezone: Tz,
 | 
			
		||||
    interval: Option<i64>,
 | 
			
		||||
    expires: Option<NaiveDateTime>,
 | 
			
		||||
    expires_parser: Option<TimeParser>,
 | 
			
		||||
    content: Content,
 | 
			
		||||
    set_by: Option<u32>,
 | 
			
		||||
    ctx: &'a Context,
 | 
			
		||||
    guild_id: Option<GuildId>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> MultiReminderBuilder<'a> {
 | 
			
		||||
    pub fn new(ctx: &'a Context, guild_id: Option<GuildId>) -> Self {
 | 
			
		||||
        MultiReminderBuilder {
 | 
			
		||||
            scopes: vec![],
 | 
			
		||||
            utc_time: Utc::now().naive_utc(),
 | 
			
		||||
            utc_time_parser: None,
 | 
			
		||||
            timezone: Tz::UTC,
 | 
			
		||||
            interval: None,
 | 
			
		||||
            expires: None,
 | 
			
		||||
            expires_parser: None,
 | 
			
		||||
            content: Content::new(),
 | 
			
		||||
            set_by: None,
 | 
			
		||||
            ctx,
 | 
			
		||||
            guild_id,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn content(mut self, content: Content) -> Self {
 | 
			
		||||
        self.content = content;
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn time<T: Into<i64>>(mut self, time: T) -> Self {
 | 
			
		||||
        self.utc_time = NaiveDateTime::from_timestamp(time.into(), 0);
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn time_parser(mut self, parser: TimeParser) -> Self {
 | 
			
		||||
        self.utc_time_parser = Some(parser);
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn expires<T: Into<i64>>(mut self, time: Option<T>) -> Self {
 | 
			
		||||
        if let Some(t) = time {
 | 
			
		||||
            self.expires = Some(NaiveDateTime::from_timestamp(t.into(), 0));
 | 
			
		||||
        } else {
 | 
			
		||||
            self.expires = None;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn expires_parser(mut self, parser: Option<TimeParser>) -> Self {
 | 
			
		||||
        self.expires_parser = parser;
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn author(mut self, user: UserData) -> Self {
 | 
			
		||||
        self.set_by = Some(user.id);
 | 
			
		||||
        self.timezone = user.timezone();
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn interval(mut self, interval: Option<i64>) -> Self {
 | 
			
		||||
        self.interval = interval;
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn set_scopes(&mut self, scopes: Vec<ReminderScope>) {
 | 
			
		||||
        self.scopes = scopes;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn build(mut self) -> (HashSet<ReminderError>, HashSet<ReminderScope>) {
 | 
			
		||||
        let pool = self
 | 
			
		||||
            .ctx
 | 
			
		||||
            .data
 | 
			
		||||
            .read()
 | 
			
		||||
            .await
 | 
			
		||||
            .get::<SQLPool>()
 | 
			
		||||
            .cloned()
 | 
			
		||||
            .unwrap();
 | 
			
		||||
 | 
			
		||||
        let mut errors = HashSet::new();
 | 
			
		||||
 | 
			
		||||
        let mut ok_locs = HashSet::new();
 | 
			
		||||
 | 
			
		||||
        if let Some(expire_parser) = self.expires_parser {
 | 
			
		||||
            if let Ok(expires) = expire_parser.timestamp() {
 | 
			
		||||
                self.expires = Some(NaiveDateTime::from_timestamp(expires, 0));
 | 
			
		||||
            } else {
 | 
			
		||||
                errors.insert(ReminderError::InvalidExpiration);
 | 
			
		||||
 | 
			
		||||
                return (errors, ok_locs);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if let Some(time_parser) = self.utc_time_parser {
 | 
			
		||||
            if let Ok(time) = time_parser.timestamp() {
 | 
			
		||||
                self.utc_time = NaiveDateTime::from_timestamp(time, 0);
 | 
			
		||||
            } else {
 | 
			
		||||
                errors.insert(ReminderError::InvalidTime);
 | 
			
		||||
 | 
			
		||||
                return (errors, ok_locs);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if self.interval.map_or(false, |i| (i as i64) < *MIN_INTERVAL) {
 | 
			
		||||
            errors.insert(ReminderError::ShortInterval);
 | 
			
		||||
        } else if self.interval.map_or(false, |i| (i as i64) > *MAX_TIME) {
 | 
			
		||||
            errors.insert(ReminderError::LongInterval);
 | 
			
		||||
        } else {
 | 
			
		||||
            for scope in self.scopes {
 | 
			
		||||
                let db_channel_id = match scope {
 | 
			
		||||
                    ReminderScope::User(user_id) => {
 | 
			
		||||
                        if let Ok(user) = UserId(user_id).to_user(&self.ctx).await {
 | 
			
		||||
                            let user_data =
 | 
			
		||||
                                UserData::from_user(&user, &self.ctx, &pool).await.unwrap();
 | 
			
		||||
 | 
			
		||||
                            if let Some(guild_id) = self.guild_id {
 | 
			
		||||
                                if guild_id.member(&self.ctx, user).await.is_err() {
 | 
			
		||||
                                    Err(ReminderError::InvalidTag)
 | 
			
		||||
                                } else {
 | 
			
		||||
                                    Ok(user_data.dm_channel)
 | 
			
		||||
                                }
 | 
			
		||||
                            } else {
 | 
			
		||||
                                Ok(user_data.dm_channel)
 | 
			
		||||
                            }
 | 
			
		||||
                        } else {
 | 
			
		||||
                            Err(ReminderError::InvalidTag)
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                    ReminderScope::Channel(channel_id) => {
 | 
			
		||||
                        let channel = ChannelId(channel_id).to_channel(&self.ctx).await.unwrap();
 | 
			
		||||
 | 
			
		||||
                        if let Some(guild_channel) = channel.clone().guild() {
 | 
			
		||||
                            if Some(guild_channel.guild_id) != self.guild_id {
 | 
			
		||||
                                Err(ReminderError::InvalidTag)
 | 
			
		||||
                            } else {
 | 
			
		||||
                                let mut channel_data =
 | 
			
		||||
                                    ChannelData::from_channel(channel, &pool).await.unwrap();
 | 
			
		||||
 | 
			
		||||
                                if channel_data.webhook_id.is_none()
 | 
			
		||||
                                    || channel_data.webhook_token.is_none()
 | 
			
		||||
                                {
 | 
			
		||||
                                    match create_webhook(&self.ctx, guild_channel, "Reminder").await
 | 
			
		||||
                                    {
 | 
			
		||||
                                        Ok(webhook) => {
 | 
			
		||||
                                            channel_data.webhook_id =
 | 
			
		||||
                                                Some(webhook.id.as_u64().to_owned());
 | 
			
		||||
                                            channel_data.webhook_token = webhook.token;
 | 
			
		||||
 | 
			
		||||
                                            channel_data.commit_changes(&pool).await;
 | 
			
		||||
 | 
			
		||||
                                            Ok(channel_data.id)
 | 
			
		||||
                                        }
 | 
			
		||||
 | 
			
		||||
                                        Err(e) => Err(ReminderError::DiscordError(e.to_string())),
 | 
			
		||||
                                    }
 | 
			
		||||
                                } else {
 | 
			
		||||
                                    Ok(channel_data.id)
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
                        } else {
 | 
			
		||||
                            Err(ReminderError::InvalidTag)
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                };
 | 
			
		||||
 | 
			
		||||
                match db_channel_id {
 | 
			
		||||
                    Ok(c) => {
 | 
			
		||||
                        let builder = ReminderBuilder {
 | 
			
		||||
                            pool: pool.clone(),
 | 
			
		||||
                            uid: generate_uid(),
 | 
			
		||||
                            channel: c,
 | 
			
		||||
                            utc_time: self.utc_time,
 | 
			
		||||
                            timezone: self.timezone.to_string(),
 | 
			
		||||
                            interval: self.interval,
 | 
			
		||||
                            expires: self.expires,
 | 
			
		||||
                            content: self.content.content.clone(),
 | 
			
		||||
                            tts: self.content.tts,
 | 
			
		||||
                            attachment_name: self.content.attachment_name.clone(),
 | 
			
		||||
                            attachment: self.content.attachment.clone(),
 | 
			
		||||
                            set_by: self.set_by,
 | 
			
		||||
                        };
 | 
			
		||||
 | 
			
		||||
                        match builder.build().await {
 | 
			
		||||
                            Ok(_) => {
 | 
			
		||||
                                ok_locs.insert(scope);
 | 
			
		||||
                            }
 | 
			
		||||
                            Err(e) => {
 | 
			
		||||
                                errors.insert(e);
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                    Err(e) => {
 | 
			
		||||
                        errors.insert(e);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        (errors, ok_locs)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										74
									
								
								src/models/reminder/content.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								src/models/reminder/content.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,74 @@
 | 
			
		||||
use serenity::model::{channel::Message, guild::Guild, misc::Mentionable};
 | 
			
		||||
 | 
			
		||||
use regex::Captures;
 | 
			
		||||
 | 
			
		||||
use crate::{consts::REGEX_CONTENT_SUBSTITUTION, models::reminder::errors::ContentError};
 | 
			
		||||
 | 
			
		||||
pub struct Content {
 | 
			
		||||
    pub content: String,
 | 
			
		||||
    pub tts: bool,
 | 
			
		||||
    pub attachment: Option<Vec<u8>>,
 | 
			
		||||
    pub attachment_name: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Content {
 | 
			
		||||
    pub fn new() -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            content: "".to_string(),
 | 
			
		||||
            tts: false,
 | 
			
		||||
            attachment: None,
 | 
			
		||||
            attachment_name: None,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn build<S: ToString>(content: S, message: &Message) -> Result<Self, ContentError> {
 | 
			
		||||
        if message.attachments.len() > 1 {
 | 
			
		||||
            Err(ContentError::TooManyAttachments)
 | 
			
		||||
        } else if let Some(attachment) = message.attachments.get(0) {
 | 
			
		||||
            if attachment.size > 8_000_000 {
 | 
			
		||||
                Err(ContentError::AttachmentTooLarge)
 | 
			
		||||
            } else if let Ok(attachment_bytes) = attachment.download().await {
 | 
			
		||||
                Ok(Self {
 | 
			
		||||
                    content: content.to_string(),
 | 
			
		||||
                    tts: false,
 | 
			
		||||
                    attachment: Some(attachment_bytes),
 | 
			
		||||
                    attachment_name: Some(attachment.filename.clone()),
 | 
			
		||||
                })
 | 
			
		||||
            } else {
 | 
			
		||||
                Err(ContentError::AttachmentDownloadFailed)
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            Ok(Self {
 | 
			
		||||
                content: content.to_string(),
 | 
			
		||||
                tts: false,
 | 
			
		||||
                attachment: None,
 | 
			
		||||
                attachment_name: None,
 | 
			
		||||
            })
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn substitute(&mut self, guild: Guild) {
 | 
			
		||||
        if self.content.starts_with("/tts ") {
 | 
			
		||||
            self.tts = true;
 | 
			
		||||
            self.content = self.content.split_off(5);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self.content = REGEX_CONTENT_SUBSTITUTION
 | 
			
		||||
            .replace(&self.content, |caps: &Captures| {
 | 
			
		||||
                if let Some(user) = caps.name("user") {
 | 
			
		||||
                    format!("<@{}>", user.as_str())
 | 
			
		||||
                } else if let Some(role_name) = caps.name("role") {
 | 
			
		||||
                    if let Some(role) = guild.role_by_name(role_name.as_str()) {
 | 
			
		||||
                        role.mention().to_string()
 | 
			
		||||
                    } else {
 | 
			
		||||
                        format!("<<{}>>", role_name.as_str().to_string())
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                    String::new()
 | 
			
		||||
                }
 | 
			
		||||
            })
 | 
			
		||||
            .to_string()
 | 
			
		||||
            .replace("<<everyone>>", "@everyone")
 | 
			
		||||
            .replace("<<here>>", "@here");
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										81
									
								
								src/models/reminder/errors.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								src/models/reminder/errors.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,81 @@
 | 
			
		||||
use crate::consts::{MAX_TIME, MIN_INTERVAL};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub enum InteractionError {
 | 
			
		||||
    InvalidFormat,
 | 
			
		||||
    InvalidBase64,
 | 
			
		||||
    InvalidSize,
 | 
			
		||||
    NoReminder,
 | 
			
		||||
    SignatureMismatch,
 | 
			
		||||
    InvalidAction,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ToString for InteractionError {
 | 
			
		||||
    fn to_string(&self) -> String {
 | 
			
		||||
        match self {
 | 
			
		||||
            InteractionError::InvalidFormat => {
 | 
			
		||||
                String::from("The interaction data was improperly formatted")
 | 
			
		||||
            }
 | 
			
		||||
            InteractionError::InvalidBase64 => String::from("The interaction data was invalid"),
 | 
			
		||||
            InteractionError::InvalidSize => String::from("The interaction data was invalid"),
 | 
			
		||||
            InteractionError::NoReminder => String::from("Reminder could not be found"),
 | 
			
		||||
            InteractionError::SignatureMismatch => {
 | 
			
		||||
                String::from("Only the user who did the command can use interactions")
 | 
			
		||||
            }
 | 
			
		||||
            InteractionError::InvalidAction => String::from("The action was invalid"),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(PartialEq, Eq, Hash, Debug)]
 | 
			
		||||
pub enum ReminderError {
 | 
			
		||||
    LongTime,
 | 
			
		||||
    LongInterval,
 | 
			
		||||
    PastTime,
 | 
			
		||||
    ShortInterval,
 | 
			
		||||
    InvalidTag,
 | 
			
		||||
    InvalidTime,
 | 
			
		||||
    InvalidExpiration,
 | 
			
		||||
    DiscordError(String),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ReminderError {
 | 
			
		||||
    pub fn display(&self, is_natural: bool) -> String {
 | 
			
		||||
        match self {
 | 
			
		||||
            ReminderError::LongTime => "That time is too far in the future. Please specify a shorter time.".to_string(),
 | 
			
		||||
            ReminderError::LongInterval => format!("Please ensure the interval specified is less than {max_time} days", max_time = *MAX_TIME / 86_400),
 | 
			
		||||
            ReminderError::PastTime => "Please ensure the time provided is in the future. If the time should be in the future, please be more specific with the definition.".to_string(),
 | 
			
		||||
            ReminderError::ShortInterval => format!("Please ensure the interval provided is longer than {min_interval} seconds", min_interval = *MIN_INTERVAL),
 | 
			
		||||
            ReminderError::InvalidTag => "Couldn't find a location by your tag. Your tag must be either a channel or a user (not a role)".to_string(),
 | 
			
		||||
            ReminderError::InvalidTime => if is_natural {
 | 
			
		||||
                "Your time failed to process. Please make it as clear as possible, for example `\"16th of july\"` or `\"in 20 minutes\"`".to_string()
 | 
			
		||||
            } else {
 | 
			
		||||
                "Make sure the time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] etc. or `day/month/year-hour:minute:second`".to_string()
 | 
			
		||||
            },
 | 
			
		||||
            ReminderError::InvalidExpiration => if is_natural {
 | 
			
		||||
                "Your expiration time failed to process. Please make it as clear as possible, for example `\"16th of july\"` or `\"in 20 minutes\"`".to_string()
 | 
			
		||||
            } else {
 | 
			
		||||
                "Make sure the expiration time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] etc. or `day/month/year-hour:minute:second`".to_string()
 | 
			
		||||
            },
 | 
			
		||||
            ReminderError::DiscordError(s) => format!("A Discord error occurred: **{}**", s)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub enum ContentError {
 | 
			
		||||
    TooManyAttachments,
 | 
			
		||||
    AttachmentTooLarge,
 | 
			
		||||
    AttachmentDownloadFailed,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ToString for ContentError {
 | 
			
		||||
    fn to_string(&self) -> String {
 | 
			
		||||
        match self {
 | 
			
		||||
            ContentError::TooManyAttachments => "remind/too_many_attachments",
 | 
			
		||||
            ContentError::AttachmentTooLarge => "remind/attachment_too_large",
 | 
			
		||||
            ContentError::AttachmentDownloadFailed => "remind/attachment_download_failed",
 | 
			
		||||
        }
 | 
			
		||||
        .to_string()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										40
									
								
								src/models/reminder/helper.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								src/models/reminder/helper.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,40 @@
 | 
			
		||||
use crate::consts::{CHARACTERS, DAY, HOUR, MINUTE};
 | 
			
		||||
 | 
			
		||||
use num_integer::Integer;
 | 
			
		||||
 | 
			
		||||
use rand::{rngs::OsRng, seq::IteratorRandom};
 | 
			
		||||
 | 
			
		||||
pub fn longhand_displacement(seconds: u64) -> String {
 | 
			
		||||
    let (days, seconds) = seconds.div_rem(&DAY);
 | 
			
		||||
    let (hours, seconds) = seconds.div_rem(&HOUR);
 | 
			
		||||
    let (minutes, seconds) = seconds.div_rem(&MINUTE);
 | 
			
		||||
 | 
			
		||||
    let mut sections = vec![];
 | 
			
		||||
 | 
			
		||||
    for (var, name) in [days, hours, minutes, seconds]
 | 
			
		||||
        .iter()
 | 
			
		||||
        .zip(["days", "hours", "minutes", "seconds"].iter())
 | 
			
		||||
    {
 | 
			
		||||
        if *var > 0 {
 | 
			
		||||
            sections.push(format!("{} {}", var, name));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    sections.join(", ")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn generate_uid() -> String {
 | 
			
		||||
    let mut generator: OsRng = Default::default();
 | 
			
		||||
 | 
			
		||||
    (0..64)
 | 
			
		||||
        .map(|_| {
 | 
			
		||||
            CHARACTERS
 | 
			
		||||
                .chars()
 | 
			
		||||
                .choose(&mut generator)
 | 
			
		||||
                .unwrap()
 | 
			
		||||
                .to_owned()
 | 
			
		||||
                .to_string()
 | 
			
		||||
        })
 | 
			
		||||
        .collect::<Vec<String>>()
 | 
			
		||||
        .join("")
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										59
									
								
								src/models/reminder/look_flags.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								src/models/reminder/look_flags.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,59 @@
 | 
			
		||||
use serenity::model::id::ChannelId;
 | 
			
		||||
 | 
			
		||||
use crate::consts::REGEX_CHANNEL;
 | 
			
		||||
 | 
			
		||||
pub enum TimeDisplayType {
 | 
			
		||||
    Absolute,
 | 
			
		||||
    Relative,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct LookFlags {
 | 
			
		||||
    pub limit: u16,
 | 
			
		||||
    pub show_disabled: bool,
 | 
			
		||||
    pub channel_id: Option<ChannelId>,
 | 
			
		||||
    pub time_display: TimeDisplayType,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Default for LookFlags {
 | 
			
		||||
    fn default() -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            limit: u16::MAX,
 | 
			
		||||
            show_disabled: true,
 | 
			
		||||
            channel_id: None,
 | 
			
		||||
            time_display: TimeDisplayType::Relative,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl LookFlags {
 | 
			
		||||
    pub fn from_string(args: &str) -> Self {
 | 
			
		||||
        let mut new_flags: Self = Default::default();
 | 
			
		||||
 | 
			
		||||
        for arg in args.split(' ') {
 | 
			
		||||
            match arg {
 | 
			
		||||
                "enabled" => {
 | 
			
		||||
                    new_flags.show_disabled = false;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                "time" => {
 | 
			
		||||
                    new_flags.time_display = TimeDisplayType::Absolute;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                param => {
 | 
			
		||||
                    if let Ok(val) = param.parse::<u16>() {
 | 
			
		||||
                        new_flags.limit = val;
 | 
			
		||||
                    } else if let Some(channel) = REGEX_CHANNEL
 | 
			
		||||
                        .captures(arg)
 | 
			
		||||
                        .map(|cap| cap.get(1))
 | 
			
		||||
                        .flatten()
 | 
			
		||||
                        .map(|c| c.as_str().parse::<u64>().unwrap())
 | 
			
		||||
                    {
 | 
			
		||||
                        new_flags.channel_id = Some(ChannelId(channel));
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        new_flags
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,3 +1,9 @@
 | 
			
		||||
pub mod builder;
 | 
			
		||||
pub mod content;
 | 
			
		||||
pub mod errors;
 | 
			
		||||
mod helper;
 | 
			
		||||
pub mod look_flags;
 | 
			
		||||
 | 
			
		||||
use serenity::{
 | 
			
		||||
    client::Context,
 | 
			
		||||
    model::id::{ChannelId, GuildId, UserId},
 | 
			
		||||
@@ -6,32 +12,45 @@ use serenity::{
 | 
			
		||||
use chrono::NaiveDateTime;
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    consts::{DAY, HOUR, MINUTE, REGEX_CHANNEL},
 | 
			
		||||
    models::reminder::{
 | 
			
		||||
        errors::InteractionError,
 | 
			
		||||
        helper::longhand_displacement,
 | 
			
		||||
        look_flags::{LookFlags, TimeDisplayType},
 | 
			
		||||
    },
 | 
			
		||||
    SQLPool,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use num_integer::Integer;
 | 
			
		||||
use ring::hmac;
 | 
			
		||||
use std::convert::{TryFrom, TryInto};
 | 
			
		||||
use std::env;
 | 
			
		||||
 | 
			
		||||
fn longhand_displacement(seconds: u64) -> String {
 | 
			
		||||
    let (days, seconds) = seconds.div_rem(&DAY);
 | 
			
		||||
    let (hours, seconds) = seconds.div_rem(&HOUR);
 | 
			
		||||
    let (minutes, seconds) = seconds.div_rem(&MINUTE);
 | 
			
		||||
use sqlx::MySqlPool;
 | 
			
		||||
use std::{
 | 
			
		||||
    convert::{TryFrom, TryInto},
 | 
			
		||||
    env,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
    let mut sections = vec![];
 | 
			
		||||
#[derive(Clone, Copy)]
 | 
			
		||||
pub enum ReminderAction {
 | 
			
		||||
    Delete,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
    for (var, name) in [days, hours, minutes, seconds]
 | 
			
		||||
        .iter()
 | 
			
		||||
        .zip(["days", "hours", "minutes", "seconds"].iter())
 | 
			
		||||
    {
 | 
			
		||||
        if *var > 0 {
 | 
			
		||||
            sections.push(format!("{} {}", var, name));
 | 
			
		||||
impl ToString for ReminderAction {
 | 
			
		||||
    fn to_string(&self) -> String {
 | 
			
		||||
        match self {
 | 
			
		||||
            Self::Delete => String::from("del"),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
    sections.join(", ")
 | 
			
		||||
impl TryFrom<&str> for ReminderAction {
 | 
			
		||||
    type Error = ();
 | 
			
		||||
 | 
			
		||||
    fn try_from(value: &str) -> Result<Self, Self::Error> {
 | 
			
		||||
        match value {
 | 
			
		||||
            "del" => Ok(Self::Delete),
 | 
			
		||||
 | 
			
		||||
            _ => Err(()),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
@@ -49,9 +68,7 @@ pub struct Reminder {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Reminder {
 | 
			
		||||
    pub async fn from_uid(ctx: &Context, uid: String) -> Option<Self> {
 | 
			
		||||
        let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
    pub async fn from_uid(pool: &MySqlPool, uid: String) -> Option<Self> {
 | 
			
		||||
        sqlx::query_as_unchecked!(
 | 
			
		||||
            Self,
 | 
			
		||||
            "
 | 
			
		||||
@@ -81,7 +98,7 @@ WHERE
 | 
			
		||||
            ",
 | 
			
		||||
            uid
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(&pool)
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .ok()
 | 
			
		||||
    }
 | 
			
		||||
@@ -178,7 +195,7 @@ LIMIT
 | 
			
		||||
        let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
        if let Some(guild_id) = guild_id {
 | 
			
		||||
            let guild_opt = guild_id.to_guild_cached(&ctx).await;
 | 
			
		||||
            let guild_opt = guild_id.to_guild_cached(&ctx);
 | 
			
		||||
 | 
			
		||||
            if let Some(guild) = guild_opt {
 | 
			
		||||
                let channels = guild
 | 
			
		||||
@@ -333,7 +350,7 @@ WHERE
 | 
			
		||||
        member_id: U,
 | 
			
		||||
        payload: String,
 | 
			
		||||
    ) -> Result<(Self, ReminderAction), InteractionError> {
 | 
			
		||||
        let sections = payload.split(".").collect::<Vec<&str>>();
 | 
			
		||||
        let sections = payload.split('.').collect::<Vec<&str>>();
 | 
			
		||||
 | 
			
		||||
        if sections.len() != 3 {
 | 
			
		||||
            Err(InteractionError::InvalidFormat)
 | 
			
		||||
@@ -397,111 +414,3 @@ DELETE FROM reminders WHERE id = ?
 | 
			
		||||
        .unwrap();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub enum InteractionError {
 | 
			
		||||
    InvalidFormat,
 | 
			
		||||
    InvalidBase64,
 | 
			
		||||
    InvalidSize,
 | 
			
		||||
    NoReminder,
 | 
			
		||||
    SignatureMismatch,
 | 
			
		||||
    InvalidAction,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ToString for InteractionError {
 | 
			
		||||
    fn to_string(&self) -> String {
 | 
			
		||||
        match self {
 | 
			
		||||
            InteractionError::InvalidFormat => {
 | 
			
		||||
                String::from("The interaction data was improperly formatted")
 | 
			
		||||
            }
 | 
			
		||||
            InteractionError::InvalidBase64 => String::from("The interaction data was invalid"),
 | 
			
		||||
            InteractionError::InvalidSize => String::from("The interaction data was invalid"),
 | 
			
		||||
            InteractionError::NoReminder => String::from("Reminder could not be found"),
 | 
			
		||||
            InteractionError::SignatureMismatch => {
 | 
			
		||||
                String::from("Only the user who did the command can use interactions")
 | 
			
		||||
            }
 | 
			
		||||
            InteractionError::InvalidAction => String::from("The action was invalid"),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Clone, Copy)]
 | 
			
		||||
pub enum ReminderAction {
 | 
			
		||||
    Delete,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ToString for ReminderAction {
 | 
			
		||||
    fn to_string(&self) -> String {
 | 
			
		||||
        match self {
 | 
			
		||||
            Self::Delete => String::from("del"),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl TryFrom<&str> for ReminderAction {
 | 
			
		||||
    type Error = ();
 | 
			
		||||
 | 
			
		||||
    fn try_from(value: &str) -> Result<Self, Self::Error> {
 | 
			
		||||
        match value {
 | 
			
		||||
            "del" => Ok(Self::Delete),
 | 
			
		||||
 | 
			
		||||
            _ => Err(()),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
enum TimeDisplayType {
 | 
			
		||||
    Absolute,
 | 
			
		||||
    Relative,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct LookFlags {
 | 
			
		||||
    pub limit: u16,
 | 
			
		||||
    pub show_disabled: bool,
 | 
			
		||||
    pub channel_id: Option<ChannelId>,
 | 
			
		||||
    time_display: TimeDisplayType,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Default for LookFlags {
 | 
			
		||||
    fn default() -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            limit: u16::MAX,
 | 
			
		||||
            show_disabled: true,
 | 
			
		||||
            channel_id: None,
 | 
			
		||||
            time_display: TimeDisplayType::Relative,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl LookFlags {
 | 
			
		||||
    pub fn from_string(args: &str) -> Self {
 | 
			
		||||
        let mut new_flags: Self = Default::default();
 | 
			
		||||
 | 
			
		||||
        for arg in args.split(' ') {
 | 
			
		||||
            match arg {
 | 
			
		||||
                "enabled" => {
 | 
			
		||||
                    new_flags.show_disabled = false;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                "time" => {
 | 
			
		||||
                    new_flags.time_display = TimeDisplayType::Absolute;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                param => {
 | 
			
		||||
                    if let Ok(val) = param.parse::<u16>() {
 | 
			
		||||
                        new_flags.limit = val;
 | 
			
		||||
                    } else if let Some(channel) = REGEX_CHANNEL
 | 
			
		||||
                        .captures(&arg)
 | 
			
		||||
                        .map(|cap| cap.get(1))
 | 
			
		||||
                        .flatten()
 | 
			
		||||
                        .map(|c| c.as_str().parse::<u64>().unwrap())
 | 
			
		||||
                    {
 | 
			
		||||
                        new_flags.channel_id = Some(ChannelId(channel));
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        new_flags
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user