diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index 52d126f..e81b8cd 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -14,8 +14,6 @@ use serenity::{ }, }; -use regex::Regex; - use chrono_tz::Tz; use chrono::offset::Utc; @@ -31,19 +29,16 @@ use crate::{ SQLPool, FrameworkCtx, framework::SendIterator, + consts::{ + REGEX_ALIAS, + REGEX_CHANNEL, + REGEX_COMMANDS, + REGEX_ROLE, + }, }; use std::iter; -lazy_static! { - static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap(); - - static ref REGEX_ROLE: Regex = Regex::new(r#"<@&([0-9]+)>"#).unwrap(); - - static ref REGEX_COMMANDS: Regex = Regex::new(r#"([a-z]+)"#).unwrap(); - - static ref REGEX_ALIAS: Regex = Regex::new(r#"(?P[\S]{1,12})(?:(?: (?P.*)$)|$)"#).unwrap(); -} #[command] #[supports_dm(false)] @@ -82,7 +77,6 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult { .get::().cloned().expect("Could not get SQLPool from data"); let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); - let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap(); if !args.is_empty() { match args.parse::() { @@ -106,7 +100,7 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult { } else { let content = user_data.response(&pool, "timezone/no_argument").await - .replace("{prefix}", &guild_data.prefix) + .replace("{prefix}", &GuildData::prefix_from_id(msg.guild_id, &pool).await) .replacen("{timezone}", &user_data.timezone, 1); let _ = msg.channel_id.say(&ctx, content).await; diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index a4430bf..1e8a762 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -23,6 +23,13 @@ use serenity::{ use tokio::process::Command; use crate::{ + consts::{ + REGEX_CHANNEL, + REGEX_CHANNEL_USER, + MIN_INTERVAL, + MAX_TIME, + CHARACTERS, + }, models::{ ChannelData, GuildData, @@ -71,18 +78,6 @@ use regex::Regex; use serde_json::json; -lazy_static! { - static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap(); - - static ref REGEX_CHANNEL_USER: Regex = Regex::new(r#"^\s*<(#|@)(?:!)?(\d+)>\s*$"#).unwrap(); - - static ref MIN_INTERVAL: i64 = env::var("MIN_INTERVAL").ok().map(|inner| inner.parse::().ok()).flatten().unwrap_or(600); - - static ref MAX_TIME: i64 = env::var("MAX_TIME").ok().map(|inner| inner.parse::().ok()).flatten().unwrap_or(60*60*24*365*50); -} - -static CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"; - #[command] #[supports_dm(false)] @@ -330,7 +325,7 @@ SELECT reminders.id, reminders.time, messages.content, channels.channel FROM reminders -INNER JOIN +LEFT OUTER JOIN channels ON channels.id = reminders.channel_id @@ -392,14 +387,14 @@ SELECT reminders.id, reminders.time, messages.content, channels.channel FROM reminders -INNER JOIN +LEFT OUTER JOIN channels ON - reminders.channel_id = channels.id + channels.id = reminders.channel_id INNER JOIN messages ON - reminders.message_id = messages.id + messages.id = reminders.message_id WHERE channels.guild_id = (SELECT id FROM guilds WHERE guild = ?) ", guild_id) @@ -467,17 +462,30 @@ WHERE .collect::>(); if parts.len() == valid_parts.len() { + let joined = valid_parts.join(","); + + let count_row = sqlx::query!( + " +SELECT COUNT(1) AS count FROM reminders WHERE FIND_IN_SET(id, ?) + ", joined) + .fetch_one(&pool) + .await + .unwrap(); + sqlx::query!( " DELETE FROM reminders WHERE FIND_IN_SET(id, ?) - ", valid_parts.join(",")) + ", joined) .execute(&pool) .await .unwrap(); // TODO add deletion events to event list - let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "del/count").await).await; + let content = user_data.response(&pool, "del/count").await + .replacen("{}", &count_row.count.to_string(), 1); + + let _ = msg.channel_id.say(&ctx, content).await; } } @@ -599,6 +607,7 @@ custom_error!{ReminderError InvalidTag = "Invalid reminder scope", NotEnoughArgs = "Not enough args", InvalidTime = "Invalid time provided", + NeedSubscription = "Subscription required and not found", DiscordError = "Bad response received from Discord" } @@ -618,6 +627,7 @@ impl ToResponse for ReminderError { Self::InvalidTag => "remind/invalid_tag", Self::NotEnoughArgs => "remind/no_argument", Self::InvalidTime => "remind/invalid_time", + Self::NeedSubscription => "interval/donor", Self::DiscordError => "remind/no_webhook", }.to_string() } @@ -625,6 +635,7 @@ impl ToResponse for ReminderError { fn to_response_natural(&self) -> String { match self { Self::LongTime => "natural/long_time".to_string(), + Self::InvalidTime => "natural/invalid_time".to_string(), _ => self.to_response(), } } @@ -709,6 +720,9 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem Err(ReminderError::NotEnoughArgs) } } + else if command == RemindCommand::Interval { + Err(ReminderError::NeedSubscription) + } else { let content = args_iter.collect::>().join(" "); @@ -769,6 +783,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem let offset = time_parser.map(|tp| tp.displacement().ok()).flatten().unwrap_or(0) as u64; let str_response = user_data.response(&pool, &response.to_response()).await + .replace("{prefix}", &GuildData::prefix_from_id(msg.guild_id, &pool).await) .replacen("{location}", &scope_id.mention(), 1) .replacen("{offset}", &shorthand_displacement(offset), 1) .replacen("{min_interval}", &MIN_INTERVAL.to_string(), 1) @@ -887,6 +902,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) -> CommandResult { let offset = timestamp as u64 - since_epoch.as_secs(); let str_response = user_data.response(&pool, &res.to_response_natural()).await + .replace("{prefix}", &GuildData::prefix_from_id(msg.guild_id, &pool).await) .replacen("{location}", &location_id.mention(), 1) .replacen("{offset}", &shorthand_displacement(offset), 1) .replacen("{min_interval}", &MIN_INTERVAL.to_string(), 1) @@ -1004,7 +1020,7 @@ async fn create_reminder, S: ToString + Type + Encode unix_time { + if time >= unix_time { if time > unix_time + *MAX_TIME { Err(ReminderError::LongTime) } @@ -1032,6 +1048,10 @@ INSERT INTO reminders (uid, message_id, channel_id, time, `interval`, method, se Ok(()) } } + else if time < 0 { + // case required for if python returns -1 + Err(ReminderError::InvalidTime) + } else { Err(ReminderError::PastTime) } diff --git a/src/consts.rs b/src/consts.rs index b6defb1..6c4385a 100644 --- a/src/consts.rs +++ b/src/consts.rs @@ -4,3 +4,43 @@ pub const MAX_MESSAGE_LENGTH: usize = 2048; pub const DAY: u64 = 86_400; pub const HOUR: u64 = 3_600; pub const MINUTE: u64 = 60; + +pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"; + +use std::{ + iter::FromIterator, + env, + collections::HashSet, +}; + +use lazy_static; + +use regex::Regex; + +lazy_static! { + pub static ref SUBSCRIPTION_ROLES: HashSet = HashSet::from_iter(env::var("SUBSCRIPTION_ROLES") + .map( + |var| var + .split(',') + .filter_map(|item| { + item.parse::().ok() + }) + .collect::>() + ).unwrap_or_else(|_| vec![])); + + pub static ref CNC_GUILD: Option = env::var("CNC_GUILD").map(|var| var.parse::().ok()).ok().flatten(); + + pub static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap(); + + pub static ref REGEX_ROLE: Regex = Regex::new(r#"<@&([0-9]+)>"#).unwrap(); + + pub static ref REGEX_COMMANDS: Regex = Regex::new(r#"([a-z]+)"#).unwrap(); + + pub static ref REGEX_ALIAS: Regex = Regex::new(r#"(?P[\S]{1,12})(?:(?: (?P.*)$)|$)"#).unwrap(); + + pub static ref REGEX_CHANNEL_USER: Regex = Regex::new(r#"^\s*<(#|@)(?:!)?(\d+)>\s*$"#).unwrap(); + + pub static ref MIN_INTERVAL: i64 = env::var("MIN_INTERVAL").ok().map(|inner| inner.parse::().ok()).flatten().unwrap_or(600); + + pub static ref MAX_TIME: i64 = env::var("MAX_TIME").ok().map(|inner| inner.parse::().ok()).flatten().unwrap_or(60*60*24*365*50); +} diff --git a/src/main.rs b/src/main.rs index 45d56de..ccf1e19 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,10 @@ mod consts; use serenity::{ cache::Cache, - http::CacheHttp, + http::{ + CacheHttp, + client::Http, + }, client::{ bridge::gateway::GatewayIntents, Client, @@ -43,6 +46,7 @@ use crate::{ framework::RegexFramework, consts::{ PREFIX, DAY, HOUR, MINUTE, + SUBSCRIPTION_ROLES, CNC_GUILD, }, commands::{ info_cmds, @@ -51,7 +55,9 @@ use crate::{ moderation_cmds, }, }; + use num_integer::Integer; +use serenity::futures::TryFutureExt; struct SQLPool; @@ -77,7 +83,11 @@ static THEME_COLOR: u32 = 0x8fb677; async fn main() -> Result<(), Box> { dotenv()?; - let framework = RegexFramework::new(env::var("CLIENT_ID").expect("Missing CLIENT_ID from environment").parse()?) + let token = env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment"); + + let http = Http::new_with_token(&token); + + let framework = RegexFramework::new(http.get_current_user().map_ok(|user| user.id.as_u64().to_owned()).await?) .ignore_bots(true) .default_prefix(&env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string())) @@ -122,7 +132,7 @@ async fn main() -> Result<(), Box> { let framework_arc = Arc::new(Box::new(framework) as Box); - let mut client = Client::new(&env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment")) + let mut client = Client::new(&token) .intents(GatewayIntents::GUILD_MESSAGES | GatewayIntents::GUILDS | GatewayIntents::DIRECT_MESSAGES) .framework_arc(framework_arc.clone()) .await.expect("Error occurred creating client"); @@ -144,23 +154,13 @@ async fn main() -> Result<(), Box> { pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into) -> bool { - let role_ids = env::var("SUBSCRIPTION_ROLES") - .map( - |var| var - .split(',') - .filter_map(|item| { - item.parse::().ok() - }) - .collect::>() - ); - if let Some(subscription_guild) = env::var("CNC_GUILD").map(|var| var.parse::().ok()).ok().flatten() { - if let Ok(role_ids) = role_ids { - // todo remove unwrap and propagate error - let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await.unwrap(); + if let Some(subscription_guild) = *CNC_GUILD { + let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await; - for role in guild_member.roles { - if role_ids.contains(role.as_u64()) { + if let Ok(member) = guild_member { + for role in member.roles { + if SUBSCRIPTION_ROLES.contains(role.as_u64()) { return true } }