Begin to work on thread support

This commit is contained in:
jude 2024-03-03 11:58:22 +00:00
parent 8e6e1a18b7
commit dcee9e0d2a
3 changed files with 82 additions and 59 deletions

View File

@ -1,9 +1,13 @@
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use poise::serenity_prelude::model::channel::Channel; use poise::serenity_prelude::{model::channel::Channel, CacheHttp, ChannelId, CreateWebhook};
use secrecy::ExposeSecret;
use sqlx::MySqlPool; use sqlx::MySqlPool;
use crate::{consts::DEFAULT_AVATAR, Error};
pub struct ChannelData { pub struct ChannelData {
pub id: u32, pub id: u32,
pub channel: u64,
pub name: Option<String>, pub name: Option<String>,
pub nudge: i16, pub nudge: i16,
pub blacklisted: bool, pub blacklisted: bool,
@ -22,7 +26,12 @@ impl ChannelData {
if let Ok(c) = sqlx::query_as_unchecked!( if let Ok(c) = sqlx::query_as_unchecked!(
Self, Self,
"SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?", "
SELECT id, channel, name, nudge, blacklisted, webhook_id, webhook_token, paused,
paused_until
FROM channels
WHERE channel = ?
",
channel_id channel_id
) )
.fetch_one(pool) .fetch_one(pool)
@ -32,7 +41,8 @@ impl ChannelData {
} else { } else {
let props = channel.to_owned().guild().map(|g| (g.guild_id.get().to_owned(), g.name)); let props = channel.to_owned().guild().map(|g| (g.guild_id.get().to_owned(), g.name));
let (guild_id, channel_name) = if let Some((a, b)) = props { (Some(a), Some(b)) } else { (None, None) }; let (guild_id, channel_name) =
if let Some((a, b)) = props { (Some(a), Some(b)) } else { (None, None) };
sqlx::query!( sqlx::query!(
"INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?))", "INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?))",
@ -46,7 +56,9 @@ impl ChannelData {
Ok(sqlx::query_as_unchecked!( Ok(sqlx::query_as_unchecked!(
Self, Self,
" "
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ? SELECT id, channel, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until
FROM channels
WHERE channel = ?
", ",
channel_id channel_id
) )
@ -58,8 +70,16 @@ SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_u
pub async fn commit_changes(&self, pool: &MySqlPool) { pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!( sqlx::query!(
" "
UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until \ UPDATE channels
= ? WHERE id = ? SET
name = ?,
nudge = ?,
blacklisted = ?,
webhook_id = ?,
webhook_token = ?,
paused = ?,
paused_until = ?
WHERE id = ?
", ",
self.name, self.name,
self.nudge, self.nudge,
@ -74,4 +94,24 @@ UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhoo
.await .await
.unwrap(); .unwrap();
} }
pub async fn ensure_webhook(
&mut self,
ctx: impl CacheHttp,
pool: &MySqlPool,
) -> Result<(), Error> {
if self.webhook_id.is_none() || self.webhook_token.is_none() {
let guild_channel = ChannelId::new(self.channel);
let webhook = guild_channel
.create_webhook(ctx.http(), CreateWebhook::new("Reminder").avatar(&*DEFAULT_AVATAR))
.await?;
self.webhook_id = Some(webhook.id.get().to_owned());
self.webhook_token = webhook.token.map(|s| s.expose_secret().clone());
self.commit_changes(pool).await;
}
Ok(())
}
} }

View File

@ -3,19 +3,13 @@ use std::collections::HashSet;
use chrono::{Duration, NaiveDateTime, Utc}; use chrono::{Duration, NaiveDateTime, Utc};
use chrono_tz::Tz; use chrono_tz::Tz;
use poise::serenity_prelude::{ use poise::serenity_prelude::{
http::CacheHttp, model::id::{ChannelId, GuildId, UserId},
model::{ ChannelType,
channel::GuildChannel,
id::{ChannelId, GuildId, UserId},
webhook::Webhook,
},
ChannelType, CreateWebhook, Result as SerenityResult,
}; };
use secrecy::ExposeSecret;
use sqlx::MySqlPool; use sqlx::MySqlPool;
use crate::{ use crate::{
consts::{DAY, DEFAULT_AVATAR, MAX_TIME, MIN_INTERVAL}, consts::{DAY, MAX_TIME, MIN_INTERVAL},
interval_parser::Interval, interval_parser::Interval,
models::{ models::{
channel_data::ChannelData, channel_data::ChannelData,
@ -25,25 +19,23 @@ use crate::{
Context, Context,
}; };
async fn create_webhook( #[derive(Hash, PartialEq, Eq, Copy, Clone)]
ctx: impl CacheHttp, pub struct ChannelWithThread {
channel: GuildChannel, pub channel_id: u64,
name: impl Into<String>, pub thread_id: Option<u64>,
) -> SerenityResult<Webhook> {
channel.create_webhook(ctx.http(), CreateWebhook::new(name).avatar(&*DEFAULT_AVATAR)).await
} }
#[derive(Hash, PartialEq, Eq)] #[derive(Hash, PartialEq, Eq)]
pub enum ReminderScope { pub enum ReminderScope {
User(u64), User(u64),
Channel(u64), Channel(ChannelWithThread),
} }
impl ReminderScope { impl ReminderScope {
pub fn mention(&self) -> String { pub fn mention(&self) -> String {
match self { match self {
Self::User(id) => format!("<@{}>", id), Self::User(id) => format!("<@{}>", id),
Self::Channel(id) => format!("<#{}>", id), Self::Channel(c) => format!("<#{}>", c.channel_id),
} }
} }
} }
@ -89,6 +81,7 @@ impl ReminderBuilder {
INSERT INTO reminders ( INSERT INTO reminders (
`uid`, `uid`,
`channel_id`, `channel_id`,
`thread_id`,
`utc_time`, `utc_time`,
`timezone`, `timezone`,
`interval_seconds`, `interval_seconds`,
@ -101,11 +94,12 @@ impl ReminderBuilder {
`attachment`, `attachment`,
`set_by` `set_by`
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
) )
", ",
self.uid, self.uid,
self.channel, self.channel,
self.thread_id,
utc_time, utc_time,
self.timezone, self.timezone,
self.interval_seconds, self.interval_seconds,
@ -218,7 +212,6 @@ impl<'a> MultiReminderBuilder<'a> {
errors.insert(ReminderError::LongInterval); errors.insert(ReminderError::LongInterval);
} else { } else {
for scope in self.scopes { for scope in self.scopes {
let thread_id = None;
let db_channel_id = match scope { let db_channel_id = match scope {
ReminderScope::User(user_id) => { ReminderScope::User(user_id) => {
if let Ok(user) = UserId::new(user_id).to_user(&self.ctx).await { if let Ok(user) = UserId::new(user_id).to_user(&self.ctx).await {
@ -238,34 +231,34 @@ impl<'a> MultiReminderBuilder<'a> {
{ {
Err(ReminderError::UserBlockedDm) Err(ReminderError::UserBlockedDm)
} else { } else {
Ok(user_data.dm_channel) Ok((user_data.dm_channel, None))
} }
} else { } else {
Ok(user_data.dm_channel) Ok((user_data.dm_channel, None))
} }
} else { } else {
Err(ReminderError::InvalidTag) Err(ReminderError::InvalidTag)
} }
} }
ReminderScope::Channel(channel_id) => { ReminderScope::Channel(channel_with_thread) => {
let channel = let channel = ChannelId::new(channel_with_thread.channel_id)
ChannelId::new(channel_id).to_channel(&self.ctx).await.unwrap(); .to_channel(&self.ctx)
.await
.unwrap();
if let Some(mut guild_channel) = channel.clone().guild() { if let Some(guild_channel) = channel.clone().guild() {
if Some(guild_channel.guild_id) != self.guild_id { if Some(guild_channel.guild_id) != self.guild_id {
Err(ReminderError::InvalidTag) Err(ReminderError::InvalidTag)
} else { } else {
let mut channel_data = if guild_channel.kind let mut channel_data = if guild_channel.kind
== ChannelType::PublicThread == ChannelType::PublicThread
{ {
// fixme jesus christ
let parent = guild_channel let parent = guild_channel
.parent_id .parent_id
.unwrap() .unwrap()
.to_channel(&self.ctx) .to_channel(&self.ctx)
.await .await
.unwrap(); .unwrap();
guild_channel = parent.clone().guild().unwrap();
ChannelData::from_channel(&parent, &self.ctx.data().database) ChannelData::from_channel(&parent, &self.ctx.data().database)
.await .await
.unwrap() .unwrap()
@ -275,28 +268,13 @@ impl<'a> MultiReminderBuilder<'a> {
.unwrap() .unwrap()
}; };
if channel_data.webhook_id.is_none() match channel_data
|| channel_data.webhook_token.is_none() .ensure_webhook(&self.ctx, &self.ctx.data().database)
.await
.map_err(|e| ReminderError::DiscordError(e.to_string()))
{ {
match create_webhook(&self.ctx, guild_channel, "Reminder").await Ok(()) => Ok((channel_data.id, channel_with_thread.thread_id)),
{ Err(e) => Err(e),
Ok(webhook) => {
channel_data.webhook_id =
Some(webhook.id.get().to_owned());
channel_data.webhook_token =
webhook.token.map(|s| s.expose_secret().clone());
channel_data
.commit_changes(&self.ctx.data().database)
.await;
Ok(channel_data.id)
}
Err(e) => Err(ReminderError::DiscordError(e.to_string())),
}
} else {
Ok(channel_data.id)
} }
} }
} else { } else {
@ -306,11 +284,11 @@ impl<'a> MultiReminderBuilder<'a> {
}; };
match db_channel_id { match db_channel_id {
Ok(c) => { Ok((channel, thread_id)) => {
let builder = ReminderBuilder { let builder = ReminderBuilder {
pool: self.ctx.data().database.clone(), pool: self.ctx.data().database.clone(),
uid: generate_uid(), uid: generate_uid(),
channel: c, channel,
thread_id, thread_id,
utc_time: self.utc_time, utc_time: self.utc_time,
timezone: self.timezone.to_string(), timezone: self.timezone.to_string(),

View File

@ -26,7 +26,7 @@ use crate::{
interval_parser::parse_duration, interval_parser::parse_duration,
models::{ models::{
reminder::{ reminder::{
builder::{MultiReminderBuilder, ReminderScope}, builder::{ChannelWithThread, MultiReminderBuilder, ReminderScope},
content::Content, content::Content,
errors::ReminderError, errors::ReminderError,
}, },
@ -406,7 +406,8 @@ pub async fn create_reminder(
let id = i.get(2).unwrap().as_str().parse::<u64>().unwrap(); let id = i.get(2).unwrap().as_str().parse::<u64>().unwrap();
if pref == "#" { if pref == "#" {
ReminderScope::Channel(id) let channel_with_thread = ChannelWithThread { channel_id: id, thread_id: None };
ReminderScope::Channel(channel_with_thread)
} else { } else {
ReminderScope::User(id) ReminderScope::User(id)
} }
@ -482,7 +483,11 @@ pub async fn create_reminder(
if list.is_empty() { if list.is_empty() {
if ctx.guild_id().is_some() { if ctx.guild_id().is_some() {
vec![ReminderScope::Channel(ctx.channel_id().get())] let channel_with_threads = ChannelWithThread {
channel_id: ctx.channel_id().get(),
thread_id: None,
};
vec![ReminderScope::Channel(channel_with_threads)]
} else { } else {
vec![ReminderScope::User(ctx.author().id.get())] vec![ReminderScope::User(ctx.author().id.get())]
} }