diff --git a/src/models/channel_data.rs b/src/models/channel_data.rs index 3373ca3..0122b80 100644 --- a/src/models/channel_data.rs +++ b/src/models/channel_data.rs @@ -1,9 +1,13 @@ use chrono::NaiveDateTime; -use poise::serenity_prelude::model::channel::Channel; +use poise::serenity_prelude::{model::channel::Channel, CacheHttp, ChannelId, CreateWebhook}; +use secrecy::ExposeSecret; use sqlx::MySqlPool; +use crate::{consts::DEFAULT_AVATAR, Error}; + pub struct ChannelData { pub id: u32, + pub channel: u64, pub name: Option, pub nudge: i16, pub blacklisted: bool, @@ -22,7 +26,12 @@ impl ChannelData { if let Ok(c) = sqlx::query_as_unchecked!( Self, - "SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?", + " + SELECT id, channel, name, nudge, blacklisted, webhook_id, webhook_token, paused, + paused_until + FROM channels + WHERE channel = ? + ", channel_id ) .fetch_one(pool) @@ -32,7 +41,8 @@ impl ChannelData { } else { let props = channel.to_owned().guild().map(|g| (g.guild_id.get().to_owned(), g.name)); - let (guild_id, channel_name) = if let Some((a, b)) = props { (Some(a), Some(b)) } else { (None, None) }; + let (guild_id, channel_name) = + if let Some((a, b)) = props { (Some(a), Some(b)) } else { (None, None) }; sqlx::query!( "INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?))", @@ -46,7 +56,9 @@ impl ChannelData { Ok(sqlx::query_as_unchecked!( Self, " -SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ? + SELECT id, channel, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until + FROM channels + WHERE channel = ? ", channel_id ) @@ -58,8 +70,16 @@ SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_u pub async fn commit_changes(&self, pool: &MySqlPool) { sqlx::query!( " -UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until \ - = ? WHERE id = ? + UPDATE channels + SET + name = ?, + nudge = ?, + blacklisted = ?, + webhook_id = ?, + webhook_token = ?, + paused = ?, + paused_until = ? + WHERE id = ? ", self.name, self.nudge, @@ -74,4 +94,24 @@ UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhoo .await .unwrap(); } + + pub async fn ensure_webhook( + &mut self, + ctx: impl CacheHttp, + pool: &MySqlPool, + ) -> Result<(), Error> { + if self.webhook_id.is_none() || self.webhook_token.is_none() { + let guild_channel = ChannelId::new(self.channel); + + let webhook = guild_channel + .create_webhook(ctx.http(), CreateWebhook::new("Reminder").avatar(&*DEFAULT_AVATAR)) + .await?; + + self.webhook_id = Some(webhook.id.get().to_owned()); + self.webhook_token = webhook.token.map(|s| s.expose_secret().clone()); + self.commit_changes(pool).await; + } + + Ok(()) + } } diff --git a/src/models/reminder/builder.rs b/src/models/reminder/builder.rs index 218fa8d..c77540e 100644 --- a/src/models/reminder/builder.rs +++ b/src/models/reminder/builder.rs @@ -3,19 +3,13 @@ use std::collections::HashSet; use chrono::{Duration, NaiveDateTime, Utc}; use chrono_tz::Tz; use poise::serenity_prelude::{ - http::CacheHttp, - model::{ - channel::GuildChannel, - id::{ChannelId, GuildId, UserId}, - webhook::Webhook, - }, - ChannelType, CreateWebhook, Result as SerenityResult, + model::id::{ChannelId, GuildId, UserId}, + ChannelType, }; -use secrecy::ExposeSecret; use sqlx::MySqlPool; use crate::{ - consts::{DAY, DEFAULT_AVATAR, MAX_TIME, MIN_INTERVAL}, + consts::{DAY, MAX_TIME, MIN_INTERVAL}, interval_parser::Interval, models::{ channel_data::ChannelData, @@ -25,25 +19,23 @@ use crate::{ Context, }; -async fn create_webhook( - ctx: impl CacheHttp, - channel: GuildChannel, - name: impl Into, -) -> SerenityResult { - channel.create_webhook(ctx.http(), CreateWebhook::new(name).avatar(&*DEFAULT_AVATAR)).await +#[derive(Hash, PartialEq, Eq, Copy, Clone)] +pub struct ChannelWithThread { + pub channel_id: u64, + pub thread_id: Option, } #[derive(Hash, PartialEq, Eq)] pub enum ReminderScope { User(u64), - Channel(u64), + Channel(ChannelWithThread), } impl ReminderScope { pub fn mention(&self) -> String { match self { Self::User(id) => format!("<@{}>", id), - Self::Channel(id) => format!("<#{}>", id), + Self::Channel(c) => format!("<#{}>", c.channel_id), } } } @@ -89,6 +81,7 @@ impl ReminderBuilder { INSERT INTO reminders ( `uid`, `channel_id`, + `thread_id`, `utc_time`, `timezone`, `interval_seconds`, @@ -101,11 +94,12 @@ impl ReminderBuilder { `attachment`, `set_by` ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) ", self.uid, self.channel, + self.thread_id, utc_time, self.timezone, self.interval_seconds, @@ -218,7 +212,6 @@ impl<'a> MultiReminderBuilder<'a> { errors.insert(ReminderError::LongInterval); } else { for scope in self.scopes { - let thread_id = None; let db_channel_id = match scope { ReminderScope::User(user_id) => { if let Ok(user) = UserId::new(user_id).to_user(&self.ctx).await { @@ -238,34 +231,34 @@ impl<'a> MultiReminderBuilder<'a> { { Err(ReminderError::UserBlockedDm) } else { - Ok(user_data.dm_channel) + Ok((user_data.dm_channel, None)) } } else { - Ok(user_data.dm_channel) + Ok((user_data.dm_channel, None)) } } else { Err(ReminderError::InvalidTag) } } - ReminderScope::Channel(channel_id) => { - let channel = - ChannelId::new(channel_id).to_channel(&self.ctx).await.unwrap(); + ReminderScope::Channel(channel_with_thread) => { + let channel = ChannelId::new(channel_with_thread.channel_id) + .to_channel(&self.ctx) + .await + .unwrap(); - if let Some(mut guild_channel) = channel.clone().guild() { + if let Some(guild_channel) = channel.clone().guild() { if Some(guild_channel.guild_id) != self.guild_id { Err(ReminderError::InvalidTag) } else { let mut channel_data = if guild_channel.kind == ChannelType::PublicThread { - // fixme jesus christ let parent = guild_channel .parent_id .unwrap() .to_channel(&self.ctx) .await .unwrap(); - guild_channel = parent.clone().guild().unwrap(); ChannelData::from_channel(&parent, &self.ctx.data().database) .await .unwrap() @@ -275,28 +268,13 @@ impl<'a> MultiReminderBuilder<'a> { .unwrap() }; - if channel_data.webhook_id.is_none() - || channel_data.webhook_token.is_none() + match channel_data + .ensure_webhook(&self.ctx, &self.ctx.data().database) + .await + .map_err(|e| ReminderError::DiscordError(e.to_string())) { - match create_webhook(&self.ctx, guild_channel, "Reminder").await - { - Ok(webhook) => { - channel_data.webhook_id = - Some(webhook.id.get().to_owned()); - channel_data.webhook_token = - webhook.token.map(|s| s.expose_secret().clone()); - - channel_data - .commit_changes(&self.ctx.data().database) - .await; - - Ok(channel_data.id) - } - - Err(e) => Err(ReminderError::DiscordError(e.to_string())), - } - } else { - Ok(channel_data.id) + Ok(()) => Ok((channel_data.id, channel_with_thread.thread_id)), + Err(e) => Err(e), } } } else { @@ -306,11 +284,11 @@ impl<'a> MultiReminderBuilder<'a> { }; match db_channel_id { - Ok(c) => { + Ok((channel, thread_id)) => { let builder = ReminderBuilder { pool: self.ctx.data().database.clone(), uid: generate_uid(), - channel: c, + channel, thread_id, utc_time: self.utc_time, timezone: self.timezone.to_string(), diff --git a/src/models/reminder/mod.rs b/src/models/reminder/mod.rs index 21418ac..294c91e 100644 --- a/src/models/reminder/mod.rs +++ b/src/models/reminder/mod.rs @@ -26,7 +26,7 @@ use crate::{ interval_parser::parse_duration, models::{ reminder::{ - builder::{MultiReminderBuilder, ReminderScope}, + builder::{ChannelWithThread, MultiReminderBuilder, ReminderScope}, content::Content, errors::ReminderError, }, @@ -406,7 +406,8 @@ pub async fn create_reminder( let id = i.get(2).unwrap().as_str().parse::().unwrap(); if pref == "#" { - ReminderScope::Channel(id) + let channel_with_thread = ChannelWithThread { channel_id: id, thread_id: None }; + ReminderScope::Channel(channel_with_thread) } else { ReminderScope::User(id) } @@ -482,7 +483,11 @@ pub async fn create_reminder( if list.is_empty() { if ctx.guild_id().is_some() { - vec![ReminderScope::Channel(ctx.channel_id().get())] + let channel_with_threads = ChannelWithThread { + channel_id: ctx.channel_id().get(), + thread_id: None, + }; + vec![ReminderScope::Channel(channel_with_threads)] } else { vec![ReminderScope::User(ctx.author().id.get())] }