moved interval extracting into separate function. pass around the same pool instead of relocking

This commit is contained in:
jude 2020-09-15 14:43:49 +01:00
parent 89d7403a54
commit 9287fb5416
3 changed files with 136 additions and 50 deletions

View File

@ -3,6 +3,7 @@ use custom_error::custom_error;
use regex_command_attr::command;
use serenity::{
http::CacheHttp,
client::Context,
model::{
misc::Mentionable,
@ -26,12 +27,15 @@ use crate::{
Reminder,
Timer,
},
check_subscription,
SQLPool,
time_parser::TimeParser,
};
use chrono::NaiveDateTime;
use chrono_tz::Etc::UTC;
use rand::{
rngs::OsRng,
RngCore,
@ -51,6 +55,7 @@ use std::{
use regex::Regex;
use serde_json::json;
use sqlx::MySqlPool;
lazy_static! {
static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap();
@ -519,6 +524,7 @@ DELETE FROM timers WHERE owner = ? AND name = ?
Ok(())
}
#[derive(PartialEq)]
enum RemindCommand {
Remind,
Interval,
@ -564,6 +570,16 @@ impl ReminderError {
}
}
fn generate_uid() -> String {
let mut generator: OsRng = Default::default();
let mut bytes = vec![0u8, 64];
generator.fill_bytes(&mut bytes);
bytes.iter().map(|i| (CHARACTERS.as_bytes()[(i.to_owned() as usize) % CHARACTERS.len()] as char).to_string()).collect::<Vec<String>>().join("")
}
#[command]
#[permission_level(Managed)]
async fn remind(ctx: &Context, msg: &Message, args: String) -> CommandResult {
@ -581,24 +597,70 @@ async fn interval(ctx: &Context, msg: &Message, args: String) -> CommandResult {
}
async fn remind_command(ctx: &Context, msg: &Message, args: String, command: RemindCommand) {
let user_data;
async fn check_interval(
ctx: impl CacheHttp,
msg: &Message,
mut args_iter: impl Iterator<Item=&str>,
scope_id: &ReminderScope,
time_parser: &TimeParser,
command: RemindCommand,
pool: &MySqlPool)
-> Result<(), ReminderError> {
if command == RemindCommand::Interval && check_subscription(&ctx, &msg.author).await {
if let Some(interval_arg) = args_iter.next() {
let interval = TimeParser::new(interval_arg.to_string(), UTC);
if let Ok(interval_seconds) = interval.displacement() {
let content = args_iter.collect::<Vec<&str>>().join(" ");
create_reminder(
ctx,
pool,
msg.author.id.as_u64().to_owned(),
msg.guild_id,
scope_id,
time_parser,
Some(interval_seconds as u32),
content).await
}
else {
Err(ReminderError::InvalidTime)
}
}
else {
Err(ReminderError::NotEnoughArgs)
}
}
else {
let content = args_iter.collect::<Vec<&str>>().join(" ");
create_reminder(
ctx,
pool,
msg.author.id.as_u64().to_owned(),
msg.guild_id,
scope_id,
time_parser,
None,
content).await
}
}
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let mut args_iter = args.split(' ').filter(|s| s.len() > 0);
if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) {
let mut time_parser = None;
let mut scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned());
let scope_id;
let mut time_parser = None;
let content;
let guild_id = msg.guild_id;
let response = if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER
// todo reimplement using next_if and Peekable
let response = if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) {
if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER
.captures(&first_arg)
.map(|cap| (cap.get(1), cap.get(2))) {
@ -612,47 +674,46 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
if let Some(next) = args_iter.next().map(|inner| inner.to_string()) {
time_parser = Some(TimeParser::new(next, user_data.timezone.parse().unwrap()));
content = args_iter.collect::<Vec<&str>>().join(" ");
create_reminder(ctx, msg.author.id.as_u64().to_owned(), guild_id, &scope_id, &time_parser.as_ref().unwrap(), content).await
check_interval(&ctx, msg, args_iter, &scope_id, &time_parser.as_ref().unwrap(), command, &pool).await
}
else {
Err(ReminderError::NotEnoughArgs)
}
}
else {
scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned());
time_parser = Some(TimeParser::new(first_arg, user_data.timezone()));
time_parser = Some(TimeParser::new(first_arg, user_data.timezone.parse().unwrap()));
content = args_iter.collect::<Vec<&str>>().join(" ");
create_reminder(ctx, msg.author.id.as_u64().to_owned(), guild_id, &scope_id, &time_parser.as_ref().unwrap(), content).await
};
let str_response = match response {
Ok(_) => user_data.response(&pool, "remind/success").await,
Err(reminder_error) => user_data.response(&pool, &reminder_error.to_response()).await,
check_interval(&ctx, msg, args_iter, &scope_id, &time_parser.as_ref().unwrap(), command, &pool).await
}
.replacen("{location}", &scope_id.mention(), 1)
.replacen("{offset}", &time_parser.map(|tp| tp.displacement().ok()).flatten().unwrap_or(-1).to_string(), 1)
.replacen("{min_interval}", "min_interval", 1)
.replacen("{max_time}", "max_time", 1);
let _ = msg.channel_id.say(&ctx, &str_response).await;
}
else {
Err(ReminderError::NotEnoughArgs)
};
let str_response = match response {
Ok(_) => user_data.response(&pool, "remind/success").await,
Err(reminder_error) => user_data.response(&pool, &reminder_error.to_response()).await,
}
.replacen("{location}", &scope_id.mention(), 1)
.replacen("{offset}", &time_parser.map(|tp| tp.displacement().ok()).flatten().unwrap_or(-1).to_string(), 1)
.replacen("{min_interval}", "min_interval", 1)
.replacen("{max_time}", "max_time", 1);
let _ = msg.channel_id.say(&ctx, &str_response).await;
}
async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>, scope_id: &ReminderScope, time_parser: &TimeParser, content: String)
async fn create_reminder(
ctx: impl CacheHttp,
pool: &MySqlPool,
user_id: u64,
guild_id: Option<GuildId>,
scope_id: &ReminderScope,
time_parser: &TimeParser,
interval: Option<u32>,
content: String)
-> Result<(), ReminderError> {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let db_channel_id = match scope_id {
ReminderScope::User(user_id) => {
let user = UserId(*user_id).to_user(&ctx).await.unwrap();
@ -673,7 +734,7 @@ async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>,
if let Some(guild_channel) = channel.guild() {
if channel_data.webhook_token.is_none() || channel_data.webhook_id.is_none() {
if let Ok(webhook) = ctx.http.create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await {
if let Ok(webhook) = ctx.http().create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await {
channel_data.webhook_id = Some(webhook.id.as_u64().to_owned());
channel_data.webhook_token = Some(webhook.token);
@ -707,7 +768,7 @@ async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>,
"
INSERT INTO messages (content) VALUES (?)
", content)
.execute(&pool)
.execute(&pool.clone())
.await
.unwrap();
@ -719,7 +780,7 @@ INSERT INTO reminders (uid, message_id, channel_id, time, method, set_by) VALUES
?, ?, 'remind',
(SELECT id FROM users WHERE user = ? LIMIT 1))
", generate_uid(), content, db_channel_id, time as u32, user_id)
.execute(&pool)
.execute(pool)
.await
.unwrap();
@ -737,13 +798,3 @@ INSERT INTO reminders (uid, message_id, channel_id, time, method, set_by) VALUES
}
}
}
fn generate_uid() -> String {
let mut generator: OsRng = Default::default();
let mut bytes = vec![0u8, 64];
generator.fill_bytes(&mut bytes);
bytes.iter().map(|i| (CHARACTERS.as_bytes()[(i.to_owned() as usize) % CHARACTERS.len()] as char).to_string()).collect::<Vec<String>>().join("")
}

View File

@ -7,10 +7,14 @@ mod commands;
mod time_parser;
use serenity::{
http::CacheHttp,
client::{
bridge::gateway::GatewayIntents,
Client,
},
model::id::{
GuildId, UserId,
},
framework::Framework,
prelude::TypeMapKey,
};
@ -122,3 +126,34 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
let role_ids = env::var("SUBSCRIPTION_ROLES")
.map(
|var| var
.split(",")
.filter_map(|item| {
item.parse::<u64>().ok()
})
.collect::<Vec<u64>>()
);
if let Some(subscription_guild) = env::var("CNC_GUILD").map(|var| var.parse::<u64>().ok()).ok().flatten() {
if let Ok(role_ids) = role_ids {
// todo remove unwrap and propagate error
let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await.unwrap();
for role in guild_member.roles {
if role_ids.contains(role.as_u64()) {
return true
}
}
}
false
}
else {
true
}
}

View File

@ -1,5 +1,5 @@
use serenity::{
prelude::Context,
http::CacheHttp,
model::{
guild::Guild,
channel::Channel,
@ -141,7 +141,7 @@ pub struct UserData {
}
impl UserData {
pub async fn from_user(user: &User, ctx: &&Context, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
pub async fn from_user(user: &User, ctx: impl CacheHttp, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user.id.as_u64().clone();
if let Ok(c) = sqlx::query_as_unchecked!(Self,