From 523ab7f03aa8f04fd439bffc0a25c593dc441f60 Mon Sep 17 00:00:00 2001 From: jude Date: Thu, 11 May 2023 15:02:26 +0100 Subject: [PATCH] Partial thread support --- src/commands/reminder_cmds.rs | 20 ++++++++--- src/hooks.rs | 15 ++++---- src/models/mod.rs | 17 +++++++-- src/models/reminder/builder.rs | 64 ++++++++++++++++++++++++++++++++++ 4 files changed, 103 insertions(+), 13 deletions(-) diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index 25e877a..c0b509e 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -5,7 +5,8 @@ use chrono_tz::Tz; use num_integer::Integer; use poise::{ serenity_prelude::{ - builder::CreateEmbed, component::ButtonStyle, model::channel::Channel, ReactionType, + builder::CreateEmbed, component::ButtonStyle, model::channel::Channel, ChannelType, + ReactionType, }, CreateReply, Modal, }; @@ -664,10 +665,18 @@ async fn create_reminder( let list = channels.map(|arg| parse_mention_list(&arg)).unwrap_or_default(); if list.is_empty() { - if ctx.guild_id().is_some() { - vec![ReminderScope::Channel(ctx.channel_id().0)] - } else { - vec![ReminderScope::User(ctx.author().id.0)] + let channel = ctx.channel_id().to_channel(&ctx.discord()).await?; + + match channel.guild() { + Some(guild_channel) => { + if guild_channel.kind == ChannelType::PublicThread { + vec![ReminderScope::Thread(ctx.channel_id().0)] + } else { + vec![ReminderScope::Channel(ctx.channel_id().0)] + } + } + + None => vec![ReminderScope::User(ctx.author().id.0)], } } else { list @@ -815,6 +824,7 @@ fn create_response( embed } +// TODO process threads here fn parse_mention_list(mentions: &str) -> Vec { REGEX_CHANNEL_USER .captures_iter(mentions) diff --git a/src/hooks.rs b/src/hooks.rs index c4606a3..35ae9c0 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -53,19 +53,22 @@ async fn check_self_permissions(ctx: Context<'_>) -> bool { .member_permissions(&ctx.discord(), user_id) .await .map_or(false, |p| p.manage_webhooks()); + let (view_channel, send_messages, embed_links) = ctx .channel_id() - .to_channel_cached(&ctx.discord()) + .to_channel(&ctx.discord()) + .await + .ok() .and_then(|c| { if let Channel::Guild(channel) = c { - channel.permissions_for_user(&ctx.discord(), user_id).ok() + let perms = channel.permissions_for_user(&ctx.discord(), user_id).ok()?; + + Some((perms.view_channel(), perms.send_messages(), perms.embed_links())) } else { None } }) - .map_or((false, false, false), |p| { - (p.view_channel(), p.send_messages(), p.embed_links()) - }); + .unwrap_or((false, false, false)); if manage_webhooks && send_messages && embed_links { true @@ -81,8 +84,8 @@ async fn check_self_permissions(ctx: Context<'_>) -> bool { {} **Manage Webhooks**", if view_channel { "✅" } else { "❌" }, if send_messages { "✅" } else { "❌" }, - if manage_webhooks { "✅" } else { "❌" }, if embed_links { "✅" } else { "❌" }, + if manage_webhooks { "✅" } else { "❌" }, )) }) .await; diff --git a/src/models/mod.rs b/src/models/mod.rs index db7e0f9..760d15a 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -5,7 +5,7 @@ pub mod timer; pub mod user_data; use chrono_tz::Tz; -use poise::serenity_prelude::{async_trait, model::id::UserId}; +use poise::serenity_prelude::{async_trait, model::id::UserId, ChannelType}; use crate::{ models::{channel_data::ChannelData, user_data::UserData}, @@ -43,7 +43,20 @@ impl CtxData for Context<'_> { } async fn channel_data(&self) -> Result> { - let channel = self.channel_id().to_channel_cached(&self.discord()).unwrap(); + // If we're in a thread, get the parent channel. + let recv_channel = self.channel_id().to_channel(&self.discord()).await?; + + let channel = match recv_channel.guild() { + Some(guild_channel) => { + if guild_channel.kind == ChannelType::PublicThread { + guild_channel.parent_id.unwrap().to_channel_cached(&self.discord()).unwrap() + } else { + self.channel_id().to_channel_cached(&self.discord()).unwrap() + } + } + + None => self.channel_id().to_channel_cached(&self.discord()).unwrap(), + }; ChannelData::from_channel(&channel, &self.data().database).await } diff --git a/src/models/reminder/builder.rs b/src/models/reminder/builder.rs index 9fd91a6..866dbcd 100644 --- a/src/models/reminder/builder.rs +++ b/src/models/reminder/builder.rs @@ -36,6 +36,7 @@ async fn create_webhook( pub enum ReminderScope { User(u64), Channel(u64), + Thread(u64), } impl ReminderScope { @@ -43,6 +44,7 @@ impl ReminderScope { match self { Self::User(id) => format!("<@{}>", id), Self::Channel(id) => format!("<#{}>", id), + Self::Thread(id) => format!("<#{}>", id), } } } @@ -51,6 +53,7 @@ pub struct ReminderBuilder { pool: MySqlPool, uid: String, channel: u32, + thread_id: Option, utc_time: NaiveDateTime, timezone: String, interval_seconds: Option, @@ -299,6 +302,66 @@ impl<'a> MultiReminderBuilder<'a> { Err(ReminderError::InvalidTag) } } + ReminderScope::Thread(thread_id) => { + let thread = + ChannelId(thread_id).to_channel(&self.ctx.discord()).await.unwrap(); + + if let Some(guild_channel) = thread.guild() { + if Some(guild_channel.guild_id) != self.guild_id { + Err(ReminderError::InvalidTag) + } else { + match guild_channel.parent_id { + Some(parent_id) => { + let channel = parent_id + .to_channel(&self.ctx.discord()) + .await + .unwrap(); + + let mut channel_data = ChannelData::from_channel( + &channel, + &self.ctx.data().database, + ) + .await + .unwrap(); + + if channel_data.webhook_id.is_none() + || channel_data.webhook_token.is_none() + { + match create_webhook( + &self.ctx.discord(), + channel.guild().unwrap(), + "Reminder", + ) + .await + { + Ok(webhook) => { + channel_data.webhook_id = + Some(webhook.id.as_u64().to_owned()); + channel_data.webhook_token = webhook.token; + + channel_data + .commit_changes(&self.ctx.data().database) + .await; + + Ok(channel_data.id) + } + + Err(e) => { + Err(ReminderError::DiscordError(e.to_string())) + } + } + } else { + Ok(channel_data.id) + } + } + + None => Err(ReminderError::InvalidTag), + } + } + } else { + Err(ReminderError::InvalidTag) + } + } }; match db_channel_id { @@ -307,6 +370,7 @@ impl<'a> MultiReminderBuilder<'a> { pool: self.ctx.data().database.clone(), uid: generate_uid(), channel: c, + thread_id: None, utc_time: self.utc_time, timezone: self.timezone.to_string(), interval_seconds: self.interval.map(|i| i.sec as i64),