diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index 434a884..b518f54 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -74,7 +74,7 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult { let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); - if args.len() > 0 { + if !args.is_empty() { match args.parse::() { Ok(_) => { user_data.timezone = args; @@ -139,7 +139,7 @@ async fn prefix(ctx: &Context, msg: &Message, args: String) -> CommandResult { if args.len() > 5 { let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "prefix/too_long").await).await; } - else if args.len() == 0 { + else if args.is_empty() { let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "prefix/no_argument").await).await; } else { @@ -171,7 +171,7 @@ async fn restrict(ctx: &Context, msg: &Message, args: String) -> CommandResult { let role_opt = role_id.to_role_cached(&ctx).await; if let Some(role) = role_opt { - if commands.len() == 0 { + if commands.is_empty() { let _ = sqlx::query!( " DELETE FROM command_restrictions WHERE role_id = (SELECT id FROM roles WHERE role = ?) @@ -206,8 +206,8 @@ INSERT INTO command_restrictions (role_id, command) VALUES ((SELECT id FROM role } } } - else if args.len() == 0 { - let guild_id = msg.guild_id.unwrap().as_u64().clone(); + else if args.is_empty() { + let guild_id = msg.guild_id.unwrap().as_u64().to_owned(); let rows = sqlx::query!( " diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index ac117f5..2e7ba3e 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -43,26 +43,32 @@ use rand::{ RngCore, }; +use sqlx::{ + Type, + MySql, + MySqlPool, + encode::Encode, +}; + use std::str::from_utf8; use num_integer::Integer; use num_traits::cast::ToPrimitive; use std::{ - string::ToString, + convert::TryInto, default::Default, + env, + string::ToString, time::{ SystemTime, UNIX_EPOCH, }, - env, }; use regex::Regex; use serde_json::json; -use sqlx::MySqlPool; -use std::convert::TryInto; lazy_static! { static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap(); @@ -725,8 +731,6 @@ async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult { let to_str = user_data.response(&pool, "natural/to").await; let every_str = user_data.response(&pool, "natural/every").await; - let location_ids = vec![msg.channel_id.as_u64().to_owned()]; - let mut args_iter = args.splitn(1, &send_str); let (time_crop_opt, msg_crop_opt) = (args_iter.next(), args_iter.next()); @@ -748,7 +752,53 @@ async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult { None }).flatten() { + let mut location_ids = vec![ReminderScope::Channel(msg.channel_id.as_u64().to_owned())]; + let mut content = msg_crop; + // check other options and then create reminder :) + if msg.guild_id.is_some() { + let re_match = Regex::new(&format!(r#"(?P.*) {} (?P((?:<@\d+>)|(?:<@!\d+>)|(?:<#\d+>)|(?:\s+))+)$"#, to_str)) + .unwrap() + .captures(msg_crop); + + if let Some(captures) = re_match { + content = captures.name("msg").unwrap().as_str(); + + let mentions = captures.name("mentions").unwrap().as_str(); + location_ids = REGEX_CHANNEL_USER + .captures_iter(mentions) + .map(|i| { + let pref = i.get(1).unwrap().as_str(); + let id = i.get(2).unwrap().as_str().parse::().unwrap(); + + if pref == "#" { + ReminderScope::Channel(id) + } else { + ReminderScope::User(id) + } + }).collect::>(); + } + } + + let mut issue_count = 0; + + for location in location_ids { + let res = create_reminder( + &ctx, + &pool, + msg.author.id.as_u64().to_owned(), + msg.guild_id, + &location, + timestamp, + None, + &content).await; + + if res.is_ok() { + issue_count += 1; + } + } + + let _ = msg.channel_id.say(&ctx, format!("successfully set {} reminders", issue_count)).await; } // something not right with the time parse else { @@ -759,7 +809,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult { Ok(()) } -async fn create_reminder>( +async fn create_reminder, S: ToString + Type + Encode>( ctx: impl CacheHttp, pool: &MySqlPool, user_id: u64, @@ -767,9 +817,11 @@ async fn create_reminder>( scope_id: &ReminderScope, time_parser: T, interval: Option, - content: String) + content: S) -> Result<(), ReminderError> { + let content_string = content.to_string(); + let db_channel_id = match scope_id { ReminderScope::User(user_id) => { let user = UserId(*user_id).to_user(&ctx).await.unwrap(); @@ -807,7 +859,7 @@ async fn create_reminder>( }; // validate time, channel, content - if content.len() == 0 { + if content_string.is_empty() { Err(ReminderError::NotEnoughArgs) } // todo replace numbers with configurable values diff --git a/src/commands/todo_cmds.rs b/src/commands/todo_cmds.rs index 7b34e2f..4930896 100644 --- a/src/commands/todo_cmds.rs +++ b/src/commands/todo_cmds.rs @@ -182,7 +182,7 @@ enum SubCommand { #[permission_level(Managed)] async fn todo_parse(ctx: &Context, msg: &Message, args: String) -> CommandResult { - let mut split = args.split(" "); + let mut split = args.split(' '); if let Some(target) = split.next() { let target_opt = match target { diff --git a/src/time_parser.rs b/src/time_parser.rs index 83f7cb6..4be77ce 100644 --- a/src/time_parser.rs +++ b/src/time_parser.rs @@ -51,14 +51,9 @@ impl TryFrom<&TimeParser> for i64 { impl TimeParser { pub fn new(input: String, timezone: Tz) -> Self { - let inverted = if input.starts_with("-") { - true - } - else { - false - }; + let inverted = input.starts_with('-'); - let parse_type = if input.contains("/") || input.contains(":") { + let parse_type = if input.contains('/') || input.contains(':') { ParseType::Explicit } else { @@ -68,7 +63,7 @@ impl TimeParser { Self { timezone, inverted, - time_string: input.trim_start_matches("-").to_string(), + time_string: input.trim_start_matches('-').to_string(), parse_type, } } @@ -109,10 +104,10 @@ impl TimeParser { fn process_explicit(&self) -> Result { - let segments = self.time_string.matches("-").count(); + let segments = self.time_string.matches('-').count(); let parse_string = if segments == 1 { - let slashes = self.time_string.matches("/").count(); + let slashes = self.time_string.matches('/').count(); match slashes { 0 => Ok("%d-".to_string()), @@ -123,7 +118,7 @@ impl TimeParser { } else { Ok("".to_string()) }? + if segments == 1 { - let colons = self.time_string.matches(":").count(); + let colons = self.time_string.matches(':').count(); match colons { 1 => Ok("%H:%M"),