8 Commits

18 changed files with 1130 additions and 874 deletions

4
Cargo.lock generated
View File

@ -1285,9 +1285,10 @@ dependencies = [
[[package]] [[package]]
name = "reminder_rs" name = "reminder_rs"
version = "1.5.0-2" version = "1.5.1"
dependencies = [ dependencies = [
"Inflector", "Inflector",
"base64 0.13.0",
"chrono", "chrono",
"chrono-tz", "chrono-tz",
"dashmap", "dashmap",
@ -1302,6 +1303,7 @@ dependencies = [
"regex", "regex",
"regex_command_attr", "regex_command_attr",
"reqwest", "reqwest",
"ring",
"serde", "serde",
"serde_json", "serde_json",
"serenity", "serenity",

View File

@ -1,6 +1,6 @@
[package] [package]
name = "reminder_rs" name = "reminder_rs"
version = "1.5.0-2" version = "1.5.1"
authors = ["jellywx <judesouthworth@pm.me>"] authors = ["jellywx <judesouthworth@pm.me>"]
edition = "2018" edition = "2018"
@ -25,6 +25,8 @@ levenshtein = "1.0"
# serenity = { version = "0.10", features = ["collector"] } # serenity = { version = "0.10", features = ["collector"] }
serenity = { path = "/home/jude/serenity", features = ["collector", "unstable_discord_api"] } serenity = { path = "/home/jude/serenity", features = ["collector", "unstable_discord_api"] }
sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono"]} sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono"]}
ring = "0.16"
base64 = "0.13.0"
[dependencies.regex_command_attr] [dependencies.regex_command_attr]
path = "./regex_command_attr" path = "./regex_command_attr"

View File

@ -1,18 +1,22 @@
use regex_command_attr::command; use regex_command_attr::command;
use serenity::{client::Context, model::channel::Message}; use serenity::{builder::CreateEmbedFooter, client::Context, model::channel::Message};
use chrono::offset::Utc; use chrono::offset::Utc;
use crate::{ use crate::{
command_help, consts::DEFAULT_PREFIX, get_ctx_data, language_manager::LanguageManager, command_help,
models::UserData, FrameworkCtx, THEME_COLOR, consts::DEFAULT_PREFIX,
get_ctx_data,
language_manager::LanguageManager,
models::{user_data::UserData, CtxGuildData},
FrameworkCtx, THEME_COLOR,
}; };
use crate::models::CtxGuildData; use std::{
use serenity::builder::CreateEmbedFooter; sync::Arc,
use std::sync::Arc; time::{SystemTime, UNIX_EPOCH},
use std::time::{SystemTime, UNIX_EPOCH}; };
#[command] #[command]
#[can_blacklist(false)] #[can_blacklist(false)]
@ -202,7 +206,6 @@ async fn clock(ctx: &Context, msg: &Message, _args: String) {
let language = UserData::language_of(&msg.author, &pool).await; let language = UserData::language_of(&msg.author, &pool).await;
let timezone = UserData::timezone_of(&msg.author, &pool).await; let timezone = UserData::timezone_of(&msg.author, &pool).await;
let meridian = UserData::meridian_of(&msg.author, &pool).await;
let now = Utc::now().with_timezone(&timezone); let now = Utc::now().with_timezone(&timezone);
@ -212,7 +215,7 @@ async fn clock(ctx: &Context, msg: &Message, _args: String) {
.channel_id .channel_id
.say( .say(
&ctx, &ctx,
clock_display.replacen("{}", &now.format(meridian.fmt_str()).to_string(), 1), clock_display.replacen("{}", &now.format("%H:%M").to_string(), 1),
) )
.await; .await;
} }

View File

@ -24,11 +24,10 @@ use crate::{
consts::{REGEX_ALIAS, REGEX_CHANNEL, REGEX_COMMANDS, REGEX_ROLE, THEME_COLOR}, consts::{REGEX_ALIAS, REGEX_CHANNEL, REGEX_COMMANDS, REGEX_ROLE, THEME_COLOR},
framework::SendIterator, framework::SendIterator,
get_ctx_data, get_ctx_data,
models::{ChannelData, GuildData, UserData}, models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData, CtxGuildData},
FrameworkCtx, PopularTimezones, FrameworkCtx, PopularTimezones,
}; };
use crate::models::CtxGuildData;
use std::{collections::HashMap, iter}; use std::{collections::HashMap, iter};
#[command] #[command]
@ -75,18 +74,16 @@ async fn blacklist(ctx: &Context, msg: &Message, args: String) {
.say(&ctx, lm.get(&language, "blacklist/added_from")) .say(&ctx, lm.get(&language, "blacklist/added_from"))
.await; .await;
} }
} else if local {
let _ = msg
.channel_id
.say(&ctx, lm.get(&language, "blacklist/removed"))
.await;
} else { } else {
if local { let _ = msg
let _ = msg .channel_id
.channel_id .say(&ctx, lm.get(&language, "blacklist/removed_from"))
.say(&ctx, lm.get(&language, "blacklist/removed")) .await;
.await;
} else {
let _ = msg
.channel_id
.say(&ctx, lm.get(&language, "blacklist/removed_from"))
.await;
}
} }
} }
@ -113,11 +110,7 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) {
let content = lm let content = lm
.get(&user_data.language, "timezone/set_p") .get(&user_data.language, "timezone/set_p")
.replacen("{timezone}", &user_data.timezone, 1) .replacen("{timezone}", &user_data.timezone, 1)
.replacen( .replacen("{time}", &now.format("%H:%M").to_string(), 1);
"{time}",
&now.format(user_data.meridian().fmt_str_short()).to_string(),
1,
);
let _ = let _ =
msg.channel_id msg.channel_id
@ -154,10 +147,7 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) {
tz.to_string(), tz.to_string(),
format!( format!(
"🕗 `{}`", "🕗 `{}`",
Utc::now() Utc::now().with_timezone(tz).format("%H:%M").to_string()
.with_timezone(tz)
.format(user_data.meridian().fmt_str_short())
.to_string()
), ),
true, true,
) )
@ -211,10 +201,7 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) {
t.to_string(), t.to_string(),
format!( format!(
"🕗 `{}`", "🕗 `{}`",
Utc::now() Utc::now().with_timezone(t).format("%H:%M").to_string()
.with_timezone(t)
.format(user_data.meridian().fmt_str_short())
.to_string()
), ),
true, true,
) )
@ -252,49 +239,6 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) {
} }
} }
#[command("meridian")]
async fn change_meridian(ctx: &Context, msg: &Message, args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
if &args == "12" {
user_data.meridian_time = true;
user_data.commit_changes(&pool).await;
let _ = msg
.channel_id
.send_message(&ctx, |m| {
m.embed(|e| {
e.title(lm.get(&user_data.language, "meridian/title"))
.color(*THEME_COLOR)
.description(lm.get(&user_data.language, "meridian/12"))
})
})
.await;
} else if &args == "24" {
user_data.meridian_time = false;
user_data.commit_changes(&pool).await;
let _ = msg
.channel_id
.send_message(&ctx, |m| {
m.embed(|e| {
e.title(lm.get(&user_data.language, "meridian/title"))
.color(*THEME_COLOR)
.description(lm.get(&user_data.language, "meridian/24"))
})
})
.await;
} else {
let prefix = ctx.prefix(msg.guild_id).await;
command_help(ctx, msg, lm, &prefix, &user_data.language, "meridian").await;
}
}
#[command("lang")] #[command("lang")]
async fn language(ctx: &Context, msg: &Message, args: String) { async fn language(ctx: &Context, msg: &Message, args: String) {
let (pool, lm) = get_ctx_data(&ctx).await; let (pool, lm) = get_ctx_data(&ctx).await;

View File

@ -1,14 +1,14 @@
use regex_command_attr::command; use regex_command_attr::command;
use serenity::{ use serenity::{
cache::Cache,
client::Context, client::Context,
http::CacheHttp, http::CacheHttp,
model::{ model::{
channel::GuildChannel,
channel::Message, channel::Message,
channel::{Channel, GuildChannel},
guild::Guild, guild::Guild,
id::{ChannelId, GuildId, UserId}, id::{ChannelId, GuildId, UserId},
interactions::ButtonStyle,
misc::Mentionable, misc::Mentionable,
webhook::Webhook, webhook::Webhook,
}, },
@ -18,17 +18,23 @@ use serenity::{
use crate::{ use crate::{
check_subscription_on_message, command_help, check_subscription_on_message, command_help,
consts::{ consts::{
CHARACTERS, DAY, HOUR, MAX_TIME, MINUTE, MIN_INTERVAL, REGEX_CHANNEL, REGEX_CHANNEL_USER, CHARACTERS, MAX_TIME, MIN_INTERVAL, REGEX_CHANNEL_USER, REGEX_CONTENT_SUBSTITUTION,
REGEX_CONTENT_SUBSTITUTION, REGEX_NATURAL_COMMAND_1, REGEX_NATURAL_COMMAND_2, REGEX_NATURAL_COMMAND_1, REGEX_NATURAL_COMMAND_2, REGEX_REMIND_COMMAND, THEME_COLOR,
REGEX_REMIND_COMMAND, THEME_COLOR,
}, },
framework::SendIterator, framework::SendIterator,
get_ctx_data, get_ctx_data,
models::{ChannelData, GuildData, Timer, UserData}, models::{
channel_data::ChannelData,
guild_data::GuildData,
reminder::{LookFlags, Reminder, ReminderAction},
timer::Timer,
user_data::UserData,
CtxGuildData,
},
time_parser::{natural_parser, TimeParser}, time_parser::{natural_parser, TimeParser},
}; };
use chrono::{offset::TimeZone, NaiveDateTime}; use chrono::NaiveDateTime;
use rand::{rngs::OsRng, seq::IteratorRandom}; use rand::{rngs::OsRng, seq::IteratorRandom};
@ -40,33 +46,13 @@ use std::{
collections::HashSet, collections::HashSet,
convert::TryInto, convert::TryInto,
default::Default, default::Default,
env,
fmt::Display, fmt::Display,
string::ToString, string::ToString,
time::{SystemTime, UNIX_EPOCH}, time::{SystemTime, UNIX_EPOCH},
}; };
use crate::models::CtxGuildData;
use regex::Captures; use regex::Captures;
use serenity::model::channel::Channel;
fn longhand_displacement(seconds: u64) -> String {
let (days, seconds) = seconds.div_rem(&DAY);
let (hours, seconds) = seconds.div_rem(&HOUR);
let (minutes, seconds) = seconds.div_rem(&MINUTE);
let mut sections = vec![];
for (var, name) in [days, hours, minutes, seconds]
.iter()
.zip(["days", "hours", "minutes", "seconds"].iter())
{
if *var > 0 {
sections.push(format!("{} {}", var, name));
}
}
sections.join(", ")
}
async fn create_webhook( async fn create_webhook(
ctx: impl CacheHttp, ctx: impl CacheHttp,
@ -100,7 +86,6 @@ async fn pause(ctx: &Context, msg: &Message, args: String) {
let language = UserData::language_of(&msg.author, &pool).await; let language = UserData::language_of(&msg.author, &pool).await;
let timezone = UserData::timezone_of(&msg.author, &pool).await; let timezone = UserData::timezone_of(&msg.author, &pool).await;
let meridian = UserData::meridian_of(&msg.author, &pool).await;
let mut channel = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool) let mut channel = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool)
.await .await
@ -136,13 +121,9 @@ async fn pause(ctx: &Context, msg: &Message, args: String) {
channel.commit_changes(&pool).await; channel.commit_changes(&pool).await;
let content = lm.get(&language, "pause/paused_until").replace( let content = lm
"{}", .get(&language, "pause/paused_until")
&timezone .replace("{}", &format!("<t:{}:D>", timestamp));
.timestamp(timestamp, 0)
.format(meridian.fmt_str())
.to_string(),
);
let _ = msg.channel_id.say(&ctx, content).await; let _ = msg.channel_id.say(&ctx, content).await;
} }
@ -273,108 +254,6 @@ async fn nudge(ctx: &Context, msg: &Message, args: String) {
} }
} }
enum TimeDisplayType {
Absolute,
Relative,
}
struct LookFlags {
pub limit: u16,
pub show_disabled: bool,
pub channel_id: Option<u64>,
time_display: TimeDisplayType,
}
impl Default for LookFlags {
fn default() -> Self {
Self {
limit: u16::MAX,
show_disabled: true,
channel_id: None,
time_display: TimeDisplayType::Relative,
}
}
}
impl LookFlags {
fn from_string(args: &str) -> Self {
let mut new_flags: Self = Default::default();
for arg in args.split(' ') {
match arg {
"enabled" => {
new_flags.show_disabled = false;
}
"time" => {
new_flags.time_display = TimeDisplayType::Absolute;
}
param => {
if let Ok(val) = param.parse::<u16>() {
new_flags.limit = val;
} else {
if let Some(channel) = REGEX_CHANNEL
.captures(&arg)
.map(|cap| cap.get(1))
.flatten()
.map(|c| c.as_str().parse::<u64>().unwrap())
{
new_flags.channel_id = Some(channel);
}
}
}
}
}
new_flags
}
}
struct LookReminder {
id: u32,
time: NaiveDateTime,
interval: Option<u32>,
channel: u64,
content: String,
description: String,
}
impl LookReminder {
fn display_content(&self) -> String {
if self.content.len() > 0 {
self.content.clone()
} else {
self.description.clone()
}
}
fn display(&self, flags: &LookFlags, inter: &str) -> String {
let time_display = match flags.time_display {
TimeDisplayType::Absolute => format!("<t:{}>", self.time.timestamp()),
TimeDisplayType::Relative => format!("<t:{}:R>", self.time.timestamp()),
};
if let Some(interval) = self.interval {
format!(
"'{}' *{}* **{}**, repeating every **{}**",
self.display_content(),
&inter,
time_display,
longhand_displacement(interval as u64)
)
} else {
format!(
"'{}' *{}* **{}**",
self.display_content(),
&inter,
time_display
)
}
}
}
#[command("look")] #[command("look")]
#[permission_level(Managed)] #[permission_level(Managed)]
async fn look(ctx: &Context, msg: &Message, args: String) { async fn look(ctx: &Context, msg: &Message, args: String) {
@ -384,53 +263,19 @@ async fn look(ctx: &Context, msg: &Message, args: String) {
let flags = LookFlags::from_string(&args); let flags = LookFlags::from_string(&args);
let enabled = if flags.show_disabled { "0,1" } else { "1" };
let channel_opt = msg.channel_id.to_channel_cached(&ctx).await; let channel_opt = msg.channel_id.to_channel_cached(&ctx).await;
let channel_id = if let Some(Channel::Guild(channel)) = channel_opt { let channel_id = if let Some(Channel::Guild(channel)) = channel_opt {
if Some(channel.guild_id) == msg.guild_id { if Some(channel.guild_id) == msg.guild_id {
flags flags.channel_id.unwrap_or(msg.channel_id)
.channel_id
.unwrap_or_else(|| msg.channel_id.as_u64().to_owned())
} else { } else {
msg.channel_id.as_u64().to_owned() msg.channel_id
} }
} else { } else {
msg.channel_id.as_u64().to_owned() msg.channel_id
}; };
let reminders = sqlx::query_as!( let reminders = Reminder::from_channel(ctx, channel_id, &flags).await;
LookReminder,
"
SELECT
reminders.id,
reminders.utc_time AS time,
reminders.interval,
channels.channel,
reminders.content,
reminders.embed_description AS description
FROM
reminders
INNER JOIN
channels
ON
reminders.channel_id = channels.id
WHERE
channels.channel = ? AND
FIND_IN_SET(reminders.enabled, ?)
ORDER BY
reminders.utc_time
LIMIT
?
",
channel_id,
enabled,
flags.limit
)
.fetch_all(&pool)
.await
.unwrap();
if reminders.is_empty() { if reminders.is_empty() {
let _ = msg let _ = msg
@ -460,95 +305,10 @@ async fn delete(ctx: &Context, msg: &Message, _args: String) {
.say(&ctx, lm.get(&user_data.language, "del/listing")) .say(&ctx, lm.get(&user_data.language, "del/listing"))
.await; .await;
let reminders = if let Some(guild_id) = msg.guild_id {
let guild_opt = guild_id.to_guild_cached(&ctx).await;
if let Some(guild) = guild_opt {
let channels = guild
.channels
.keys()
.into_iter()
.map(|k| k.as_u64().to_string())
.collect::<Vec<String>>()
.join(",");
sqlx::query_as_unchecked!(
LookReminder,
"
SELECT
reminders.id,
reminders.utc_time AS time,
reminders.interval,
channels.channel,
reminders.content,
reminders.embed_description AS description
FROM
reminders
LEFT OUTER JOIN
channels
ON
channels.id = reminders.channel_id
WHERE
FIND_IN_SET(channels.channel, ?)
",
channels
)
.fetch_all(&pool)
.await
} else {
sqlx::query_as_unchecked!(
LookReminder,
"
SELECT
reminders.id,
reminders.utc_time AS time,
reminders.interval,
channels.channel,
reminders.content,
reminders.embed_description AS description
FROM
reminders
LEFT OUTER JOIN
channels
ON
channels.id = reminders.channel_id
WHERE
channels.guild_id = (SELECT id FROM guilds WHERE guild = ?)
",
guild_id.as_u64()
)
.fetch_all(&pool)
.await
}
} else {
sqlx::query_as!(
LookReminder,
"
SELECT
reminders.id,
reminders.utc_time AS time,
reminders.interval,
channels.channel,
reminders.content,
reminders.embed_description AS description
FROM
reminders
INNER JOIN
channels
ON
channels.id = reminders.channel_id
WHERE
channels.channel = ?
",
msg.channel_id.as_u64()
)
.fetch_all(&pool)
.await
}
.unwrap();
let mut reminder_ids: Vec<u32> = vec![]; let mut reminder_ids: Vec<u32> = vec![];
let reminders = Reminder::from_guild(ctx, msg.guild_id, msg.author.id).await;
let enumerated_reminders = reminders.iter().enumerate().map(|(count, reminder)| { let enumerated_reminders = reminders.iter().enumerate().map(|(count, reminder)| {
reminder_ids.push(reminder.id); reminder_ids.push(reminder.id);
@ -557,7 +317,7 @@ WHERE
count + 1, count + 1,
reminder.display_content(), reminder.display_content(),
reminder.channel, reminder.channel,
reminder.time.timestamp() reminder.utc_time.timestamp()
) )
}); });
@ -801,7 +561,6 @@ impl ReminderScope {
#[derive(PartialEq, Eq, Hash, Debug)] #[derive(PartialEq, Eq, Hash, Debug)]
enum ReminderError { enum ReminderError {
LongTime,
LongInterval, LongInterval,
PastTime, PastTime,
ShortInterval, ShortInterval,
@ -828,7 +587,6 @@ trait ToResponse {
impl ToResponse for ReminderError { impl ToResponse for ReminderError {
fn to_response(&self) -> &'static str { fn to_response(&self) -> &'static str {
match self { match self {
Self::LongTime => "remind/long_time",
Self::LongInterval => "interval/long_interval", Self::LongInterval => "interval/long_interval",
Self::PastTime => "remind/past_time", Self::PastTime => "remind/past_time",
Self::ShortInterval => "interval/short_interval", Self::ShortInterval => "interval/short_interval",
@ -841,7 +599,6 @@ impl ToResponse for ReminderError {
fn to_response_natural(&self) -> &'static str { fn to_response_natural(&self) -> &'static str {
match self { match self {
Self::LongTime => "natural/long_time",
Self::InvalidTime => "natural/invalid_time", Self::InvalidTime => "natural/invalid_time",
_ => self.to_response(), _ => self.to_response(),
} }
@ -925,7 +682,7 @@ impl Content {
Ok(Self { Ok(Self {
content: content.to_string(), content: content.to_string(),
tts: false, tts: false,
attachment: Some(attachment_bytes.clone()), attachment: Some(attachment_bytes),
attachment_name: Some(attachment.filename.clone()), attachment_name: Some(attachment.filename.clone()),
}) })
} else { } else {
@ -1128,7 +885,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
Some(captures) => { Some(captures) => {
let parsed = parse_mention_list(captures.name("mentions").unwrap().as_str()); let parsed = parse_mention_list(captures.name("mentions").unwrap().as_str());
let scopes = if parsed.len() == 0 { let scopes = if parsed.is_empty() {
vec![ReminderScope::Channel(msg.channel_id.into())] vec![ReminderScope::Channel(msg.channel_id.into())]
} else { } else {
parsed parsed
@ -1167,6 +924,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
match content_res { match content_res {
Ok(mut content) => { Ok(mut content) => {
let mut ok_locations = vec![]; let mut ok_locations = vec![];
let mut ok_reminders = vec![];
let mut err_locations = vec![]; let mut err_locations = vec![];
let mut err_types = HashSet::new(); let mut err_types = HashSet::new();
@ -1178,17 +936,22 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
msg.guild_id, msg.guild_id,
&scope, &scope,
&time_parser, &time_parser,
expires_parser.as_ref().clone(), expires_parser.as_ref(),
interval, interval,
&mut content, &mut content,
) )
.await; .await;
if let Err(e) = res { match res {
err_locations.push(scope); Err(e) => {
err_types.insert(e); err_locations.push(scope);
} else { err_types.insert(e);
ok_locations.push(scope); }
Ok(id) => {
ok_locations.push(scope);
ok_reminders.push(id);
}
} }
} }
@ -1265,6 +1028,22 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
.description(format!("{}\n\n{}", success_part, error_part)) .description(format!("{}\n\n{}", success_part, error_part))
.color(*THEME_COLOR) .color(*THEME_COLOR)
}) })
.components(|c| {
if ok_locations.len() == 1 {
c.create_action_row(|r| {
r.create_button(|b| {
b.style(ButtonStyle::Danger)
.label("Delete")
.custom_id(ok_reminders[0].signed_action(
msg.author.id,
ReminderAction::Delete,
))
})
});
}
c
})
}) })
.await; .await;
} }
@ -1397,7 +1176,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) {
&scope, &scope,
timestamp, timestamp,
expires, expires,
interval.clone(), interval,
&mut content, &mut content,
) )
.await; .await;
@ -1529,7 +1308,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) {
} }
async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>( async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>(
ctx: impl CacheHttp + AsRef<Cache>, ctx: &Context,
pool: &MySqlPool, pool: &MySqlPool,
user_id: U, user_id: U,
guild_id: Option<GuildId>, guild_id: Option<GuildId>,
@ -1538,7 +1317,7 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>(
expires_parser: Option<T>, expires_parser: Option<T>,
interval: Option<i64>, interval: Option<i64>,
content: &mut Content, content: &mut Content,
) -> Result<(), ReminderError> { ) -> Result<Reminder, ReminderError> {
let user_id = user_id.into(); let user_id = user_id.into();
if let Some(g_id) = guild_id { if let Some(g_id) = guild_id {
@ -1551,11 +1330,19 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>(
let db_channel_id = match scope_id { let db_channel_id = match scope_id {
ReminderScope::User(user_id) => { ReminderScope::User(user_id) => {
let user = UserId(*user_id).to_user(&ctx).await.unwrap(); if let Ok(user) = UserId(*user_id).to_user(&ctx).await {
let user_data = UserData::from_user(&user, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&user, &ctx, &pool).await.unwrap(); if let Some(guild_id) = guild_id {
if guild_id.member(&ctx, user).await.is_err() {
return Err(ReminderError::InvalidTag);
}
}
user_data.dm_channel user_data.dm_channel
} else {
return Err(ReminderError::InvalidTag);
}
} }
ReminderScope::Channel(channel_id) => { ReminderScope::Channel(channel_id) => {
@ -1610,11 +1397,10 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>(
.as_secs() as i64; .as_secs() as i64;
if time >= unix_time - 10 { if time >= unix_time - 10 {
if time > unix_time + *MAX_TIME { let uid = generate_uid();
Err(ReminderError::LongTime)
} else { sqlx::query!(
sqlx::query!( "
"
INSERT INTO reminders ( INSERT INTO reminders (
uid, uid,
content, content,
@ -1639,23 +1425,24 @@ INSERT INTO reminders (
(SELECT id FROM users WHERE user = ? LIMIT 1) (SELECT id FROM users WHERE user = ? LIMIT 1)
) )
", ",
generate_uid(), uid,
content.content, content.content,
content.tts, content.tts,
content.attachment, content.attachment,
content.attachment_name, content.attachment_name,
db_channel_id, db_channel_id,
time as u32, time,
expires, expires,
interval, interval,
user_id user_id
) )
.execute(pool) .execute(pool)
.await .await
.unwrap(); .unwrap();
Ok(()) let reminder = Reminder::from_uid(ctx, uid).await.unwrap();
}
Ok(reminder)
} else if time < 0 { } else if time < 0 {
// case required for if python returns -1 // case required for if python returns -1
Err(ReminderError::InvalidTime) Err(ReminderError::InvalidTime)

View File

@ -12,8 +12,10 @@ use serenity::{
use std::fmt; use std::fmt;
use crate::models::CtxGuildData; use crate::{
use crate::{command_help, get_ctx_data, models::UserData}; command_help, get_ctx_data,
models::{user_data::UserData, CtxGuildData},
};
use sqlx::MySqlPool; use sqlx::MySqlPool;
use std::convert::TryFrom; use std::convert::TryFrom;

View File

@ -51,7 +51,7 @@ lazy_static! {
.split(',') .split(',')
.filter_map(|item| { item.parse::<u64>().ok() }) .filter_map(|item| { item.parse::<u64>().ok() })
.collect::<Vec<u64>>()) .collect::<Vec<u64>>())
.unwrap_or_else(|_| vec![]) .unwrap_or_else(|_| Vec::new())
); );
pub static ref CNC_GUILD: Option<u64> = env::var("CNC_GUILD") pub static ref CNC_GUILD: Option<u64> = env::var("CNC_GUILD")

View File

@ -8,7 +8,7 @@ use serenity::{
model::{ model::{
channel::{Channel, GuildChannel, Message}, channel::{Channel, GuildChannel, Message},
guild::{Guild, Member}, guild::{Guild, Member},
id::ChannelId, id::{ChannelId, MessageId},
}, },
Result as SerenityResult, Result as SerenityResult,
}; };
@ -19,10 +19,11 @@ use regex::{Match, Regex, RegexBuilder};
use std::{collections::HashMap, fmt}; use std::{collections::HashMap, fmt};
use crate::language_manager::LanguageManager; use crate::{
use crate::models::{CtxGuildData, GuildData, UserData}; language_manager::LanguageManager,
use crate::{models::ChannelData, LimitExecutors, SQLPool}; models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData, CtxGuildData},
use serenity::model::id::MessageId; LimitExecutors, SQLPool,
};
type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, ()>; type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, ()>;
@ -239,7 +240,7 @@ impl RegexFramework {
let mut command_names_vec = let mut command_names_vec =
self.commands.keys().map(|k| &k[..]).collect::<Vec<&str>>(); self.commands.keys().map(|k| &k[..]).collect::<Vec<&str>>();
command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len())); command_names_vec.sort_unstable_by_key(|a| a.len());
command_names = command_names_vec.join("|"); command_names = command_names_vec.join("|");
} }
@ -275,7 +276,7 @@ impl RegexFramework {
}) })
.collect::<Vec<&str>>(); .collect::<Vec<&str>>();
command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len())); command_names_vec.sort_unstable_by_key(|a| a.len());
dm_command_names = command_names_vec.join("|"); dm_command_names = command_names_vec.join("|");
} }
@ -398,12 +399,14 @@ impl Framework for RegexFramework {
{ {
let guild_id = guild.id.as_u64().to_owned(); let guild_id = guild.id.as_u64().to_owned();
GuildData::from_guild(guild, &pool).await.expect( GuildData::from_guild(guild, &pool)
&format!( .await
.unwrap_or_else(|_| {
panic!(
"Failed to create new guild object for {}", "Failed to create new guild object for {}",
guild_id guild_id
), )
); });
} }
if msg.id == MessageId(0) if msg.id == MessageId(0)

View File

@ -14,7 +14,7 @@ pub struct LanguageManager {
impl LanguageManager { impl LanguageManager {
pub fn from_compiled(content: &'static str) -> Result<Self, Box<dyn Error + Send + Sync>> { pub fn from_compiled(content: &'static str) -> Result<Self, Box<dyn Error + Send + Sync>> {
let new: Self = from_str(content.as_ref())?; let new: Self = from_str(content)?;
Ok(new) Ok(new)
} }
@ -23,13 +23,13 @@ impl LanguageManager {
self.strings self.strings
.get(language) .get(language)
.map(|sm| sm.get(name)) .map(|sm| sm.get(name))
.expect(&format!(r#"Language does not exist: "{}""#, language)) .unwrap_or_else(|| panic!(r#"Language does not exist: "{}""#, language))
.unwrap_or_else(|| { .unwrap_or_else(|| {
self.strings self.strings
.get(&*LOCAL_LANGUAGE) .get(&*LOCAL_LANGUAGE)
.map(|sm| { .map(|sm| {
sm.get(name) sm.get(name)
.expect(&format!(r#"String does not exist: "{}""#, name)) .unwrap_or_else(|| panic!(r#"String does not exist: "{}""#, name))
}) })
.expect("LOCAL_LANGUAGE is not available") .expect("LOCAL_LANGUAGE is not available")
}) })

View File

@ -36,7 +36,7 @@ use crate::{
consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR}, consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR},
framework::RegexFramework, framework::RegexFramework,
language_manager::LanguageManager, language_manager::LanguageManager,
models::GuildData, models::{guild_data::GuildData, user_data::UserData},
}; };
use inflector::Inflector; use inflector::Inflector;
@ -46,7 +46,7 @@ use dashmap::DashMap;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::models::UserData; use crate::models::reminder::{Reminder, ReminderAction};
use chrono::Utc; use chrono::Utc;
use chrono_tz::Tz; use chrono_tz::Tz;
use serenity::model::prelude::{ use serenity::model::prelude::{
@ -179,10 +179,11 @@ DELETE FROM channels WHERE channel = ?
.cloned() .cloned()
.expect("Could not get SQLPool from data"); .expect("Could not get SQLPool from data");
GuildData::from_guild(guild, &pool).await.expect(&format!( GuildData::from_guild(guild, &pool)
"Failed to create new guild object for {}", .await
guild_id .unwrap_or_else(|_| {
)); panic!("Failed to create new guild object for {}", guild_id)
});
} }
if let Ok(token) = env::var("DISCORDBOTS_TOKEN") { if let Ok(token) = env::var("DISCORDBOTS_TOKEN") {
@ -230,7 +231,12 @@ DELETE FROM channels WHERE channel = ?
} }
} }
async fn guild_delete(&self, ctx: Context, guild: GuildUnavailable, _guild: Option<Guild>) { async fn guild_delete(
&self,
ctx: Context,
deleted_guild: GuildUnavailable,
_guild: Option<Guild>,
) {
let pool = ctx let pool = ctx
.data .data
.read() .read()
@ -246,13 +252,13 @@ DELETE FROM channels WHERE channel = ?
.get::<GuildDataCache>() .get::<GuildDataCache>()
.cloned() .cloned()
.unwrap(); .unwrap();
guild_data_cache.remove(&guild.id); guild_data_cache.remove(&deleted_guild.id);
sqlx::query!( sqlx::query!(
" "
DELETE FROM guilds WHERE guild = ? DELETE FROM guilds WHERE guild = ?
", ",
guild.id.as_u64() deleted_guild.id.as_u64()
) )
.execute(&pool) .execute(&pool)
.await .await
@ -268,8 +274,6 @@ DELETE FROM guilds WHERE guild = ?
if let (Some(InteractionData::MessageComponent(data)), Some(member)) = if let (Some(InteractionData::MessageComponent(data)), Some(member)) =
(interaction.clone().data, interaction.clone().member) (interaction.clone().data, interaction.clone().member)
{ {
println!("{}", data.custom_id);
if data.custom_id.starts_with("timezone:") { if data.custom_id.starts_with("timezone:") {
let mut user_data = UserData::from_user(&member.user, &ctx, &pool) let mut user_data = UserData::from_user(&member.user, &ctx, &pool)
.await .await
@ -296,7 +300,7 @@ DELETE FROM guilds WHERE guild = ?
.replacen("{timezone}", &user_data.timezone, 1) .replacen("{timezone}", &user_data.timezone, 1)
.replacen( .replacen(
"{time}", "{time}",
&now.format(user_data.meridian().fmt_str_short()).to_string(), &now.format("%H:%M").to_string(),
1, 1,
); );
@ -333,10 +337,45 @@ DELETE FROM guilds WHERE guild = ?
lm.get(&user_data.language, "lang/set_p"), lm.get(&user_data.language, "lang/set_p"),
) )
}) })
.flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
}) })
}) })
.await; .await;
} }
} else {
match Reminder::from_interaction(&ctx, member.user.id, data.custom_id).await
{
Ok((reminder, action)) => {
let response = match action {
ReminderAction::Delete => {
reminder.delete(&ctx).await;
"Reminder has been deleted"
}
};
let _ = interaction
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| d
.content(response)
.flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
)
})
.await;
}
Err(ie) => {
let _ = interaction
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| d
.content(ie.to_string())
.flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
)
})
.await;
}
}
} }
} }
} }
@ -401,7 +440,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.add_command("blacklist", &moderation_cmds::BLACKLIST_COMMAND) .add_command("blacklist", &moderation_cmds::BLACKLIST_COMMAND)
.add_command("restrict", &moderation_cmds::RESTRICT_COMMAND) .add_command("restrict", &moderation_cmds::RESTRICT_COMMAND)
.add_command("timezone", &moderation_cmds::TIMEZONE_COMMAND) .add_command("timezone", &moderation_cmds::TIMEZONE_COMMAND)
.add_command("meridian", &moderation_cmds::CHANGE_MERIDIAN_COMMAND)
.add_command("prefix", &moderation_cmds::PREFIX_COMMAND) .add_command("prefix", &moderation_cmds::PREFIX_COMMAND)
.add_command("lang", &moderation_cmds::LANGUAGE_COMMAND) .add_command("lang", &moderation_cmds::LANGUAGE_COMMAND)
.add_command("pause", &reminder_cmds::PAUSE_COMMAND) .add_command("pause", &reminder_cmds::PAUSE_COMMAND)

View File

@ -1,452 +0,0 @@
use serenity::{
async_trait,
http::CacheHttp,
model::{
channel::Channel,
guild::Guild,
id::{GuildId, UserId},
user::User,
},
prelude::Context,
};
use sqlx::MySqlPool;
use chrono::NaiveDateTime;
use chrono_tz::Tz;
use log::error;
use crate::{
consts::{DEFAULT_PREFIX, LOCAL_LANGUAGE, LOCAL_TIMEZONE},
GuildDataCache, SQLPool,
};
use std::sync::Arc;
use tokio::sync::RwLock;
#[async_trait]
pub trait CtxGuildData {
async fn guild_data<G: Into<GuildId> + Send + Sync>(
&self,
guild_id: G,
) -> Result<Arc<RwLock<GuildData>>, sqlx::Error>;
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String;
}
#[async_trait]
impl CtxGuildData for Context {
async fn guild_data<G: Into<GuildId> + Send + Sync>(
&self,
guild_id: G,
) -> Result<Arc<RwLock<GuildData>>, sqlx::Error> {
let guild_id = guild_id.into();
let guild = guild_id.to_guild_cached(&self.cache).await.unwrap();
let guild_cache = self
.data
.read()
.await
.get::<GuildDataCache>()
.cloned()
.unwrap();
let x = if let Some(guild_data) = guild_cache.get(&guild_id) {
Ok(guild_data.clone())
} else {
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
match GuildData::from_guild(guild, &pool).await {
Ok(d) => {
let lock = Arc::new(RwLock::new(d));
guild_cache.insert(guild_id, lock.clone());
Ok(lock)
}
Err(e) => Err(e),
}
};
x
}
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String {
if let Some(guild_id) = guild_id {
self.guild_data(guild_id)
.await
.unwrap()
.read()
.await
.prefix
.clone()
} else {
DEFAULT_PREFIX.clone()
}
}
}
pub struct GuildData {
pub id: u32,
pub name: Option<String>,
pub prefix: String,
}
impl GuildData {
pub async fn from_guild(guild: Guild, pool: &MySqlPool) -> Result<Self, sqlx::Error> {
let guild_id = guild.id.as_u64().to_owned();
match sqlx::query_as!(
Self,
"
SELECT id, name, prefix FROM guilds WHERE guild = ?
",
guild_id
)
.fetch_one(pool)
.await
{
Ok(mut g) => {
g.name = Some(guild.name);
Ok(g)
}
Err(sqlx::Error::RowNotFound) => {
sqlx::query!(
"
INSERT INTO guilds (guild, name, prefix) VALUES (?, ?, ?)
",
guild_id,
guild.name,
*DEFAULT_PREFIX
)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as!(
Self,
"
SELECT id, name, prefix FROM guilds WHERE guild = ?
",
guild_id
)
.fetch_one(pool)
.await?)
}
Err(e) => {
error!("Unexpected error in guild query: {:?}", e);
Err(e)
}
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE guilds SET name = ?, prefix = ? WHERE id = ?
",
self.name,
self.prefix,
self.id
)
.execute(pool)
.await
.unwrap();
}
}
pub struct ChannelData {
pub id: u32,
pub name: Option<String>,
pub nudge: i16,
pub blacklisted: bool,
pub webhook_id: Option<u64>,
pub webhook_token: Option<String>,
pub paused: bool,
pub paused_until: Option<NaiveDateTime>,
}
impl ChannelData {
pub async fn from_channel(
channel: Channel,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let channel_id = channel.id().as_u64().to_owned();
if let Ok(c) = sqlx::query_as_unchecked!(Self,
"
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
", channel_id)
.fetch_one(pool)
.await {
Ok(c)
}
else {
let props = channel.guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name));
let (guild_id, channel_name) = if let Some((a, b)) = props {
(Some(a), Some(b))
} else {
(None, None)
};
sqlx::query!(
"
INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?))
", channel_id, channel_name, guild_id)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as_unchecked!(Self,
"
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
", channel_id)
.fetch_one(pool)
.await?)
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until = ? WHERE id = ?
", self.name, self.nudge, self.blacklisted, self.webhook_id, self.webhook_token, self.paused, self.paused_until, self.id)
.execute(pool)
.await.unwrap();
}
}
pub struct UserData {
pub id: u32,
pub user: u64,
pub name: String,
pub dm_channel: u32,
pub language: String,
pub timezone: String,
pub meridian_time: bool,
}
pub struct MeridianType(bool);
impl MeridianType {
pub fn fmt_str(&self) -> &str {
if self.0 {
"%Y-%m-%d %I:%M:%S %p"
} else {
"%Y-%m-%d %H:%M:%S"
}
}
pub fn fmt_str_short(&self) -> &str {
if self.0 {
"%I:%M %p"
} else {
"%H:%M"
}
}
}
impl UserData {
pub async fn language_of<U>(user: U, pool: &MySqlPool) -> String
where
U: Into<UserId>,
{
let user_id = user.into().as_u64().to_owned();
match sqlx::query!(
"
SELECT language FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await
{
Ok(r) => r.language,
Err(_) => LOCAL_LANGUAGE.clone(),
}
}
pub async fn timezone_of<U>(user: U, pool: &MySqlPool) -> Tz
where
U: Into<UserId>,
{
let user_id = user.into().as_u64().to_owned();
match sqlx::query!(
"
SELECT timezone FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await
{
Ok(r) => r.timezone,
Err(_) => LOCAL_TIMEZONE.clone(),
}
.parse()
.unwrap()
}
pub async fn meridian_of<U>(user: U, pool: &MySqlPool) -> MeridianType
where
U: Into<UserId>,
{
let user_id = user.into().as_u64().to_owned();
match sqlx::query!(
"
SELECT meridian_time FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await
{
Ok(r) => MeridianType(r.meridian_time != 0),
Err(_) => MeridianType(false),
}
}
pub async fn from_user(
user: &User,
ctx: impl CacheHttp,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user.id.as_u64().to_owned();
match sqlx::query_as_unchecked!(
Self,
"
SELECT id, user, name, dm_channel, IF(language IS NULL, ?, language) AS language, IF(timezone IS NULL, ?, timezone) AS timezone, meridian_time FROM users WHERE user = ?
",
*LOCAL_LANGUAGE, *LOCAL_TIMEZONE, user_id
)
.fetch_one(pool)
.await
{
Ok(c) => Ok(c),
Err(sqlx::Error::RowNotFound) => {
let dm_channel = user.create_dm_channel(ctx).await?;
let dm_id = dm_channel.id.as_u64().to_owned();
let pool_c = pool.clone();
sqlx::query!(
"
INSERT IGNORE INTO channels (channel) VALUES (?)
",
dm_id
)
.execute(&pool_c)
.await?;
sqlx::query!(
"
INSERT INTO users (user, name, dm_channel, language, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?, ?)
", user_id, user.name, dm_id, *LOCAL_LANGUAGE, *LOCAL_TIMEZONE)
.execute(&pool_c)
.await?;
Ok(sqlx::query_as_unchecked!(
Self,
"
SELECT id, user, name, dm_channel, language, timezone, meridian_time FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await?)
}
Err(e) => {
error!("Error querying for user: {:?}", e);
Err(Box::new(e))
},
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE users SET name = ?, language = ?, timezone = ?, meridian_time = ? WHERE id = ?
",
self.name,
self.language,
self.timezone,
self.meridian_time,
self.id
)
.execute(pool)
.await
.unwrap();
}
pub fn timezone(&self) -> Tz {
self.timezone.parse().unwrap()
}
pub fn meridian(&self) -> MeridianType {
MeridianType(self.meridian_time)
}
}
pub struct Timer {
pub name: String,
pub start_time: NaiveDateTime,
pub owner: u64,
}
impl Timer {
pub async fn from_owner(owner: u64, pool: &MySqlPool) -> Vec<Self> {
sqlx::query_as_unchecked!(
Timer,
"
SELECT name, start_time, owner FROM timers WHERE owner = ?
",
owner
)
.fetch_all(pool)
.await
.unwrap()
}
pub async fn count_from_owner(owner: u64, pool: &MySqlPool) -> u32 {
sqlx::query!(
"
SELECT COUNT(1) as count FROM timers WHERE owner = ?
",
owner
)
.fetch_one(pool)
.await
.unwrap()
.count as u32
}
pub async fn create(name: &str, owner: u64, pool: &MySqlPool) {
sqlx::query!(
"
INSERT INTO timers (name, owner) VALUES (?, ?)
",
name,
owner
)
.execute(pool)
.await
.unwrap();
}
}

View File

@ -0,0 +1,67 @@
use serenity::model::channel::Channel;
use sqlx::MySqlPool;
use chrono::NaiveDateTime;
pub struct ChannelData {
pub id: u32,
pub name: Option<String>,
pub nudge: i16,
pub blacklisted: bool,
pub webhook_id: Option<u64>,
pub webhook_token: Option<String>,
pub paused: bool,
pub paused_until: Option<NaiveDateTime>,
}
impl ChannelData {
pub async fn from_channel(
channel: Channel,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let channel_id = channel.id().as_u64().to_owned();
if let Ok(c) = sqlx::query_as_unchecked!(Self,
"
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
", channel_id)
.fetch_one(pool)
.await {
Ok(c)
}
else {
let props = channel.guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name));
let (guild_id, channel_name) = if let Some((a, b)) = props {
(Some(a), Some(b))
} else {
(None, None)
};
sqlx::query!(
"
INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?))
", channel_id, channel_name, guild_id)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as_unchecked!(Self,
"
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
", channel_id)
.fetch_one(pool)
.await?)
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until = ? WHERE id = ?
", self.name, self.nudge, self.blacklisted, self.webhook_id, self.webhook_token, self.paused, self.paused_until, self.id)
.execute(pool)
.await.unwrap();
}
}

79
src/models/guild_data.rs Normal file
View File

@ -0,0 +1,79 @@
use serenity::model::guild::Guild;
use sqlx::MySqlPool;
use log::error;
use crate::consts::DEFAULT_PREFIX;
pub struct GuildData {
pub id: u32,
pub name: Option<String>,
pub prefix: String,
}
impl GuildData {
pub async fn from_guild(guild: Guild, pool: &MySqlPool) -> Result<Self, sqlx::Error> {
let guild_id = guild.id.as_u64().to_owned();
match sqlx::query_as!(
Self,
"
SELECT id, name, prefix FROM guilds WHERE guild = ?
",
guild_id
)
.fetch_one(pool)
.await
{
Ok(mut g) => {
g.name = Some(guild.name);
Ok(g)
}
Err(sqlx::Error::RowNotFound) => {
sqlx::query!(
"
INSERT INTO guilds (guild, name, prefix) VALUES (?, ?, ?)
",
guild_id,
guild.name,
*DEFAULT_PREFIX
)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as!(
Self,
"
SELECT id, name, prefix FROM guilds WHERE guild = ?
",
guild_id
)
.fetch_one(pool)
.await?)
}
Err(e) => {
error!("Unexpected error in guild query: {:?}", e);
Err(e)
}
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE guilds SET name = ?, prefix = ? WHERE id = ?
",
self.name,
self.prefix,
self.id
)
.execute(pool)
.await
.unwrap();
}
}

78
src/models/mod.rs Normal file
View File

@ -0,0 +1,78 @@
pub mod channel_data;
pub mod guild_data;
pub mod reminder;
pub mod timer;
pub mod user_data;
use serenity::{async_trait, model::id::GuildId, prelude::Context};
use crate::{consts::DEFAULT_PREFIX, GuildDataCache, SQLPool};
use guild_data::GuildData;
use std::sync::Arc;
use tokio::sync::RwLock;
#[async_trait]
pub trait CtxGuildData {
async fn guild_data<G: Into<GuildId> + Send + Sync>(
&self,
guild_id: G,
) -> Result<Arc<RwLock<GuildData>>, sqlx::Error>;
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String;
}
#[async_trait]
impl CtxGuildData for Context {
async fn guild_data<G: Into<GuildId> + Send + Sync>(
&self,
guild_id: G,
) -> Result<Arc<RwLock<GuildData>>, sqlx::Error> {
let guild_id = guild_id.into();
let guild = guild_id.to_guild_cached(&self.cache).await.unwrap();
let guild_cache = self
.data
.read()
.await
.get::<GuildDataCache>()
.cloned()
.unwrap();
let x = if let Some(guild_data) = guild_cache.get(&guild_id) {
Ok(guild_data.clone())
} else {
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
match GuildData::from_guild(guild, &pool).await {
Ok(d) => {
let lock = Arc::new(RwLock::new(d));
guild_cache.insert(guild_id, lock.clone());
Ok(lock)
}
Err(e) => Err(e),
}
};
x
}
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String {
if let Some(guild_id) = guild_id {
self.guild_data(guild_id)
.await
.unwrap()
.read()
.await
.prefix
.clone()
} else {
DEFAULT_PREFIX.clone()
}
}
}

507
src/models/reminder.rs Normal file
View File

@ -0,0 +1,507 @@
use serenity::{
client::Context,
model::id::{ChannelId, GuildId, UserId},
};
use chrono::NaiveDateTime;
use crate::{
consts::{DAY, HOUR, MINUTE, REGEX_CHANNEL},
SQLPool,
};
use num_integer::Integer;
use ring::hmac;
use std::convert::{TryFrom, TryInto};
use std::env;
fn longhand_displacement(seconds: u64) -> String {
let (days, seconds) = seconds.div_rem(&DAY);
let (hours, seconds) = seconds.div_rem(&HOUR);
let (minutes, seconds) = seconds.div_rem(&MINUTE);
let mut sections = vec![];
for (var, name) in [days, hours, minutes, seconds]
.iter()
.zip(["days", "hours", "minutes", "seconds"].iter())
{
if *var > 0 {
sections.push(format!("{} {}", var, name));
}
}
sections.join(", ")
}
#[derive(Debug)]
pub struct Reminder {
pub id: u32,
pub uid: String,
pub channel: u64,
pub utc_time: NaiveDateTime,
pub interval: Option<u32>,
pub expires: Option<NaiveDateTime>,
pub enabled: bool,
pub content: String,
pub embed_description: String,
pub set_by: Option<u64>,
}
impl Reminder {
pub async fn from_uid(ctx: &Context, uid: String) -> Option<Self> {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
sqlx::query_as_unchecked!(
Self,
"
SELECT
reminders.id,
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.expires,
reminders.enabled,
reminders.content,
reminders.embed_description,
users.user AS set_by
FROM
reminders
INNER JOIN
channels
ON
reminders.channel_id = channels.id
LEFT JOIN
users
ON
reminders.set_by = users.id
WHERE
reminders.uid = ?
",
uid
)
.fetch_one(&pool)
.await
.ok()
}
pub async fn from_id(ctx: &Context, id: u32) -> Option<Self> {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
sqlx::query_as_unchecked!(
Self,
"
SELECT
reminders.id,
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.expires,
reminders.enabled,
reminders.content,
reminders.embed_description,
users.user AS set_by
FROM
reminders
INNER JOIN
channels
ON
reminders.channel_id = channels.id
LEFT JOIN
users
ON
reminders.set_by = users.id
WHERE
reminders.id = ?
",
id
)
.fetch_one(&pool)
.await
.ok()
}
pub async fn from_channel<C: Into<ChannelId>>(
ctx: &Context,
channel_id: C,
flags: &LookFlags,
) -> Vec<Self> {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let enabled = if flags.show_disabled { "0,1" } else { "1" };
let channel_id = channel_id.into();
sqlx::query_as_unchecked!(
Self,
"
SELECT
reminders.id,
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.expires,
reminders.enabled,
reminders.content,
reminders.embed_description,
users.user AS set_by
FROM
reminders
INNER JOIN
channels
ON
reminders.channel_id = channels.id
LEFT JOIN
users
ON
reminders.set_by = users.id
WHERE
channels.channel = ? AND
FIND_IN_SET(reminders.enabled, ?)
ORDER BY
reminders.utc_time
LIMIT
?
",
channel_id.as_u64(),
enabled,
flags.limit
)
.fetch_all(&pool)
.await
.unwrap()
}
pub async fn from_guild(ctx: &Context, guild_id: Option<GuildId>, user: UserId) -> Vec<Self> {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
if let Some(guild_id) = guild_id {
let guild_opt = guild_id.to_guild_cached(&ctx).await;
if let Some(guild) = guild_opt {
let channels = guild
.channels
.keys()
.into_iter()
.map(|k| k.as_u64().to_string())
.collect::<Vec<String>>()
.join(",");
sqlx::query_as_unchecked!(
Self,
"
SELECT
reminders.id,
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.expires,
reminders.enabled,
reminders.content,
reminders.embed_description,
users.user AS set_by
FROM
reminders
LEFT JOIN
channels
ON
channels.id = reminders.channel_id
LEFT JOIN
users
ON
reminders.set_by = users.id
WHERE
FIND_IN_SET(channels.channel, ?)
",
channels
)
.fetch_all(&pool)
.await
} else {
sqlx::query_as_unchecked!(
Self,
"
SELECT
reminders.id,
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.expires,
reminders.enabled,
reminders.content,
reminders.embed_description,
users.user AS set_by
FROM
reminders
LEFT JOIN
channels
ON
channels.id = reminders.channel_id
LEFT JOIN
users
ON
reminders.set_by = users.id
WHERE
channels.guild_id = (SELECT id FROM guilds WHERE guild = ?)
",
guild_id.as_u64()
)
.fetch_all(&pool)
.await
}
} else {
sqlx::query_as_unchecked!(
Self,
"
SELECT
reminders.id,
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.expires,
reminders.enabled,
reminders.content,
reminders.embed_description,
users.user AS set_by
FROM
reminders
INNER JOIN
channels
ON
channels.id = reminders.channel_id
LEFT JOIN
users
ON
reminders.set_by = users.id
WHERE
channels.id = (SELECT dm_channel FROM users WHERE user = ?)
",
user.as_u64()
)
.fetch_all(&pool)
.await
}
.unwrap()
}
pub fn display_content(&self) -> &str {
if self.content.is_empty() {
&self.embed_description
} else {
&self.content
}
}
pub fn display(&self, flags: &LookFlags, inter: &str) -> String {
let time_display = match flags.time_display {
TimeDisplayType::Absolute => format!("<t:{}>", self.utc_time.timestamp()),
TimeDisplayType::Relative => format!("<t:{}:R>", self.utc_time.timestamp()),
};
if let Some(interval) = self.interval {
format!(
"'{}' *{}* **{}**, repeating every **{}** (set by {})",
self.display_content(),
&inter,
time_display,
longhand_displacement(interval as u64),
self.set_by
.map(|i| format!("<@{}>", i))
.unwrap_or_else(|| "unknown".to_string())
)
} else {
format!(
"'{}' *{}* **{}** (set by {})",
self.display_content(),
&inter,
time_display,
self.set_by
.map(|i| format!("<@{}>", i))
.unwrap_or_else(|| "unknown".to_string())
)
}
}
pub async fn from_interaction<U: Into<u64>>(
ctx: &Context,
member_id: U,
payload: String,
) -> Result<(Self, ReminderAction), InteractionError> {
let sections = payload.split(".").collect::<Vec<&str>>();
if sections.len() != 3 {
Err(InteractionError::InvalidFormat)
} else {
let action = ReminderAction::try_from(sections[0])
.map_err(|_| InteractionError::InvalidAction)?;
let reminder_id = u32::from_le_bytes(
base64::decode(sections[1])
.map_err(|_| InteractionError::InvalidBase64)?
.try_into()
.map_err(|_| InteractionError::InvalidSize)?,
);
if let Some(reminder) = Self::from_id(ctx, reminder_id).await {
if reminder.signed_action(member_id, action) == payload {
Ok((reminder, action))
} else {
Err(InteractionError::SignatureMismatch)
}
} else {
Err(InteractionError::NoReminder)
}
}
}
pub fn signed_action<U: Into<u64>>(&self, member_id: U, action: ReminderAction) -> String {
let s_key = hmac::Key::new(
hmac::HMAC_SHA256,
env::var("SECRET_KEY")
.expect("No SECRET_KEY provided")
.as_bytes(),
);
let mut context = hmac::Context::with_key(&s_key);
context.update(&self.id.to_le_bytes());
context.update(&member_id.into().to_le_bytes());
let signature = context.sign();
format!(
"{}.{}.{}",
action.to_string(),
base64::encode(self.id.to_le_bytes()),
base64::encode(&signature)
)
}
pub async fn delete(&self, ctx: &Context) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
sqlx::query!(
"
DELETE FROM reminders WHERE id = ?
",
self.id
)
.execute(&pool)
.await
.unwrap();
}
}
#[derive(Debug)]
pub enum InteractionError {
InvalidFormat,
InvalidBase64,
InvalidSize,
NoReminder,
SignatureMismatch,
InvalidAction,
}
impl ToString for InteractionError {
fn to_string(&self) -> String {
match self {
InteractionError::InvalidFormat => {
String::from("The interaction data was improperly formatted")
}
InteractionError::InvalidBase64 => String::from("The interaction data was invalid"),
InteractionError::InvalidSize => String::from("The interaction data was invalid"),
InteractionError::NoReminder => String::from("Reminder could not be found"),
InteractionError::SignatureMismatch => {
String::from("Only the user who did the command can use interactions")
}
InteractionError::InvalidAction => String::from("The action was invalid"),
}
}
}
#[derive(Clone, Copy)]
pub enum ReminderAction {
Delete,
}
impl ToString for ReminderAction {
fn to_string(&self) -> String {
match self {
Self::Delete => String::from("del"),
}
}
}
impl TryFrom<&str> for ReminderAction {
type Error = ();
fn try_from(value: &str) -> Result<Self, Self::Error> {
match value {
"del" => Ok(Self::Delete),
_ => Err(()),
}
}
}
enum TimeDisplayType {
Absolute,
Relative,
}
pub struct LookFlags {
pub limit: u16,
pub show_disabled: bool,
pub channel_id: Option<ChannelId>,
time_display: TimeDisplayType,
}
impl Default for LookFlags {
fn default() -> Self {
Self {
limit: u16::MAX,
show_disabled: true,
channel_id: None,
time_display: TimeDisplayType::Relative,
}
}
}
impl LookFlags {
pub fn from_string(args: &str) -> Self {
let mut new_flags: Self = Default::default();
for arg in args.split(' ') {
match arg {
"enabled" => {
new_flags.show_disabled = false;
}
"time" => {
new_flags.time_display = TimeDisplayType::Absolute;
}
param => {
if let Ok(val) = param.parse::<u16>() {
new_flags.limit = val;
} else if let Some(channel) = REGEX_CHANNEL
.captures(&arg)
.map(|cap| cap.get(1))
.flatten()
.map(|c| c.as_str().parse::<u64>().unwrap())
{
new_flags.channel_id = Some(ChannelId(channel));
}
}
}
}
new_flags
}
}

50
src/models/timer.rs Normal file
View File

@ -0,0 +1,50 @@
use sqlx::MySqlPool;
use chrono::NaiveDateTime;
pub struct Timer {
pub name: String,
pub start_time: NaiveDateTime,
pub owner: u64,
}
impl Timer {
pub async fn from_owner(owner: u64, pool: &MySqlPool) -> Vec<Self> {
sqlx::query_as_unchecked!(
Timer,
"
SELECT name, start_time, owner FROM timers WHERE owner = ?
",
owner
)
.fetch_all(pool)
.await
.unwrap()
}
pub async fn count_from_owner(owner: u64, pool: &MySqlPool) -> u32 {
sqlx::query!(
"
SELECT COUNT(1) as count FROM timers WHERE owner = ?
",
owner
)
.fetch_one(pool)
.await
.unwrap()
.count as u32
}
pub async fn create(name: &str, owner: u64, pool: &MySqlPool) {
sqlx::query!(
"
INSERT INTO timers (name, owner) VALUES (?, ?)
",
name,
owner
)
.execute(pool)
.await
.unwrap();
}
}

146
src/models/user_data.rs Normal file
View File

@ -0,0 +1,146 @@
use serenity::{
http::CacheHttp,
model::{id::UserId, user::User},
};
use sqlx::MySqlPool;
use chrono_tz::Tz;
use log::error;
use crate::consts::{LOCAL_LANGUAGE, LOCAL_TIMEZONE};
pub struct UserData {
pub id: u32,
pub user: u64,
pub name: String,
pub dm_channel: u32,
pub language: String,
pub timezone: String,
}
impl UserData {
pub async fn language_of<U>(user: U, pool: &MySqlPool) -> String
where
U: Into<UserId>,
{
let user_id = user.into().as_u64().to_owned();
match sqlx::query!(
"
SELECT language FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await
{
Ok(r) => r.language,
Err(_) => LOCAL_LANGUAGE.clone(),
}
}
pub async fn timezone_of<U>(user: U, pool: &MySqlPool) -> Tz
where
U: Into<UserId>,
{
let user_id = user.into().as_u64().to_owned();
match sqlx::query!(
"
SELECT timezone FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await
{
Ok(r) => r.timezone,
Err(_) => LOCAL_TIMEZONE.clone(),
}
.parse()
.unwrap()
}
pub async fn from_user(
user: &User,
ctx: impl CacheHttp,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user.id.as_u64().to_owned();
match sqlx::query_as_unchecked!(
Self,
"
SELECT id, user, name, dm_channel, IF(language IS NULL, ?, language) AS language, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ?
",
*LOCAL_LANGUAGE, *LOCAL_TIMEZONE, user_id
)
.fetch_one(pool)
.await
{
Ok(c) => Ok(c),
Err(sqlx::Error::RowNotFound) => {
let dm_channel = user.create_dm_channel(ctx).await?;
let dm_id = dm_channel.id.as_u64().to_owned();
let pool_c = pool.clone();
sqlx::query!(
"
INSERT IGNORE INTO channels (channel) VALUES (?)
",
dm_id
)
.execute(&pool_c)
.await?;
sqlx::query!(
"
INSERT INTO users (user, name, dm_channel, language, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?, ?)
", user_id, user.name, dm_id, *LOCAL_LANGUAGE, *LOCAL_TIMEZONE)
.execute(&pool_c)
.await?;
Ok(sqlx::query_as_unchecked!(
Self,
"
SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await?)
}
Err(e) => {
error!("Error querying for user: {:?}", e);
Err(Box::new(e))
},
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE users SET name = ?, language = ?, timezone = ? WHERE id = ?
",
self.name,
self.language,
self.timezone,
self.id
)
.execute(pool)
.await
.unwrap();
}
pub fn timezone(&self) -> Tz {
self.timezone.parse().unwrap()
}
}

View File

@ -112,7 +112,7 @@ impl TimeParser {
DateTime::with_second, DateTime::with_second,
]) { ]) {
time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorHMS)?) time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorHMS)?)
.map_or_else(|| Err(InvalidTime::ParseErrorHMS), |inner| Ok(inner))?; .map_or_else(|| Err(InvalidTime::ParseErrorHMS), Ok)?;
} }
if let Some(dmy) = segments.next() { if let Some(dmy) = segments.next() {
@ -128,7 +128,7 @@ impl TimeParser {
{ {
if let Some(t) = t { if let Some(t) = t {
time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorDMY)?) time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorDMY)?)
.map_or_else(|| Err(InvalidTime::ParseErrorDMY), |inner| Ok(inner))?; .map_or_else(|| Err(InvalidTime::ParseErrorDMY), Ok)?;
} }
} }
@ -136,7 +136,7 @@ impl TimeParser {
if year.len() == 4 { if year.len() == 4 {
time = time time = time
.with_year(year.parse().map_err(|_| InvalidTime::ParseErrorDMY)?) .with_year(year.parse().map_err(|_| InvalidTime::ParseErrorDMY)?)
.map_or_else(|| Err(InvalidTime::ParseErrorDMY), |inner| Ok(inner))?; .map_or_else(|| Err(InvalidTime::ParseErrorDMY), Ok)?;
} else if year.len() == 2 { } else if year.len() == 2 {
time = time time = time
.with_year( .with_year(
@ -144,9 +144,9 @@ impl TimeParser {
.parse() .parse()
.map_err(|_| InvalidTime::ParseErrorDMY)?, .map_err(|_| InvalidTime::ParseErrorDMY)?,
) )
.map_or_else(|| Err(InvalidTime::ParseErrorDMY), |inner| Ok(inner))?; .map_or_else(|| Err(InvalidTime::ParseErrorDMY), Ok)?;
} else { } else {
Err(InvalidTime::ParseErrorDMY)?; return Err(InvalidTime::ParseErrorDMY);
} }
} }
} }
@ -157,10 +157,10 @@ impl TimeParser {
fn process_displacement(&self) -> Result<i64, InvalidTime> { fn process_displacement(&self) -> Result<i64, InvalidTime> {
let mut current_buffer = "0".to_string(); let mut current_buffer = "0".to_string();
let mut seconds = 0 as i64; let mut seconds = 0_i64;
let mut minutes = 0 as i64; let mut minutes = 0_i64;
let mut hours = 0 as i64; let mut hours = 0_i64;
let mut days = 0 as i64; let mut days = 0_i64;
for character in self.time_string.chars() { for character in self.time_string.chars() {
match character { match character {
@ -205,7 +205,7 @@ impl TimeParser {
} }
} }
pub(crate) async fn natural_parser(time: &str, timezone: &str) -> Option<i64> { pub async fn natural_parser(time: &str, timezone: &str) -> Option<i64> {
Command::new(&*PYTHON_LOCATION) Command::new(&*PYTHON_LOCATION)
.arg("-c") .arg("-c")
.arg(include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/dp.py"))) .arg(include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/dp.py")))