Support sending reminders to threads

This commit is contained in:
jude 2024-03-03 13:04:50 +00:00
parent dcee9e0d2a
commit 6f0bdf9852
3 changed files with 91 additions and 58 deletions

View File

@ -0,0 +1,5 @@
-- Add migration script here
ALTER TABLE reminders
ADD INDEX `utc_time_index` (`utc_time`);
ALTER TABLE reminders
ADD INDEX `status_index` (`status`);

View File

@ -232,6 +232,7 @@ pub struct Reminder {
id: u32, id: u32,
channel_id: u64, channel_id: u64,
thread_id: Option<u64>,
webhook_id: Option<u64>, webhook_id: Option<u64>,
webhook_token: Option<String>, webhook_token: Option<String>,
@ -262,10 +263,11 @@ impl Reminder {
match sqlx::query_as_unchecked!( match sqlx::query_as_unchecked!(
Reminder, Reminder,
r#" r#"
SELECT SELECT
reminders.`id` AS id, reminders.`id` AS id,
channels.`channel` AS channel_id, channels.`channel` AS channel_id,
reminders.`thread_id` AS thread_id,
channels.`webhook_id` AS webhook_id, channels.`webhook_id` AS webhook_id,
channels.`webhook_token` AS webhook_token, channels.`webhook_token` AS webhook_token,
@ -289,13 +291,13 @@ SELECT
reminders.`avatar` AS avatar, reminders.`avatar` AS avatar,
reminders.`username` AS username reminders.`username` AS username
FROM FROM
reminders reminders
INNER JOIN INNER JOIN
channels channels
ON ON
reminders.channel_id = channels.id reminders.channel_id = channels.id
WHERE WHERE
reminders.`status` = 'pending' AND reminders.`status` = 'pending' AND
reminders.`id` IN ( reminders.`id` IN (
SELECT SELECT
@ -337,7 +339,9 @@ WHERE
async fn reset_webhook(&self, pool: impl Executor<'_, Database = Database> + Copy) { async fn reset_webhook(&self, pool: impl Executor<'_, Database = Database> + Copy) {
let _ = sqlx::query!( let _ = sqlx::query!(
"UPDATE channels SET webhook_id = NULL, webhook_token = NULL WHERE channel = ?", "
UPDATE channels SET webhook_id = NULL, webhook_token = NULL WHERE channel = ?
",
self.channel_id self.channel_id
) )
.execute(pool) .execute(pool)
@ -473,7 +477,11 @@ WHERE
reminder: &Reminder, reminder: &Reminder,
embed: Option<CreateEmbed>, embed: Option<CreateEmbed>,
) -> Result<()> { ) -> Result<()> {
let channel = ChannelId::new(reminder.channel_id).to_channel(&cache_http).await; let channel = if let Some(thread_id) = reminder.thread_id {
ChannelId::new(thread_id).to_channel(&cache_http).await
} else {
ChannelId::new(reminder.channel_id).to_channel(&cache_http).await
};
let mut message = CreateMessage::new().content(&reminder.content).tts(reminder.tts); let mut message = CreateMessage::new().content(&reminder.content).tts(reminder.tts);
@ -524,7 +532,14 @@ WHERE
webhook: Webhook, webhook: Webhook,
embed: Option<CreateEmbed>, embed: Option<CreateEmbed>,
) -> Result<()> { ) -> Result<()> {
let mut builder = ExecuteWebhook::new().content(&reminder.content).tts(reminder.tts); let mut builder = if let Some(thread_id) = reminder.thread_id {
ExecuteWebhook::new()
.content(&reminder.content)
.tts(reminder.tts)
.in_thread(thread_id)
} else {
ExecuteWebhook::new().content(&reminder.content).tts(reminder.tts)
};
if let Some(username) = &reminder.username { if let Some(username) = &reminder.username {
if !username.is_empty() { if !username.is_empty() {
@ -571,7 +586,9 @@ WHERE
.map_or(true, |inner| inner >= Utc::now().naive_local())) .map_or(true, |inner| inner >= Utc::now().naive_local()))
{ {
let _ = sqlx::query!( let _ = sqlx::query!(
"UPDATE `channels` SET paused = 0, paused_until = NULL WHERE `channel` = ?", "
UPDATE `channels` SET paused = 0, paused_until = NULL WHERE `channel` = ?
",
self.channel_id self.channel_id
) )
.execute(pool) .execute(pool)

View File

@ -13,7 +13,7 @@ use chrono_tz::Tz;
use poise::{ use poise::{
serenity_prelude::{ serenity_prelude::{
model::id::{ChannelId, GuildId, UserId}, model::id::{ChannelId, GuildId, UserId},
ButtonStyle, Cache, CreateActionRow, CreateButton, CreateEmbed, ReactionType, ButtonStyle, Cache, ChannelType, CreateActionRow, CreateButton, CreateEmbed, ReactionType,
}, },
CreateReply, CreateReply,
}; };
@ -482,12 +482,23 @@ pub async fn create_reminder(
let list = channels.map(|arg| parse_mention_list(&arg)).unwrap_or_default(); let list = channels.map(|arg| parse_mention_list(&arg)).unwrap_or_default();
if list.is_empty() { if list.is_empty() {
if ctx.guild_id().is_some() { if let Some(channel) = ctx.guild_channel().await {
if channel.kind == ChannelType::PublicThread
|| channel.kind == ChannelType::PrivateThread
{
let parent = channel.parent_id.unwrap();
let channel_with_threads = ChannelWithThread {
channel_id: parent.get(),
thread_id: Some(ctx.channel_id().get()),
};
vec![ReminderScope::Channel(channel_with_threads)]
} else {
let channel_with_threads = ChannelWithThread { let channel_with_threads = ChannelWithThread {
channel_id: ctx.channel_id().get(), channel_id: ctx.channel_id().get(),
thread_id: None, thread_id: None,
}; };
vec![ReminderScope::Channel(channel_with_threads)] vec![ReminderScope::Channel(channel_with_threads)]
}
} else { } else {
vec![ReminderScope::User(ctx.author().id.get())] vec![ReminderScope::User(ctx.author().id.get())]
} }