13 Commits

8 changed files with 926 additions and 648 deletions

543
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[package]
name = "reminder_rs"
version = "1.4.13"
version = "1.5.0-2"
authors = ["jellywx <judesouthworth@pm.me>"]
edition = "2018"
@ -23,7 +23,7 @@ rand = "0.7"
Inflector = "0.11"
levenshtein = "1.0"
# serenity = { version = "0.10", features = ["collector"] }
serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["collector"] }
serenity = { path = "/home/jude/serenity", features = ["collector", "unstable_discord_api"] }
sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono"]}
[dependencies.regex_command_attr]

View File

@ -1,3 +1,5 @@
CREATE DATABASE IF NOT EXISTS reminders;
SET FOREIGN_KEY_CHECKS=0;
USE reminders;

File diff suppressed because one or more lines are too long

View File

@ -1,13 +1,13 @@
use regex_command_attr::command;
use serenity::{
builder::CreateActionRow,
client::Context,
framework::Framework,
model::{
channel::ReactionType,
channel::{Channel, Message},
id::ChannelId,
id::RoleId,
channel::Message,
id::{ChannelId, MessageId, RoleId},
interactions::ButtonStyle,
},
};
@ -28,11 +28,8 @@ use crate::{
FrameworkCtx, PopularTimezones,
};
#[cfg(feature = "prefix-cache")]
use crate::PrefixCache;
use crate::models::CtxGuildData;
use std::{collections::HashMap, iter, time::Duration};
use std::{collections::HashMap, iter};
#[command]
#[supports_dm(false)]
@ -143,22 +140,28 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) {
Err(_) => {
let filtered_tz = TZ_VARIANTS
.iter()
.map(|tz| (tz, tz.to_string(), levenshtein(&tz.to_string(), &args)))
.filter(|(_, tz, dist)| args.contains(tz) || tz.contains(&args) || dist < &4)
.filter(|tz| {
args.contains(&tz.to_string())
|| tz.to_string().contains(&args)
|| levenshtein(&tz.to_string(), &args) < 4
})
.take(25)
.map(|(tz, tz_s, _)| {
(
tz_s,
format!(
"🕗 `{}`",
Utc::now()
.with_timezone(tz)
.format(user_data.meridian().fmt_str_short())
.to_string()
),
true,
)
});
.map(|t| t.to_owned())
.collect::<Vec<Tz>>();
let fields = filtered_tz.iter().map(|tz| {
(
tz.to_string(),
format!(
"🕗 `{}`",
Utc::now()
.with_timezone(tz)
.format(user_data.meridian().fmt_str_short())
.to_string()
),
true,
)
});
let _ = msg
.channel_id
@ -167,9 +170,24 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) {
e.title(lm.get(&user_data.language, "timezone/no_timezone_title"))
.description(lm.get(&user_data.language, "timezone/no_timezone"))
.color(*THEME_COLOR)
.fields(filtered_tz)
.fields(fields)
.footer(|f| f.text(footer_text))
.url("https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee")
}).components(|c| {
for row in filtered_tz.as_slice().chunks(5) {
let mut action_row = CreateActionRow::default();
for timezone in row {
action_row.create_button(|b| {
b.style(ButtonStyle::Secondary)
.label(timezone.to_string())
.custom_id(format!("timezone:{}", timezone.to_string()))
});
}
c.add_action_row(action_row);
}
c
})
})
.await;
@ -213,6 +231,22 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) {
.footer(|f| f.text(footer_text))
.url("https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee")
})
.components(|c| {
for row in popular_timezones.as_slice().chunks(5) {
let mut action_row = CreateActionRow::default();
for timezone in row {
action_row.create_button(|b| {
b.style(ButtonStyle::Secondary)
.label(timezone.to_string())
.custom_id(format!("timezone:{}", timezone.to_string()))
});
}
c.add_action_row(action_row);
}
c
})
})
.await;
}
@ -304,6 +338,28 @@ async fn language(ctx: &Context, msg: &Message, args: String) {
.description(lm.get(&user_data.language, "lang/invalid"))
.fields(language_codes)
})
.components(|c| {
for row in lm
.all_languages()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect::<Vec<(String, String)>>()
.as_slice()
.chunks(5)
{
let mut action_row = CreateActionRow::default();
for (code, name) in row {
action_row.create_button(|b| {
b.style(ButtonStyle::Primary)
.label(name.to_title_case())
.custom_id(format!("lang:{}", code.to_uppercase()))
});
}
c.add_action_row(action_row);
}
c
})
})
.await;
}
@ -317,21 +373,7 @@ async fn language(ctx: &Context, msg: &Message, args: String) {
)
});
let flags = lm
.all_languages()
.map(|(k, _)| ReactionType::Unicode(lm.get(k, "flag").to_string()));
let can_react = if let Some(Channel::Guild(channel)) = msg.channel(&ctx).await {
channel
.permissions_for_user(&ctx, ctx.cache.current_user().await)
.await
.map(|p| p.add_reactions())
.unwrap_or(false)
} else {
true
};
let reactor = msg
let _ = msg
.channel_id
.send_message(&ctx, |m| {
m.embed(|e| {
@ -339,57 +381,31 @@ async fn language(ctx: &Context, msg: &Message, args: String) {
.color(*THEME_COLOR)
.description(lm.get(&user_data.language, "lang/select"))
.fields(language_codes)
});
})
.components(|c| {
for row in lm
.all_languages()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect::<Vec<(String, String)>>()
.as_slice()
.chunks(5)
{
let mut action_row = CreateActionRow::default();
for (code, name) in row {
action_row.create_button(|b| {
b.style(ButtonStyle::Primary)
.label(name.to_title_case())
.custom_id(format!("lang:{}", code.to_uppercase()))
});
}
if can_react {
m.reactions(flags);
}
c.add_action_row(action_row);
}
m
c
})
})
.await;
if let Ok(sent_msg) = reactor {
let reaction_reply = sent_msg
.await_reaction(&ctx)
.timeout(Duration::from_secs(45))
.await;
if let Some(reaction_action) = reaction_reply {
if reaction_action.is_added() {
if let ReactionType::Unicode(emoji) = &reaction_action.as_inner_ref().emoji {
if let Some(lang) = lm.get_language_by_flag(emoji) {
user_data.language = lang.to_string();
user_data.commit_changes(&pool).await;
let _ = msg
.channel_id
.send_message(&ctx, |m| {
m.embed(|e| {
e.title(lm.get(&user_data.language, "lang/set_p_title"))
.color(*THEME_COLOR)
.description(lm.get(&user_data.language, "lang/set_p"))
})
})
.await;
}
}
}
}
if let Some(Channel::Guild(channel)) = msg.channel(&ctx).await {
let has_perms = channel
.permissions_for_user(&ctx, ctx.cache.current_user().await)
.await
.map(|p| p.manage_messages())
.unwrap_or(false);
if has_perms {
let _ = sent_msg.delete_reactions(&ctx).await;
}
}
}
}
}
@ -399,9 +415,7 @@ async fn language(ctx: &Context, msg: &Message, args: String) {
async fn prefix(ctx: &Context, msg: &Message, args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let mut guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool)
.await
.unwrap();
let guild_data = ctx.guild_data(msg.guild_id.unwrap()).await.unwrap();
let language = UserData::language_of(&msg.author, &pool).await;
if args.len() > 5 {
@ -415,18 +429,15 @@ async fn prefix(ctx: &Context, msg: &Message, args: String) {
.say(&ctx, lm.get(&language, "prefix/no_argument"))
.await;
} else {
guild_data.prefix = args;
guild_data.write().await.prefix = args;
#[cfg(feature = "prefix-cache")]
let prefix_cache = ctx.data.read().await.get::<PrefixCache>().cloned().unwrap();
#[cfg(feature = "prefix-cache")]
prefix_cache.insert(msg.guild_id.unwrap(), guild_data.prefix.clone());
guild_data.read().await.commit_changes(&pool).await;
guild_data.commit_changes(&pool).await;
let content =
lm.get(&language, "prefix/success")
.replacen("{prefix}", &guild_data.prefix, 1);
let content = lm.get(&language, "prefix/success").replacen(
"{prefix}",
&guild_data.read().await.prefix,
1,
);
let _ = msg.channel_id.say(&ctx, content).await;
}
@ -670,6 +681,7 @@ SELECT command FROM command_aliases WHERE guild_id = (SELECT id FROM guilds WHER
let mut new_msg = msg.clone();
new_msg.content = format!("<@{}> {}", &ctx.cache.current_user_id().await, row.command);
new_msg.id = MessageId(0);
framework.dispatch(ctx.clone(), new_msg).await;
},

View File

@ -1,7 +1,5 @@
use regex_command_attr::command;
use chrono_tz::Tz;
use serenity::{
cache::Cache,
client::Context,
@ -47,20 +45,10 @@ use std::{
time::{SystemTime, UNIX_EPOCH},
};
use crate::models::{CtxGuildData, MeridianType};
use crate::models::CtxGuildData;
use regex::Captures;
use serenity::model::channel::Channel;
fn shorthand_displacement(seconds: u64) -> String {
let (days, seconds) = seconds.div_rem(&DAY);
let (hours, seconds) = seconds.div_rem(&HOUR);
let (minutes, seconds) = seconds.div_rem(&MINUTE);
let time_repr = format!("{:02}:{:02}:{:02}", hours, minutes, seconds);
format!("{} days, {}", days, time_repr)
}
fn longhand_displacement(seconds: u64) -> String {
let (days, seconds) = seconds.div_rem(&DAY);
let (hours, seconds) = seconds.div_rem(&HOUR);
@ -193,7 +181,7 @@ UPDATE reminders
INNER JOIN `channels`
ON `channels`.id = reminders.channel_id
SET
reminders.`time` = reminders.`time` + ?
reminders.`utc_time` = reminders.`utc_time` + ?
WHERE channels.guild_id = ?
",
displacement,
@ -205,7 +193,7 @@ UPDATE reminders
} else {
sqlx::query!(
"
UPDATE reminders SET `time` = `time` + ? WHERE reminders.channel_id = ?
UPDATE reminders SET `utc_time` = `utc_time` + ? WHERE reminders.channel_id = ?
",
displacement,
user_data.dm_channel
@ -345,11 +333,11 @@ impl LookFlags {
struct LookReminder {
id: u32,
time: u32,
time: NaiveDateTime,
interval: Option<u32>,
channel: u64,
content: String,
description: Option<String>,
description: String,
}
impl LookReminder {
@ -357,31 +345,15 @@ impl LookReminder {
if self.content.len() > 0 {
self.content.clone()
} else {
self.description.clone().unwrap_or(String::from(""))
self.description.clone()
}
}
fn display(
&self,
flags: &LookFlags,
meridian: &MeridianType,
timezone: &Tz,
inter: &str,
) -> String {
fn display(&self, flags: &LookFlags, inter: &str) -> String {
let time_display = match flags.time_display {
TimeDisplayType::Absolute => timezone
.timestamp(self.time as i64, 0)
.format(meridian.fmt_str())
.to_string(),
TimeDisplayType::Absolute => format!("<t:{}>", self.time.timestamp()),
TimeDisplayType::Relative => {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
longhand_displacement((self.time as u64).checked_sub(now).unwrap_or(1))
}
TimeDisplayType::Relative => format!("<t:{}:R>", self.time.timestamp()),
};
if let Some(interval) = self.interval {
@ -409,8 +381,6 @@ async fn look(ctx: &Context, msg: &Message, args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let language = UserData::language_of(&msg.author, &pool).await;
let timezone = UserData::timezone_of(&msg.author, &pool).await;
let meridian = UserData::meridian_of(&msg.author, &pool).await;
let flags = LookFlags::from_string(&args);
@ -434,26 +404,23 @@ async fn look(ctx: &Context, msg: &Message, args: String) {
LookReminder,
"
SELECT
reminders.id, reminders.time, reminders.interval, channels.channel, messages.content, embeds.description
reminders.id,
reminders.utc_time AS time,
reminders.interval,
channels.channel,
reminders.content,
reminders.embed_description AS description
FROM
reminders
INNER JOIN
channels
ON
reminders.channel_id = channels.id
INNER JOIN
messages
ON
messages.id = reminders.message_id
LEFT JOIN
embeds
ON
embeds.id = messages.embed_id
WHERE
channels.channel = ? AND
FIND_IN_SET(reminders.enabled, ?)
ORDER BY
reminders.time
reminders.utc_time
LIMIT
?
",
@ -475,7 +442,7 @@ LIMIT
let display = reminders
.iter()
.map(|reminder| reminder.display(&flags, &meridian, &timezone, &inter));
.map(|reminder| reminder.display(&flags, &inter));
let _ = msg.channel_id.say_lines(&ctx, display).await;
}
@ -509,21 +476,18 @@ async fn delete(ctx: &Context, msg: &Message, _args: String) {
LookReminder,
"
SELECT
reminders.id, reminders.time, reminders.interval, channels.channel, messages.content, embeds.description
reminders.id,
reminders.utc_time AS time,
reminders.interval,
channels.channel,
reminders.content,
reminders.embed_description AS description
FROM
reminders
LEFT OUTER JOIN
channels
ON
channels.id = reminders.channel_id
INNER JOIN
messages
ON
messages.id = reminders.message_id
LEFT JOIN
embeds
ON
embeds.id = messages.embed_id
WHERE
FIND_IN_SET(channels.channel, ?)
",
@ -536,21 +500,18 @@ WHERE
LookReminder,
"
SELECT
reminders.id, reminders.time, reminders.interval, channels.channel, messages.content, embeds.description
reminders.id,
reminders.utc_time AS time,
reminders.interval,
channels.channel,
reminders.content,
reminders.embed_description AS description
FROM
reminders
LEFT OUTER JOIN
channels
ON
channels.id = reminders.channel_id
INNER JOIN
messages
ON
messages.id = reminders.message_id
LEFT JOIN
embeds
ON
embeds.id = messages.embed_id
WHERE
channels.guild_id = (SELECT id FROM guilds WHERE guild = ?)
",
@ -564,17 +525,14 @@ WHERE
LookReminder,
"
SELECT
reminders.id, reminders.time, reminders.interval, channels.channel, messages.content, embeds.description
reminders.id,
reminders.utc_time AS time,
reminders.interval,
channels.channel,
reminders.content,
reminders.embed_description AS description
FROM
reminders
INNER JOIN
messages
ON
reminders.message_id = messages.id
LEFT JOIN
embeds
ON
embeds.id = messages.embed_id
INNER JOIN
channels
ON
@ -593,14 +551,13 @@ WHERE
let enumerated_reminders = reminders.iter().enumerate().map(|(count, reminder)| {
reminder_ids.push(reminder.id);
let time = user_data.timezone().timestamp(reminder.time as i64, 0);
format!(
"**{}**: '{}' *<#{}>* at {}",
"**{}**: '{}' *<#{}>* at <t:{}>",
count + 1,
reminder.display_content(),
reminder.channel,
time.format(user_data.meridian().fmt_str())
reminder.time.timestamp()
)
});
@ -1047,65 +1004,36 @@ async fn countdown(ctx: &Context, msg: &Message, args: String) {
event_name, target_ts
);
sqlx::query!(
"
INSERT INTO embeds (title, description, color) VALUES (?, ?, ?)
",
event_name,
description,
*THEME_COLOR
)
.execute(&pool)
.await
.unwrap();
let embed_id = sqlx::query!(
"
SELECT id FROM embeds WHERE title = ? AND description = ?
",
event_name,
description
)
.fetch_one(&pool)
.await
.unwrap();
sqlx::query!(
"
INSERT INTO messages (embed_id) VALUES (?)
",
embed_id.id
)
.execute(&pool)
.await
.unwrap();
sqlx::query!(
"
INSERT INTO reminders (
`uid`,
`name`,
`message_id`,
`embed_title`,
`embed_description`,
`embed_color`,
`channel_id`,
`time`,
`utc_time`,
`interval`,
`method`,
`set_by`,
`expires`
) VALUES (
?,
'Countdown',
(SELECT id FROM messages WHERE embed_id = ?),
(SELECT id FROM channels WHERE channel = ?),
?,
?,
'countdown',
?,
?,
?,
?,
(SELECT id FROM users WHERE user = ?),
FROM_UNIXTIME(?)
)
",
generate_uid(),
embed_id.id,
event_name,
description,
*THEME_COLOR,
msg.channel_id.as_u64(),
first_time,
interval,
@ -1271,9 +1199,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
.replace("{location}", &ok_locations[0].mention())
.replace(
"{offset}",
&shorthand_displacement(
time_parser.displacement().unwrap() as u64,
),
&format!("<t:{}:R>", time_parser.timestamp().unwrap()),
),
n => lm
.get(&language, "remind/success_bulk")
@ -1288,9 +1214,7 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
)
.replace(
"{offset}",
&shorthand_displacement(
time_parser.displacement().unwrap() as u64,
),
&format!("<t:{}:R>", time_parser.timestamp().unwrap()),
),
};
@ -1397,11 +1321,6 @@ async fn remind_command(ctx: &Context, msg: &Message, args: String, command: Rem
async fn natural(ctx: &Context, msg: &Message, args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let now = SystemTime::now();
let since_epoch = now
.duration_since(UNIX_EPOCH)
.expect("Time calculated as going backwards. Very bad");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
match REGEX_NATURAL_COMMAND_1.captures(&args) {
@ -1465,8 +1384,6 @@ async fn natural(ctx: &Context, msg: &Message, args: String) {
match content_res {
Ok(mut content) => {
let offset = timestamp as u64 - since_epoch.as_secs();
let mut ok_locations = vec![];
let mut err_locations = vec![];
let mut err_types = HashSet::new();
@ -1498,7 +1415,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) {
1 => lm
.get(&user_data.language, "remind/success")
.replace("{location}", &ok_locations[0].mention())
.replace("{offset}", &shorthand_displacement(offset)),
.replace("{offset}", &format!("<t:{}:R>", timestamp)),
n => lm
.get(&user_data.language, "remind/success_bulk")
.replace("{number}", &n.to_string())
@ -1510,7 +1427,7 @@ async fn natural(ctx: &Context, msg: &Message, args: String) {
.collect::<Vec<String>>()
.join(", "),
)
.replace("{offset}", &shorthand_displacement(offset)),
.replace("{offset}", &format!("<t:{}:R>", timestamp)),
};
let error_part = format!(
@ -1698,37 +1615,45 @@ async fn create_reminder<'a, U: Into<u64>, T: TryInto<i64>>(
} else {
sqlx::query!(
"
INSERT INTO messages (content, tts, attachment, attachment_name) VALUES (?, ?, ?, ?)
INSERT INTO reminders (
uid,
content,
tts,
attachment,
attachment_name,
channel_id,
`utc_time`,
expires,
`interval`,
set_by
) VALUES (
?,
?,
?,
?,
?,
?,
DATE_ADD(FROM_UNIXTIME(0), INTERVAL ? SECOND),
DATE_ADD(FROM_UNIXTIME(0), INTERVAL ? SECOND),
?,
(SELECT id FROM users WHERE user = ? LIMIT 1)
)
",
generate_uid(),
content.content,
content.tts,
content.attachment,
content.attachment_name,
db_channel_id,
time as u32,
expires,
interval,
user_id
)
.execute(&pool.clone())
.execute(pool)
.await
.unwrap();
sqlx::query!(
"
INSERT INTO reminders (uid, message_id, channel_id, time, expires, `interval`, method, set_by) VALUES
(?,
(SELECT id FROM messages WHERE content = ? ORDER BY id DESC LIMIT 1),
?, ?, FROM_UNIXTIME(?), ?, 'remind',
(SELECT id FROM users WHERE user = ? LIMIT 1))
",
generate_uid(),
content.content,
db_channel_id,
time as u32,
expires,
interval,
user_id
)
.execute(pool)
.await
.unwrap();
Ok(())
}
} else if time < 0 {

View File

@ -21,7 +21,8 @@ use std::{collections::HashMap, fmt};
use crate::language_manager::LanguageManager;
use crate::models::{CtxGuildData, GuildData, UserData};
use crate::{models::ChannelData, SQLPool};
use crate::{models::ChannelData, LimitExecutors, SQLPool};
use serenity::model::id::MessageId;
type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, ()>;
@ -297,9 +298,9 @@ impl RegexFramework {
}
enum PermissionCheck {
None, // No permissions
Basic(bool, bool, bool, bool), // Send + Embed permissions (sufficient to reply)
All, // Above + Manage Webhooks (sufficient to operate)
None, // No permissions
Basic(bool, bool), // Send + Embed permissions (sufficient to reply)
All, // Above + Manage Webhooks (sufficient to operate)
}
#[async_trait]
@ -324,8 +325,6 @@ impl Framework for RegexFramework {
PermissionCheck::Basic(
guild_perms.manage_webhooks(),
channel_perms.embed_links(),
channel_perms.add_reactions(),
channel_perms.manage_messages(),
)
} else {
PermissionCheck::None
@ -345,144 +344,153 @@ impl Framework for RegexFramework {
// gate to prevent analysing messages unnecessarily
if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() {
}
// Guild Command
else if let (Some(guild), Some(Channel::Guild(channel))) =
(msg.guild(&ctx).await, msg.channel(&ctx).await)
{
let data = ctx.data.read().await;
} else {
// Guild Command
if let (Some(guild), Some(Channel::Guild(channel))) =
(msg.guild(&ctx).await, msg.channel(&ctx).await)
{
let data = ctx.data.read().await;
let pool = data
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let pool = data
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
if let Some(full_match) = self.command_matcher.captures(&msg.content) {
if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
let lm = data.get::<LanguageManager>().unwrap();
if let Some(full_match) = self.command_matcher.captures(&msg.content) {
if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
let lm = data.get::<LanguageManager>().unwrap();
let language = UserData::language_of(&msg.author, &pool);
let language = UserData::language_of(&msg.author, &pool);
match check_self_permissions(&ctx, &guild, &channel).await {
Ok(perms) => match perms {
PermissionCheck::All => {
let command = self
.commands
.get(&full_match.name("cmd").unwrap().as_str().to_lowercase())
match check_self_permissions(&ctx, &guild, &channel).await {
Ok(perms) => match perms {
PermissionCheck::All => {
let command = self
.commands
.get(
&full_match
.name("cmd")
.unwrap()
.as_str()
.to_lowercase(),
)
.unwrap();
let channel_data = ChannelData::from_channel(
msg.channel(&ctx).await.unwrap(),
&pool,
)
.await
.unwrap();
let channel_data = ChannelData::from_channel(
msg.channel(&ctx).await.unwrap(),
&pool,
)
.await
.unwrap();
if !command.can_blacklist || !channel_data.blacklisted {
let args = full_match
.name("args")
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
if !command.can_blacklist || !channel_data.blacklisted {
let args = full_match
.name("args")
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
let member = guild.member(&ctx, &msg.author).await.unwrap();
let member = guild.member(&ctx, &msg.author).await.unwrap();
if command.check_permissions(&ctx, &guild, &member).await {
dbg!(command.name);
if command.check_permissions(&ctx, &guild, &member).await {
dbg!(command.name);
{
let guild_id = guild.id.as_u64().to_owned();
{
let guild_id = guild.id.as_u64().to_owned();
GuildData::from_guild(guild, &pool).await.expect(
&format!(
"Failed to create new guild object for {}",
guild_id
),
);
}
(command.func)(&ctx, &msg, args).await;
} else if command.required_perms == PermissionLevel::Restricted
{
let _ = msg
.channel_id
.say(
&ctx,
lm.get(&language.await, "no_perms_restricted"),
)
.await;
} else if command.required_perms == PermissionLevel::Managed {
let _ = msg
.channel_id
.say(
&ctx,
lm.get(&language.await, "no_perms_managed")
.replace(
"{prefix}",
&ctx.prefix(msg.guild_id).await,
GuildData::from_guild(guild, &pool).await.expect(
&format!(
"Failed to create new guild object for {}",
guild_id
),
)
.await;
);
}
if msg.id == MessageId(0)
|| !ctx.check_executing(msg.author.id).await
{
ctx.set_executing(msg.author.id).await;
(command.func)(&ctx, &msg, args).await;
ctx.drop_executing(msg.author.id).await;
}
} else if command.required_perms
== PermissionLevel::Restricted
{
let _ = msg
.channel_id
.say(
&ctx,
lm.get(&language.await, "no_perms_restricted"),
)
.await;
} else if command.required_perms == PermissionLevel::Managed
{
let _ = msg
.channel_id
.say(
&ctx,
lm.get(&language.await, "no_perms_managed")
.replace(
"{prefix}",
&ctx.prefix(msg.guild_id).await,
),
)
.await;
}
}
}
PermissionCheck::Basic(manage_webhooks, embed_links) => {
let response = lm
.get(&language.await, "no_perms_general")
.replace(
"{manage_webhooks}",
if manage_webhooks { "" } else { "" },
)
.replace(
"{embed_links}",
if embed_links { "" } else { "" },
);
let _ = msg.channel_id.say(&ctx, response).await;
}
PermissionCheck::None => {
warn!("Missing enough permissions for guild {}", guild.id);
}
},
Err(e) => {
error!(
"Error occurred getting permissions in guild {}: {:?}",
guild.id, e
);
}
PermissionCheck::Basic(
manage_webhooks,
embed_links,
add_reactions,
manage_messages,
) => {
let response = lm
.get(&language.await, "no_perms_general")
.replace(
"{manage_webhooks}",
if manage_webhooks { "" } else { "" },
)
.replace("{embed_links}", if embed_links { "" } else { "" })
.replace(
"{add_reactions}",
if add_reactions { "" } else { "" },
)
.replace(
"{manage_messages}",
if manage_messages { "" } else { "" },
);
let _ = msg.channel_id.say(&ctx, response).await;
}
PermissionCheck::None => {
warn!("Missing enough permissions for guild {}", guild.id);
}
},
Err(e) => {
error!(
"Error occurred getting permissions in guild {}: {:?}",
guild.id, e
);
}
}
}
}
}
// DM Command
else if self.dm_enabled {
if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) {
let command = self
.commands
.get(&full_match.name("cmd").unwrap().as_str().to_lowercase())
.unwrap();
let args = full_match
.name("args")
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
// DM Command
else if self.dm_enabled {
if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) {
let command = self
.commands
.get(&full_match.name("cmd").unwrap().as_str().to_lowercase())
.unwrap();
let args = full_match
.name("args")
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
dbg!(command.name);
dbg!(command.name);
(command.func)(&ctx, &msg, args).await;
if msg.id == MessageId(0) || !ctx.check_executing(msg.author.id).await {
ctx.set_executing(msg.author.id).await;
(command.func)(&ctx, &msg, args).await;
ctx.drop_executing(msg.author.id).await;
}
}
}
}
}

View File

@ -12,12 +12,14 @@ use serenity::{
async_trait,
cache::Cache,
client::{bridge::gateway::GatewayIntents, Client},
futures::TryFutureExt,
http::{client::Http, CacheHttp},
model::{
channel::GuildChannel,
channel::Message,
guild::{Guild, GuildUnavailable},
id::{GuildId, UserId},
interactions::{Interaction, InteractionData, InteractionType},
},
prelude::{Context, EventHandler, TypeMapKey},
utils::shard_id,
@ -27,7 +29,7 @@ use sqlx::mysql::MySqlPool;
use dotenv::dotenv;
use std::{collections::HashMap, env, sync::Arc};
use std::{collections::HashMap, env, sync::Arc, time::Instant};
use crate::{
commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds},
@ -37,8 +39,6 @@ use crate::{
models::GuildData,
};
use serenity::futures::TryFutureExt;
use inflector::Inflector;
use log::info;
@ -46,7 +46,12 @@ use dashmap::DashMap;
use tokio::sync::RwLock;
use crate::models::UserData;
use chrono::Utc;
use chrono_tz::Tz;
use serenity::model::prelude::{
InteractionApplicationCommandCallbackDataFlags, InteractionResponseType,
};
struct GuildDataCache;
@ -78,6 +83,65 @@ impl TypeMapKey for PopularTimezones {
type Value = Arc<Vec<Tz>>;
}
struct CurrentlyExecuting;
impl TypeMapKey for CurrentlyExecuting {
type Value = Arc<RwLock<HashMap<UserId, Instant>>>;
}
#[async_trait]
trait LimitExecutors {
async fn check_executing(&self, user: UserId) -> bool;
async fn set_executing(&self, user: UserId);
async fn drop_executing(&self, user: UserId);
}
#[async_trait]
impl LimitExecutors for Context {
async fn check_executing(&self, user: UserId) -> bool {
let currently_executing = self
.data
.read()
.await
.get::<CurrentlyExecuting>()
.cloned()
.unwrap();
let lock = currently_executing.read().await;
lock.get(&user)
.map_or(false, |now| now.elapsed().as_secs() < 4)
}
async fn set_executing(&self, user: UserId) {
let currently_executing = self
.data
.read()
.await
.get::<CurrentlyExecuting>()
.cloned()
.unwrap();
let mut lock = currently_executing.write().await;
lock.insert(user, Instant::now());
}
async fn drop_executing(&self, user: UserId) {
let currently_executing = self
.data
.read()
.await
.get::<CurrentlyExecuting>()
.cloned()
.unwrap();
let mut lock = currently_executing.write().await;
lock.remove(&user);
}
}
struct Handler;
#[async_trait]
@ -194,6 +258,91 @@ DELETE FROM guilds WHERE guild = ?
.await
.unwrap();
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
let (pool, lm) = get_ctx_data(&&ctx).await;
match interaction.kind {
InteractionType::ApplicationCommand => {}
InteractionType::MessageComponent => {
if let (Some(InteractionData::MessageComponent(data)), Some(member)) =
(interaction.clone().data, interaction.clone().member)
{
println!("{}", data.custom_id);
if data.custom_id.starts_with("timezone:") {
let mut user_data = UserData::from_user(&member.user, &ctx, &pool)
.await
.unwrap();
let new_timezone = data.custom_id.replace("timezone:", "").parse::<Tz>();
if let Ok(timezone) = new_timezone {
user_data.timezone = timezone.to_string();
user_data.commit_changes(&pool).await;
let _ = interaction.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
let footer_text = lm.get(&user_data.language, "timezone/footer").replacen(
"{timezone}",
&user_data.timezone,
1,
);
let now = Utc::now().with_timezone(&user_data.timezone());
let content = lm
.get(&user_data.language, "timezone/set_p")
.replacen("{timezone}", &user_data.timezone, 1)
.replacen(
"{time}",
&now.format(user_data.meridian().fmt_str_short()).to_string(),
1,
);
d.create_embed(|e| e.title(lm.get(&user_data.language, "timezone/set_p_title"))
.color(*THEME_COLOR)
.description(content)
.footer(|f| f.text(footer_text)))
.flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL);
d
})
}).await;
}
} else if data.custom_id.starts_with("lang:") {
let mut user_data = UserData::from_user(&member.user, &ctx, &pool)
.await
.unwrap();
let lang_code = data.custom_id.replace("lang:", "");
if let Some(lang) = lm.get_language(&lang_code) {
user_data.language = lang.to_string();
user_data.commit_changes(&pool).await;
let _ = interaction
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.create_embed(|e| {
e.title(
lm.get(&user_data.language, "lang/set_p_title"),
)
.color(*THEME_COLOR)
.description(
lm.get(&user_data.language, "lang/set_p"),
)
})
})
})
.await;
}
}
}
}
_ => {}
}
}
}
#[tokio::main]
@ -210,6 +359,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.get_current_user()
.map_ok(|user| user.id.as_u64().to_owned())
.await?;
let application_id = http.get_current_application_info().await?.id;
let dm_enabled = env::var("DM_ENABLED").map_or(true, |var| var == "1");
@ -275,6 +425,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
| GatewayIntents::GUILDS
| GatewayIntents::GUILD_MESSAGE_REACTIONS
})
.application_id(application_id.0)
.event_handler(Handler)
.framework_arc(framework_arc.clone())
.await
@ -309,7 +460,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut data = client.data.write().await;
data.insert::<GuildDataCache>(Arc::new(guild_data_cache));
data.insert::<CurrentlyExecuting>(Arc::new(RwLock::new(HashMap::new())));
data.insert::<SQLPool>(pool);
data.insert::<PopularTimezones>(Arc::new(popular_timezones));
data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new()));