diff --git a/Cargo.toml b/Cargo.toml index 16cf289..beb2e09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "reminder_rs" -version = "1.6.0-beta2" +version = "1.6.0-beta3" authors = ["jellywx "] edition = "2018" diff --git a/README.md b/README.md index bb8e49e..1b52b54 100644 --- a/README.md +++ b/README.md @@ -41,4 +41,3 @@ __Other Variables__ * Convert aliases to macros * Help command -* Test everything diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index c1731dd..82b6455 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -323,7 +323,7 @@ async fn look(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions) { .iter() .map(|reminder| reminder.display(&flags, &timezone)) .fold(0, |t, r| t + r.len()) - .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH); + .div_ceil(&EMBED_DESCRIPTION_MAX_LENGTH); let pager = LookPager::new(flags, timezone); diff --git a/src/component_models/mod.rs b/src/component_models/mod.rs index bf16b2b..18e0c2e 100644 --- a/src/component_models/mod.rs +++ b/src/component_models/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod pager; use std::io::Cursor; use chrono_tz::Tz; +use num_integer::Integer; use rmp_serde::Serializer; use serde::{Deserialize, Serialize}; use serenity::{ @@ -78,7 +79,7 @@ impl ComponentDataModel { .iter() .map(|reminder| reminder.display(&flags, &pager.timezone)) .fold(0, |t, r| t + r.len()) - .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH); + .div_ceil(&EMBED_DESCRIPTION_MAX_LENGTH); let channel_name = if let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx) { diff --git a/src/consts.rs b/src/consts.rs index e1dc822..63e9511 100644 --- a/src/consts.rs +++ b/src/consts.rs @@ -14,6 +14,11 @@ use regex::Regex; use serenity::model::prelude::AttachmentType; lazy_static! { + pub static ref REMIND_INTERVAL: u64 = env::var("REMIND_INTERVAL") + .map(|inner| inner.parse::().ok()) + .ok() + .flatten() + .unwrap_or(10); pub static ref DEFAULT_AVATAR: AttachmentType<'static> = ( include_bytes!(concat!( env!("CARGO_MANIFEST_DIR"), diff --git a/src/main.rs b/src/main.rs index a5fa808..5813643 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,9 +9,17 @@ mod framework; mod hooks; mod interval_parser; mod models; +mod sender; mod time_parser; -use std::{collections::HashMap, env, sync::Arc}; +use std::{ + collections::HashMap, + env, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; use chrono_tz::Tz; use dotenv::dotenv; @@ -31,12 +39,15 @@ use serenity::{ utils::shard_id, }; use sqlx::mysql::MySqlPool; -use tokio::sync::RwLock; +use tokio::{ + sync::RwLock, + time::{Duration, Instant}, +}; use crate::{ commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds}, component_models::ComponentDataModel, - consts::{CNC_GUILD, SUBSCRIPTION_ROLES, THEME_COLOR}, + consts::{CNC_GUILD, REMIND_INTERVAL, SUBSCRIPTION_ROLES, THEME_COLOR}, framework::RegexFramework, models::command_macro::CommandMacro, }; @@ -65,10 +76,42 @@ impl TypeMapKey for RecordingMacros { type Value = Arc>>; } -struct Handler; +struct Handler { + is_loop_running: AtomicBool, +} #[async_trait] impl EventHandler for Handler { + async fn cache_ready(&self, ctx_base: Context, _guilds: Vec) { + info!("Cache Ready!"); + info!("Preparing to send reminders"); + + if !self.is_loop_running.load(Ordering::Relaxed) { + let ctx = ctx_base.clone(); + + tokio::spawn(async move { + let pool = ctx.data.read().await.get::().cloned().unwrap(); + + loop { + let sleep_until = Instant::now() + Duration::from_secs(*REMIND_INTERVAL); + let reminders = sender::Reminder::fetch_reminders(&pool).await; + + if reminders.len() > 0 { + info!("Preparing to send {} reminders.", reminders.len()); + + for reminder in reminders { + reminder.send(pool.clone(), ctx.clone()).await; + } + } + + tokio::time::sleep_until(sleep_until).await; + } + }); + + self.is_loop_running.swap(true, Ordering::Relaxed); + } + } + async fn channel_delete(&self, ctx: Context, channel: &GuildChannel) { let pool = ctx .data @@ -186,9 +229,11 @@ async fn main() -> Result<(), Box> { let token = env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment"); - let http = Http::new_with_token(&token); + let application_id = { + let http = Http::new_with_token(&token); - let application_id = http.get_current_application_info().await?.id; + http.get_current_application_info().await?.id + }; let dm_enabled = env::var("DM_ENABLED").map_or(true, |var| var == "1"); @@ -226,7 +271,7 @@ async fn main() -> Result<(), Box> { let mut client = Client::builder(&token) .intents(GatewayIntents::GUILDS) .application_id(application_id.0) - .event_handler(Handler) + .event_handler(Handler { is_loop_running: AtomicBool::from(false) }) .await .expect("Error occurred creating client"); diff --git a/src/sender.rs b/src/sender.rs new file mode 100644 index 0000000..8a04631 --- /dev/null +++ b/src/sender.rs @@ -0,0 +1,552 @@ +use chrono::Duration; +use chrono_tz::Tz; +use log::{error, info, warn}; +use num_integer::Integer; +use regex::{Captures, Regex}; +use serenity::{ + builder::CreateEmbed, + http::{CacheHttp, Http, StatusCode}, + model::{ + channel::{Channel, Embed as SerenityEmbed}, + id::ChannelId, + webhook::Webhook, + }, + Error, Result, +}; +use sqlx::{ + types::chrono::{NaiveDateTime, Utc}, + MySqlPool, +}; + +lazy_static! { + pub static ref TIMEFROM_REGEX: Regex = + Regex::new(r#"<\d+):(?P.+)?>>"#).unwrap(); + pub static ref TIMENOW_REGEX: Regex = + Regex::new(r#"<(?:\w|/|_)+):(?P.+)?>>"#).unwrap(); +} + +fn fmt_displacement(format: &str, seconds: u64) -> String { + let mut seconds = seconds; + let mut days: u64 = 0; + let mut hours: u64 = 0; + let mut minutes: u64 = 0; + + for (rep, time_type, div) in + [("%d", &mut days, 86400), ("%h", &mut hours, 3600), ("%m", &mut minutes, 60)].iter_mut() + { + if format.contains(*rep) { + let (divided, new_seconds) = seconds.div_rem(&div); + + **time_type = divided; + seconds = new_seconds; + } + } + + format + .replace("%s", &seconds.to_string()) + .replace("%m", &minutes.to_string()) + .replace("%h", &hours.to_string()) + .replace("%d", &days.to_string()) +} + +pub fn substitute(string: &str) -> String { + let new = TIMEFROM_REGEX.replace(string, |caps: &Captures| { + let final_time = caps.name("time").unwrap().as_str(); + let format = caps.name("format").unwrap().as_str(); + + if let Ok(final_time) = final_time.parse::() { + let dt = NaiveDateTime::from_timestamp(final_time, 0); + let now = Utc::now().naive_utc(); + + let difference = { + if now < dt { + dt - Utc::now().naive_utc() + } else { + Utc::now().naive_utc() - dt + } + }; + + fmt_displacement(format, difference.num_seconds() as u64) + } else { + String::new() + } + }); + + TIMENOW_REGEX + .replace(&new, |caps: &Captures| { + let timezone = caps.name("timezone").unwrap().as_str(); + + println!("{}", timezone); + + if let Ok(tz) = timezone.parse::() { + let format = caps.name("format").unwrap().as_str(); + let now = Utc::now().with_timezone(&tz); + + now.format(format).to_string() + } else { + String::new() + } + }) + .to_string() +} + +struct Embed { + inner: EmbedInner, + fields: Vec, +} + +struct EmbedInner { + title: String, + description: String, + image_url: Option, + thumbnail_url: Option, + footer: String, + footer_url: Option, + author: String, + author_url: Option, + color: u32, +} + +struct EmbedField { + title: String, + value: String, + inline: bool, +} + +impl Embed { + pub async fn from_id(pool: &MySqlPool, id: u32) -> Option { + let mut inner = sqlx::query_as_unchecked!( + EmbedInner, + " +SELECT + `embed_title` AS title, + `embed_description` AS description, + `embed_image_url` AS image_url, + `embed_thumbnail_url` AS thumbnail_url, + `embed_footer` AS footer, + `embed_footer_url` AS footer_url, + `embed_author` AS author, + `embed_author_url` AS author_url, + `embed_color` AS color +FROM + reminders +WHERE + `id` = ? + ", + id + ) + .fetch_one(&pool.clone()) + .await + .unwrap(); + + inner.title = substitute(&inner.title); + inner.description = substitute(&inner.description); + inner.footer = substitute(&inner.footer); + + let mut fields = sqlx::query_as_unchecked!( + EmbedField, + " +SELECT + title, + value, + inline +FROM + embed_fields +WHERE + reminder_id = ? + ", + id + ) + .fetch_all(pool) + .await + .unwrap(); + + fields.iter_mut().for_each(|mut field| { + field.title = substitute(&field.title); + field.value = substitute(&field.value); + }); + + let e = Embed { inner, fields }; + + if e.has_content() { + Some(e) + } else { + None + } + } + + pub fn has_content(&self) -> bool { + if self.inner.title.is_empty() + && self.inner.description.is_empty() + && self.inner.image_url.is_none() + && self.inner.thumbnail_url.is_none() + && self.inner.footer.is_empty() + && self.inner.footer_url.is_none() + && self.inner.author.is_empty() + && self.inner.author_url.is_none() + && self.fields.is_empty() + { + false + } else { + true + } + } +} + +impl Into for Embed { + fn into(self) -> CreateEmbed { + let mut c = CreateEmbed::default(); + + c.title(&self.inner.title) + .description(&self.inner.description) + .color(self.inner.color) + .author(|a| { + a.name(&self.inner.author); + + if let Some(author_icon) = &self.inner.author_url { + a.icon_url(author_icon); + } + + a + }) + .footer(|f| { + f.text(&self.inner.footer); + + if let Some(footer_icon) = &self.inner.footer_url { + f.icon_url(footer_icon); + } + + f + }); + + for field in &self.fields { + c.field(&field.title, &field.value, field.inline); + } + + if let Some(image_url) = &self.inner.image_url { + c.image(image_url); + } + + if let Some(thumbnail_url) = &self.inner.thumbnail_url { + c.thumbnail(thumbnail_url); + } + + c + } +} + +#[derive(Debug)] +pub struct Reminder { + id: u32, + + channel_id: u64, + webhook_id: Option, + webhook_token: Option, + + channel_paused: bool, + channel_paused_until: Option, + enabled: bool, + + tts: bool, + pin: bool, + content: String, + attachment: Option>, + attachment_name: Option, + + utc_time: NaiveDateTime, + timezone: String, + restartable: bool, + expires: Option, + interval: Option, + + avatar: Option, + username: Option, +} + +impl Reminder { + pub async fn fetch_reminders(pool: &MySqlPool) -> Vec { + sqlx::query_as_unchecked!( + Reminder, + " +SELECT + reminders.`id` AS id, + + channels.`channel` AS channel_id, + channels.`webhook_id` AS webhook_id, + channels.`webhook_token` AS webhook_token, + + channels.`paused` AS channel_paused, + channels.`paused_until` AS channel_paused_until, + reminders.`enabled` AS enabled, + + reminders.`tts` AS tts, + reminders.`pin` AS pin, + reminders.`content` AS content, + reminders.`attachment` AS attachment, + reminders.`attachment_name` AS attachment_name, + + reminders.`utc_time` AS 'utc_time', + reminders.`timezone` AS timezone, + reminders.`restartable` AS restartable, + reminders.`expires` AS expires, + reminders.`interval` AS 'interval', + + reminders.`avatar` AS avatar, + reminders.`username` AS username +FROM + reminders +INNER JOIN + channels +ON + reminders.channel_id = channels.id +WHERE + reminders.`utc_time` < NOW() + ", + ) + .fetch_all(pool) + .await + .unwrap() + .into_iter() + .map(|mut rem| { + rem.content = substitute(&rem.content); + + rem + }) + .collect::>() + } + + async fn reset_webhook(&self, pool: &MySqlPool) { + let _ = sqlx::query!( + " +UPDATE channels SET webhook_id = NULL, webhook_token = NULL WHERE channel = ? + ", + self.channel_id + ) + .execute(pool) + .await; + } + + async fn refresh(&self, pool: &MySqlPool) { + if let Some(interval) = self.interval { + let now = Utc::now().naive_local(); + let mut updated_reminder_time = self.utc_time; + + while updated_reminder_time < now { + updated_reminder_time += Duration::seconds(interval as i64); + } + + if self.expires.map_or(false, |expires| { + NaiveDateTime::from_timestamp(updated_reminder_time.timestamp(), 0) > expires + }) { + self.force_delete(pool).await; + } else { + sqlx::query!( + " +UPDATE reminders SET `utc_time` = ? WHERE `id` = ? + ", + updated_reminder_time, + self.id + ) + .execute(pool) + .await + .expect(&format!("Could not update time on Reminder {}", self.id)); + } + } else { + self.force_delete(pool).await; + } + } + + async fn force_delete(&self, pool: &MySqlPool) { + sqlx::query!( + " +DELETE FROM reminders WHERE `id` = ? + ", + self.id + ) + .execute(pool) + .await + .expect(&format!("Could not delete Reminder {}", self.id)); + } + + async fn pin_message>(&self, message_id: M, http: impl AsRef) { + let _ = http.as_ref().pin_message(self.channel_id, message_id.into(), None).await; + } + + pub async fn send(&self, pool: MySqlPool, cache_http: impl CacheHttp) { + async fn send_to_channel( + cache_http: impl CacheHttp, + reminder: &Reminder, + embed: Option, + ) -> Result<()> { + let channel = ChannelId(reminder.channel_id).to_channel(&cache_http).await; + + match channel { + Ok(Channel::Guild(channel)) => { + match channel + .send_message(&cache_http, |m| { + m.content(&reminder.content).tts(reminder.tts); + + if let (Some(attachment), Some(name)) = + (&reminder.attachment, &reminder.attachment_name) + { + m.add_file((attachment as &[u8], name.as_str())); + } + + if let Some(embed) = embed { + m.set_embed(embed); + } + + m + }) + .await + { + Ok(m) => { + if reminder.pin { + reminder.pin_message(m.id, cache_http.http()).await; + } + + Ok(()) + } + Err(e) => Err(e), + } + } + Ok(Channel::Private(channel)) => { + match channel + .send_message(&cache_http.http(), |m| { + m.content(&reminder.content).tts(reminder.tts); + + if let (Some(attachment), Some(name)) = + (&reminder.attachment, &reminder.attachment_name) + { + m.add_file((attachment as &[u8], name.as_str())); + } + + if let Some(embed) = embed { + m.set_embed(embed); + } + + m + }) + .await + { + Ok(m) => { + if reminder.pin { + reminder.pin_message(m.id, cache_http.http()).await; + } + + Ok(()) + } + Err(e) => Err(e), + } + } + Err(e) => Err(e), + _ => Err(Error::Other("Channel not of valid type")), + } + } + + async fn send_to_webhook( + cache_http: impl CacheHttp, + reminder: &Reminder, + webhook: Webhook, + embed: Option, + ) -> Result<()> { + match webhook + .execute(&cache_http.http(), reminder.pin || reminder.restartable, |w| { + w.content(&reminder.content).tts(reminder.tts); + + if let Some(username) = &reminder.username { + w.username(username); + } + + if let Some(avatar) = &reminder.avatar { + w.avatar_url(avatar); + } + + if let (Some(attachment), Some(name)) = + (&reminder.attachment, &reminder.attachment_name) + { + w.add_file((attachment as &[u8], name.as_str())); + } + + if let Some(embed) = embed { + w.embeds(vec![SerenityEmbed::fake(|c| { + *c = embed; + c + })]); + } + + w + }) + .await + { + Ok(m) => { + if reminder.pin { + if let Some(message) = m { + reminder.pin_message(message.id, cache_http.http()).await; + } + } + + Ok(()) + } + Err(e) => Err(e), + } + } + + if self.enabled + && !(self.channel_paused + && self + .channel_paused_until + .map_or(true, |inner| inner >= Utc::now().naive_local())) + { + let _ = sqlx::query!( + " +UPDATE `channels` SET paused = 0, paused_until = NULL WHERE `channel` = ? + ", + self.channel_id + ) + .execute(&pool.clone()) + .await; + + let embed = Embed::from_id(&pool.clone(), self.id).await.map(|e| e.into()); + + let result = if let (Some(webhook_id), Some(webhook_token)) = + (self.webhook_id, &self.webhook_token) + { + let webhook_res = + cache_http.http().get_webhook_with_token(webhook_id, webhook_token).await; + + if let Ok(webhook) = webhook_res { + send_to_webhook(cache_http, &self, webhook, embed).await + } else { + warn!("Webhook vanished: {:?}", webhook_res); + + self.reset_webhook(&pool.clone()).await; + send_to_channel(cache_http, &self, embed).await + } + } else { + send_to_channel(cache_http, &self, embed).await + }; + + if let Err(e) = result { + error!("Error sending {:?}: {:?}", self, e); + + if let Error::Http(error) = e { + if error.status_code() == Some(StatusCode::from_u16(404).unwrap()) { + error!("Seeing channel is deleted. Removing reminder"); + self.force_delete(&pool).await; + } else { + self.refresh(&pool).await; + } + } else { + self.refresh(&pool).await; + } + } else { + self.refresh(&pool).await; + } + } else { + info!("Reminder {} is paused", self.id); + + self.refresh(&pool).await; + } + } +}