2020-08-06 14:22:13 +00:00
use async_trait ::async_trait ;
use serenity ::{
2020-10-01 17:07:27 +00:00
http ::Http ,
Result as SerenityResult ,
2020-08-06 14:22:13 +00:00
client ::Context ,
2020-08-07 22:45:34 +00:00
framework ::{
Framework ,
2020-08-10 21:12:26 +00:00
standard ::CommandResult ,
2020-08-07 22:45:34 +00:00
} ,
2020-08-07 00:02:01 +00:00
model ::{
2020-08-09 20:01:50 +00:00
guild ::{
Guild ,
Member ,
} ,
2020-08-07 00:02:01 +00:00
channel ::{
Channel , GuildChannel , Message ,
}
} ,
2020-08-10 21:12:26 +00:00
futures ::prelude ::future ::BoxFuture ,
2020-08-07 00:02:01 +00:00
} ;
use log ::{
warn ,
error ,
2020-08-07 15:45:19 +00:00
info ,
2020-08-06 14:22:13 +00:00
} ;
2020-08-07 15:45:19 +00:00
use regex ::{
Regex , Match
} ;
2020-08-07 00:02:01 +00:00
2020-08-06 19:39:05 +00:00
use std ::{
collections ::HashMap ,
fmt ,
2020-09-28 12:42:20 +00:00
env ,
2020-08-06 19:39:05 +00:00
} ;
2020-08-06 14:22:13 +00:00
2020-08-25 16:19:08 +00:00
use crate ::{
models ::ChannelData ,
SQLPool ,
2020-09-28 12:42:20 +00:00
consts ::PREFIX ,
2020-08-25 16:19:08 +00:00
} ;
2020-10-01 17:07:27 +00:00
use serenity ::model ::id ::ChannelId ;
use crate ::consts ::MAX_MESSAGE_LENGTH ;
2020-08-06 14:22:13 +00:00
2020-08-10 21:12:26 +00:00
type CommandFn = for < ' fut > fn ( & ' fut Context , & ' fut Message , String ) -> BoxFuture < ' fut , CommandResult > ;
2020-10-11 17:56:27 +00:00
/// Permission tier a member needs in order to run a command.
#[derive(PartialEq, Debug)]
pub enum PermissionLevel {
    /// Anyone may run the command.
    Unrestricted,
    /// Also open to members with Manage Messages, or holding a role allowed for this command
    /// through the `command_restrictions` table.
    Managed,
    /// Reserved for members with the Manage Server permission.
    Restricted,
}
pub struct Command {
2020-08-06 18:18:30 +00:00
pub name : & 'static str ,
pub required_perms : PermissionLevel ,
pub supports_dm : bool ,
2020-08-06 19:39:05 +00:00
pub can_blacklist : bool ,
2020-08-06 18:18:30 +00:00
pub func : CommandFn ,
2020-08-06 14:22:13 +00:00
}
2020-08-09 20:01:50 +00:00
impl Command {
async fn check_permissions ( & self , ctx : & Context , guild : & Guild , member : & Member ) -> bool {
2020-10-11 17:56:27 +00:00
if self . required_perms = = PermissionLevel ::Unrestricted {
true
}
else {
for role_id in & member . roles {
let role = role_id . to_role_cached ( & ctx ) . await ;
2020-08-09 20:01:50 +00:00
2020-10-11 17:56:27 +00:00
if let Some ( cached_role ) = role {
if cached_role . permissions . manage_guild ( ) {
return true
}
else if self . required_perms = = PermissionLevel ::Managed & & cached_role . permissions . manage_messages ( ) {
return true
}
}
}
if self . required_perms = = PermissionLevel ::Managed {
2020-08-09 20:01:50 +00:00
let pool = ctx . data . read ( ) . await
. get ::< SQLPool > ( ) . cloned ( ) . expect ( " Could not get SQLPool from data " ) ;
2020-10-11 17:56:27 +00:00
match sqlx ::query! (
"
2020-08-09 20:01:50 +00:00
SELECT
role
FROM
roles
INNER JOIN
command_restrictions ON roles . id = command_restrictions . role_id
WHERE
command_restrictions . command = ? AND
2020-09-02 16:13:17 +00:00
roles . guild_id = (
2020-08-09 20:01:50 +00:00
SELECT
id
FROM
guilds
WHERE
2020-10-11 17:56:27 +00:00
guild = ? )
" , self.name, guild.id.as_u64())
2020-08-09 20:01:50 +00:00
. fetch_all ( & pool )
. await {
Ok ( rows ) = > {
let role_ids = member . roles . iter ( ) . map ( | r | * r . as_u64 ( ) ) . collect ::< Vec < u64 > > ( ) ;
for row in rows {
if role_ids . contains ( & row . role ) {
return true
}
}
false
}
Err ( sqlx ::Error ::RowNotFound ) = > {
false
}
Err ( e ) = > {
warn! ( " Unexpected error occurred querying command_restrictions: {:?} " , e ) ;
false
}
}
}
2020-10-11 17:56:27 +00:00
else {
2020-08-09 20:01:50 +00:00
false
}
}
}
}
2020-08-06 19:39:05 +00:00
impl fmt::Debug for Command {
    /// Formats the command's metadata fields; the handler function `func` is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { name, required_perms, supports_dm, can_blacklist, .. } = self;

        let mut builder = f.debug_struct("Command");
        builder.field("name", name);
        builder.field("required_perms", required_perms);
        builder.field("supports_dm", supports_dm);
        builder.field("can_blacklist", can_blacklist);
        builder.finish()
    }
}
2020-08-06 14:22:13 +00:00
// create event handler for bot
pub struct RegexFramework {
2020-08-06 18:18:30 +00:00
commands : HashMap < String , & 'static Command > ,
2020-09-03 23:29:19 +00:00
command_matcher : Regex ,
2020-08-07 22:45:34 +00:00
dm_regex_matcher : Regex ,
2020-08-06 14:22:13 +00:00
default_prefix : String ,
2020-08-07 00:02:01 +00:00
client_id : u64 ,
2020-08-06 14:22:13 +00:00
ignore_bots : bool ,
}
2020-10-01 17:07:27 +00:00
#[ async_trait ]
pub trait SendIterator {
async fn say_lines ( self , http : impl AsRef < Http > + Send + Sync + ' async_trait , content : impl Iterator < Item = String > + Send + ' async_trait ) -> SerenityResult < ( ) > ;
}
#[ async_trait ]
impl SendIterator for ChannelId {
async fn say_lines ( self , http : impl AsRef < Http > + Send + Sync + ' async_trait , content : impl Iterator < Item = String > + Send + ' async_trait ) -> SerenityResult < ( ) > {
let mut current_content = String ::new ( ) ;
for line in content {
if current_content . len ( ) + line . len ( ) > MAX_MESSAGE_LENGTH {
self . say ( & http , & current_content ) . await ? ;
current_content = line ;
}
else {
current_content = format! ( " {} \n {} " , current_content , line ) ;
}
}
if ! current_content . is_empty ( ) {
self . say ( & http , & current_content ) . await ? ;
}
Ok ( ( ) )
}
}
2020-08-06 14:22:13 +00:00
impl RegexFramework {
2020-10-12 17:37:14 +00:00
pub fn new < T : Into < u64 > > ( client_id : T ) -> Self {
2020-08-06 14:22:13 +00:00
Self {
2020-08-06 18:18:30 +00:00
commands : HashMap ::new ( ) ,
2020-09-03 23:29:19 +00:00
command_matcher : Regex ::new ( r # "^$"# ) . unwrap ( ) ,
2020-08-07 22:45:34 +00:00
dm_regex_matcher : Regex ::new ( r # "^$"# ) . unwrap ( ) ,
2020-09-28 12:42:20 +00:00
default_prefix : env ::var ( " DEFAULT_PREFIX " ) . unwrap_or_else ( | _ | PREFIX . to_string ( ) ) ,
2020-10-12 17:37:14 +00:00
client_id : client_id . into ( ) ,
2020-08-06 14:22:13 +00:00
ignore_bots : true ,
}
}
2020-09-28 12:42:20 +00:00
pub fn default_prefix < T : ToString > ( mut self , new_prefix : T ) -> Self {
2020-08-06 14:22:13 +00:00
self . default_prefix = new_prefix . to_string ( ) ;
self
}
pub fn ignore_bots ( mut self , ignore_bots : bool ) -> Self {
self . ignore_bots = ignore_bots ;
self
}
2020-10-12 17:37:14 +00:00
pub fn add_command < S : ToString > ( mut self , name : S , command : & 'static Command ) -> Self {
2020-08-09 22:59:31 +00:00
self . commands . insert ( name . to_string ( ) , command ) ;
2020-08-06 14:22:13 +00:00
self
}
pub fn build ( mut self ) -> Self {
2020-08-07 15:45:19 +00:00
{
2020-08-07 22:45:34 +00:00
let command_names ;
{
let mut command_names_vec = self . commands
. keys ( )
. map ( | k | & k [ .. ] )
. collect ::< Vec < & str > > ( ) ;
command_names_vec . sort_unstable_by ( | a , b | b . len ( ) . cmp ( & a . len ( ) ) ) ;
command_names = command_names_vec . join ( " | " ) ;
}
info! ( " Command names: {} " , command_names ) ;
2020-08-07 15:45:19 +00:00
2020-08-07 22:45:34 +00:00
{
let match_string = r # "^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(?P<prefix>\S{1,5}?))(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
. replace ( " COMMANDS " , command_names . as_str ( ) )
. replace ( " ID " , self . client_id . to_string ( ) . as_str ( ) ) ;
2020-08-07 15:45:19 +00:00
2020-09-03 23:29:19 +00:00
self . command_matcher = Regex ::new ( match_string . as_str ( ) ) . unwrap ( ) ;
2020-08-07 22:45:34 +00:00
}
2020-08-07 15:45:19 +00:00
}
2020-08-07 22:45:34 +00:00
{
let dm_command_names ;
{
let mut command_names_vec = self . commands
. iter ( )
. filter_map ( | ( key , command ) | {
if command . supports_dm {
Some ( & key [ .. ] )
} else {
None
}
} )
. collect ::< Vec < & str > > ( ) ;
command_names_vec . sort_unstable_by ( | a , b | b . len ( ) . cmp ( & a . len ( ) ) ) ;
dm_command_names = command_names_vec . join ( " | " ) ;
}
2020-08-06 14:22:13 +00:00
2020-08-07 22:45:34 +00:00
{
let match_string = r # "^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(\$)|())(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
. replace ( " COMMANDS " , dm_command_names . as_str ( ) )
. replace ( " ID " , self . client_id . to_string ( ) . as_str ( ) ) ;
2020-08-07 00:02:01 +00:00
2020-08-07 22:45:34 +00:00
self . dm_regex_matcher = Regex ::new ( match_string . as_str ( ) ) . unwrap ( ) ;
}
}
2020-08-07 00:02:01 +00:00
2020-08-06 14:22:13 +00:00
self
}
}
2020-08-07 00:02:01 +00:00
/// What the bot itself is allowed to do in a guild channel.
enum PermissionCheck {
    /// Lacks the permissions needed even to reply.
    None,
    /// Can send messages and embed links, which is enough to reply.
    Basic,
    /// `Basic` plus Manage Webhooks, which is enough to operate fully.
    All,
}
2020-08-06 14:22:13 +00:00
#[ async_trait ]
impl Framework for RegexFramework {
async fn dispatch ( & self , ctx : Context , msg : Message ) {
2020-08-07 00:02:01 +00:00
2020-08-07 22:45:34 +00:00
async fn check_self_permissions ( ctx : & Context , guild : & Guild , channel : & GuildChannel ) -> Result < PermissionCheck , Box < dyn std ::error ::Error + Sync + Send > > {
2020-08-07 00:02:01 +00:00
let user_id = ctx . cache . current_user_id ( ) . await ;
let guild_perms = guild . member_permissions ( user_id ) ;
let perms = channel . permissions_for_user ( ctx , user_id ) . await ? ;
let basic_perms = perms . send_messages ( ) & & perms . embed_links ( ) ;
Ok ( if basic_perms & & guild_perms . manage_webhooks ( ) {
PermissionCheck ::All
}
else if basic_perms {
PermissionCheck ::Basic
}
else {
PermissionCheck ::None
} )
}
2020-08-07 22:45:34 +00:00
async fn check_prefix ( ctx : & Context , guild : & Guild , prefix_opt : Option < Match < '_ > > ) -> bool {
2020-08-07 15:45:19 +00:00
if let Some ( prefix ) = prefix_opt {
let pool = ctx . data . read ( ) . await
. get ::< SQLPool > ( ) . cloned ( ) . expect ( " Could not get SQLPool from data " ) ;
2020-09-01 16:07:51 +00:00
match sqlx ::query! ( " SELECT prefix FROM guilds WHERE guild = ? " , guild . id . as_u64 ( ) )
2020-08-07 15:45:19 +00:00
. fetch_one ( & pool )
. await {
Ok ( row ) = > {
prefix . as_str ( ) = = row . prefix
}
Err ( sqlx ::Error ::RowNotFound ) = > {
2020-08-09 22:59:31 +00:00
let _ = sqlx ::query! ( " INSERT INTO guilds (guild, name) VALUES (?, ?) " , guild . id . as_u64 ( ) , guild . name )
2020-08-07 22:45:34 +00:00
. execute ( & pool )
. await ;
2020-08-07 15:45:19 +00:00
prefix . as_str ( ) = = " $ "
}
Err ( e ) = > {
warn! ( " Unexpected error in prefix query: {:?} " , e ) ;
false
}
}
}
else {
true
}
}
2020-08-07 00:02:01 +00:00
// gate to prevent analysing messages unnecessarily
if ( msg . author . bot & & self . ignore_bots ) | |
msg . tts | |
2020-09-25 22:07:22 +00:00
msg . content . is_empty ( ) | |
! msg . attachments . is_empty ( ) { }
2020-08-07 00:02:01 +00:00
// Guild Command
else if let ( Some ( guild ) , Some ( Channel ::Guild ( channel ) ) ) = ( msg . guild ( & ctx ) . await , msg . channel ( & ctx ) . await ) {
2020-08-09 20:01:50 +00:00
let member = guild . member ( & ctx , & msg . author ) . await . unwrap ( ) ;
2020-09-03 23:29:19 +00:00
if let Some ( full_match ) = self . command_matcher . captures ( & msg . content [ .. ] ) {
2020-08-07 00:02:01 +00:00
2020-08-07 22:45:34 +00:00
if check_prefix ( & ctx , & guild , full_match . name ( " prefix " ) ) . await {
2020-08-07 00:02:01 +00:00
2020-08-07 15:45:19 +00:00
match check_self_permissions ( & ctx , & guild , & channel ) . await {
Ok ( perms ) = > match perms {
2020-08-07 22:45:34 +00:00
PermissionCheck ::All = > {
2020-08-25 16:19:08 +00:00
let pool = ctx . data . read ( ) . await
. get ::< SQLPool > ( ) . cloned ( ) . expect ( " Could not get SQLPool from data " ) ;
2020-08-07 22:45:34 +00:00
let command = self . commands . get ( full_match . name ( " cmd " ) . unwrap ( ) . as_str ( ) ) . unwrap ( ) ;
2020-09-01 16:07:51 +00:00
let channel_data = ChannelData ::from_channel ( msg . channel ( & ctx ) . await . unwrap ( ) , & pool ) . await ;
2020-08-25 16:19:08 +00:00
2020-08-26 17:26:28 +00:00
if ! command . can_blacklist | | ! channel_data . map ( | c | c . blacklisted ) . unwrap_or ( false ) {
2020-08-25 16:19:08 +00:00
let args = full_match . name ( " args " )
. map ( | m | m . as_str ( ) )
. unwrap_or ( " " )
. to_string ( ) ;
2020-08-07 22:45:34 +00:00
2020-08-25 16:19:08 +00:00
if command . check_permissions ( & ctx , & guild , & member ) . await {
( command . func ) ( & ctx , & msg , args ) . await . unwrap ( ) ;
}
2020-10-11 17:56:27 +00:00
else if command . required_perms = = PermissionLevel ::Restricted {
let _ = msg . channel_id . say ( & ctx , " You must have permission level `Manage Server` or greater to use this command. " ) . await ;
}
else if command . required_perms = = PermissionLevel ::Managed {
let _ = msg . channel_id . say ( & ctx , " You must have `Manage Messages` or have a role capable of sending reminders to that channel. Please talk to your server admin, and ask them to use the `{prefix}restrict` command to specify allowed roles. " ) . await ;
}
2020-08-09 20:01:50 +00:00
}
2020-08-07 22:45:34 +00:00
}
2020-08-07 00:02:01 +00:00
2020-08-07 22:45:34 +00:00
PermissionCheck ::Basic = > {
2020-08-10 21:12:26 +00:00
let _ = msg . channel_id . say ( & ctx , " Not enough perms " ) . await ;
2020-08-07 22:45:34 +00:00
}
2020-08-07 00:02:01 +00:00
2020-08-07 15:45:19 +00:00
PermissionCheck ::None = > {
warn! ( " Missing enough permissions for guild {} " , guild . id ) ;
}
2020-08-07 00:02:01 +00:00
}
2020-08-07 15:45:19 +00:00
Err ( e ) = > {
error! ( " Error occurred getting permissions in guild {}: {:?} " , guild . id , e ) ;
}
2020-08-07 00:02:01 +00:00
}
}
}
}
// DM Command
2020-09-25 22:07:22 +00:00
else if let Some ( full_match ) = self . dm_regex_matcher . captures ( & msg . content [ .. ] ) {
let command = self . commands . get ( full_match . name ( " cmd " ) . unwrap ( ) . as_str ( ) ) . unwrap ( ) ;
let args = full_match . name ( " args " )
. map ( | m | m . as_str ( ) )
. unwrap_or ( " " )
. to_string ( ) ;
( command . func ) ( & ctx , & msg , args ) . await . unwrap ( ) ;
2020-08-07 00:02:01 +00:00
}
2020-08-06 14:22:13 +00:00
}
}