357 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
			
		
		
	
	
			357 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
| use std::{collections::HashSet, fmt::Display};
 | |
| 
 | |
| use chrono::{Duration, NaiveDateTime, Utc};
 | |
| use chrono_tz::Tz;
 | |
| use poise::serenity_prelude::{
 | |
|     http::CacheHttp,
 | |
|     model::{
 | |
|         channel::GuildChannel,
 | |
|         id::{ChannelId, GuildId, UserId},
 | |
|         webhook::Webhook,
 | |
|     },
 | |
|     ChannelType, CreateWebhook, Result as SerenityResult,
 | |
| };
 | |
| use secrecy::ExposeSecret;
 | |
| use sqlx::MySqlPool;
 | |
| 
 | |
| use crate::{
 | |
|     consts::{DAY, DEFAULT_AVATAR, MAX_TIME, MIN_INTERVAL},
 | |
|     interval_parser::Interval,
 | |
|     models::{
 | |
|         channel_data::ChannelData,
 | |
|         reminder::{content::Content, errors::ReminderError, helper::generate_uid, Reminder},
 | |
|         user_data::UserData,
 | |
|     },
 | |
|     Context,
 | |
| };
 | |
| 
 | |
| async fn create_webhook(
 | |
|     ctx: impl CacheHttp,
 | |
|     channel: GuildChannel,
 | |
|     name: impl ToString,
 | |
| ) -> SerenityResult<Webhook> {
 | |
|     channel.create_webhook(ctx.http(), CreateWebhook::new(name).avatar(&*DEFAULT_AVATAR)).await
 | |
| }
 | |
| 
 | |
/// Destination a reminder can be addressed to.
///
/// Each variant carries the raw numeric ID of its target.
#[derive(Hash, PartialEq, Eq)]
pub enum ReminderScope {
    /// A user, identified by their numeric ID.
    User(u64),
    /// A channel, identified by its numeric ID.
    Channel(u64),
}

impl ReminderScope {
    /// Renders the scope as mention markup: `<@id>` for a user, `<#id>` for a channel.
    pub fn mention(&self) -> String {
        // Both variants share the same shape and differ only in the sigil.
        let (sigil, id) = match self {
            Self::User(id) => ('@', id),
            Self::Channel(id) => ('#', id),
        };

        format!("<{sigil}{id}>")
    }
}
 | |
| 
 | |
| pub struct ReminderBuilder {
 | |
|     pool: MySqlPool,
 | |
|     uid: String,
 | |
|     channel: u32,
 | |
|     thread_id: Option<u64>,
 | |
|     utc_time: NaiveDateTime,
 | |
|     timezone: String,
 | |
|     interval_seconds: Option<i64>,
 | |
|     interval_days: Option<i64>,
 | |
|     interval_months: Option<i64>,
 | |
|     expires: Option<NaiveDateTime>,
 | |
|     content: String,
 | |
|     tts: bool,
 | |
|     attachment_name: Option<String>,
 | |
|     attachment: Option<Vec<u8>>,
 | |
|     set_by: Option<u32>,
 | |
| }
 | |
| 
 | |
| impl ReminderBuilder {
 | |
|     pub async fn build(self) -> Result<Reminder, ReminderError> {
 | |
|         let queried_time = sqlx::query!(
 | |
|             "SELECT DATE_ADD(?, INTERVAL (SELECT nudge FROM channels WHERE id = ?) SECOND) AS `utc_time`",
 | |
|             self.utc_time,
 | |
|             self.channel,
 | |
|         )
 | |
|         .fetch_one(&self.pool)
 | |
|         .await
 | |
|         .unwrap();
 | |
| 
 | |
|         match queried_time.utc_time {
 | |
|             Some(utc_time) => {
 | |
|                 if utc_time < (Utc::now() - Duration::seconds(60)).naive_local() {
 | |
|                     Err(ReminderError::PastTime)
 | |
|                 } else {
 | |
|                     sqlx::query!(
 | |
|                         "
 | |
| INSERT INTO reminders (
 | |
|     `uid`,
 | |
|     `channel_id`,
 | |
|     `utc_time`,
 | |
|     `timezone`,
 | |
|     `interval_seconds`,
 | |
|     `interval_days`,
 | |
|     `interval_months`,
 | |
|     `expires`,
 | |
|     `content`,
 | |
|     `tts`,
 | |
|     `attachment_name`,
 | |
|     `attachment`,
 | |
|     `set_by`
 | |
| ) VALUES (
 | |
|     ?,
 | |
|     ?,
 | |
|     ?,
 | |
|     ?,
 | |
|     ?,
 | |
|     ?,
 | |
|     ?,
 | |
|     ?,
 | |
|     ?,
 | |
|     ?,
 | |
|     ?,
 | |
|     ?,
 | |
|     ?
 | |
| )
 | |
|             ",
 | |
|                         self.uid,
 | |
|                         self.channel,
 | |
|                         utc_time,
 | |
|                         self.timezone,
 | |
|                         self.interval_seconds,
 | |
|                         self.interval_days,
 | |
|                         self.interval_months,
 | |
|                         self.expires,
 | |
|                         self.content,
 | |
|                         self.tts,
 | |
|                         self.attachment_name,
 | |
|                         self.attachment,
 | |
|                         self.set_by
 | |
|                     )
 | |
|                     .execute(&self.pool)
 | |
|                     .await
 | |
|                     .unwrap();
 | |
| 
 | |
|                     Ok(Reminder::from_uid(&self.pool, &self.uid).await.unwrap())
 | |
|                 }
 | |
|             }
 | |
| 
 | |
|             None => Err(ReminderError::LongTime),
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| pub struct MultiReminderBuilder<'a> {
 | |
|     scopes: Vec<ReminderScope>,
 | |
|     utc_time: NaiveDateTime,
 | |
|     timezone: Tz,
 | |
|     interval: Option<Interval>,
 | |
|     expires: Option<NaiveDateTime>,
 | |
|     content: Content,
 | |
|     set_by: Option<u32>,
 | |
|     ctx: &'a Context<'a>,
 | |
|     guild_id: Option<GuildId>,
 | |
| }
 | |
| 
 | |
| impl<'a> MultiReminderBuilder<'a> {
 | |
|     pub fn new(ctx: &'a Context, guild_id: Option<GuildId>) -> Self {
 | |
|         MultiReminderBuilder {
 | |
|             scopes: vec![],
 | |
|             utc_time: Utc::now().naive_utc(),
 | |
|             timezone: Tz::UTC,
 | |
|             interval: None,
 | |
|             expires: None,
 | |
|             content: Content::new(),
 | |
|             set_by: None,
 | |
|             ctx,
 | |
|             guild_id,
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     pub fn timezone(mut self, timezone: Tz) -> Self {
 | |
|         self.timezone = timezone;
 | |
| 
 | |
|         self
 | |
|     }
 | |
| 
 | |
|     pub fn content(mut self, content: Content) -> Self {
 | |
|         self.content = content;
 | |
| 
 | |
|         self
 | |
|     }
 | |
| 
 | |
|     pub fn time<T: Into<i64>>(mut self, time: T) -> Self {
 | |
|         if let Some(utc_time) = NaiveDateTime::from_timestamp_opt(time.into(), 0) {
 | |
|             self.utc_time = utc_time;
 | |
|         }
 | |
| 
 | |
|         self
 | |
|     }
 | |
| 
 | |
|     pub fn expires<T: Into<i64>>(mut self, time: Option<T>) -> Self {
 | |
|         self.expires = time.map(|t| NaiveDateTime::from_timestamp_opt(t.into(), 0)).flatten();
 | |
| 
 | |
|         self
 | |
|     }
 | |
| 
 | |
|     pub fn author(mut self, user: UserData) -> Self {
 | |
|         self.set_by = Some(user.id);
 | |
|         self.timezone = user.timezone();
 | |
| 
 | |
|         self
 | |
|     }
 | |
| 
 | |
|     pub fn interval(mut self, interval: Option<Interval>) -> Self {
 | |
|         self.interval = interval;
 | |
| 
 | |
|         self
 | |
|     }
 | |
| 
 | |
|     pub fn set_scopes(&mut self, scopes: Vec<ReminderScope>) {
 | |
|         self.scopes = scopes;
 | |
|     }
 | |
| 
 | |
|     pub async fn build(self) -> (HashSet<ReminderError>, HashSet<(Reminder, ReminderScope)>) {
 | |
|         let mut errors = HashSet::new();
 | |
| 
 | |
|         let mut ok_locs = HashSet::new();
 | |
| 
 | |
|         if self
 | |
|             .interval
 | |
|             .map_or(false, |i| ((i.sec + i.day * DAY + i.month * 30 * DAY) as i64) < *MIN_INTERVAL)
 | |
|         {
 | |
|             errors.insert(ReminderError::ShortInterval);
 | |
|         } else if self
 | |
|             .interval
 | |
|             .map_or(false, |i| ((i.sec + i.day * DAY + i.month * 30 * DAY) as i64) > *MAX_TIME)
 | |
|         {
 | |
|             errors.insert(ReminderError::LongInterval);
 | |
|         } else {
 | |
|             for scope in self.scopes {
 | |
|                 let thread_id = None;
 | |
|                 let db_channel_id = match scope {
 | |
|                     ReminderScope::User(user_id) => {
 | |
|                         if let Ok(user) = UserId::new(user_id).to_user(&self.ctx).await {
 | |
|                             let user_data = UserData::from_user(
 | |
|                                 &user,
 | |
|                                 &self.ctx.serenity_context(),
 | |
|                                 &self.ctx.data().database,
 | |
|                             )
 | |
|                             .await
 | |
|                             .unwrap();
 | |
| 
 | |
|                             if let Some(guild_id) = self.guild_id {
 | |
|                                 if guild_id.member(&self.ctx, user).await.is_err() {
 | |
|                                     Err(ReminderError::InvalidTag)
 | |
|                                 } else if self.set_by.map_or(true, |i| i != user_data.id)
 | |
|                                     && !user_data.allowed_dm
 | |
|                                 {
 | |
|                                     Err(ReminderError::UserBlockedDm)
 | |
|                                 } else {
 | |
|                                     Ok(user_data.dm_channel)
 | |
|                                 }
 | |
|                             } else {
 | |
|                                 Ok(user_data.dm_channel)
 | |
|                             }
 | |
|                         } else {
 | |
|                             Err(ReminderError::InvalidTag)
 | |
|                         }
 | |
|                     }
 | |
|                     ReminderScope::Channel(channel_id) => {
 | |
|                         let channel =
 | |
|                             ChannelId::new(channel_id).to_channel(&self.ctx).await.unwrap();
 | |
| 
 | |
|                         if let Some(mut guild_channel) = channel.clone().guild() {
 | |
|                             if Some(guild_channel.guild_id) != self.guild_id {
 | |
|                                 Err(ReminderError::InvalidTag)
 | |
|                             } else {
 | |
|                                 let mut channel_data = if guild_channel.kind
 | |
|                                     == ChannelType::PublicThread
 | |
|                                 {
 | |
|                                     // fixme jesus christ
 | |
|                                     let parent = guild_channel
 | |
|                                         .parent_id
 | |
|                                         .unwrap()
 | |
|                                         .to_channel(&self.ctx)
 | |
|                                         .await
 | |
|                                         .unwrap();
 | |
|                                     guild_channel = parent.clone().guild().unwrap();
 | |
|                                     ChannelData::from_channel(&parent, &self.ctx.data().database)
 | |
|                                         .await
 | |
|                                         .unwrap()
 | |
|                                 } else {
 | |
|                                     ChannelData::from_channel(&channel, &self.ctx.data().database)
 | |
|                                         .await
 | |
|                                         .unwrap()
 | |
|                                 };
 | |
| 
 | |
|                                 if channel_data.webhook_id.is_none()
 | |
|                                     || channel_data.webhook_token.is_none()
 | |
|                                 {
 | |
|                                     match create_webhook(&self.ctx, guild_channel, "Reminder").await
 | |
|                                     {
 | |
|                                         Ok(webhook) => {
 | |
|                                             channel_data.webhook_id =
 | |
|                                                 Some(webhook.id.get().to_owned());
 | |
|                                             channel_data.webhook_token =
 | |
|                                                 webhook.token.map(|s| s.expose_secret());
 | |
| 
 | |
|                                             channel_data
 | |
|                                                 .commit_changes(&self.ctx.data().database)
 | |
|                                                 .await;
 | |
| 
 | |
|                                             Ok(channel_data.id)
 | |
|                                         }
 | |
| 
 | |
|                                         Err(e) => Err(ReminderError::DiscordError(e.to_string())),
 | |
|                                     }
 | |
|                                 } else {
 | |
|                                     Ok(channel_data.id)
 | |
|                                 }
 | |
|                             }
 | |
|                         } else {
 | |
|                             Err(ReminderError::InvalidTag)
 | |
|                         }
 | |
|                     }
 | |
|                 };
 | |
| 
 | |
|                 match db_channel_id {
 | |
|                     Ok(c) => {
 | |
|                         let builder = ReminderBuilder {
 | |
|                             pool: self.ctx.data().database.clone(),
 | |
|                             uid: generate_uid(),
 | |
|                             channel: c,
 | |
|                             thread_id,
 | |
|                             utc_time: self.utc_time,
 | |
|                             timezone: self.timezone.to_string(),
 | |
|                             interval_seconds: self.interval.map(|i| i.sec as i64),
 | |
|                             interval_days: self.interval.map(|i| i.day as i64),
 | |
|                             interval_months: self.interval.map(|i| i.month as i64),
 | |
|                             expires: self.expires,
 | |
|                             content: self.content.content.clone(),
 | |
|                             tts: self.content.tts,
 | |
|                             attachment_name: self.content.attachment_name.clone(),
 | |
|                             attachment: self.content.attachment.clone(),
 | |
|                             set_by: self.set_by,
 | |
|                         };
 | |
| 
 | |
|                         match builder.build().await {
 | |
|                             Ok(r) => {
 | |
|                                 ok_locs.insert((r, scope));
 | |
|                             }
 | |
|                             Err(e) => {
 | |
|                                 errors.insert(e);
 | |
|                             }
 | |
|                         }
 | |
|                     }
 | |
|                     Err(e) => {
 | |
|                         errors.insert(e);
 | |
|                     }
 | |
|                 }
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         (errors, ok_locs)
 | |
|     }
 | |
| }
 |