move to regex framework to hopefully reduce bad cpu load

This commit is contained in:
2021-01-26 01:32:43 +00:00
parent d57e1d3ab1
commit 0ea979a2b7
12 changed files with 1277 additions and 234 deletions

339
src/framework.rs Normal file
View File

@ -0,0 +1,339 @@
use serenity::{
async_trait,
client::Context,
constants::MESSAGE_CODE_LIMIT,
framework::{standard::Args, Framework},
futures::prelude::future::BoxFuture,
http::Http,
model::{
channel::{Channel, GuildChannel, Message},
guild::{Guild, Member},
id::ChannelId,
},
Result as SerenityResult,
};
use log::{error, info, warn};
use regex::{Match, Regex, RegexBuilder};
use std::{collections::HashMap, fmt};
use crate::{guild_data::GuildData, MySQL};
use serenity::framework::standard::{CommandResult, Delimiter};
type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, Args) -> BoxFuture<'fut, CommandResult>;
#[derive(Debug, PartialEq)]
pub enum PermissionLevel {
Unrestricted,
Managed,
Restricted,
}
pub struct Command {
pub name: &'static str,
pub required_perms: PermissionLevel,
pub func: CommandFn,
}
impl Command {
async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool {
if self.required_perms == PermissionLevel::Unrestricted {
true
} else {
let permissions = guild.member_permissions(&ctx, &member.user).await.unwrap();
if permissions.manage_guild() && self.required_perms == PermissionLevel::Managed {
return true;
}
if self.required_perms == PermissionLevel::Managed {
let pool = ctx
.data
.read()
.await
.get::<MySQL>()
.cloned()
.expect("Could not get SQLPool from data");
match sqlx::query!(
"
SELECT role
FROM roles
WHERE guild_id = ?
",
guild.id.as_u64()
)
.fetch_all(&pool)
.await
{
Ok(rows) => {
let role_ids = member
.roles
.iter()
.map(|r| *r.as_u64())
.collect::<Vec<u64>>();
for row in rows {
if role_ids.contains(&row.role) || &row.role == guild.id.as_u64() {
return true;
}
}
false
}
Err(sqlx::Error::RowNotFound) => false,
Err(e) => {
warn!("Unexpected error occurred querying roles: {:?}", e);
false
}
}
} else {
false
}
}
}
}
impl fmt::Debug for Command {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Command")
.field("name", &self.name)
.field("required_perms", &self.required_perms)
.finish()
}
}
#[async_trait]
pub trait SendIterator {
async fn say_lines(
self,
http: impl AsRef<Http> + Send + Sync + 'async_trait,
content: impl Iterator<Item = String> + Send + 'async_trait,
) -> SerenityResult<()>;
}
#[async_trait]
impl SendIterator for ChannelId {
async fn say_lines(
self,
http: impl AsRef<Http> + Send + Sync + 'async_trait,
content: impl Iterator<Item = String> + Send + 'async_trait,
) -> SerenityResult<()> {
let mut current_content = String::new();
for line in content {
if current_content.len() + line.len() > MESSAGE_CODE_LIMIT as usize {
self.send_message(&http, |m| {
m.allowed_mentions(|am| am.empty_parse())
.content(&current_content)
})
.await?;
current_content = line;
} else {
current_content = format!("{}\n{}", current_content, line);
}
}
if !current_content.is_empty() {
self.send_message(&http, |m| {
m.allowed_mentions(|am| am.empty_parse())
.content(&current_content)
})
.await?;
}
Ok(())
}
}
pub struct RegexFramework {
commands: HashMap<String, &'static Command>,
command_matcher: Regex,
default_prefix: String,
client_id: u64,
ignore_bots: bool,
case_insensitive: bool,
}
impl RegexFramework {
pub fn new<T: Into<u64>>(client_id: T) -> Self {
Self {
commands: HashMap::new(),
command_matcher: Regex::new(r#"^$"#).unwrap(),
default_prefix: "".to_string(),
client_id: client_id.into(),
ignore_bots: true,
case_insensitive: true,
}
}
pub fn case_insensitive(mut self, case_insensitive: bool) -> Self {
self.case_insensitive = case_insensitive;
self
}
pub fn default_prefix<T: ToString>(mut self, new_prefix: T) -> Self {
self.default_prefix = new_prefix.to_string();
self
}
pub fn ignore_bots(mut self, ignore_bots: bool) -> Self {
self.ignore_bots = ignore_bots;
self
}
pub fn add_command<S: ToString>(mut self, name: S, command: &'static Command) -> Self {
self.commands.insert(name.to_string(), command);
self
}
pub fn build(mut self) -> Self {
let command_names;
{
let mut command_names_vec = self.commands.keys().map(|k| &k[..]).collect::<Vec<&str>>();
command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len()));
command_names = command_names_vec.join("|");
}
info!("Command names: {}", command_names);
{
let match_string = r#"^(?:(?:<@ID>\s*)|(?:<@!ID>\s*)|(?P<prefix>\S{1,5}?))(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
.replace("COMMANDS", command_names.as_str())
.replace("ID", self.client_id.to_string().as_str());
self.command_matcher = RegexBuilder::new(match_string.as_str())
.case_insensitive(self.case_insensitive)
.dot_matches_new_line(true)
.build()
.unwrap();
}
self
}
}
enum PermissionCheck {
None, // No permissions
All, // Sufficient permissions
}
#[async_trait]
impl Framework for RegexFramework {
async fn dispatch(&self, ctx: Context, msg: Message) {
async fn check_self_permissions(
ctx: &Context,
channel: &GuildChannel,
) -> SerenityResult<PermissionCheck> {
let user_id = ctx.cache.current_user_id().await;
let channel_perms = channel.permissions_for_user(ctx, user_id).await?;
Ok(
if channel_perms.send_messages() && channel_perms.embed_links() {
PermissionCheck::All
} else {
PermissionCheck::None
},
)
}
async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option<Match<'_>>) -> bool {
if let Some(prefix) = prefix_opt {
let pool = ctx
.data
.read()
.await
.get::<MySQL>()
.cloned()
.expect("Could not get SQLPool from data");
let guild_prefix = match GuildData::get_from_id(guild.clone(), pool.clone()).await {
Some(guild_data) => guild_data.prefix,
None => {
GuildData::create_from_guild(guild, pool).await.unwrap();
String::from("?")
}
};
guild_prefix.as_str() == prefix.as_str()
} else {
true
}
}
// gate to prevent analysing messages unnecessarily
if msg.author.bot || msg.content.is_empty() {
}
// Guild Command
else if let (Some(guild), Some(Channel::Guild(channel))) =
(msg.guild(&ctx).await, msg.channel(&ctx).await)
{
if let Some(full_match) = self.command_matcher.captures(&msg.content) {
if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
match check_self_permissions(&ctx, &channel).await {
Ok(perms) => match perms {
PermissionCheck::All => {
let command = self
.commands
.get(&full_match.name("cmd").unwrap().as_str().to_lowercase())
.unwrap();
let args = full_match
.name("args")
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
let member = guild.member(&ctx, &msg.author).await.unwrap();
if command.check_permissions(&ctx, &guild, &member).await {
dbg!(command.name);
(command.func)(
&ctx,
&msg,
Args::new(&args, &[Delimiter::Single(' ')]),
)
.await
.unwrap();
} else if command.required_perms == PermissionLevel::Restricted {
let _ = msg.channel_id.say(&ctx, "You must either be an Admin or have a role specified in `?roles` to do this command").await;
} else if command.required_perms == PermissionLevel::Managed {
let _ = msg
.channel_id
.say(&ctx, "You must be an Admin to do this command")
.await;
}
}
PermissionCheck::None => {
warn!("Missing enough permissions for guild {}", guild.id);
}
},
Err(e) => {
error!(
"Error occurred getting permissions in guild {}: {:?}",
guild.id, e
);
}
}
}
}
}
}
}

View File

@ -25,7 +25,7 @@ SELECT id, prefix, volume, allow_greets
match guild_data {
Ok(g) => Some(g),
Err(sqlx::Error::RowNotFound) => Self::create_from_guild(guild, db_pool).await.ok(),
Err(sqlx::Error::RowNotFound) => Self::create_from_guild(&guild, db_pool).await.ok(),
Err(e) => {
println!("{:?}", e);
@ -36,7 +36,7 @@ SELECT id, prefix, volume, allow_greets
}
pub async fn create_from_guild(
guild: Guild,
guild: &Guild,
db_pool: MySqlPool,
) -> Result<GuildData, Box<dyn std::error::Error + Send + Sync>> {
sqlx::query!(
@ -62,7 +62,7 @@ INSERT IGNORE INTO roles (guild_id, role)
.await?;
Ok(GuildData {
id: *guild.id.as_u64(),
id: guild.id.as_u64().to_owned(),
prefix: String::from("?"),
volume: 100,
allow_greets: true,

View File

@ -4,18 +4,18 @@ extern crate lazy_static;
extern crate reqwest;
mod error;
mod guilddata;
mod framework;
mod guild_data;
mod sound;
use guilddata::GuildData;
use guild_data::GuildData;
use sound::Sound;
use regex_command_attr::command;
use serenity::{
client::{bridge::gateway::GatewayIntents, Client, Context},
framework::standard::{
macros::{check, command, group, hook},
Args, CommandError, CommandResult, DispatchError, Reason, StandardFramework,
},
framework::standard::{Args, CommandResult},
http::Http,
model::{
channel::{Channel, Message},
@ -34,12 +34,11 @@ use songbird::{
Call, SerenityInit,
};
type CheckResult = Result<(), Reason>;
use sqlx::mysql::MySqlPool;
use dotenv::dotenv;
use crate::framework::RegexFramework;
use std::{collections::HashMap, env, sync::Arc, time::Duration};
use tokio::sync::MutexGuard;
@ -72,164 +71,6 @@ lazy_static! {
};
}
#[group]
#[commands(
info,
help,
list_sounds,
change_public,
search_sounds,
show_popular_sounds,
show_random_sounds,
set_greet_sound
)]
#[checks(self_perm_check)]
struct AllUsers;
#[group]
#[commands(play, upload_new_sound, change_volume, delete_sound, stop_playing)]
#[checks(self_perm_check, role_check)]
struct RoleManagedUsers;
#[group]
#[commands(change_prefix, set_allowed_roles, allow_greet_sounds)]
#[checks(self_perm_check, permission_check)]
struct PermissionManagedUsers;
#[check]
#[name("self_perm_check")]
async fn self_perm_check(ctx: &Context, msg: &Message, _args: &mut Args) -> CheckResult {
let channel_o = msg.channel(&ctx).await;
if let Some(channel_e) = channel_o {
if let Channel::Guild(channel) = channel_e {
let permissions_r = channel
.permissions_for_user(&ctx, &ctx.cache.current_user_id().await)
.await;
if let Ok(permissions) = permissions_r {
if permissions.send_messages() && permissions.embed_links() {
Ok(())
} else {
Err(Reason::Log(
"Bot does not have enough permissions".to_string(),
))
}
} else {
Err(Reason::Log("No perms found".to_string()))
}
} else {
Err(Reason::Log("No DM commands".to_string()))
}
} else {
Err(Reason::Log("Channel not available".to_string()))
}
}
#[check]
#[name("role_check")]
async fn role_check(ctx: &Context, msg: &Message, _args: &mut Args) -> CheckResult {
async fn check_for_roles(ctx: &&Context, msg: &&Message) -> CheckResult {
let pool = ctx
.data
.read()
.await
.get::<MySQL>()
.cloned()
.expect("Could not get SQLPool from data");
let guild_opt = msg.guild(&ctx).await;
match guild_opt {
Some(guild) => {
let member_res = guild.member(*ctx, msg.author.id).await;
match member_res {
Ok(member) => {
let user_roles: String = member
.roles
.iter()
.map(|r| (*r.as_u64()).to_string())
.collect::<Vec<String>>()
.join(", ");
let guild_id = *msg.guild_id.unwrap().as_u64();
let role_res = sqlx::query!(
"
SELECT COUNT(1) as count
FROM roles
WHERE
(guild_id = ? AND role IN (?)) OR
(role = ?)
",
guild_id,
user_roles,
guild_id
)
.fetch_one(&pool)
.await;
match role_res {
Ok(role_count) => {
if role_count.count > 0 {
Ok(())
}
else {
Err(Reason::User("User has not got a sufficient role. Use `?roles` to set up role restrictions".to_string()))
}
}
Err(_) => {
Err(Reason::User("User has not got a sufficient role. Use `?roles` to set up role restrictions".to_string()))
}
}
}
Err(_) => Err(Reason::User(
"Unexpected error looking up user roles".to_string(),
)),
}
}
None => Err(Reason::User(
"Unexpected error looking up guild".to_string(),
)),
}
}
if perform_permission_check(ctx, &msg).await.is_ok() {
Ok(())
} else {
check_for_roles(&ctx, &msg).await
}
}
#[check]
#[name("permission_check")]
async fn permission_check(ctx: &Context, msg: &Message, _args: &mut Args) -> CheckResult {
perform_permission_check(ctx, &msg).await
}
async fn perform_permission_check(ctx: &Context, msg: &&Message) -> CheckResult {
if let Some(guild) = msg.guild(&ctx).await {
if guild
.member_permissions(&ctx, &msg.author)
.await
.unwrap()
.manage_guild()
{
Ok(())
} else {
Err(Reason::User(String::from(
"User needs `Manage Guild` permission",
)))
}
} else {
Err(Reason::User(String::from("Guild not cached")))
}
}
// create event handler for bot
struct Handler;
@ -403,28 +244,6 @@ async fn join_channel(
}
}
#[hook]
async fn log_errors(_: &Context, m: &Message, cmd_name: &str, error: Result<(), CommandError>) {
if let Err(e) = error {
println!("Error in command {} ({}): {:?}", cmd_name, m.content, e);
}
}
#[hook]
async fn dispatch_error_hook(ctx: &Context, msg: &Message, error: DispatchError) {
match error {
DispatchError::CheckFailed(_f, reason) => {
if let Reason::User(description) = reason {
let _ = msg
.reply(ctx, format!("You cannot do this command: {}", description))
.await;
}
}
_ => {}
}
}
// entry point
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
@ -436,46 +255,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let logged_in_id = http.get_current_user().await?.id;
let framework = StandardFramework::new()
.configure(|c| {
c.dynamic_prefix(|ctx, msg| {
Box::pin(async move {
let pool = ctx
.data
.read()
.await
.get::<MySQL>()
.cloned()
.expect("Could not get SQLPool from data");
let guild = match msg.guild(&ctx.cache).await {
Some(guild) => guild,
None => {
return Some(String::from("?"));
}
};
match GuildData::get_from_id(guild.clone(), pool.clone()).await {
Some(guild_data) => Some(guild_data.prefix),
None => {
GuildData::create_from_guild(guild, pool).await.unwrap();
Some(String::from("?"))
}
}
})
})
.allow_dm(false)
.ignore_bots(true)
.ignore_webhooks(true)
.on_mention(Some(logged_in_id))
})
.group(&ALLUSERS_GROUP)
.group(&ROLEMANAGEDUSERS_GROUP)
.group(&PERMISSIONMANAGEDUSERS_GROUP)
.after(log_errors)
.on_dispatch_error(dispatch_error_hook);
let framework = RegexFramework::new(logged_in_id)
.default_prefix("?")
.case_insensitive(true)
.ignore_bots(true)
// info commands
.add_command("help", &HELP_COMMAND)
.add_command("info", &INFO_COMMAND)
.add_command("invite", &INFO_COMMAND)
.add_command("donate", &INFO_COMMAND)
.build();
let mut client =
Client::builder(&env::var("DISCORD_TOKEN").expect("Missing token from environment"))
@ -518,8 +307,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
#[command("play")]
#[aliases("p")]
#[command]
async fn play(ctx: &Context, msg: &Message, args: Args) -> CommandResult {
let guild = match msg.guild(&ctx.cache).await {
Some(guild) => guild,

View File

@ -130,7 +130,7 @@ SELECT src
.await
.unwrap();
return record.src;
record.src
}
pub async fn store_sound_source(