Handle threads in channels option

This commit is contained in:
jude 2024-11-16 12:36:24 +00:00
parent 307649eea0
commit 56dbb95e22
2 changed files with 568 additions and 259 deletions

738
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -3,11 +3,6 @@ pub mod content;
pub mod errors; pub mod errors;
mod helper; mod helper;
use std::{
collections::HashSet,
hash::{Hash, Hasher},
};
use crate::{ use crate::{
commands::look::{LookFlags, TimeDisplayType}, commands::look::{LookFlags, TimeDisplayType},
component_models::{ComponentDataModel, UndoReminder}, component_models::{ComponentDataModel, UndoReminder},
@ -27,6 +22,7 @@ use crate::{
}; };
use chrono::{DateTime, NaiveDateTime, Utc}; use chrono::{DateTime, NaiveDateTime, Utc};
use chrono_tz::Tz; use chrono_tz::Tz;
use log::warn;
use poise::{ use poise::{
serenity_prelude::{ serenity_prelude::{
model::id::{ChannelId, GuildId, UserId}, model::id::{ChannelId, GuildId, UserId},
@ -34,7 +30,13 @@ use poise::{
}, },
CreateReply, CreateReply,
}; };
use serenity::all::Channel;
use sqlx::{Executor, FromRow}; use sqlx::{Executor, FromRow};
use std::thread::ThreadId;
use std::{
collections::HashSet,
hash::{Hash, Hasher},
};
#[derive(Debug, Clone, FromRow)] #[derive(Debug, Clone, FromRow)]
#[allow(dead_code)] #[allow(dead_code)]
@ -367,6 +369,61 @@ impl Reminder {
} }
} }
async fn parse_scopes(ctx: Context<'_>, mentions: &str) -> Vec<ReminderScope> {
let captures = REGEX_CHANNEL_USER.captures_iter(mentions);
let mut scopes = vec![];
for capture in captures {
let pref = capture.get(1).unwrap().as_str();
let id = capture.get(2).unwrap().as_str().parse::<u64>().unwrap();
if pref == "#" {
let channel_id = ChannelId::new(id);
let channel = channel_id.to_channel(&ctx).await;
match channel {
Ok(channel) => match channel {
Channel::Guild(channel) => match channel.kind {
ChannelType::PublicThread | ChannelType::PrivateThread => {
scopes.push(ReminderScope::Channel(ChannelWithThread {
channel_id: channel
.parent_id
.expect("No parent_id for a thread")
.get(),
thread_id: Some(id),
}));
}
_ => {
scopes.push(ReminderScope::Channel(ChannelWithThread {
channel_id: id,
thread_id: None,
}));
}
},
Channel::Private(channel) => {
scopes.push(ReminderScope::User(channel.recipient.id.get()));
}
_ => {
warn!("Unknown channel type");
}
},
Err(e) => {
warn!("Could not get channel {}: {}", id, e);
scopes.push(ReminderScope::Channel(ChannelWithThread {
channel_id: id,
thread_id: None,
}))
}
}
} else {
scopes.push(ReminderScope::User(id));
}
}
scopes
}
pub async fn create_reminder( pub async fn create_reminder(
ctx: Context<'_>, ctx: Context<'_>,
time: String, time: String,
@ -377,23 +434,6 @@ pub async fn create_reminder(
tts: Option<bool>, tts: Option<bool>,
timezone: Option<Tz>, timezone: Option<Tz>,
) -> Result<(), Error> { ) -> Result<(), Error> {
fn parse_mention_list(mentions: &str) -> Vec<ReminderScope> {
REGEX_CHANNEL_USER
.captures_iter(mentions)
.map(|i| {
let pref = i.get(1).unwrap().as_str();
let id = i.get(2).unwrap().as_str().parse::<u64>().unwrap();
if pref == "#" {
let channel_with_thread = ChannelWithThread { channel_id: id, thread_id: None };
ReminderScope::Channel(channel_with_thread)
} else {
ReminderScope::User(id)
}
})
.collect::<Vec<ReminderScope>>()
}
fn create_response( fn create_response(
successes: &HashSet<(Reminder, ReminderScope)>, successes: &HashSet<(Reminder, ReminderScope)>,
errors: &HashSet<ReminderError>, errors: &HashSet<ReminderError>,
@ -458,7 +498,10 @@ pub async fn create_reminder(
}; };
let scopes = { let scopes = {
let list = channels.map(|arg| parse_mention_list(&arg)).unwrap_or_default(); let list = match channels {
Some(channels) => parse_scopes(ctx, &channels).await,
None => vec![],
};
if list.is_empty() { if list.is_empty() {
if let Some(channel) = ctx.guild_channel().await { if let Some(channel) = ctx.guild_channel().await {