reminder-bot/src/framework.rs

473 lines
16 KiB
Rust
Raw Normal View History

2020-08-06 14:22:13 +00:00
use async_trait::async_trait;
use serenity::{
client::Context,
constants::MESSAGE_CODE_LIMIT,
2020-10-26 11:10:00 +00:00
framework::Framework,
futures::prelude::future::BoxFuture,
http::Http,
2020-08-07 00:02:01 +00:00
model::{
channel::{Channel, GuildChannel, Message},
guild::{Guild, Member},
id::ChannelId,
2020-08-07 00:02:01 +00:00
},
Result as SerenityResult,
2020-08-07 00:02:01 +00:00
};
use log::{error, info, warn};
2020-08-06 14:22:13 +00:00
2020-10-22 09:31:47 +00:00
use regex::{Match, Regex, RegexBuilder};
2020-08-07 00:02:01 +00:00
use std::{collections::HashMap, fmt};
2020-08-06 14:22:13 +00:00
use crate::language_manager::LanguageManager;
use crate::models::{GuildData, UserData};
use crate::{models::ChannelData, SQLPool};
2020-08-06 14:22:13 +00:00
2020-10-26 11:10:00 +00:00
/// Signature of a command handler: it receives the context, the message that triggered it
/// and the argument text that followed the command name, and returns a boxed future.
type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, ()>;
/// Permission a guild member needs in order to invoke a command
/// (enforced by `Command::check_permissions`).
///
/// The enum is fieldless and compared by value, so it is `Copy` and fully `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PermissionLevel {
    /// Anyone may invoke the command.
    Unrestricted,
    /// Members with Manage Guild or Manage Messages, or who hold a role granted access to
    /// this command in the `command_restrictions` table.
    Managed,
    /// Only members with Manage Guild.
    Restricted,
}
/// A command that can be registered with the regex framework.
pub struct Command {
    /// Name of the command; also the key looked up in `command_restrictions`.
    pub name: &'static str,
    /// Handler invoked with the context, the triggering message and the argument text.
    pub func: CommandFn,
    /// Permission a guild member needs to run the command.
    pub required_perms: PermissionLevel,
    /// Whether the command is also matched in direct messages.
    pub supports_dm: bool,
    /// Whether the command is suppressed in channels flagged as blacklisted.
    pub can_blacklist: bool,
}
2020-08-09 20:01:50 +00:00
impl Command {
async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool {
if self.required_perms == PermissionLevel::Unrestricted {
true
} else {
let permissions = guild.member_permissions(&ctx, &member.user).await.unwrap();
if permissions.manage_guild()
|| (permissions.manage_messages()
&& self.required_perms == PermissionLevel::Managed)
2020-10-13 10:36:20 +00:00
{
return true;
}
if self.required_perms == PermissionLevel::Managed {
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
2020-08-09 20:01:50 +00:00
match sqlx::query!(
"
2020-08-09 20:01:50 +00:00
SELECT
role
FROM
roles
INNER JOIN
command_restrictions ON roles.id = command_restrictions.role_id
WHERE
command_restrictions.command = ? AND
2020-09-02 16:13:17 +00:00
roles.guild_id = (
2020-08-09 20:01:50 +00:00
SELECT
id
FROM
guilds
WHERE
guild = ?)
",
self.name,
guild.id.as_u64()
)
.fetch_all(&pool)
.await
{
2020-08-09 20:01:50 +00:00
Ok(rows) => {
let role_ids = member
.roles
.iter()
.map(|r| *r.as_u64())
.collect::<Vec<u64>>();
2020-08-09 20:01:50 +00:00
for row in rows {
if role_ids.contains(&row.role) {
return true;
2020-08-09 20:01:50 +00:00
}
}
false
}
Err(sqlx::Error::RowNotFound) => false,
2020-08-09 20:01:50 +00:00
Err(e) => {
warn!(
"Unexpected error occurred querying command_restrictions: {:?}",
e
);
2020-08-09 20:01:50 +00:00
false
}
}
} else {
2020-08-09 20:01:50 +00:00
false
}
}
}
}
impl fmt::Debug for Command {
    /// Renders the command's metadata as a `Command { .. }` debug struct; the handler
    /// function `func` is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            name,
            required_perms,
            supports_dm,
            can_blacklist,
            ..
        } = self;

        let mut out = f.debug_struct("Command");
        out.field("name", name);
        out.field("required_perms", required_perms);
        out.field("supports_dm", supports_dm);
        out.field("can_blacklist", can_blacklist);
        out.finish()
    }
}
/// Sending a sequence of strings as lines of text through the Discord HTTP client.
#[async_trait]
pub trait SendIterator {
/// Sends each string yielded by `content` as a line of text using `http`.
///
/// The `'async_trait` bounds let the boxed future produced by `#[async_trait]` hold both
/// arguments. Grouping of lines into messages is left to the implementor.
async fn say_lines(
self,
http: impl AsRef<Http> + Send + Sync + 'async_trait,
content: impl Iterator<Item = String> + Send + 'async_trait,
) -> SerenityResult<()>;
}
#[async_trait]
impl SendIterator for ChannelId {
    /// Sends `content` to this channel, one item per line, packing as many lines as fit into
    /// each message.
    ///
    /// Lines are joined with `'\n'`. A message is flushed before appending the next line
    /// would make it longer than `MESSAGE_CODE_LIMIT` bytes, separator included. Mentions are
    /// never parsed, so the text cannot ping anyone. A single line that is by itself longer
    /// than the limit is not split; it is sent on its own (presumably rejected by Discord).
    /// Leading empty lines are dropped, and nothing is sent if every line is empty.
    ///
    /// # Errors
    /// Returns the first error from `send_message`; lines after that point are not sent.
    async fn say_lines(
        self,
        http: impl AsRef<Http> + Send + Sync + 'async_trait,
        content: impl Iterator<Item = String> + Send + 'async_trait,
    ) -> SerenityResult<()> {
        let limit = MESSAGE_CODE_LIMIT as usize;
        let mut current_content = String::new();

        for line in content {
            if current_content.is_empty() {
                // Nothing buffered: start a new message with this line. This never flushes,
                // so an empty message is never sent.
                current_content = line;
            } else if current_content.len() + 1 + line.len() > limit {
                // Appending `'\n' + line` would overflow the limit, so flush the buffer first.
                self.send_message(&http, |m| {
                    m.allowed_mentions(|am| am.empty_parse())
                        .content(&current_content)
                })
                .await?;

                current_content = line;
            } else {
                // Append in place rather than re-formatting the whole buffer each time.
                current_content.push('\n');
                current_content.push_str(&line);
            }
        }

        if !current_content.is_empty() {
            self.send_message(&http, |m| {
                m.allowed_mentions(|am| am.empty_parse())
                    .content(&current_content)
            })
            .await?;
        }

        Ok(())
    }
}
2020-10-22 09:31:47 +00:00
/// Message-dispatching framework that recognises commands using regular expressions.
pub struct RegexFramework {
// Registered commands, keyed by the name they are matched under.
commands: HashMap<String, &'static Command>,
// Matcher applied to guild messages; a `^$` placeholder until `build` compiles it.
command_matcher: Regex,
// Matcher applied to direct messages (DM-capable commands only); compiled by `build`.
dm_regex_matcher: Regex,
// Set via `default_prefix`. NOTE(review): not read anywhere in this file — confirm it is
// used elsewhere, or remove it.
default_prefix: String,
// Bot user ID, substituted into the `<@ID>` / `<@!ID>` mention alternatives.
client_id: u64,
// When true, messages authored by bot accounts are ignored.
ignore_bots: bool,
// Whether both matchers are compiled case-insensitively.
case_insensitive: bool,
}
2020-08-06 14:22:13 +00:00
impl RegexFramework {
pub fn new<T: Into<u64>>(client_id: T) -> Self {
2020-08-06 14:22:13 +00:00
Self {
commands: HashMap::new(),
command_matcher: Regex::new(r#"^$"#).unwrap(),
dm_regex_matcher: Regex::new(r#"^$"#).unwrap(),
default_prefix: "".to_string(),
client_id: client_id.into(),
2020-08-06 14:22:13 +00:00
ignore_bots: true,
2020-10-22 09:31:47 +00:00
case_insensitive: true,
2020-08-06 14:22:13 +00:00
}
}
2020-10-22 09:31:47 +00:00
pub fn case_insensitive(mut self, case_insensitive: bool) -> Self {
self.case_insensitive = case_insensitive;
self
}
pub fn default_prefix<T: ToString>(mut self, new_prefix: T) -> Self {
2020-08-06 14:22:13 +00:00
self.default_prefix = new_prefix.to_string();
self
}
pub fn ignore_bots(mut self, ignore_bots: bool) -> Self {
self.ignore_bots = ignore_bots;
self
}
pub fn add_command<S: ToString>(mut self, name: S, command: &'static Command) -> Self {
2020-08-09 22:59:31 +00:00
self.commands.insert(name.to_string(), command);
2020-08-06 14:22:13 +00:00
self
}
pub fn build(mut self) -> Self {
{
let command_names;
{
let mut command_names_vec =
self.commands.keys().map(|k| &k[..]).collect::<Vec<&str>>();
command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len()));
command_names = command_names_vec.join("|");
}
info!("Command names: {}", command_names);
{
let match_string = r#"^(?:(?:<@ID>\s*)|(?:<@!ID>\s*)|(?P<prefix>\S{1,5}?))(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
.replace("COMMANDS", command_names.as_str())
.replace("ID", self.client_id.to_string().as_str());
2020-10-22 09:31:47 +00:00
self.command_matcher = RegexBuilder::new(match_string.as_str())
.case_insensitive(self.case_insensitive)
2020-10-26 10:16:38 +00:00
.dot_matches_new_line(true)
2020-10-22 09:31:47 +00:00
.build()
.unwrap();
}
}
{
let dm_command_names;
{
let mut command_names_vec = self
.commands
.iter()
.filter_map(|(key, command)| {
if command.supports_dm {
Some(&key[..])
} else {
None
}
})
.collect::<Vec<&str>>();
command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len()));
dm_command_names = command_names_vec.join("|");
}
2020-08-06 14:22:13 +00:00
{
let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(\$)|())(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
.replace("COMMANDS", dm_command_names.as_str())
.replace("ID", self.client_id.to_string().as_str());
2020-08-07 00:02:01 +00:00
2020-10-22 09:31:47 +00:00
self.dm_regex_matcher = RegexBuilder::new(match_string.as_str())
.case_insensitive(self.case_insensitive)
2020-10-26 10:16:38 +00:00
.dot_matches_new_line(true)
2020-10-22 09:31:47 +00:00
.build()
.unwrap();
}
}
2020-08-07 00:02:01 +00:00
2020-08-06 14:22:13 +00:00
self
}
}
2020-08-07 00:02:01 +00:00
/// What the bot itself is allowed to do in the channel a command was sent in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PermissionCheck {
    /// The bot cannot send messages here, so it cannot even report the problem.
    None,
    /// The bot can send messages, but lacks Manage Webhooks (guild-level) and/or Embed Links.
    /// The fields say whether each permission is granted, in the order
    /// `(manage_webhooks, embed_links, add_reactions, manage_messages)`.
    Basic(bool, bool, bool, bool),
    /// Send Messages + Embed Links + Manage Webhooks: sufficient to run commands.
    All,
}
2020-08-06 14:22:13 +00:00
#[async_trait]
impl Framework for RegexFramework {
async fn dispatch(&self, ctx: Context, msg: Message) {
async fn check_self_permissions(
ctx: &Context,
guild: &Guild,
channel: &GuildChannel,
) -> SerenityResult<PermissionCheck> {
2020-08-07 00:02:01 +00:00
let user_id = ctx.cache.current_user_id().await;
let guild_perms = guild.member_permissions(&ctx, user_id).await?;
let channel_perms = channel.permissions_for_user(ctx, user_id).await?;
2020-08-07 00:02:01 +00:00
let basic_perms = channel_perms.send_messages();
Ok(
if basic_perms && guild_perms.manage_webhooks() && channel_perms.embed_links() {
PermissionCheck::All
} else if basic_perms {
PermissionCheck::Basic(
guild_perms.manage_webhooks(),
channel_perms.embed_links(),
channel_perms.add_reactions(),
channel_perms.manage_messages(),
)
} else {
PermissionCheck::None
},
)
2020-08-07 00:02:01 +00:00
}
async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option<Match<'_>>) -> bool {
if let Some(prefix) = prefix_opt {
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let guild_prefix = GuildData::prefix_from_id(Some(guild.id), &pool).await;
guild_prefix.as_str() == prefix.as_str()
} else {
true
}
}
2020-08-07 00:02:01 +00:00
// gate to prevent analysing messages unnecessarily
if (msg.author.bot && self.ignore_bots) || msg.content.is_empty() {
}
2020-08-07 00:02:01 +00:00
// Guild Command
else if let (Some(guild), Some(Channel::Guild(channel))) =
(msg.guild(&ctx).await, msg.channel(&ctx).await)
{
2020-12-22 14:28:18 +00:00
let data = ctx.data.read().await;
let pool = data
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
2020-12-22 14:28:18 +00:00
GuildData::from_guild(guild, &pool).await;
2020-12-22 14:28:18 +00:00
if let Some(full_match) = self.command_matcher.captures(&msg.content) {
if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
let lm = data.get::<LanguageManager>().unwrap();
let language = UserData::language_of(&msg.author, &pool).await;
match check_self_permissions(&ctx, &guild, &channel).await {
Ok(perms) => match perms {
PermissionCheck::All => {
let command = self
.commands
2020-10-22 09:31:47 +00:00
.get(&full_match.name("cmd").unwrap().as_str().to_lowercase())
.unwrap();
let channel_data = ChannelData::from_channel(
msg.channel(&ctx).await.unwrap(),
&pool,
)
.await
.unwrap();
if !command.can_blacklist || !channel_data.blacklisted {
let args = full_match
.name("args")
2020-08-25 16:19:08 +00:00
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
let member = guild.member(&ctx, &msg.author).await.unwrap();
2020-08-25 16:19:08 +00:00
if command.check_permissions(&ctx, &guild, &member).await {
2020-10-26 11:10:00 +00:00
(command.func)(&ctx, &msg, args).await;
} else if command.required_perms == PermissionLevel::Restricted
{
let _ = msg
.channel_id
.say(&ctx, lm.get(&language, "no_perms_restricted"))
.await;
} else if command.required_perms == PermissionLevel::Managed {
let _ = msg
.channel_id
.say(
&ctx,
lm.get(&language, "no_perms_managed").replace(
"{prefix}",
&GuildData::prefix_from_id(msg.guild_id, &pool)
.await,
),
)
.await;
}
2020-08-09 20:01:50 +00:00
}
}
2020-08-07 00:02:01 +00:00
PermissionCheck::Basic(
manage_webhooks,
embed_links,
add_reactions,
manage_messages,
) => {
let response = lm
.get(&language, "no_perms_general")
.replace(
"{manage_webhooks}",
if manage_webhooks { "" } else { "" },
)
.replace("{embed_links}", if embed_links { "" } else { "" })
.replace(
"{add_reactions}",
if add_reactions { "" } else { "" },
)
.replace(
"{manage_messages}",
if manage_messages { "" } else { "" },
);
let _ = msg.channel_id.say(&ctx, response).await;
}
2020-08-07 00:02:01 +00:00
PermissionCheck::None => {
warn!("Missing enough permissions for guild {}", guild.id);
}
},
2020-08-07 00:02:01 +00:00
Err(e) => {
error!(
"Error occurred getting permissions in guild {}: {:?}",
guild.id, e
);
}
2020-08-07 00:02:01 +00:00
}
}
}
}
// DM Command
2020-09-25 22:07:22 +00:00
else if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) {
let command = self
.commands
2020-10-22 09:31:47 +00:00
.get(&full_match.name("cmd").unwrap().as_str().to_lowercase())
.unwrap();
let args = full_match
.name("args")
2020-09-25 22:07:22 +00:00
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
2020-10-26 11:10:00 +00:00
(command.func)(&ctx, &msg, args).await;
2020-08-07 00:02:01 +00:00
}
2020-08-06 14:22:13 +00:00
}
}