moved interval extracting into separate function. pass around the same pool instead of relocking
This commit is contained in:
parent
89d7403a54
commit
9287fb5416
@ -3,6 +3,7 @@ use custom_error::custom_error;
|
||||
use regex_command_attr::command;
|
||||
|
||||
use serenity::{
|
||||
http::CacheHttp,
|
||||
client::Context,
|
||||
model::{
|
||||
misc::Mentionable,
|
||||
@ -26,12 +27,15 @@ use crate::{
|
||||
Reminder,
|
||||
Timer,
|
||||
},
|
||||
check_subscription,
|
||||
SQLPool,
|
||||
time_parser::TimeParser,
|
||||
};
|
||||
|
||||
use chrono::NaiveDateTime;
|
||||
|
||||
use chrono_tz::Etc::UTC;
|
||||
|
||||
use rand::{
|
||||
rngs::OsRng,
|
||||
RngCore,
|
||||
@ -51,6 +55,7 @@ use std::{
|
||||
use regex::Regex;
|
||||
|
||||
use serde_json::json;
|
||||
use sqlx::MySqlPool;
|
||||
|
||||
lazy_static! {
|
||||
static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap();
|
||||
@ -519,6 +524,7 @@ DELETE FROM timers WHERE owner = ? AND name = ?
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum RemindCommand {
|
||||
Remind,
|
||||
Interval,
|
||||
@ -564,6 +570,16 @@ impl ReminderError {
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_uid() -> String {
|
||||
let mut generator: OsRng = Default::default();
|
||||
|
||||
let mut bytes = vec![0u8, 64];
|
||||
|
||||
generator.fill_bytes(&mut bytes);
|
||||
|
||||
bytes.iter().map(|i| (CHARACTERS.as_bytes()[(i.to_owned() as usize) % CHARACTERS.len()] as char).to_string()).collect::<Vec<String>>().join("")
|
||||
}
|
||||
|
||||
#[command]
|
||||
#[permission_level(Managed)]
|
||||
async fn remind(ctx: &Context, msg: &Message, args: String) -> CommandResult {
|
||||
@ -581,24 +597,70 @@ async fn interval(ctx: &Context, msg: &Message, args: String) -> CommandResult {
|
||||
}
|
||||
|
||||
async fn remind_command(ctx: &Context, msg: &Message, args: String, command: RemindCommand) {
|
||||
let user_data;
|
||||
|
||||
async fn check_interval(
|
||||
ctx: impl CacheHttp,
|
||||
msg: &Message,
|
||||
mut args_iter: impl Iterator<Item=&str>,
|
||||
scope_id: &ReminderScope,
|
||||
time_parser: &TimeParser,
|
||||
command: RemindCommand,
|
||||
pool: &MySqlPool)
|
||||
-> Result<(), ReminderError> {
|
||||
|
||||
if command == RemindCommand::Interval && check_subscription(&ctx, &msg.author).await {
|
||||
if let Some(interval_arg) = args_iter.next() {
|
||||
let interval = TimeParser::new(interval_arg.to_string(), UTC);
|
||||
|
||||
if let Ok(interval_seconds) = interval.displacement() {
|
||||
let content = args_iter.collect::<Vec<&str>>().join(" ");
|
||||
|
||||
create_reminder(
|
||||
ctx,
|
||||
pool,
|
||||
msg.author.id.as_u64().to_owned(),
|
||||
msg.guild_id,
|
||||
scope_id,
|
||||
time_parser,
|
||||
Some(interval_seconds as u32),
|
||||
content).await
|
||||
}
|
||||
else {
|
||||
Err(ReminderError::InvalidTime)
|
||||
}
|
||||
}
|
||||
else {
|
||||
Err(ReminderError::NotEnoughArgs)
|
||||
}
|
||||
}
|
||||
else {
|
||||
let content = args_iter.collect::<Vec<&str>>().join(" ");
|
||||
|
||||
create_reminder(
|
||||
ctx,
|
||||
pool,
|
||||
msg.author.id.as_u64().to_owned(),
|
||||
msg.guild_id,
|
||||
scope_id,
|
||||
time_parser,
|
||||
None,
|
||||
content).await
|
||||
}
|
||||
}
|
||||
|
||||
let pool = ctx.data.read().await
|
||||
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
|
||||
|
||||
user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
|
||||
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
|
||||
|
||||
let mut args_iter = args.split(' ').filter(|s| s.len() > 0);
|
||||
|
||||
if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) {
|
||||
let mut time_parser = None;
|
||||
let mut scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned());
|
||||
|
||||
let scope_id;
|
||||
let mut time_parser = None;
|
||||
let content;
|
||||
|
||||
let guild_id = msg.guild_id;
|
||||
|
||||
let response = if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER
|
||||
// todo reimplement using next_if and Peekable
|
||||
let response = if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) {
|
||||
if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER
|
||||
.captures(&first_arg)
|
||||
.map(|cap| (cap.get(1), cap.get(2))) {
|
||||
|
||||
@ -612,47 +674,46 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
|
||||
if let Some(next) = args_iter.next().map(|inner| inner.to_string()) {
|
||||
time_parser = Some(TimeParser::new(next, user_data.timezone.parse().unwrap()));
|
||||
|
||||
content = args_iter.collect::<Vec<&str>>().join(" ");
|
||||
|
||||
create_reminder(ctx, msg.author.id.as_u64().to_owned(), guild_id, &scope_id, &time_parser.as_ref().unwrap(), content).await
|
||||
check_interval(&ctx, msg, args_iter, &scope_id, &time_parser.as_ref().unwrap(), command, &pool).await
|
||||
}
|
||||
else {
|
||||
Err(ReminderError::NotEnoughArgs)
|
||||
}
|
||||
}
|
||||
else {
|
||||
scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned());
|
||||
time_parser = Some(TimeParser::new(first_arg, user_data.timezone()));
|
||||
|
||||
time_parser = Some(TimeParser::new(first_arg, user_data.timezone.parse().unwrap()));
|
||||
|
||||
content = args_iter.collect::<Vec<&str>>().join(" ");
|
||||
|
||||
create_reminder(ctx, msg.author.id.as_u64().to_owned(), guild_id, &scope_id, &time_parser.as_ref().unwrap(), content).await
|
||||
};
|
||||
|
||||
let str_response = match response {
|
||||
Ok(_) => user_data.response(&pool, "remind/success").await,
|
||||
|
||||
Err(reminder_error) => user_data.response(&pool, &reminder_error.to_response()).await,
|
||||
check_interval(&ctx, msg, args_iter, &scope_id, &time_parser.as_ref().unwrap(), command, &pool).await
|
||||
}
|
||||
.replacen("{location}", &scope_id.mention(), 1)
|
||||
.replacen("{offset}", &time_parser.map(|tp| tp.displacement().ok()).flatten().unwrap_or(-1).to_string(), 1)
|
||||
.replacen("{min_interval}", "min_interval", 1)
|
||||
.replacen("{max_time}", "max_time", 1);
|
||||
|
||||
let _ = msg.channel_id.say(&ctx, &str_response).await;
|
||||
}
|
||||
else {
|
||||
Err(ReminderError::NotEnoughArgs)
|
||||
};
|
||||
|
||||
let str_response = match response {
|
||||
Ok(_) => user_data.response(&pool, "remind/success").await,
|
||||
|
||||
Err(reminder_error) => user_data.response(&pool, &reminder_error.to_response()).await,
|
||||
}
|
||||
.replacen("{location}", &scope_id.mention(), 1)
|
||||
.replacen("{offset}", &time_parser.map(|tp| tp.displacement().ok()).flatten().unwrap_or(-1).to_string(), 1)
|
||||
.replacen("{min_interval}", "min_interval", 1)
|
||||
.replacen("{max_time}", "max_time", 1);
|
||||
|
||||
let _ = msg.channel_id.say(&ctx, &str_response).await;
|
||||
}
|
||||
|
||||
async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>, scope_id: &ReminderScope, time_parser: &TimeParser, content: String)
|
||||
async fn create_reminder(
|
||||
ctx: impl CacheHttp,
|
||||
pool: &MySqlPool,
|
||||
user_id: u64,
|
||||
guild_id: Option<GuildId>,
|
||||
scope_id: &ReminderScope,
|
||||
time_parser: &TimeParser,
|
||||
interval: Option<u32>,
|
||||
content: String)
|
||||
-> Result<(), ReminderError> {
|
||||
|
||||
let pool = ctx.data.read().await
|
||||
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
|
||||
|
||||
let db_channel_id = match scope_id {
|
||||
ReminderScope::User(user_id) => {
|
||||
let user = UserId(*user_id).to_user(&ctx).await.unwrap();
|
||||
@ -673,7 +734,7 @@ async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>,
|
||||
|
||||
if let Some(guild_channel) = channel.guild() {
|
||||
if channel_data.webhook_token.is_none() || channel_data.webhook_id.is_none() {
|
||||
if let Ok(webhook) = ctx.http.create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await {
|
||||
if let Ok(webhook) = ctx.http().create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await {
|
||||
channel_data.webhook_id = Some(webhook.id.as_u64().to_owned());
|
||||
channel_data.webhook_token = Some(webhook.token);
|
||||
|
||||
@ -707,7 +768,7 @@ async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>,
|
||||
"
|
||||
INSERT INTO messages (content) VALUES (?)
|
||||
", content)
|
||||
.execute(&pool)
|
||||
.execute(&pool.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@ -719,7 +780,7 @@ INSERT INTO reminders (uid, message_id, channel_id, time, method, set_by) VALUES
|
||||
?, ?, 'remind',
|
||||
(SELECT id FROM users WHERE user = ? LIMIT 1))
|
||||
", generate_uid(), content, db_channel_id, time as u32, user_id)
|
||||
.execute(&pool)
|
||||
.execute(pool)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@ -737,13 +798,3 @@ INSERT INTO reminders (uid, message_id, channel_id, time, method, set_by) VALUES
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_uid() -> String {
|
||||
let mut generator: OsRng = Default::default();
|
||||
|
||||
let mut bytes = vec![0u8, 64];
|
||||
|
||||
generator.fill_bytes(&mut bytes);
|
||||
|
||||
bytes.iter().map(|i| (CHARACTERS.as_bytes()[(i.to_owned() as usize) % CHARACTERS.len()] as char).to_string()).collect::<Vec<String>>().join("")
|
||||
}
|
||||
|
35
src/main.rs
35
src/main.rs
@ -7,10 +7,14 @@ mod commands;
|
||||
mod time_parser;
|
||||
|
||||
use serenity::{
|
||||
http::CacheHttp,
|
||||
client::{
|
||||
bridge::gateway::GatewayIntents,
|
||||
Client,
|
||||
},
|
||||
model::id::{
|
||||
GuildId, UserId,
|
||||
},
|
||||
framework::Framework,
|
||||
prelude::TypeMapKey,
|
||||
};
|
||||
@ -122,3 +126,34 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
|
||||
let role_ids = env::var("SUBSCRIPTION_ROLES")
|
||||
.map(
|
||||
|var| var
|
||||
.split(",")
|
||||
.filter_map(|item| {
|
||||
item.parse::<u64>().ok()
|
||||
})
|
||||
.collect::<Vec<u64>>()
|
||||
);
|
||||
|
||||
if let Some(subscription_guild) = env::var("CNC_GUILD").map(|var| var.parse::<u64>().ok()).ok().flatten() {
|
||||
if let Ok(role_ids) = role_ids {
|
||||
// todo remove unwrap and propagate error
|
||||
let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await.unwrap();
|
||||
|
||||
for role in guild_member.roles {
|
||||
if role_ids.contains(role.as_u64()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use serenity::{
|
||||
prelude::Context,
|
||||
http::CacheHttp,
|
||||
model::{
|
||||
guild::Guild,
|
||||
channel::Channel,
|
||||
@ -141,7 +141,7 @@ pub struct UserData {
|
||||
}
|
||||
|
||||
impl UserData {
|
||||
pub async fn from_user(user: &User, ctx: &&Context, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
|
||||
pub async fn from_user(user: &User, ctx: impl CacheHttp, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
|
||||
let user_id = user.id.as_u64().clone();
|
||||
|
||||
if let Ok(c) = sqlx::query_as_unchecked!(Self,
|
||||
|
Loading…
Reference in New Issue
Block a user