check prefix and order commands properly for regex

This commit is contained in:
jude 2020-08-07 16:45:19 +01:00
parent c2c5e79940
commit 305cf79ac8
2 changed files with 74 additions and 23 deletions

View File

@ -14,9 +14,13 @@ use serenity::{
use log::{ use log::{
warn, warn,
error, error,
debug,
info,
}; };
use regex::Regex; use regex::{
Regex, Match
};
use std::{ use std::{
collections::HashMap, collections::HashMap,
@ -24,6 +28,7 @@ use std::{
}; };
use serenity::framework::standard::CommandFn; use serenity::framework::standard::CommandFn;
use crate::SQLPool;
#[derive(Debug)] #[derive(Debug)]
pub enum PermissionLevel { pub enum PermissionLevel {
@ -90,11 +95,20 @@ impl RegexFramework {
} }
pub fn build(mut self) -> Self { pub fn build(mut self) -> Self {
let command_names = self.commands let command_names;
{
let mut command_names_vec = self.commands
.keys() .keys()
.map(|k| &k[..]) .map(|k| &k[..])
.collect::<Vec<&str>>() .collect::<Vec<&str>>();
.join("|");
command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len()));
command_names = command_names_vec.join("|");
}
info!("Command names: {}", command_names);
let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(?P<prefix>\S{1,5}?))(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"# let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(?P<prefix>\S{1,5}?))(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
.replace("COMMANDS", command_names.as_str()) .replace("COMMANDS", command_names.as_str())
@ -135,6 +149,34 @@ impl Framework for RegexFramework {
}) })
} }
async fn check_prefix(ctx: &Context, guild_id: u64, prefix_opt: Option<Match<'_>>) -> bool {
if let Some(prefix) = prefix_opt {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
match sqlx::query!("SELECT prefix FROM guilds WHERE id = ?", guild_id)
.fetch_one(&pool)
.await {
Ok(row) => {
prefix.as_str() == row.prefix
}
Err(sqlx::Error::RowNotFound) => {
prefix.as_str() == "$"
}
Err(e) => {
warn!("Unexpected error in prefix query: {:?}", e);
false
}
}
}
else {
true
}
}
// gate to prevent analysing messages unnecessarily // gate to prevent analysing messages unnecessarily
if (msg.author.bot && self.ignore_bots) || if (msg.author.bot && self.ignore_bots) ||
msg.tts || msg.tts ||
@ -147,17 +189,17 @@ impl Framework for RegexFramework {
// Guild Command // Guild Command
else if let (Some(guild), Some(Channel::Guild(channel))) = (msg.guild(&ctx).await, msg.channel(&ctx).await) { else if let (Some(guild), Some(Channel::Guild(channel))) = (msg.guild(&ctx).await, msg.channel(&ctx).await) {
if let Some(full_match) = self.regex_matcher.captures(msg.content.as_str()) { if let Some(full_match) = self.regex_matcher.captures(&msg.content[..]) {
if check_prefix(&ctx, *guild.id.as_u64(), full_match.name("prefix")).await {
debug!("Prefix matched on {}", msg.content);
match check_self_permissions(&ctx, &guild, &channel).await { match check_self_permissions(&ctx, &guild, &channel).await {
Ok(perms) => match perms { Ok(perms) => match perms {
PermissionCheck::All => { PermissionCheck::All => {}
} PermissionCheck::Basic => {}
PermissionCheck::Basic => {
}
PermissionCheck::None => { PermissionCheck::None => {
warn!("Missing enough permissions for guild {}", guild.id); warn!("Missing enough permissions for guild {}", guild.id);
@ -170,6 +212,7 @@ impl Framework for RegexFramework {
} }
} }
} }
}
// DM Command // DM Command
else { else {

View File

@ -21,6 +21,7 @@ use regex_command_attr::command;
use sqlx::{ use sqlx::{
Pool, Pool,
mysql::{ mysql::{
MySqlPool,
MySqlConnection, MySqlConnection,
} }
}; };
@ -46,14 +47,12 @@ impl TypeMapKey for ReqwestClient {
type Value = Arc<reqwest::Client>; type Value = Arc<reqwest::Client>;
} }
static THEME_COLOR: u32 = 0x00e0f3; static THEME_COLOR: u32 = 0x8fb677;
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
dotenv()?; dotenv()?;
println!("{:?}", HELP_COMMAND);
let framework = RegexFramework::new(env::var("CLIENT_ID").expect("Missing CLIENT_ID from environment").parse()?) let framework = RegexFramework::new(env::var("CLIENT_ID").expect("Missing CLIENT_ID from environment").parse()?)
.ignore_bots(true) .ignore_bots(true)
.default_prefix("$") .default_prefix("$")
@ -66,6 +65,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.framework(framework) .framework(framework)
.await.expect("Error occurred creating client"); .await.expect("Error occurred creating client");
{
let pool = MySqlPool::new(&env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment")).await.unwrap();
let mut data = client.data.write().await;
data.insert::<SQLPool>(pool);
data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new()));
}
client.start_autosharded().await?; client.start_autosharded().await?;
Ok(()) Ok(())