moved interval extracting into separate function. pass around the same pool instead of relocking

This commit is contained in:
jude 2020-09-15 14:43:49 +01:00
parent 89d7403a54
commit 9287fb5416
3 changed files with 136 additions and 50 deletions

View File

@ -3,6 +3,7 @@ use custom_error::custom_error;
use regex_command_attr::command; use regex_command_attr::command;
use serenity::{ use serenity::{
http::CacheHttp,
client::Context, client::Context,
model::{ model::{
misc::Mentionable, misc::Mentionable,
@ -26,12 +27,15 @@ use crate::{
Reminder, Reminder,
Timer, Timer,
}, },
check_subscription,
SQLPool, SQLPool,
time_parser::TimeParser, time_parser::TimeParser,
}; };
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use chrono_tz::Etc::UTC;
use rand::{ use rand::{
rngs::OsRng, rngs::OsRng,
RngCore, RngCore,
@ -51,6 +55,7 @@ use std::{
use regex::Regex; use regex::Regex;
use serde_json::json; use serde_json::json;
use sqlx::MySqlPool;
lazy_static! { lazy_static! {
static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap(); static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap();
@ -519,6 +524,7 @@ DELETE FROM timers WHERE owner = ? AND name = ?
Ok(()) Ok(())
} }
#[derive(PartialEq)]
enum RemindCommand { enum RemindCommand {
Remind, Remind,
Interval, Interval,
@ -564,6 +570,16 @@ impl ReminderError {
} }
} }
fn generate_uid() -> String {
let mut generator: OsRng = Default::default();
let mut bytes = vec![0u8, 64];
generator.fill_bytes(&mut bytes);
bytes.iter().map(|i| (CHARACTERS.as_bytes()[(i.to_owned() as usize) % CHARACTERS.len()] as char).to_string()).collect::<Vec<String>>().join("")
}
#[command] #[command]
#[permission_level(Managed)] #[permission_level(Managed)]
async fn remind(ctx: &Context, msg: &Message, args: String) -> CommandResult { async fn remind(ctx: &Context, msg: &Message, args: String) -> CommandResult {
@ -581,24 +597,70 @@ async fn interval(ctx: &Context, msg: &Message, args: String) -> CommandResult {
} }
async fn remind_command(ctx: &Context, msg: &Message, args: String, command: RemindCommand) { async fn remind_command(ctx: &Context, msg: &Message, args: String, command: RemindCommand) {
let user_data;
async fn check_interval(
ctx: impl CacheHttp,
msg: &Message,
mut args_iter: impl Iterator<Item=&str>,
scope_id: &ReminderScope,
time_parser: &TimeParser,
command: RemindCommand,
pool: &MySqlPool)
-> Result<(), ReminderError> {
if command == RemindCommand::Interval && check_subscription(&ctx, &msg.author).await {
if let Some(interval_arg) = args_iter.next() {
let interval = TimeParser::new(interval_arg.to_string(), UTC);
if let Ok(interval_seconds) = interval.displacement() {
let content = args_iter.collect::<Vec<&str>>().join(" ");
create_reminder(
ctx,
pool,
msg.author.id.as_u64().to_owned(),
msg.guild_id,
scope_id,
time_parser,
Some(interval_seconds as u32),
content).await
}
else {
Err(ReminderError::InvalidTime)
}
}
else {
Err(ReminderError::NotEnoughArgs)
}
}
else {
let content = args_iter.collect::<Vec<&str>>().join(" ");
create_reminder(
ctx,
pool,
msg.author.id.as_u64().to_owned(),
msg.guild_id,
scope_id,
time_parser,
None,
content).await
}
}
let pool = ctx.data.read().await let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let mut args_iter = args.split(' ').filter(|s| s.len() > 0); let mut args_iter = args.split(' ').filter(|s| s.len() > 0);
if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) {
let scope_id;
let mut time_parser = None; let mut time_parser = None;
let content; let mut scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned());
let guild_id = msg.guild_id; // todo reimplement using next_if and Peekable
let response = if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) {
let response = if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER
.captures(&first_arg) .captures(&first_arg)
.map(|cap| (cap.get(1), cap.get(2))) { .map(|cap| (cap.get(1), cap.get(2))) {
@ -612,22 +674,20 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
if let Some(next) = args_iter.next().map(|inner| inner.to_string()) { if let Some(next) = args_iter.next().map(|inner| inner.to_string()) {
time_parser = Some(TimeParser::new(next, user_data.timezone.parse().unwrap())); time_parser = Some(TimeParser::new(next, user_data.timezone.parse().unwrap()));
content = args_iter.collect::<Vec<&str>>().join(" "); check_interval(&ctx, msg, args_iter, &scope_id, &time_parser.as_ref().unwrap(), command, &pool).await
create_reminder(ctx, msg.author.id.as_u64().to_owned(), guild_id, &scope_id, &time_parser.as_ref().unwrap(), content).await
} }
else { else {
Err(ReminderError::NotEnoughArgs) Err(ReminderError::NotEnoughArgs)
} }
} }
else { else {
scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned()); time_parser = Some(TimeParser::new(first_arg, user_data.timezone()));
time_parser = Some(TimeParser::new(first_arg, user_data.timezone.parse().unwrap())); check_interval(&ctx, msg, args_iter, &scope_id, &time_parser.as_ref().unwrap(), command, &pool).await
}
content = args_iter.collect::<Vec<&str>>().join(" "); }
else {
create_reminder(ctx, msg.author.id.as_u64().to_owned(), guild_id, &scope_id, &time_parser.as_ref().unwrap(), content).await Err(ReminderError::NotEnoughArgs)
}; };
let str_response = match response { let str_response = match response {
@ -642,17 +702,18 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
let _ = msg.channel_id.say(&ctx, &str_response).await; let _ = msg.channel_id.say(&ctx, &str_response).await;
} }
else {
} async fn create_reminder(
} ctx: impl CacheHttp,
pool: &MySqlPool,
async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>, scope_id: &ReminderScope, time_parser: &TimeParser, content: String) user_id: u64,
guild_id: Option<GuildId>,
scope_id: &ReminderScope,
time_parser: &TimeParser,
interval: Option<u32>,
content: String)
-> Result<(), ReminderError> { -> Result<(), ReminderError> {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let db_channel_id = match scope_id { let db_channel_id = match scope_id {
ReminderScope::User(user_id) => { ReminderScope::User(user_id) => {
let user = UserId(*user_id).to_user(&ctx).await.unwrap(); let user = UserId(*user_id).to_user(&ctx).await.unwrap();
@ -673,7 +734,7 @@ async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>,
if let Some(guild_channel) = channel.guild() { if let Some(guild_channel) = channel.guild() {
if channel_data.webhook_token.is_none() || channel_data.webhook_id.is_none() { if channel_data.webhook_token.is_none() || channel_data.webhook_id.is_none() {
if let Ok(webhook) = ctx.http.create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await { if let Ok(webhook) = ctx.http().create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await {
channel_data.webhook_id = Some(webhook.id.as_u64().to_owned()); channel_data.webhook_id = Some(webhook.id.as_u64().to_owned());
channel_data.webhook_token = Some(webhook.token); channel_data.webhook_token = Some(webhook.token);
@ -707,7 +768,7 @@ async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>,
" "
INSERT INTO messages (content) VALUES (?) INSERT INTO messages (content) VALUES (?)
", content) ", content)
.execute(&pool) .execute(&pool.clone())
.await .await
.unwrap(); .unwrap();
@ -719,7 +780,7 @@ INSERT INTO reminders (uid, message_id, channel_id, time, method, set_by) VALUES
?, ?, 'remind', ?, ?, 'remind',
(SELECT id FROM users WHERE user = ? LIMIT 1)) (SELECT id FROM users WHERE user = ? LIMIT 1))
", generate_uid(), content, db_channel_id, time as u32, user_id) ", generate_uid(), content, db_channel_id, time as u32, user_id)
.execute(&pool) .execute(pool)
.await .await
.unwrap(); .unwrap();
@ -737,13 +798,3 @@ INSERT INTO reminders (uid, message_id, channel_id, time, method, set_by) VALUES
} }
} }
} }
fn generate_uid() -> String {
let mut generator: OsRng = Default::default();
let mut bytes = vec![0u8, 64];
generator.fill_bytes(&mut bytes);
bytes.iter().map(|i| (CHARACTERS.as_bytes()[(i.to_owned() as usize) % CHARACTERS.len()] as char).to_string()).collect::<Vec<String>>().join("")
}

View File

@ -7,10 +7,14 @@ mod commands;
mod time_parser; mod time_parser;
use serenity::{ use serenity::{
http::CacheHttp,
client::{ client::{
bridge::gateway::GatewayIntents, bridge::gateway::GatewayIntents,
Client, Client,
}, },
model::id::{
GuildId, UserId,
},
framework::Framework, framework::Framework,
prelude::TypeMapKey, prelude::TypeMapKey,
}; };
@ -122,3 +126,34 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(()) Ok(())
} }
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
let role_ids = env::var("SUBSCRIPTION_ROLES")
.map(
|var| var
.split(",")
.filter_map(|item| {
item.parse::<u64>().ok()
})
.collect::<Vec<u64>>()
);
if let Some(subscription_guild) = env::var("CNC_GUILD").map(|var| var.parse::<u64>().ok()).ok().flatten() {
if let Ok(role_ids) = role_ids {
// todo remove unwrap and propagate error
let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await.unwrap();
for role in guild_member.roles {
if role_ids.contains(role.as_u64()) {
return true
}
}
}
false
}
else {
true
}
}

View File

@ -1,5 +1,5 @@
use serenity::{ use serenity::{
prelude::Context, http::CacheHttp,
model::{ model::{
guild::Guild, guild::Guild,
channel::Channel, channel::Channel,
@ -141,7 +141,7 @@ pub struct UserData {
} }
impl UserData { impl UserData {
pub async fn from_user(user: &User, ctx: &&Context, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> { pub async fn from_user(user: &User, ctx: impl CacheHttp, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user.id.as_u64().clone(); let user_id = user.id.as_u64().clone();
if let Ok(c) = sqlx::query_as_unchecked!(Self, if let Ok(c) = sqlx::query_as_unchecked!(Self,