reminder-bot/src/framework.rs

395 lines
12 KiB
Rust
Raw Normal View History

2020-08-06 14:22:13 +00:00
use async_trait::async_trait;
use serenity::{
2020-08-29 17:07:15 +00:00
Result as SerenityResult,
2020-08-06 14:22:13 +00:00
client::Context,
framework::{
Framework,
standard::CommandResult,
},
2020-08-07 00:02:01 +00:00
model::{
2020-08-29 17:07:15 +00:00
id::ChannelId,
2020-08-09 20:01:50 +00:00
guild::{
Guild,
Member,
},
2020-08-07 00:02:01 +00:00
channel::{
Channel, GuildChannel, Message,
}
},
futures::prelude::future::BoxFuture,
2020-08-07 00:02:01 +00:00
};
use log::{
warn,
error,
debug,
info,
2020-08-06 14:22:13 +00:00
};
use regex::{
Regex, Match
};
2020-08-07 00:02:01 +00:00
use std::{
collections::HashMap,
fmt,
};
2020-08-06 14:22:13 +00:00
2020-08-25 16:19:08 +00:00
use crate::{
models::ChannelData,
SQLPool,
};
2020-08-06 14:22:13 +00:00
/// Function-pointer signature shared by every command handler.
///
/// A handler receives the client context, the message that triggered it and an
/// owned `String`; `dispatch` fills that string with the text captured after
/// the command name (empty when nothing followed it). The returned boxed
/// future borrows the context and message for `'fut` and resolves to a
/// `CommandResult`.
type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, CommandResult>;
2020-08-29 17:07:15 +00:00
/// Extension trait for sending a message whose text is identified by a name
/// rather than given literally.
///
/// The implementation for `ChannelId` in this file resolves the text from the
/// `strings` database table.
#[async_trait]
pub trait SendFromDb {
/// Sends the string called `name`, in `language`, and returns the sent
/// message (or the serenity error from sending).
///
/// NOTE(review): `ctx` is a double reference (`&&Context`); this looks
/// accidental but is part of the public signature, so it is left as is.
async fn say_named(&self, ctx: &&Context, language: String, name: &str) -> SerenityResult<Message>;
}
/// Row type for the `strings` lookup performed by `say_named`: it holds the
/// single `value` column that the query selects.
struct Value {
// Text of the stored string.
value: String,
}
#[async_trait]
impl SendFromDb for ChannelId {
    /// Looks up the string `name` in the `strings` table and sends it to this
    /// channel.
    ///
    /// The query accepts rows in either the requested `language` or `'EN'` and
    /// orders them by `language = 'EN'`, taking the first row.
    /// NOTE(review): presumably this makes the requested language win over the
    /// English fallback when both exist; confirm against the database's
    /// boolean ordering.
    ///
    /// # Panics
    ///
    /// Panics when the `SQLPool` is missing from the context data, or when the
    /// query fails or yields no row.
    async fn say_named(&self, ctx: &&Context, language: String, name: &str) -> SerenityResult<Message> {
        let db = ctx
            .data
            .read()
            .await
            .get::<SQLPool>()
            .cloned()
            .expect("Could not get SQLPool from data");

        let lookup = sqlx::query_as!(
            Value,
            "
SELECT value FROM strings WHERE (language = ? OR language = 'EN') AND name = ? ORDER BY language = 'EN'
",
            language,
            name
        )
        .fetch_one(&db)
        .await;

        let localized = lookup.expect("No string with that name");

        self.say(ctx, localized.value).await
    }
}
/// How restricted a command is for ordinary guild members.
///
/// Members with the guild-level "manage guild" permission pass the check in
/// `Command::check_permissions` regardless of the level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PermissionLevel {
    /// Any member may run the command.
    Unrestricted,
    /// Only members holding a role listed for this command in the
    /// `command_restrictions` table may run it.
    Managed,
    /// No member passes on the level alone; only the "manage guild" override
    /// grants access.
    Restricted,
}
/// Static description of a single bot command.
pub struct Command {
    /// Name of the command; also used as the `command` key when looking up
    /// role restrictions in the database.
    pub name: &'static str,
    /// Permission level a member needs in order to run the command.
    pub required_perms: PermissionLevel,
    /// Whether the command is included in the matcher used for direct
    /// messages.
    pub supports_dm: bool,
    /// Whether the command is skipped in channels marked as blacklisted.
    pub can_blacklist: bool,
    /// Handler invoked when the command runs.
    pub func: CommandFn,
}
2020-08-09 20:01:50 +00:00
impl Command {
    /// Decides whether `member` may run this command in `guild`.
    ///
    /// Members with the guild-level "manage guild" permission are always
    /// allowed. Otherwise the outcome depends on `required_perms`:
    ///
    /// * `Unrestricted` - allowed.
    /// * `Managed` - allowed only if the member holds at least one role that
    ///   the `command_restrictions` table lists for this command and guild.
    ///   A database error is logged and treated as "not allowed".
    /// * `Restricted` - not allowed.
    ///
    /// # Panics
    ///
    /// Panics if the `SQLPool` is missing from the context data.
    async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool {
        // Guild managers bypass every per-command restriction.
        if guild.member_permissions(&member.user).manage_guild() {
            return true;
        }

        match self.required_perms {
            PermissionLevel::Unrestricted => true,

            PermissionLevel::Restricted => false,

            PermissionLevel::Managed => {
                let pool = ctx.data.read().await
                    .get::<SQLPool>().cloned().expect("Could not get SQLPool from data");

                let allowed_roles = sqlx::query!("
SELECT
role
FROM
roles
INNER JOIN
command_restrictions ON roles.id = command_restrictions.role_id
WHERE
command_restrictions.command = ? AND
command_restrictions.guild_id = (
SELECT
id
FROM
guilds
WHERE
guild = ?
)", self.name, guild.id.as_u64())
                    .fetch_all(&pool)
                    .await;

                match allowed_roles {
                    // `fetch_all` reports "no rows" as an empty Vec, never as
                    // `RowNotFound`, so an empty result simply yields `false`.
                    Ok(rows) => rows.iter().any(|row| {
                        member.roles.iter().any(|role| *role.as_u64() == row.role)
                    }),

                    Err(e) => {
                        warn!("Unexpected error occurred querying command_restrictions: {:?}", e);
                        false
                    }
                }
            }
        }
    }
}
impl fmt::Debug for Command {
    /// Formats the command's descriptive fields; the `func` pointer is left
    /// out of the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("Command");

        repr.field("name", &self.name);
        repr.field("required_perms", &self.required_perms);
        repr.field("supports_dm", &self.supports_dm);
        repr.field("can_blacklist", &self.can_blacklist);

        repr.finish()
    }
}
2020-08-06 14:22:13 +00:00
// create event handler for bot
pub struct RegexFramework {
commands: HashMap<String, &'static Command>,
2020-08-07 00:02:01 +00:00
regex_matcher: Regex,
dm_regex_matcher: Regex,
2020-08-06 14:22:13 +00:00
default_prefix: String,
2020-08-07 00:02:01 +00:00
client_id: u64,
2020-08-06 14:22:13 +00:00
ignore_bots: bool,
}
impl RegexFramework {
2020-08-07 00:02:01 +00:00
pub fn new(client_id: u64) -> Self {
2020-08-06 14:22:13 +00:00
Self {
commands: HashMap::new(),
2020-08-07 00:02:01 +00:00
regex_matcher: Regex::new(r#"^$"#).unwrap(),
dm_regex_matcher: Regex::new(r#"^$"#).unwrap(),
2020-08-06 14:22:13 +00:00
default_prefix: String::from("$"),
2020-08-07 00:02:01 +00:00
client_id,
2020-08-06 14:22:13 +00:00
ignore_bots: true,
}
}
pub fn default_prefix(mut self, new_prefix: &str) -> Self {
self.default_prefix = new_prefix.to_string();
self
}
pub fn ignore_bots(mut self, ignore_bots: bool) -> Self {
self.ignore_bots = ignore_bots;
self
}
2020-08-09 22:59:31 +00:00
pub fn add_command(mut self, name: &str, command: &'static Command) -> Self {
self.commands.insert(name.to_string(), command);
2020-08-06 14:22:13 +00:00
self
}
pub fn build(mut self) -> Self {
{
let command_names;
{
let mut command_names_vec = self.commands
.keys()
.map(|k| &k[..])
.collect::<Vec<&str>>();
command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len()));
command_names = command_names_vec.join("|");
}
info!("Command names: {}", command_names);
{
let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(?P<prefix>\S{1,5}?))(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
.replace("COMMANDS", command_names.as_str())
.replace("ID", self.client_id.to_string().as_str());
self.regex_matcher = Regex::new(match_string.as_str()).unwrap();
}
}
{
let dm_command_names;
{
let mut command_names_vec = self.commands
.iter()
.filter_map(|(key, command)| {
if command.supports_dm {
Some(&key[..])
} else {
None
}
})
.collect::<Vec<&str>>();
command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len()));
dm_command_names = command_names_vec.join("|");
}
2020-08-06 14:22:13 +00:00
{
let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(\$)|())(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
.replace("COMMANDS", dm_command_names.as_str())
.replace("ID", self.client_id.to_string().as_str());
2020-08-07 00:02:01 +00:00
self.dm_regex_matcher = Regex::new(match_string.as_str()).unwrap();
}
}
2020-08-07 00:02:01 +00:00
2020-08-06 14:22:13 +00:00
self
}
}
2020-08-07 00:02:01 +00:00
/// Result of checking what the bot itself is permitted to do in the channel a
/// command was issued in (produced by `check_self_permissions` in `dispatch`).
enum PermissionCheck {
None, // No permissions: the bot cannot both send messages and embed links here
Basic, // Send + Embed permissions (sufficient to reply)
All, // Above + Manage Webhooks at guild level (sufficient to operate)
}
2020-08-06 14:22:13 +00:00
#[async_trait]
impl Framework for RegexFramework {
async fn dispatch(&self, ctx: Context, msg: Message) {
2020-08-07 00:02:01 +00:00
async fn check_self_permissions(ctx: &Context, guild: &Guild, channel: &GuildChannel) -> Result<PermissionCheck, Box<dyn std::error::Error + Sync + Send>> {
2020-08-07 00:02:01 +00:00
let user_id = ctx.cache.current_user_id().await;
let guild_perms = guild.member_permissions(user_id);
let perms = channel.permissions_for_user(ctx, user_id).await?;
let basic_perms = perms.send_messages() && perms.embed_links();
Ok(if basic_perms && guild_perms.manage_webhooks() {
PermissionCheck::All
}
else if basic_perms {
PermissionCheck::Basic
}
else {
PermissionCheck::None
})
}
async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option<Match<'_>>) -> bool {
if let Some(prefix) = prefix_opt {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
match sqlx::query!("SELECT prefix FROM guilds WHERE id = ?", guild.id.as_u64())
.fetch_one(&pool)
.await {
Ok(row) => {
prefix.as_str() == row.prefix
}
Err(sqlx::Error::RowNotFound) => {
2020-08-09 22:59:31 +00:00
let _ = sqlx::query!("INSERT INTO guilds (guild, name) VALUES (?, ?)", guild.id.as_u64(), guild.name)
.execute(&pool)
.await;
prefix.as_str() == "$"
}
Err(e) => {
warn!("Unexpected error in prefix query: {:?}", e);
false
}
}
}
else {
true
}
}
2020-08-07 00:02:01 +00:00
// gate to prevent analysing messages unnecessarily
if (msg.author.bot && self.ignore_bots) ||
msg.tts ||
msg.content.len() == 0 ||
msg.attachments.len() > 0
{
return
}
// Guild Command
else if let (Some(guild), Some(Channel::Guild(channel))) = (msg.guild(&ctx).await, msg.channel(&ctx).await) {
2020-08-09 20:01:50 +00:00
let member = guild.member(&ctx, &msg.author).await.unwrap();
if let Some(full_match) = self.regex_matcher.captures(&msg.content[..]) {
2020-08-07 00:02:01 +00:00
if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
2020-08-07 00:02:01 +00:00
debug!("Prefix matched on {}", msg.content);
2020-08-07 00:02:01 +00:00
match check_self_permissions(&ctx, &guild, &channel).await {
Ok(perms) => match perms {
PermissionCheck::All => {
2020-08-25 16:19:08 +00:00
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let command = self.commands.get(full_match.name("cmd").unwrap().as_str()).unwrap();
2020-08-25 16:19:08 +00:00
let channel_data = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), pool).await;
2020-08-26 17:26:28 +00:00
if !command.can_blacklist || !channel_data.map(|c| c.blacklisted).unwrap_or(false) {
2020-08-25 16:19:08 +00:00
let args = full_match.name("args")
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
2020-08-25 16:19:08 +00:00
if command.check_permissions(&ctx, &guild, &member).await {
(command.func)(&ctx, &msg, args).await.unwrap();
}
2020-08-09 20:01:50 +00:00
}
}
2020-08-07 00:02:01 +00:00
PermissionCheck::Basic => {
let _ = msg.channel_id.say(&ctx, "Not enough perms").await;
}
2020-08-07 00:02:01 +00:00
PermissionCheck::None => {
warn!("Missing enough permissions for guild {}", guild.id);
}
2020-08-07 00:02:01 +00:00
}
Err(e) => {
error!("Error occurred getting permissions in guild {}: {:?}", guild.id, e);
}
2020-08-07 00:02:01 +00:00
}
}
}
}
// DM Command
else {
if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) {
let command = self.commands.get(full_match.name("cmd").unwrap().as_str()).unwrap();
let args = full_match.name("args")
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
(command.func)(&ctx, &msg, args).await.unwrap();
}
2020-08-07 00:02:01 +00:00
}
2020-08-06 14:22:13 +00:00
}
}