diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index d7cfb67..f927252 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -3,6 +3,7 @@ use custom_error::custom_error; use regex_command_attr::command; use serenity::{ + cache::Cache, http::CacheHttp, client::Context, model::{ @@ -29,10 +30,10 @@ use crate::{ Reminder, Timer, }, - check_subscription, SQLPool, time_parser::TimeParser, framework::SendIterator, + check_subscription_on_message, }; use chrono::NaiveDateTime; @@ -69,7 +70,6 @@ use std::{ use regex::Regex; use serde_json::json; -use serenity::cache::Cache; lazy_static! { static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap(); @@ -570,7 +570,13 @@ custom_error!{ReminderError DiscordError = "Bad response received from Discord" } -impl ReminderError { +trait ToResponse { + fn to_response(&self) -> String; + + fn to_response_natural(&self) -> String; +} + +impl ToResponse for ReminderError { fn to_response(&self) -> String { match self { Self::LongTime => "remind/long_time", @@ -583,6 +589,31 @@ impl ReminderError { Self::DiscordError => "remind/no_webhook", }.to_string() } + + fn to_response_natural(&self) -> String { + match self { + Self::LongTime => "natural/long_time".to_string(), + _ => self.to_response(), + } + } +} + +impl ToResponse for Result { + fn to_response(&self) -> String { + match self { + Ok(_) => "remind/success".to_string(), + + Err(reminder_error) => reminder_error.to_response(), + } + } + + fn to_response_natural(&self) -> String { + match self { + Ok(_) => "remind/success".to_string(), + + Err(reminder_error) => reminder_error.to_response_natural(), + } + } } fn generate_uid() -> String { @@ -623,10 +654,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem pool: &MySqlPool) -> Result<(), ReminderError> { - let subscribed = check_subscription(&ctx, &msg.author).await || - if let Some(guild) = msg.guild(&ctx).await { check_subscription(&ctx, guild.owner_id).await } else { false }; - - if command == RemindCommand::Interval && subscribed { + if command == RemindCommand::Interval && check_subscription_on_message(&ctx, &msg).await { if let Some(interval_arg) = args_iter.next() { let interval = TimeParser::new(interval_arg.to_string(), UTC); @@ -708,11 +736,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem Err(ReminderError::NotEnoughArgs) }; - let str_response = match response { - Ok(_) => user_data.response(&pool, "remind/success").await, - - Err(reminder_error) => user_data.response(&pool, &reminder_error.to_response()).await, - } + let str_response = response.to_response() .replacen("{location}", &scope_id.mention(), 1) .replacen("{offset}", &time_parser.map(|tp| tp.displacement().ok()).flatten().unwrap_or(-1).to_string(), 1) .replacen("{min_interval}", &MIN_INTERVAL.to_string(), 1) @@ -726,6 +750,11 @@ async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult { let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); + let now = SystemTime::now(); + let since_epoch = now + .duration_since(UNIX_EPOCH) + .expect("Time calculated as going backwards. Very bad"); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let send_str = user_data.response(&pool, "natural/send").await; @@ -740,7 +769,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult { let python_call = Command::new("venv/bin/python3") .arg("dp.py") .arg(time_crop) - .arg(user_data.timezone) + .arg(&user_data.timezone) .arg(&env::var("LOCAL_TIMEZONE").unwrap_or_else(|_| "UTC".to_string())) .output() .await; @@ -782,10 +811,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult { } } - let subscribed = check_subscription(&ctx, &msg.author).await || - if let Some(guild) = msg.guild(&ctx).await { check_subscription(&ctx, guild.owner_id).await } else { false }; - - if subscribed { + if check_subscription_on_message(&ctx, &msg).await { let re_match = Regex::new(&format!(r#"(?P.*) {} (?P.*)$"#, every_str)) .unwrap() .captures(content); @@ -803,11 +829,6 @@ async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult { .output() .await; - let now = SystemTime::now(); - let since_epoch = now - .duration_since(UNIX_EPOCH) - .expect("Time calculated as going backwards. Very bad"); - interval = python_call.ok().map(|inner| if inner.status.success() { Some(from_utf8(&*inner.stdout).unwrap().parse::().unwrap() - since_epoch.as_secs() as i64) @@ -819,25 +840,51 @@ async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult { } - let mut issue_count = 0; + if location_ids.len() == 1 { + let location_id = location_ids.get(0).unwrap(); - for location in location_ids { let res = create_reminder( &ctx, &pool, msg.author.id.as_u64().to_owned(), msg.guild_id, - &location, + &location_id, timestamp, interval, &content).await; - if res.is_ok() { - issue_count += 1; - } - } - let _ = msg.channel_id.say(&ctx, format!("successfully set {} reminders", issue_count)).await; + let str_response = res.to_response_natural() + .replacen("{location}", &location_id.mention(), 1) + .replacen("{offset}", &(timestamp as u64 - since_epoch.as_secs()).to_string(), 1) + .replacen("{min_interval}", &MIN_INTERVAL.to_string(), 1) + .replacen("{max_time}", &MAX_TIME.to_string(), 1); + + let _ = msg.channel_id.say(&ctx, &str_response).await; + } + else { + let mut issue_count = 0_u8; + + for location in location_ids { + let res = create_reminder( + &ctx, + &pool, + msg.author.id.as_u64().to_owned(), + msg.guild_id, + &location, + timestamp, + interval, + &content).await; + + if res.is_ok() { + issue_count += 1; + } + } + + let content = user_data.response(&pool, "natural/bulk_set").await.replace("{count}", &issue_count.to_string()); + + let _ = msg.channel_id.say(&ctx, content).await; + } } else { let _ = msg.channel_id.say(&ctx, "DEV ERROR: Failed to invoke Python").await; diff --git a/src/main.rs b/src/main.rs index c608f33..f40cc8e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,13 +8,17 @@ mod time_parser; mod consts; use serenity::{ + cache::Cache, http::CacheHttp, client::{ bridge::gateway::GatewayIntents, Client, }, - model::id::{ - GuildId, UserId, + model::{ + id::{ + GuildId, UserId, + }, + channel::Message, }, framework::Framework, prelude::TypeMapKey, @@ -164,3 +168,8 @@ pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into, msg: &Message) -> bool { + check_subscription(&cache_http, &msg.author).await || + if let Some(guild) = msg.guild(&cache_http).await { check_subscription(&cache_http, guild.owner_id).await } else { false } +}