diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index 6422241..04add6f 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -3,6 +3,7 @@ use custom_error::custom_error; use regex_command_attr::command; use serenity::{ + http::CacheHttp, client::Context, model::{ misc::Mentionable, @@ -26,12 +27,15 @@ use crate::{ Reminder, Timer, }, + check_subscription, SQLPool, time_parser::TimeParser, }; use chrono::NaiveDateTime; +use chrono_tz::Etc::UTC; + use rand::{ rngs::OsRng, RngCore, @@ -51,6 +55,7 @@ use std::{ use regex::Regex; use serde_json::json; +use sqlx::MySqlPool; lazy_static! { static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap(); @@ -519,6 +524,7 @@ DELETE FROM timers WHERE owner = ? AND name = ? Ok(()) } +#[derive(PartialEq)] enum RemindCommand { Remind, Interval, @@ -564,6 +570,16 @@ impl ReminderError { } } +fn generate_uid() -> String { + let mut generator: OsRng = Default::default(); + + let mut bytes = vec![0u8, 64]; + + generator.fill_bytes(&mut bytes); + + bytes.iter().map(|i| (CHARACTERS.as_bytes()[(i.to_owned() as usize) % CHARACTERS.len()] as char).to_string()).collect::>().join("") +} + #[command] #[permission_level(Managed)] async fn remind(ctx: &Context, msg: &Message, args: String) -> CommandResult { @@ -581,24 +597,70 @@ async fn interval(ctx: &Context, msg: &Message, args: String) -> CommandResult { } async fn remind_command(ctx: &Context, msg: &Message, args: String, command: RemindCommand) { - let user_data; + + async fn check_interval( + ctx: impl CacheHttp, + msg: &Message, + mut args_iter: impl Iterator, + scope_id: &ReminderScope, + time_parser: &TimeParser, + command: RemindCommand, + pool: &MySqlPool) + -> Result<(), ReminderError> { + + if command == RemindCommand::Interval && check_subscription(&ctx, &msg.author).await { + if let Some(interval_arg) = args_iter.next() { + let interval = TimeParser::new(interval_arg.to_string(), UTC); + + if let Ok(interval_seconds) = interval.displacement() { + let content = args_iter.collect::>().join(" "); + + create_reminder( + ctx, + pool, + msg.author.id.as_u64().to_owned(), + msg.guild_id, + scope_id, + time_parser, + Some(interval_seconds as u32), + content).await + } + else { + Err(ReminderError::InvalidTime) + } + } + else { + Err(ReminderError::NotEnoughArgs) + } + } + else { + let content = args_iter.collect::>().join(" "); + + create_reminder( + ctx, + pool, + msg.author.id.as_u64().to_owned(), + msg.guild_id, + scope_id, + time_parser, + None, + content).await + } + } let pool = ctx.data.read().await .get::().cloned().expect("Could not get SQLPool from data"); - user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); + let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let mut args_iter = args.split(' ').filter(|s| s.len() > 0); - if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) { + let mut time_parser = None; + let mut scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned()); - let scope_id; - let mut time_parser = None; - let content; - - let guild_id = msg.guild_id; - - let response = if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER + // todo reimplement using next_if and Peekable + let response = if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) { + if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER .captures(&first_arg) .map(|cap| (cap.get(1), cap.get(2))) { @@ -612,47 +674,46 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem if let Some(next) = args_iter.next().map(|inner| inner.to_string()) { time_parser = Some(TimeParser::new(next, user_data.timezone.parse().unwrap())); - content = args_iter.collect::>().join(" "); - - create_reminder(ctx, msg.author.id.as_u64().to_owned(), guild_id, &scope_id, &time_parser.as_ref().unwrap(), content).await + check_interval(&ctx, msg, args_iter, &scope_id, &time_parser.as_ref().unwrap(), command, &pool).await } else { Err(ReminderError::NotEnoughArgs) } } else { - scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned()); + time_parser = Some(TimeParser::new(first_arg, user_data.timezone())); - time_parser = Some(TimeParser::new(first_arg, user_data.timezone.parse().unwrap())); - - content = args_iter.collect::>().join(" "); - - create_reminder(ctx, msg.author.id.as_u64().to_owned(), guild_id, &scope_id, &time_parser.as_ref().unwrap(), content).await - }; - - let str_response = match response { - Ok(_) => user_data.response(&pool, "remind/success").await, - - Err(reminder_error) => user_data.response(&pool, &reminder_error.to_response()).await, + check_interval(&ctx, msg, args_iter, &scope_id, &time_parser.as_ref().unwrap(), command, &pool).await } - .replacen("{location}", &scope_id.mention(), 1) - .replacen("{offset}", &time_parser.map(|tp| tp.displacement().ok()).flatten().unwrap_or(-1).to_string(), 1) - .replacen("{min_interval}", "min_interval", 1) - .replacen("{max_time}", "max_time", 1); - - let _ = msg.channel_id.say(&ctx, &str_response).await; } else { + Err(ReminderError::NotEnoughArgs) + }; + let str_response = match response { + Ok(_) => user_data.response(&pool, "remind/success").await, + + Err(reminder_error) => user_data.response(&pool, &reminder_error.to_response()).await, } + .replacen("{location}", &scope_id.mention(), 1) + .replacen("{offset}", &time_parser.map(|tp| tp.displacement().ok()).flatten().unwrap_or(-1).to_string(), 1) + .replacen("{min_interval}", "min_interval", 1) + .replacen("{max_time}", "max_time", 1); + + let _ = msg.channel_id.say(&ctx, &str_response).await; } -async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option, scope_id: &ReminderScope, time_parser: &TimeParser, content: String) +async fn create_reminder( + ctx: impl CacheHttp, + pool: &MySqlPool, + user_id: u64, + guild_id: Option, + scope_id: &ReminderScope, + time_parser: &TimeParser, + interval: Option, + content: String) -> Result<(), ReminderError> { - let pool = ctx.data.read().await - .get::().cloned().expect("Could not get SQLPool from data"); - let db_channel_id = match scope_id { ReminderScope::User(user_id) => { let user = UserId(*user_id).to_user(&ctx).await.unwrap(); @@ -673,7 +734,7 @@ async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option, if let Some(guild_channel) = channel.guild() { if channel_data.webhook_token.is_none() || channel_data.webhook_id.is_none() { - if let Ok(webhook) = ctx.http.create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await { + if let Ok(webhook) = ctx.http().create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await { channel_data.webhook_id = Some(webhook.id.as_u64().to_owned()); channel_data.webhook_token = Some(webhook.token); @@ -707,7 +768,7 @@ async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option, " INSERT INTO messages (content) VALUES (?) ", content) - .execute(&pool) + .execute(&pool.clone()) .await .unwrap(); @@ -719,7 +780,7 @@ INSERT INTO reminders (uid, message_id, channel_id, time, method, set_by) VALUES ?, ?, 'remind', (SELECT id FROM users WHERE user = ? LIMIT 1)) ", generate_uid(), content, db_channel_id, time as u32, user_id) - .execute(&pool) + .execute(pool) .await .unwrap(); @@ -737,13 +798,3 @@ INSERT INTO reminders (uid, message_id, channel_id, time, method, set_by) VALUES } } } - -fn generate_uid() -> String { - let mut generator: OsRng = Default::default(); - - let mut bytes = vec![0u8, 64]; - - generator.fill_bytes(&mut bytes); - - bytes.iter().map(|i| (CHARACTERS.as_bytes()[(i.to_owned() as usize) % CHARACTERS.len()] as char).to_string()).collect::>().join("") -} diff --git a/src/main.rs b/src/main.rs index 6c071bd..1f57848 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,10 +7,14 @@ mod commands; mod time_parser; use serenity::{ + http::CacheHttp, client::{ bridge::gateway::GatewayIntents, Client, }, + model::id::{ + GuildId, UserId, + }, framework::Framework, prelude::TypeMapKey, }; @@ -122,3 +126,34 @@ async fn main() -> Result<(), Box> { Ok(()) } + + +pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into) -> bool { + let role_ids = env::var("SUBSCRIPTION_ROLES") + .map( + |var| var + .split(",") + .filter_map(|item| { + item.parse::().ok() + }) + .collect::>() + ); + + if let Some(subscription_guild) = env::var("CNC_GUILD").map(|var| var.parse::().ok()).ok().flatten() { + if let Ok(role_ids) = role_ids { + // todo remove unwrap and propagate error + let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await.unwrap(); + + for role in guild_member.roles { + if role_ids.contains(role.as_u64()) { + return true + } + } + } + + false + } + else { + true + } +} diff --git a/src/models.rs b/src/models.rs index 9a52443..6680963 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,5 +1,5 @@ use serenity::{ - prelude::Context, + http::CacheHttp, model::{ guild::Guild, channel::Channel, @@ -141,7 +141,7 @@ pub struct UserData { } impl UserData { - pub async fn from_user(user: &User, ctx: &&Context, pool: &MySqlPool) -> Result> { + pub async fn from_user(user: &User, ctx: impl CacheHttp, pool: &MySqlPool) -> Result> { let user_id = user.id.as_u64().clone(); if let Ok(c) = sqlx::query_as_unchecked!(Self,