moved interval extracting into separate function. pass around the same pool instead of relocking
This commit is contained in:
parent
89d7403a54
commit
9287fb5416
@ -3,6 +3,7 @@ use custom_error::custom_error;
|
|||||||
use regex_command_attr::command;
|
use regex_command_attr::command;
|
||||||
|
|
||||||
use serenity::{
|
use serenity::{
|
||||||
|
http::CacheHttp,
|
||||||
client::Context,
|
client::Context,
|
||||||
model::{
|
model::{
|
||||||
misc::Mentionable,
|
misc::Mentionable,
|
||||||
@ -26,12 +27,15 @@ use crate::{
|
|||||||
Reminder,
|
Reminder,
|
||||||
Timer,
|
Timer,
|
||||||
},
|
},
|
||||||
|
check_subscription,
|
||||||
SQLPool,
|
SQLPool,
|
||||||
time_parser::TimeParser,
|
time_parser::TimeParser,
|
||||||
};
|
};
|
||||||
|
|
||||||
use chrono::NaiveDateTime;
|
use chrono::NaiveDateTime;
|
||||||
|
|
||||||
|
use chrono_tz::Etc::UTC;
|
||||||
|
|
||||||
use rand::{
|
use rand::{
|
||||||
rngs::OsRng,
|
rngs::OsRng,
|
||||||
RngCore,
|
RngCore,
|
||||||
@ -51,6 +55,7 @@ use std::{
|
|||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
|
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use sqlx::MySqlPool;
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap();
|
static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap();
|
||||||
@ -519,6 +524,7 @@ DELETE FROM timers WHERE owner = ? AND name = ?
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq)]
|
||||||
enum RemindCommand {
|
enum RemindCommand {
|
||||||
Remind,
|
Remind,
|
||||||
Interval,
|
Interval,
|
||||||
@ -564,6 +570,16 @@ impl ReminderError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn generate_uid() -> String {
|
||||||
|
let mut generator: OsRng = Default::default();
|
||||||
|
|
||||||
|
let mut bytes = vec![0u8, 64];
|
||||||
|
|
||||||
|
generator.fill_bytes(&mut bytes);
|
||||||
|
|
||||||
|
bytes.iter().map(|i| (CHARACTERS.as_bytes()[(i.to_owned() as usize) % CHARACTERS.len()] as char).to_string()).collect::<Vec<String>>().join("")
|
||||||
|
}
|
||||||
|
|
||||||
#[command]
|
#[command]
|
||||||
#[permission_level(Managed)]
|
#[permission_level(Managed)]
|
||||||
async fn remind(ctx: &Context, msg: &Message, args: String) -> CommandResult {
|
async fn remind(ctx: &Context, msg: &Message, args: String) -> CommandResult {
|
||||||
@ -581,24 +597,70 @@ async fn interval(ctx: &Context, msg: &Message, args: String) -> CommandResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn remind_command(ctx: &Context, msg: &Message, args: String, command: RemindCommand) {
|
async fn remind_command(ctx: &Context, msg: &Message, args: String, command: RemindCommand) {
|
||||||
let user_data;
|
|
||||||
|
async fn check_interval(
|
||||||
|
ctx: impl CacheHttp,
|
||||||
|
msg: &Message,
|
||||||
|
mut args_iter: impl Iterator<Item=&str>,
|
||||||
|
scope_id: &ReminderScope,
|
||||||
|
time_parser: &TimeParser,
|
||||||
|
command: RemindCommand,
|
||||||
|
pool: &MySqlPool)
|
||||||
|
-> Result<(), ReminderError> {
|
||||||
|
|
||||||
|
if command == RemindCommand::Interval && check_subscription(&ctx, &msg.author).await {
|
||||||
|
if let Some(interval_arg) = args_iter.next() {
|
||||||
|
let interval = TimeParser::new(interval_arg.to_string(), UTC);
|
||||||
|
|
||||||
|
if let Ok(interval_seconds) = interval.displacement() {
|
||||||
|
let content = args_iter.collect::<Vec<&str>>().join(" ");
|
||||||
|
|
||||||
|
create_reminder(
|
||||||
|
ctx,
|
||||||
|
pool,
|
||||||
|
msg.author.id.as_u64().to_owned(),
|
||||||
|
msg.guild_id,
|
||||||
|
scope_id,
|
||||||
|
time_parser,
|
||||||
|
Some(interval_seconds as u32),
|
||||||
|
content).await
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
Err(ReminderError::InvalidTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
Err(ReminderError::NotEnoughArgs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
let content = args_iter.collect::<Vec<&str>>().join(" ");
|
||||||
|
|
||||||
|
create_reminder(
|
||||||
|
ctx,
|
||||||
|
pool,
|
||||||
|
msg.author.id.as_u64().to_owned(),
|
||||||
|
msg.guild_id,
|
||||||
|
scope_id,
|
||||||
|
time_parser,
|
||||||
|
None,
|
||||||
|
content).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let pool = ctx.data.read().await
|
let pool = ctx.data.read().await
|
||||||
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
|
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
|
||||||
|
|
||||||
user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
|
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
|
||||||
|
|
||||||
let mut args_iter = args.split(' ').filter(|s| s.len() > 0);
|
let mut args_iter = args.split(' ').filter(|s| s.len() > 0);
|
||||||
|
|
||||||
if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) {
|
|
||||||
|
|
||||||
let scope_id;
|
|
||||||
let mut time_parser = None;
|
let mut time_parser = None;
|
||||||
let content;
|
let mut scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned());
|
||||||
|
|
||||||
let guild_id = msg.guild_id;
|
// todo reimplement using next_if and Peekable
|
||||||
|
let response = if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) {
|
||||||
let response = if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER
|
if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER
|
||||||
.captures(&first_arg)
|
.captures(&first_arg)
|
||||||
.map(|cap| (cap.get(1), cap.get(2))) {
|
.map(|cap| (cap.get(1), cap.get(2))) {
|
||||||
|
|
||||||
@ -612,22 +674,20 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
|
|||||||
if let Some(next) = args_iter.next().map(|inner| inner.to_string()) {
|
if let Some(next) = args_iter.next().map(|inner| inner.to_string()) {
|
||||||
time_parser = Some(TimeParser::new(next, user_data.timezone.parse().unwrap()));
|
time_parser = Some(TimeParser::new(next, user_data.timezone.parse().unwrap()));
|
||||||
|
|
||||||
content = args_iter.collect::<Vec<&str>>().join(" ");
|
check_interval(&ctx, msg, args_iter, &scope_id, &time_parser.as_ref().unwrap(), command, &pool).await
|
||||||
|
|
||||||
create_reminder(ctx, msg.author.id.as_u64().to_owned(), guild_id, &scope_id, &time_parser.as_ref().unwrap(), content).await
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
Err(ReminderError::NotEnoughArgs)
|
Err(ReminderError::NotEnoughArgs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned());
|
time_parser = Some(TimeParser::new(first_arg, user_data.timezone()));
|
||||||
|
|
||||||
time_parser = Some(TimeParser::new(first_arg, user_data.timezone.parse().unwrap()));
|
check_interval(&ctx, msg, args_iter, &scope_id, &time_parser.as_ref().unwrap(), command, &pool).await
|
||||||
|
}
|
||||||
content = args_iter.collect::<Vec<&str>>().join(" ");
|
}
|
||||||
|
else {
|
||||||
create_reminder(ctx, msg.author.id.as_u64().to_owned(), guild_id, &scope_id, &time_parser.as_ref().unwrap(), content).await
|
Err(ReminderError::NotEnoughArgs)
|
||||||
};
|
};
|
||||||
|
|
||||||
let str_response = match response {
|
let str_response = match response {
|
||||||
@ -642,17 +702,18 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
|
|||||||
|
|
||||||
let _ = msg.channel_id.say(&ctx, &str_response).await;
|
let _ = msg.channel_id.say(&ctx, &str_response).await;
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
|
|
||||||
}
|
async fn create_reminder(
|
||||||
}
|
ctx: impl CacheHttp,
|
||||||
|
pool: &MySqlPool,
|
||||||
async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>, scope_id: &ReminderScope, time_parser: &TimeParser, content: String)
|
user_id: u64,
|
||||||
|
guild_id: Option<GuildId>,
|
||||||
|
scope_id: &ReminderScope,
|
||||||
|
time_parser: &TimeParser,
|
||||||
|
interval: Option<u32>,
|
||||||
|
content: String)
|
||||||
-> Result<(), ReminderError> {
|
-> Result<(), ReminderError> {
|
||||||
|
|
||||||
let pool = ctx.data.read().await
|
|
||||||
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
|
|
||||||
|
|
||||||
let db_channel_id = match scope_id {
|
let db_channel_id = match scope_id {
|
||||||
ReminderScope::User(user_id) => {
|
ReminderScope::User(user_id) => {
|
||||||
let user = UserId(*user_id).to_user(&ctx).await.unwrap();
|
let user = UserId(*user_id).to_user(&ctx).await.unwrap();
|
||||||
@ -673,7 +734,7 @@ async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>,
|
|||||||
|
|
||||||
if let Some(guild_channel) = channel.guild() {
|
if let Some(guild_channel) = channel.guild() {
|
||||||
if channel_data.webhook_token.is_none() || channel_data.webhook_id.is_none() {
|
if channel_data.webhook_token.is_none() || channel_data.webhook_id.is_none() {
|
||||||
if let Ok(webhook) = ctx.http.create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await {
|
if let Ok(webhook) = ctx.http().create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await {
|
||||||
channel_data.webhook_id = Some(webhook.id.as_u64().to_owned());
|
channel_data.webhook_id = Some(webhook.id.as_u64().to_owned());
|
||||||
channel_data.webhook_token = Some(webhook.token);
|
channel_data.webhook_token = Some(webhook.token);
|
||||||
|
|
||||||
@ -707,7 +768,7 @@ async fn create_reminder(ctx: &Context, user_id: u64, guild_id: Option<GuildId>,
|
|||||||
"
|
"
|
||||||
INSERT INTO messages (content) VALUES (?)
|
INSERT INTO messages (content) VALUES (?)
|
||||||
", content)
|
", content)
|
||||||
.execute(&pool)
|
.execute(&pool.clone())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -719,7 +780,7 @@ INSERT INTO reminders (uid, message_id, channel_id, time, method, set_by) VALUES
|
|||||||
?, ?, 'remind',
|
?, ?, 'remind',
|
||||||
(SELECT id FROM users WHERE user = ? LIMIT 1))
|
(SELECT id FROM users WHERE user = ? LIMIT 1))
|
||||||
", generate_uid(), content, db_channel_id, time as u32, user_id)
|
", generate_uid(), content, db_channel_id, time as u32, user_id)
|
||||||
.execute(&pool)
|
.execute(pool)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -737,13 +798,3 @@ INSERT INTO reminders (uid, message_id, channel_id, time, method, set_by) VALUES
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn generate_uid() -> String {
|
|
||||||
let mut generator: OsRng = Default::default();
|
|
||||||
|
|
||||||
let mut bytes = vec![0u8, 64];
|
|
||||||
|
|
||||||
generator.fill_bytes(&mut bytes);
|
|
||||||
|
|
||||||
bytes.iter().map(|i| (CHARACTERS.as_bytes()[(i.to_owned() as usize) % CHARACTERS.len()] as char).to_string()).collect::<Vec<String>>().join("")
|
|
||||||
}
|
|
||||||
|
35
src/main.rs
35
src/main.rs
@ -7,10 +7,14 @@ mod commands;
|
|||||||
mod time_parser;
|
mod time_parser;
|
||||||
|
|
||||||
use serenity::{
|
use serenity::{
|
||||||
|
http::CacheHttp,
|
||||||
client::{
|
client::{
|
||||||
bridge::gateway::GatewayIntents,
|
bridge::gateway::GatewayIntents,
|
||||||
Client,
|
Client,
|
||||||
},
|
},
|
||||||
|
model::id::{
|
||||||
|
GuildId, UserId,
|
||||||
|
},
|
||||||
framework::Framework,
|
framework::Framework,
|
||||||
prelude::TypeMapKey,
|
prelude::TypeMapKey,
|
||||||
};
|
};
|
||||||
@ -122,3 +126,34 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
|
||||||
|
let role_ids = env::var("SUBSCRIPTION_ROLES")
|
||||||
|
.map(
|
||||||
|
|var| var
|
||||||
|
.split(",")
|
||||||
|
.filter_map(|item| {
|
||||||
|
item.parse::<u64>().ok()
|
||||||
|
})
|
||||||
|
.collect::<Vec<u64>>()
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(subscription_guild) = env::var("CNC_GUILD").map(|var| var.parse::<u64>().ok()).ok().flatten() {
|
||||||
|
if let Ok(role_ids) = role_ids {
|
||||||
|
// todo remove unwrap and propagate error
|
||||||
|
let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await.unwrap();
|
||||||
|
|
||||||
|
for role in guild_member.roles {
|
||||||
|
if role_ids.contains(role.as_u64()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use serenity::{
|
use serenity::{
|
||||||
prelude::Context,
|
http::CacheHttp,
|
||||||
model::{
|
model::{
|
||||||
guild::Guild,
|
guild::Guild,
|
||||||
channel::Channel,
|
channel::Channel,
|
||||||
@ -141,7 +141,7 @@ pub struct UserData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl UserData {
|
impl UserData {
|
||||||
pub async fn from_user(user: &User, ctx: &&Context, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
|
pub async fn from_user(user: &User, ctx: impl CacheHttp, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
|
||||||
let user_id = user.id.as_u64().clone();
|
let user_id = user.id.as_u64().clone();
|
||||||
|
|
||||||
if let Ok(c) = sqlx::query_as_unchecked!(Self,
|
if let Ok(c) = sqlx::query_as_unchecked!(Self,
|
||||||
|
Loading…
Reference in New Issue
Block a user