From e2229278588946b1580329a804b37d923cfbbd98 Mon Sep 17 00:00:00 2001 From: jude Date: Fri, 11 Sep 2020 17:41:15 +0100 Subject: [PATCH] working on the basic reminder commands --- Cargo.lock | 17 ++- Cargo.toml | 5 +- src/commands/info_cmds.rs | 8 +- src/commands/moderation_cmds.rs | 10 +- src/commands/reminder_cmds.rs | 191 +++++++++++++++++++++++++++++++- src/main.rs | 5 + src/models.rs | 10 +- 7 files changed, 222 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e98f3d5..150e2eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -273,6 +273,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "custom_error" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93a0fc65739ae998afc8d68e64bdac2efd1bc4ffa1a0703d171ef2defae3792f" + [[package]] name = "digest" version = "0.8.1" @@ -1118,6 +1124,7 @@ dependencies = [ "async-trait", "chrono", "chrono-tz", + "custom_error", "dotenv", "lazy_static", "log", @@ -1125,6 +1132,8 @@ dependencies = [ "regex", "regex_command_attr", "reqwest", + "serde", + "serde_json", "serenity", "sqlx", "tokio", @@ -1258,18 +1267,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.114" +version = "1.0.115" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5317f7588f0a5078ee60ef675ef96735a1442132dc645eb1d12c018620ed8cd3" +checksum = "e54c9a88f2da7238af84b5101443f0c0d0a3bbdc455e34a5c9497b1903ed55d5" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.114" +version = "1.0.115" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0be94b04690fbaed37cddffc5c134bf537c8e3329d53e982fe04c374978f8e" +checksum = "609feed1d0a73cc36a0182a840a9b37b4a82f0b1150369f0536a9e3f2a31dc48" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 2197e41..024a0c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" [dependencies] serenity = { version = "0.9.0-rc.1", features = ["collector"] } dotenv = "0.15" -tokio = { version = "0.2.19", features = ["fs", "sync", "process", "io-util"] } +tokio = { version = "0.2.19", features = ["process"] } reqwest = "0.10.6" sqlx = { version = "0.3.5", default-features = false, features = ["runtime-tokio", "macros", "mysql", "bigdecimal", "chrono"] } regex = "1.3.9" @@ -17,6 +17,9 @@ chrono = "0.4" chrono-tz = "0.5" lazy_static = "1.4.0" num-integer = "0.1.43" +custom_error = "1.7.1" +serde = "1.0.115" +serde_json = "1.0.57" [dependencies.regex_command_attr] path = "./regex_command_attr" diff --git a/src/commands/info_cmds.rs b/src/commands/info_cmds.rs index 15c97aa..c6d0d2f 100644 --- a/src/commands/info_cmds.rs +++ b/src/commands/info_cmds.rs @@ -45,7 +45,7 @@ async fn help(ctx: &Context, msg: &Message, _args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let desc = user_data.response(&pool, "help").await; msg.channel_id.send_message(ctx, |m| m @@ -64,7 +64,7 @@ async fn info(ctx: &Context, msg: &Message, _args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let desc = user_data.response(&pool, "info").await; msg.channel_id.send_message(ctx, |m| m @@ -83,7 +83,7 @@ async fn donate(ctx: &Context, msg: &Message, _args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let desc = user_data.response(&pool, "donate").await; msg.channel_id.send_message(ctx, |m| m @@ -115,7 +115,7 @@ async fn clock(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let tz: Tz = user_data.timezone.parse().unwrap(); diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index eae2d27..434a884 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -72,7 +72,7 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let mut user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); if args.len() > 0 { match args.parse::() { @@ -101,7 +101,7 @@ async fn language(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let mut user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); match sqlx::query!( " @@ -134,7 +134,7 @@ async fn prefix(ctx: &Context, msg: &Message, args: String) -> CommandResult { .get::().cloned().expect("Could not get SQLPool from data"); let mut guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap(); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); if args.len() > 5 { let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "prefix/too_long").await).await; @@ -159,7 +159,7 @@ async fn restrict(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap(); let role_tag_match = REGEX_ROLE.find(&args); @@ -245,7 +245,7 @@ async fn alias(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let guild_id = msg.guild_id.unwrap().as_u64().to_owned(); diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index a1a1804..54640cf 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -1,8 +1,15 @@ +use custom_error::custom_error; + use regex_command_attr::command; use serenity::{ client::Context, model::{ + id::{ + UserId, + ChannelId, + GuildId, + }, channel::{ Message, }, @@ -36,8 +43,12 @@ use std::{ use regex::Regex; +use serde_json::json; + lazy_static! { static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap(); + + static ref REGEX_CHANNEL_USER: Regex = Regex::new(r#"^\s*<(#|@)(?:!)?(\d+)>\s*$"#).unwrap(); } @@ -48,7 +59,7 @@ async fn pause(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let mut channel = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool).await.unwrap(); if args.len() == 0 { @@ -93,7 +104,7 @@ async fn offset(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); if args.len() == 0 { let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "offset/help").await).await; @@ -144,7 +155,7 @@ async fn nudge(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let mut channel = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool).await.unwrap(); if args.len() == 0 { @@ -245,7 +256,7 @@ async fn look(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let flags = LookFlags::from_string(&args); @@ -314,7 +325,7 @@ async fn delete(ctx: &Context, msg: &Message, _args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "del/listing").await).await; @@ -423,7 +434,7 @@ async fn timer(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let mut args_iter = args.splitn(2, " "); @@ -491,3 +502,171 @@ DELETE FROM timers WHERE owner = ? AND name = ? Ok(()) } + +enum RemindCommand { + Remind, + Interval, +} + +enum ReminderScope { + User(u64), + Channel(u64), +} + +custom_error!{ReminderError + LongTime = "Time too long", + LongInterval = "Interval too long", + PastTime = "Time has already passed", + ShortInterval = "Interval too short", + InvalidTag = "Invalid reminder scope", + NotEnoughArgs = "Not enough args", + InvalidTime = "Invalid time provided", + DiscordError = "Bad response received from Discord" +} + +#[command] +#[permission_level(Managed)] +async fn remind(ctx: &Context, msg: &Message, args: String) -> CommandResult { + remind_command(ctx, msg, args, RemindCommand::Remind).await; + + Ok(()) +} + +#[command] +#[permission_level(Managed)] +async fn interval(ctx: &Context, msg: &Message, args: String) -> CommandResult { + remind_command(ctx, msg, args, RemindCommand::Interval).await; + + Ok(()) +} + +async fn remind_command(ctx: &Context, msg: &Message, args: String, command: RemindCommand) { + let user_data; + + { + let pool = ctx.data.read().await + .get::().cloned().expect("Could not get SQLPool from data"); + + user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); + } + + let mut args_iter = args.split(' ').filter(|s| s.len() > 0); + + if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) { + + let scope_id; + let time_parser; + let content; + + let guild_id = msg.guild_id; + + if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER + .captures(&first_arg) + .map(|cap| (cap.get(1), cap.get(2))) { + + if scope_match.as_str() == "@" { + scope_id = ReminderScope::User(id_match.as_str().parse::().unwrap()); + } + else { + scope_id = ReminderScope::Channel(id_match.as_str().parse::().unwrap()); + } + + if let Some(next) = args_iter.next().map(|inner| inner.to_string()) { + time_parser = TimeParser::new(next, user_data.timezone.parse().unwrap()); + + content = args_iter.collect::>().join(" "); + + // TODO replace unwrap with converting response into discord response + create_reminder(ctx, guild_id, scope_id, time_parser, content).await.unwrap(); + } + else { + + } + } + else { + scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned()); + + time_parser = TimeParser::new(first_arg, user_data.timezone.parse().unwrap()); + + content = args_iter.collect::>().join(" "); + + // TODO replace unwrap with converting response into discord response + create_reminder(ctx, guild_id, scope_id, time_parser, content).await.unwrap(); + } + } + else { + + } +} + +async fn create_reminder(ctx: &Context, guild_id: Option, scope_id: ReminderScope, time_parser: TimeParser, content: String) + -> Result<(), ReminderError> { + + let pool = ctx.data.read().await + .get::().cloned().expect("Could not get SQLPool from data"); + + let db_channel_id = match scope_id { + ReminderScope::User(user_id) => { + let user = UserId(user_id).to_user(&ctx).await.unwrap(); + + let user_data = UserData::from_user(&user, &ctx, &pool).await.unwrap(); + + user_data.dm_channel + }, + + ReminderScope::Channel(channel_id) => { + let channel = ChannelId(channel_id).to_channel(&ctx).await.unwrap(); + + if channel.clone().guild().map(|gc| gc.guild_id) != guild_id { + return Err(ReminderError::InvalidTag) + } + + let mut channel_data = ChannelData::from_channel(channel.clone(), &pool).await.unwrap(); + + if let Some(guild_channel) = channel.guild() { + if channel_data.webhook_token.is_none() || channel_data.webhook_id.is_none() { + if let Ok(webhook) = ctx.http.create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await { + channel_data.webhook_id = Some(webhook.id.as_u64().to_owned()); + channel_data.webhook_token = Some(webhook.token); + + channel_data.commit_changes(&pool).await; + } + else { + return Err(ReminderError::DiscordError) + } + } + } + + channel_data.id + }, + }; + + // validate time, channel, content + if content.len() == 0 { + Err(ReminderError::NotEnoughArgs) + } + else { + match time_parser.timestamp() { + Ok(time) => { + let unix_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; + + if time > unix_time { + if time > unix_time + 60*60*24*365*50 { + Err(ReminderError::LongTime) + } + else { + + Ok(()) + } + } + else { + Err(ReminderError::PastTime) + } + }, + + Err(_) => { + Err(ReminderError::InvalidTime) + }, + } + } +} diff --git a/src/main.rs b/src/main.rs index 0418dbb..6c071bd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -76,6 +76,11 @@ async fn main() -> Result<(), Box> { .add_command("timer", &reminder_cmds::TIMER_COMMAND) + .add_command("remind", &reminder_cmds::REMIND_COMMAND) + .add_command("r", &reminder_cmds::REMIND_COMMAND) + .add_command("interval", &reminder_cmds::INTERVAL_COMMAND) + .add_command("i", &reminder_cmds::INTERVAL_COMMAND) + .add_command("look", &reminder_cmds::LOOK_COMMAND) .add_command("del", &reminder_cmds::DELETE_COMMAND) diff --git a/src/models.rs b/src/models.rs index 455a650..9a52443 100644 --- a/src/models.rs +++ b/src/models.rs @@ -60,7 +60,7 @@ UPDATE guilds SET name = ?, prefix = ? WHERE id = ? } pub struct ChannelData { - id: u32, + pub id: u32, channel: u64, pub name: String, pub nudge: i16, @@ -82,7 +82,9 @@ SELECT * FROM channels WHERE channel = ? .await.ok() } - pub async fn from_channel(channel: Channel, pool: &MySqlPool) -> Result> { + pub async fn from_channel(channel: Channel, pool: &MySqlPool) + -> Result> + { let channel_id = channel.id().as_u64().clone(); if let Ok(c) = sqlx::query_as_unchecked!(Self, @@ -139,7 +141,7 @@ pub struct UserData { } impl UserData { - pub async fn from_id(user: &User, ctx: &&Context, pool: &MySqlPool) -> Result> { + pub async fn from_user(user: &User, ctx: &&Context, pool: &MySqlPool) -> Result> { let user_id = user.id.as_u64().clone(); if let Ok(c) = sqlx::query_as_unchecked!(Self, @@ -153,7 +155,7 @@ SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ? } else { let dm_channel = user.create_dm_channel(ctx).await?; - let dm_id = dm_channel.id.as_u64().clone(); + let dm_id = dm_channel.id.as_u64().to_owned(); let pool_c = pool.clone();