diff --git a/Cargo.lock b/Cargo.lock index 72f79cf..3c69325 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1230,7 +1230,7 @@ dependencies = [ "regex", "regex_command_attr", "reqwest", - "ring", + "rmp-serde", "serde", "serde_json", "serenity", @@ -1302,6 +1302,27 @@ dependencies = [ "winapi", ] +[[package]] +name = "rmp" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f55e5fa1446c4d5dd1f5daeed2a4fe193071771a2636274d0d7a3b082aa7ad6" +dependencies = [ + "byteorder", + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "723ecff9ad04f4ad92fe1c8ca6c20d2196d9286e9c60727c4cb5511629260e9d" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rsa" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index e852940..082e7e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,11 +19,11 @@ lazy_static = "1.4" num-integer = "0.1" serde = "1.0" serde_json = "1.0" +rmp-serde = "0.15" rand = "0.7" levenshtein = "1.0" serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["collector", "unstable_discord_api"] } sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono"]} -ring = "0.16" base64 = "0.13.0" [dependencies.regex_command_attr] diff --git a/create.sql b/migration/00-initial.sql similarity index 100% rename from create.sql rename to migration/00-initial.sql diff --git a/migration/reminder_message_embed.sql b/migration/01-reminder_message_embed.sql similarity index 100% rename from migration/reminder_message_embed.sql rename to migration/01-reminder_message_embed.sql diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index 0b852a1..2c97b75 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -17,6 +17,7 @@ use serenity::{ }; use crate::{ + component_models::{ComponentDataModel, Restrict}, consts::{REGEX_ALIAS, REGEX_COMMANDS, THEME_COLOR}, framework::{CommandInvoke, CreateGenericResponse, PermissionLevel}, models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData, CtxData}, @@ -264,6 +265,8 @@ async fn restrict( let len = restrictable_commands.len(); + let restrict_pl = ComponentDataModel::Restrict(Restrict { role_id: role }); + invoke .respond( ctx.http.clone(), @@ -273,7 +276,7 @@ async fn restrict( c.create_action_row(|row| { row.create_select_menu(|select| { select - .custom_id("test_id") + .custom_id(restrict_pl.to_custom_id()) .options(|options| { for command in restrictable_commands { options.create_option(|opt| { diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index 4a0aa98..4bc7ee6 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -10,14 +10,19 @@ use num_integer::Integer; use regex_command_attr::command; use serenity::{ client::Context, - model::channel::{Channel, Message}, + futures::StreamExt, + model::{ + channel::{Channel, Message}, + id::ChannelId, + misc::Mentionable, + }, }; use crate::{ check_subscription_on_message, consts::{ - REGEX_CHANNEL_USER, REGEX_NATURAL_COMMAND_1, REGEX_NATURAL_COMMAND_2, REGEX_REMIND_COMMAND, - THEME_COLOR, + EMBED_DESCRIPTION_MAX_LENGTH, REGEX_CHANNEL_USER, REGEX_NATURAL_COMMAND_1, + REGEX_NATURAL_COMMAND_2, REGEX_REMIND_COMMAND, THEME_COLOR, }, framework::{CommandInvoke, CreateGenericResponse}, models::{ @@ -26,7 +31,7 @@ use crate::{ reminder::{ builder::{MultiReminderBuilder, ReminderScope}, content::Content, - look_flags::LookFlags, + look_flags::{LookFlags, TimeDisplayType}, Reminder, }, timer::Timer, @@ -116,146 +121,249 @@ async fn pause( } } -/* -#[command] -#[permission_level(Restricted)] -async fn offset(ctx: &Context, msg: &Message, args: String) { - let (pool, lm) = get_ctx_data(&ctx).await; +#[command("offset")] +#[description("Move all reminders in the current server by a certain amount of time. Times get added together")] +#[arg( + name = "hours", + description = "Number of hours to offset by", + kind = "Integer", + required = false +)] +#[arg( + name = "minutes", + description = "Number of minutes to offset by", + kind = "Integer", + required = false +)] +#[arg( + name = "seconds", + description = "Number of seconds to offset by", + kind = "Integer", + required = false +)] +#[required_permissions(Restricted)] +async fn offset( + ctx: &Context, + invoke: &(dyn CommandInvoke + Send + Sync), + args: HashMap, +) { + let pool = ctx.data.read().await.get::().cloned().unwrap(); - let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); + let combined_time = args.get("hours").map_or(0, |h| h.parse::().unwrap() * 3600) + + args.get("minutes").map_or(0, |m| m.parse::().unwrap() * 60) + + args.get("seconds").map_or(0, |s| s.parse::().unwrap()); - if args.is_empty() { - let prefix = ctx.prefix(msg.guild_id).await; - - command_help(ctx, msg, lm, &prefix, &user_data.language, "offset").await; + if combined_time == 0 { + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content("Please specify one of `hours`, `minutes` or `seconds`"), + ) + .await; } else { - let parser = TimeParser::new(&args, user_data.timezone()); + if let Some(guild) = invoke.guild(ctx.cache.clone()) { + let channels = guild + .channels + .iter() + .filter(|(channel_id, channel)| match channel { + Channel::Guild(guild_channel) => guild_channel.is_text_based(), + _ => false, + }) + .map(|(id, _)| id.0.to_string()) + .collect::>() + .join(","); - if let Ok(displacement) = parser.displacement() { - if let Some(guild) = msg.guild(&ctx) { - let guild_data = GuildData::from_guild(guild, &pool).await.unwrap(); - - sqlx::query!( - " + sqlx::query!( + " UPDATE reminders - INNER JOIN `channels` - ON `channels`.id = reminders.channel_id - SET - reminders.`utc_time` = reminders.`utc_time` + ? - WHERE channels.guild_id = ? - ", - displacement, - guild_data.id - ) - .execute(&pool) - .await - .unwrap(); - } else { - sqlx::query!( - " -UPDATE reminders SET `utc_time` = `utc_time` + ? WHERE reminders.channel_id = ? - ", - displacement, - user_data.dm_channel - ) - .execute(&pool) - .await - .unwrap(); - } - - let response = lm.get(&user_data.language, "offset/success").replacen( - "{}", - &displacement.to_string(), - 1, - ); - - let _ = msg.channel_id.say(&ctx, response).await; +INNER JOIN + `channels` ON `channels`.id = reminders.channel_id +SET reminders.`utc_time` = reminders.`utc_time` + ? +WHERE FIND_IN_SET(channels.`channel`, ?)", + combined_time, + channels + ) + .execute(&pool) + .await + .unwrap(); } else { - let _ = - msg.channel_id.say(&ctx, lm.get(&user_data.language, "offset/invalid_time")).await; + sqlx::query!( + "UPDATE reminders INNER JOIN `channels` ON `channels`.id = reminders.channel_id SET reminders.`utc_time` = reminders.`utc_time` + ? WHERE channels.`channel` = ?", + combined_time, + invoke.channel_id().0 + ) + .execute(&pool) + .await + .unwrap(); } + + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .content(format!("All reminders offset by {} seconds", combined_time)), + ) + .await; } } -#[command] -#[permission_level(Restricted)] -async fn nudge(ctx: &Context, msg: &Message, args: String) { - let (pool, lm) = get_ctx_data(&ctx).await; +#[command("nudge")] +#[description("Nudge all future reminders on this channel by a certain amount (don't use for DST! See `/offset`)")] +#[arg( + name = "minutes", + description = "Number of minutes to nudge new reminders by", + kind = "Integer", + required = false +)] +#[arg( + name = "seconds", + description = "Number of seconds to nudge new reminders by", + kind = "Integer", + required = false +)] +#[required_permissions(Restricted)] +async fn nudge( + ctx: &Context, + invoke: &(dyn CommandInvoke + Send + Sync), + args: HashMap, +) { + let pool = ctx.data.read().await.get::().cloned().unwrap(); - let language = UserData::language_of(&msg.author, &pool).await; - let timezone = UserData::timezone_of(&msg.author, &pool).await; + let combined_time = args.get("minutes").map_or(0, |m| m.parse::().unwrap() * 60) + + args.get("seconds").map_or(0, |s| s.parse::().unwrap()); - let mut channel = - ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool).await.unwrap(); - - if args.is_empty() { - let content = lm - .get(&language, "nudge/no_argument") - .replace("{nudge}", &format!("{}s", &channel.nudge.to_string())); - - let _ = msg.channel_id.say(&ctx, content).await; + if combined_time < i16::MIN as i64 || combined_time > i16::MAX as i64 { + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content("Nudge times must be less than 500 minutes"), + ) + .await; } else { - let parser = TimeParser::new(&args, timezone); - let nudge_time = parser.displacement(); + let mut channel_data = ctx.channel_data(invoke.channel_id()).await.unwrap(); - match nudge_time { - Ok(displacement) => { - if displacement < i16::MIN as i64 || displacement > i16::MAX as i64 { - let _ = msg.channel_id.say(&ctx, lm.get(&language, "nudge/invalid_time")).await; - } else { - channel.nudge = displacement as i16; + channel_data.nudge = combined_time as i16; + channel_data.commit_changes(&pool).await; - channel.commit_changes(&pool).await; - - let response = lm.get(&language, "nudge/success").replacen( - "{}", - &displacement.to_string(), - 1, - ); - - let _ = msg.channel_id.say(&ctx, response).await; - } - } - - Err(_) => { - let _ = msg.channel_id.say(&ctx, lm.get(&language, "nudge/invalid_time")).await; - } - } + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content(format!( + "Future reminders will be nudged by {} seconds", + combined_time + )), + ) + .await; } } #[command("look")] -#[permission_level(Managed)] -async fn look(ctx: &Context, msg: &Message, args: String) { - let (pool, _lm) = get_ctx_data(&ctx).await; +#[description("View reminders on a specific channel")] +#[arg( + name = "channel", + description = "The channel to view reminders on", + kind = "Channel", + required = false +)] +#[arg( + name = "disabled", + description = "Whether to show disabled reminders or not", + kind = "Boolean", + required = false +)] +#[arg( + name = "relative", + description = "Whether to display times as relative or exact times", + kind = "Boolean", + required = false +)] +#[required_permissions(Managed)] +async fn look( + ctx: &Context, + invoke: &(dyn CommandInvoke + Send + Sync), + args: HashMap, +) { + let pool = ctx.data.read().await.get::().cloned().unwrap(); - let timezone = UserData::timezone_of(&msg.author, &pool).await; + let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; - let flags = LookFlags::from_string(&args); + let flags = LookFlags { + show_disabled: args.get("disabled").map(|b| b == "true").unwrap_or(true), + channel_id: args.get("channel").map(|c| ChannelId(c.parse::().unwrap())), + time_display: args.get("relative").map_or(TimeDisplayType::Relative, |b| { + if b == "true" { + TimeDisplayType::Relative + } else { + TimeDisplayType::Absolute + } + }), + }; - let channel_opt = msg.channel_id.to_channel_cached(&ctx); + let channel_opt = invoke.channel_id().to_channel_cached(&ctx); let channel_id = if let Some(Channel::Guild(channel)) = channel_opt { - if Some(channel.guild_id) == msg.guild_id { - flags.channel_id.unwrap_or(msg.channel_id) + if Some(channel.guild_id) == invoke.guild_id() { + flags.channel_id.unwrap_or(invoke.channel_id()) } else { - msg.channel_id + invoke.channel_id() } } else { - msg.channel_id + invoke.channel_id() }; let reminders = Reminder::from_channel(ctx, channel_id, &flags).await; if reminders.is_empty() { - let _ = msg.channel_id.say(&ctx, "No reminders on specified channel").await; + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new().content("No reminders on specified channel"), + ) + .await; } else { - let display = reminders.iter().map(|reminder| reminder.display(&flags, &timezone)); + let mut char_count = 0; - let _ = msg.channel_id.say_lines(&ctx, display).await; + let display = reminders + .iter() + .map(|reminder| reminder.display(&flags, &timezone)) + .take_while(|p| { + char_count += p.len(); + + char_count < EMBED_DESCRIPTION_MAX_LENGTH + }) + .collect::>() + .join("\n"); + + let pages = reminders + .iter() + .map(|reminder| reminder.display(&flags, &timezone)) + .fold(0, |t, r| t + r.len()) + .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH); + + let _ = invoke + .respond( + ctx.http.clone(), + CreateGenericResponse::new() + .embed(|e| { + e.title(format!("Reminders on {}", channel_id.mention())) + .description(display) + .footer(|f| f.text(format!("Page {} of {}", 1, pages))) + }) + .components(|comp| { + comp.create_action_row(|row| { + row.create_button(|b| b.label("⏮️").custom_id(".1")) + .create_button(|b| b.label("◀️").custom_id(".2")) + .create_button(|b| b.label("▶️").custom_id(".3")) + .create_button(|b| b.label("⏭️").custom_id(".4")) + }) + }), + ) + .await; } } +/* #[command("del")] #[permission_level(Managed)] async fn delete(ctx: &Context, msg: &Message, _args: String) { diff --git a/src/component_models/mod.rs b/src/component_models/mod.rs new file mode 100644 index 0000000..20dac0b --- /dev/null +++ b/src/component_models/mod.rs @@ -0,0 +1,51 @@ +use std::io::Cursor; + +use rmp_serde::Serializer; +use serde::{Deserialize, Serialize}; +use serenity::model::{ + id::{ChannelId, RoleId}, + interactions::message_component::MessageComponentInteraction, +}; + +use crate::models::reminder::look_flags::LookFlags; + +#[derive(Deserialize, Serialize)] +#[serde(tag = "type")] +pub enum ComponentDataModel { + Restrict(Restrict), + LookPager(LookPager), +} + +impl ComponentDataModel { + pub fn to_custom_id(&self) -> String { + let mut buf = Vec::new(); + self.serialize(&mut Serializer::new(&mut buf)).unwrap(); + base64::encode(buf) + } + + pub fn from_custom_id(data: &String) -> Self { + let buf = base64::decode(data).unwrap(); + let cur = Cursor::new(buf); + rmp_serde::from_read(cur).unwrap() + } + + pub async fn act(&self, component: MessageComponentInteraction) { + match self { + ComponentDataModel::Restrict(restrict) => { + println!("{:?}", component.data.values); + } + ComponentDataModel::LookPager(pager) => {} + } + } +} + +#[derive(Deserialize, Serialize)] +pub struct Restrict { + pub role_id: RoleId, +} + +#[derive(Deserialize, Serialize)] +pub struct LookPager { + pub flags: LookFlags, + pub page_request: u16, +} diff --git a/src/consts.rs b/src/consts.rs index fe99a94..0bf03ad 100644 --- a/src/consts.rs +++ b/src/consts.rs @@ -1,6 +1,7 @@ pub const DAY: u64 = 86_400; pub const HOUR: u64 = 3_600; pub const MINUTE: u64 = 60; +pub const EMBED_DESCRIPTION_MAX_LENGTH: usize = 4000; pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"; diff --git a/src/main.rs b/src/main.rs index 925f8aa..62c27cb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,9 @@ +#![feature(int_roundings)] #[macro_use] extern crate lazy_static; mod commands; +mod component_models; mod consts; mod framework; mod models; @@ -34,6 +36,7 @@ use tokio::sync::RwLock; use crate::{ commands::{info_cmds, moderation_cmds, reminder_cmds}, + component_models::ComponentDataModel, consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR}, framework::RegexFramework, models::guild_data::GuildData, @@ -253,6 +256,10 @@ DELETE FROM guilds WHERE guild = ? framework.execute(ctx, application_command).await; } + Interaction::MessageComponent(component) => { + let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id); + component_model.act(component).await; + } _ => {} } } @@ -298,13 +305,13 @@ async fn main() -> Result<(), Box> { .add_command("n", &reminder_cmds::NATURAL_COMMAND) .add_command("", &reminder_cmds::NATURAL_COMMAND) // management commands - .add_command("look", &reminder_cmds::LOOK_COMMAND) .add_command("del", &reminder_cmds::DELETE_COMMAND) */ + .add_command(&reminder_cmds::LOOK_COMMAND) .add_command(&reminder_cmds::PAUSE_COMMAND) + .add_command(&reminder_cmds::OFFSET_COMMAND) + .add_command(&reminder_cmds::NUDGE_COMMAND) /* - .add_command("offset", &reminder_cmds::OFFSET_COMMAND) - .add_command("nudge", &reminder_cmds::NUDGE_COMMAND) // to-do commands .add_command("todo", &todo_cmds::TODO_USER_COMMAND) .add_command("todo user", &todo_cmds::TODO_USER_COMMAND) diff --git a/src/models/reminder/look_flags.rs b/src/models/reminder/look_flags.rs index 576f364..dde4568 100644 --- a/src/models/reminder/look_flags.rs +++ b/src/models/reminder/look_flags.rs @@ -1,14 +1,16 @@ +use serde::{Deserialize, Serialize}; use serenity::model::id::ChannelId; use crate::consts::REGEX_CHANNEL; +#[derive(Serialize, Deserialize)] pub enum TimeDisplayType { - Absolute, - Relative, + Absolute = 0, + Relative = 1, } +#[derive(Serialize, Deserialize)] pub struct LookFlags { - pub limit: u16, pub show_disabled: bool, pub channel_id: Option, pub time_display: TimeDisplayType, @@ -16,44 +18,6 @@ pub struct LookFlags { impl Default for LookFlags { fn default() -> Self { - Self { - limit: u16::MAX, - show_disabled: true, - channel_id: None, - time_display: TimeDisplayType::Relative, - } - } -} - -impl LookFlags { - pub fn from_string(args: &str) -> Self { - let mut new_flags: Self = Default::default(); - - for arg in args.split(' ') { - match arg { - "enabled" => { - new_flags.show_disabled = false; - } - - "time" => { - new_flags.time_display = TimeDisplayType::Absolute; - } - - param => { - if let Ok(val) = param.parse::() { - new_flags.limit = val; - } else if let Some(channel) = REGEX_CHANNEL - .captures(arg) - .map(|cap| cap.get(1)) - .flatten() - .map(|c| c.as_str().parse::().unwrap()) - { - new_flags.channel_id = Some(ChannelId(channel)); - } - } - } - } - - new_flags + Self { show_disabled: true, channel_id: None, time_display: TimeDisplayType::Relative } } } diff --git a/src/models/reminder/mod.rs b/src/models/reminder/mod.rs index ac30cfe..9837486 100644 --- a/src/models/reminder/mod.rs +++ b/src/models/reminder/mod.rs @@ -11,7 +11,6 @@ use std::{ use chrono::{NaiveDateTime, TimeZone}; use chrono_tz::Tz; -use ring::hmac; use serenity::{ client::Context, model::id::{ChannelId, GuildId, UserId}, @@ -27,31 +26,6 @@ use crate::{ SQLPool, }; -#[derive(Clone, Copy)] -pub enum ReminderAction { - Delete, -} - -impl ToString for ReminderAction { - fn to_string(&self) -> String { - match self { - Self::Delete => String::from("del"), - } - } -} - -impl TryFrom<&str> for ReminderAction { - type Error = (); - - fn try_from(value: &str) -> Result { - match value { - "del" => Ok(Self::Delete), - - _ => Err(()), - } - } -} - #[derive(Debug)] pub struct Reminder { pub id: u32, @@ -178,12 +152,9 @@ WHERE FIND_IN_SET(reminders.enabled, ?) ORDER BY reminders.utc_time -LIMIT - ? ", channel_id.as_u64(), enabled, - flags.limit ) .fetch_all(&pool) .await @@ -341,59 +312,6 @@ WHERE } } - pub async fn from_interaction>( - ctx: &Context, - member_id: U, - payload: String, - ) -> Result<(Self, ReminderAction), InteractionError> { - let sections = payload.split('.').collect::>(); - - if sections.len() != 3 { - Err(InteractionError::InvalidFormat) - } else { - let action = ReminderAction::try_from(sections[0]) - .map_err(|_| InteractionError::InvalidAction)?; - - let reminder_id = u32::from_le_bytes( - base64::decode(sections[1]) - .map_err(|_| InteractionError::InvalidBase64)? - .try_into() - .map_err(|_| InteractionError::InvalidSize)?, - ); - - if let Some(reminder) = Self::from_id(ctx, reminder_id).await { - if reminder.signed_action(member_id, action) == payload { - Ok((reminder, action)) - } else { - Err(InteractionError::SignatureMismatch) - } - } else { - Err(InteractionError::NoReminder) - } - } - } - - pub fn signed_action>(&self, member_id: U, action: ReminderAction) -> String { - let s_key = hmac::Key::new( - hmac::HMAC_SHA256, - env::var("SECRET_KEY").expect("No SECRET_KEY provided").as_bytes(), - ); - - let mut context = hmac::Context::with_key(&s_key); - - context.update(&self.id.to_le_bytes()); - context.update(&member_id.into().to_le_bytes()); - - let signature = context.sign(); - - format!( - "{}.{}.{}", - action.to_string(), - base64::encode(self.id.to_le_bytes()), - base64::encode(&signature) - ) - } - pub async fn delete(&self, ctx: &Context) { let pool = ctx.data.read().await.get::().cloned().unwrap();