check prefix and order commands properly for regex
This commit is contained in:
parent
c2c5e79940
commit
305cf79ac8
@ -14,9 +14,13 @@ use serenity::{
|
|||||||
use log::{
|
use log::{
|
||||||
warn,
|
warn,
|
||||||
error,
|
error,
|
||||||
|
debug,
|
||||||
|
info,
|
||||||
};
|
};
|
||||||
|
|
||||||
use regex::Regex;
|
use regex::{
|
||||||
|
Regex, Match
|
||||||
|
};
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
@ -24,6 +28,7 @@ use std::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use serenity::framework::standard::CommandFn;
|
use serenity::framework::standard::CommandFn;
|
||||||
|
use crate::SQLPool;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum PermissionLevel {
|
pub enum PermissionLevel {
|
||||||
@ -90,11 +95,20 @@ impl RegexFramework {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn build(mut self) -> Self {
|
pub fn build(mut self) -> Self {
|
||||||
let command_names = self.commands
|
let command_names;
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut command_names_vec = self.commands
|
||||||
.keys()
|
.keys()
|
||||||
.map(|k| &k[..])
|
.map(|k| &k[..])
|
||||||
.collect::<Vec<&str>>()
|
.collect::<Vec<&str>>();
|
||||||
.join("|");
|
|
||||||
|
command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len()));
|
||||||
|
|
||||||
|
command_names = command_names_vec.join("|");
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Command names: {}", command_names);
|
||||||
|
|
||||||
let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(?P<prefix>\S{1,5}?))(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
|
let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(?P<prefix>\S{1,5}?))(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
|
||||||
.replace("COMMANDS", command_names.as_str())
|
.replace("COMMANDS", command_names.as_str())
|
||||||
@ -135,6 +149,34 @@ impl Framework for RegexFramework {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn check_prefix(ctx: &Context, guild_id: u64, prefix_opt: Option<Match<'_>>) -> bool {
|
||||||
|
if let Some(prefix) = prefix_opt {
|
||||||
|
let pool = ctx.data.read().await
|
||||||
|
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
|
||||||
|
|
||||||
|
match sqlx::query!("SELECT prefix FROM guilds WHERE id = ?", guild_id)
|
||||||
|
.fetch_one(&pool)
|
||||||
|
.await {
|
||||||
|
Ok(row) => {
|
||||||
|
prefix.as_str() == row.prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(sqlx::Error::RowNotFound) => {
|
||||||
|
prefix.as_str() == "$"
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Unexpected error in prefix query: {:?}", e);
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// gate to prevent analysing messages unnecessarily
|
// gate to prevent analysing messages unnecessarily
|
||||||
if (msg.author.bot && self.ignore_bots) ||
|
if (msg.author.bot && self.ignore_bots) ||
|
||||||
msg.tts ||
|
msg.tts ||
|
||||||
@ -147,17 +189,17 @@ impl Framework for RegexFramework {
|
|||||||
// Guild Command
|
// Guild Command
|
||||||
else if let (Some(guild), Some(Channel::Guild(channel))) = (msg.guild(&ctx).await, msg.channel(&ctx).await) {
|
else if let (Some(guild), Some(Channel::Guild(channel))) = (msg.guild(&ctx).await, msg.channel(&ctx).await) {
|
||||||
|
|
||||||
if let Some(full_match) = self.regex_matcher.captures(msg.content.as_str()) {
|
if let Some(full_match) = self.regex_matcher.captures(&msg.content[..]) {
|
||||||
|
|
||||||
|
if check_prefix(&ctx, *guild.id.as_u64(), full_match.name("prefix")).await {
|
||||||
|
|
||||||
|
debug!("Prefix matched on {}", msg.content);
|
||||||
|
|
||||||
match check_self_permissions(&ctx, &guild, &channel).await {
|
match check_self_permissions(&ctx, &guild, &channel).await {
|
||||||
Ok(perms) => match perms {
|
Ok(perms) => match perms {
|
||||||
PermissionCheck::All => {
|
PermissionCheck::All => {}
|
||||||
|
|
||||||
}
|
PermissionCheck::Basic => {}
|
||||||
|
|
||||||
PermissionCheck::Basic => {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
PermissionCheck::None => {
|
PermissionCheck::None => {
|
||||||
warn!("Missing enough permissions for guild {}", guild.id);
|
warn!("Missing enough permissions for guild {}", guild.id);
|
||||||
@ -170,6 +212,7 @@ impl Framework for RegexFramework {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// DM Command
|
// DM Command
|
||||||
else {
|
else {
|
||||||
|
14
src/main.rs
14
src/main.rs
@ -21,6 +21,7 @@ use regex_command_attr::command;
|
|||||||
use sqlx::{
|
use sqlx::{
|
||||||
Pool,
|
Pool,
|
||||||
mysql::{
|
mysql::{
|
||||||
|
MySqlPool,
|
||||||
MySqlConnection,
|
MySqlConnection,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -46,14 +47,12 @@ impl TypeMapKey for ReqwestClient {
|
|||||||
type Value = Arc<reqwest::Client>;
|
type Value = Arc<reqwest::Client>;
|
||||||
}
|
}
|
||||||
|
|
||||||
static THEME_COLOR: u32 = 0x00e0f3;
|
static THEME_COLOR: u32 = 0x8fb677;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
dotenv()?;
|
dotenv()?;
|
||||||
|
|
||||||
println!("{:?}", HELP_COMMAND);
|
|
||||||
|
|
||||||
let framework = RegexFramework::new(env::var("CLIENT_ID").expect("Missing CLIENT_ID from environment").parse()?)
|
let framework = RegexFramework::new(env::var("CLIENT_ID").expect("Missing CLIENT_ID from environment").parse()?)
|
||||||
.ignore_bots(true)
|
.ignore_bots(true)
|
||||||
.default_prefix("$")
|
.default_prefix("$")
|
||||||
@ -66,6 +65,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||||||
.framework(framework)
|
.framework(framework)
|
||||||
.await.expect("Error occurred creating client");
|
.await.expect("Error occurred creating client");
|
||||||
|
|
||||||
|
{
|
||||||
|
let pool = MySqlPool::new(&env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment")).await.unwrap();
|
||||||
|
|
||||||
|
let mut data = client.data.write().await;
|
||||||
|
|
||||||
|
data.insert::<SQLPool>(pool);
|
||||||
|
data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new()));
|
||||||
|
}
|
||||||
|
|
||||||
client.start_autosharded().await?;
|
client.start_autosharded().await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
Loading…
Reference in New Issue
Block a user