357 lines
12 KiB
Rust
357 lines
12 KiB
Rust
use std::{collections::HashSet, fmt::Display};
|
|
|
|
use chrono::{Duration, NaiveDateTime, Utc};
|
|
use chrono_tz::Tz;
|
|
use poise::serenity_prelude::{
|
|
http::CacheHttp,
|
|
model::{
|
|
channel::GuildChannel,
|
|
id::{ChannelId, GuildId, UserId},
|
|
webhook::Webhook,
|
|
},
|
|
ChannelType, CreateWebhook, Result as SerenityResult,
|
|
};
|
|
use secrecy::ExposeSecret;
|
|
use sqlx::MySqlPool;
|
|
|
|
use crate::{
|
|
consts::{DAY, DEFAULT_AVATAR, MAX_TIME, MIN_INTERVAL},
|
|
interval_parser::Interval,
|
|
models::{
|
|
channel_data::ChannelData,
|
|
reminder::{content::Content, errors::ReminderError, helper::generate_uid, Reminder},
|
|
user_data::UserData,
|
|
},
|
|
Context,
|
|
};
|
|
|
|
async fn create_webhook(
|
|
ctx: impl CacheHttp,
|
|
channel: GuildChannel,
|
|
name: impl ToString,
|
|
) -> SerenityResult<Webhook> {
|
|
channel.create_webhook(ctx.http(), CreateWebhook::new(name).avatar(&*DEFAULT_AVATAR)).await
|
|
}
|
|
|
|
/// Destination a reminder is addressed to: either a user or a channel, each held as a raw
/// numeric id.
#[derive(Hash, PartialEq, Eq)]
pub enum ReminderScope {
    /// A user, identified by the wrapped id.
    User(u64),
    /// A channel, identified by the wrapped id.
    Channel(u64),
}

impl ReminderScope {
    /// Renders the scope as mention markup: `<@id>` for a user, `<#id>` for a channel.
    pub fn mention(&self) -> String {
        match *self {
            Self::User(user) => format!("<@{}>", user),
            Self::Channel(channel) => format!("<#{}>", channel),
        }
    }
}
|
|
|
|
pub struct ReminderBuilder {
|
|
pool: MySqlPool,
|
|
uid: String,
|
|
channel: u32,
|
|
thread_id: Option<u64>,
|
|
utc_time: NaiveDateTime,
|
|
timezone: String,
|
|
interval_seconds: Option<i64>,
|
|
interval_days: Option<i64>,
|
|
interval_months: Option<i64>,
|
|
expires: Option<NaiveDateTime>,
|
|
content: String,
|
|
tts: bool,
|
|
attachment_name: Option<String>,
|
|
attachment: Option<Vec<u8>>,
|
|
set_by: Option<u32>,
|
|
}
|
|
|
|
impl ReminderBuilder {
|
|
pub async fn build(self) -> Result<Reminder, ReminderError> {
|
|
let queried_time = sqlx::query!(
|
|
"SELECT DATE_ADD(?, INTERVAL (SELECT nudge FROM channels WHERE id = ?) SECOND) AS `utc_time`",
|
|
self.utc_time,
|
|
self.channel,
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await
|
|
.unwrap();
|
|
|
|
match queried_time.utc_time {
|
|
Some(utc_time) => {
|
|
if utc_time < (Utc::now() - Duration::seconds(60)).naive_local() {
|
|
Err(ReminderError::PastTime)
|
|
} else {
|
|
sqlx::query!(
|
|
"
|
|
INSERT INTO reminders (
|
|
`uid`,
|
|
`channel_id`,
|
|
`utc_time`,
|
|
`timezone`,
|
|
`interval_seconds`,
|
|
`interval_days`,
|
|
`interval_months`,
|
|
`expires`,
|
|
`content`,
|
|
`tts`,
|
|
`attachment_name`,
|
|
`attachment`,
|
|
`set_by`
|
|
) VALUES (
|
|
?,
|
|
?,
|
|
?,
|
|
?,
|
|
?,
|
|
?,
|
|
?,
|
|
?,
|
|
?,
|
|
?,
|
|
?,
|
|
?,
|
|
?
|
|
)
|
|
",
|
|
self.uid,
|
|
self.channel,
|
|
utc_time,
|
|
self.timezone,
|
|
self.interval_seconds,
|
|
self.interval_days,
|
|
self.interval_months,
|
|
self.expires,
|
|
self.content,
|
|
self.tts,
|
|
self.attachment_name,
|
|
self.attachment,
|
|
self.set_by
|
|
)
|
|
.execute(&self.pool)
|
|
.await
|
|
.unwrap();
|
|
|
|
Ok(Reminder::from_uid(&self.pool, &self.uid).await.unwrap())
|
|
}
|
|
}
|
|
|
|
None => Err(ReminderError::LongTime),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct MultiReminderBuilder<'a> {
|
|
scopes: Vec<ReminderScope>,
|
|
utc_time: NaiveDateTime,
|
|
timezone: Tz,
|
|
interval: Option<Interval>,
|
|
expires: Option<NaiveDateTime>,
|
|
content: Content,
|
|
set_by: Option<u32>,
|
|
ctx: &'a Context<'a>,
|
|
guild_id: Option<GuildId>,
|
|
}
|
|
|
|
impl<'a> MultiReminderBuilder<'a> {
|
|
pub fn new(ctx: &'a Context, guild_id: Option<GuildId>) -> Self {
|
|
MultiReminderBuilder {
|
|
scopes: vec![],
|
|
utc_time: Utc::now().naive_utc(),
|
|
timezone: Tz::UTC,
|
|
interval: None,
|
|
expires: None,
|
|
content: Content::new(),
|
|
set_by: None,
|
|
ctx,
|
|
guild_id,
|
|
}
|
|
}
|
|
|
|
pub fn timezone(mut self, timezone: Tz) -> Self {
|
|
self.timezone = timezone;
|
|
|
|
self
|
|
}
|
|
|
|
pub fn content(mut self, content: Content) -> Self {
|
|
self.content = content;
|
|
|
|
self
|
|
}
|
|
|
|
pub fn time<T: Into<i64>>(mut self, time: T) -> Self {
|
|
if let Some(utc_time) = NaiveDateTime::from_timestamp_opt(time.into(), 0) {
|
|
self.utc_time = utc_time;
|
|
}
|
|
|
|
self
|
|
}
|
|
|
|
pub fn expires<T: Into<i64>>(mut self, time: Option<T>) -> Self {
|
|
self.expires = time.map(|t| NaiveDateTime::from_timestamp_opt(t.into(), 0)).flatten();
|
|
|
|
self
|
|
}
|
|
|
|
pub fn author(mut self, user: UserData) -> Self {
|
|
self.set_by = Some(user.id);
|
|
self.timezone = user.timezone();
|
|
|
|
self
|
|
}
|
|
|
|
pub fn interval(mut self, interval: Option<Interval>) -> Self {
|
|
self.interval = interval;
|
|
|
|
self
|
|
}
|
|
|
|
pub fn set_scopes(&mut self, scopes: Vec<ReminderScope>) {
|
|
self.scopes = scopes;
|
|
}
|
|
|
|
pub async fn build(self) -> (HashSet<ReminderError>, HashSet<(Reminder, ReminderScope)>) {
|
|
let mut errors = HashSet::new();
|
|
|
|
let mut ok_locs = HashSet::new();
|
|
|
|
if self
|
|
.interval
|
|
.map_or(false, |i| ((i.sec + i.day * DAY + i.month * 30 * DAY) as i64) < *MIN_INTERVAL)
|
|
{
|
|
errors.insert(ReminderError::ShortInterval);
|
|
} else if self
|
|
.interval
|
|
.map_or(false, |i| ((i.sec + i.day * DAY + i.month * 30 * DAY) as i64) > *MAX_TIME)
|
|
{
|
|
errors.insert(ReminderError::LongInterval);
|
|
} else {
|
|
for scope in self.scopes {
|
|
let thread_id = None;
|
|
let db_channel_id = match scope {
|
|
ReminderScope::User(user_id) => {
|
|
if let Ok(user) = UserId::new(user_id).to_user(&self.ctx).await {
|
|
let user_data = UserData::from_user(
|
|
&user,
|
|
&self.ctx.serenity_context(),
|
|
&self.ctx.data().database,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
|
|
if let Some(guild_id) = self.guild_id {
|
|
if guild_id.member(&self.ctx, user).await.is_err() {
|
|
Err(ReminderError::InvalidTag)
|
|
} else if self.set_by.map_or(true, |i| i != user_data.id)
|
|
&& !user_data.allowed_dm
|
|
{
|
|
Err(ReminderError::UserBlockedDm)
|
|
} else {
|
|
Ok(user_data.dm_channel)
|
|
}
|
|
} else {
|
|
Ok(user_data.dm_channel)
|
|
}
|
|
} else {
|
|
Err(ReminderError::InvalidTag)
|
|
}
|
|
}
|
|
ReminderScope::Channel(channel_id) => {
|
|
let channel =
|
|
ChannelId::new(channel_id).to_channel(&self.ctx).await.unwrap();
|
|
|
|
if let Some(mut guild_channel) = channel.clone().guild() {
|
|
if Some(guild_channel.guild_id) != self.guild_id {
|
|
Err(ReminderError::InvalidTag)
|
|
} else {
|
|
let mut channel_data = if guild_channel.kind
|
|
== ChannelType::PublicThread
|
|
{
|
|
// fixme jesus christ
|
|
let parent = guild_channel
|
|
.parent_id
|
|
.unwrap()
|
|
.to_channel(&self.ctx)
|
|
.await
|
|
.unwrap();
|
|
guild_channel = parent.clone().guild().unwrap();
|
|
ChannelData::from_channel(&parent, &self.ctx.data().database)
|
|
.await
|
|
.unwrap()
|
|
} else {
|
|
ChannelData::from_channel(&channel, &self.ctx.data().database)
|
|
.await
|
|
.unwrap()
|
|
};
|
|
|
|
if channel_data.webhook_id.is_none()
|
|
|| channel_data.webhook_token.is_none()
|
|
{
|
|
match create_webhook(&self.ctx, guild_channel, "Reminder").await
|
|
{
|
|
Ok(webhook) => {
|
|
channel_data.webhook_id =
|
|
Some(webhook.id.get().to_owned());
|
|
channel_data.webhook_token =
|
|
webhook.token.map(|s| s.expose_secret());
|
|
|
|
channel_data
|
|
.commit_changes(&self.ctx.data().database)
|
|
.await;
|
|
|
|
Ok(channel_data.id)
|
|
}
|
|
|
|
Err(e) => Err(ReminderError::DiscordError(e.to_string())),
|
|
}
|
|
} else {
|
|
Ok(channel_data.id)
|
|
}
|
|
}
|
|
} else {
|
|
Err(ReminderError::InvalidTag)
|
|
}
|
|
}
|
|
};
|
|
|
|
match db_channel_id {
|
|
Ok(c) => {
|
|
let builder = ReminderBuilder {
|
|
pool: self.ctx.data().database.clone(),
|
|
uid: generate_uid(),
|
|
channel: c,
|
|
thread_id,
|
|
utc_time: self.utc_time,
|
|
timezone: self.timezone.to_string(),
|
|
interval_seconds: self.interval.map(|i| i.sec as i64),
|
|
interval_days: self.interval.map(|i| i.day as i64),
|
|
interval_months: self.interval.map(|i| i.month as i64),
|
|
expires: self.expires,
|
|
content: self.content.content.clone(),
|
|
tts: self.content.tts,
|
|
attachment_name: self.content.attachment_name.clone(),
|
|
attachment: self.content.attachment.clone(),
|
|
set_by: self.set_by,
|
|
};
|
|
|
|
match builder.build().await {
|
|
Ok(r) => {
|
|
ok_locs.insert((r, scope));
|
|
}
|
|
Err(e) => {
|
|
errors.insert(e);
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
errors.insert(e);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
(errors, ok_locs)
|
|
}
|
|
}
|