Readded some guild data code. fixed some weird cases with macro command. removed restrict command. changed db to be 'as it was'. removed execution limiters since commands are quite heavily ratelimited anyway

This commit is contained in:
jellywx 2021-10-30 20:57:33 +01:00
parent db7cca6296
commit 72228911f2
12 changed files with 150 additions and 368 deletions

4
Cargo.lock generated
View File

@ -177,7 +177,7 @@ dependencies = [
[[package]]
name = "command_attr"
version = "0.3.7"
source = "git+https://github.com/serenity-rs/serenity?branch=next#29dd43adeae81861613930e6d6385cd2497018de"
source = "git+https://github.com/serenity-rs/serenity?branch=next#d1f944b0729a83d60925ce49b7d83eeaed83bd73"
dependencies = [
"proc-macro2",
"quote",
@ -1452,7 +1452,7 @@ dependencies = [
[[package]]
name = "serenity"
version = "0.10.9"
source = "git+https://github.com/serenity-rs/serenity?branch=next#29dd43adeae81861613930e6d6385cd2497018de"
source = "git+https://github.com/serenity-rs/serenity?branch=next#d1f944b0729a83d60925ce49b7d83eeaed83bd73"
dependencies = [
"async-trait",
"async-tungstenite",

View File

@ -22,7 +22,6 @@ serde_repr = "0.1"
rmp-serde = "0.15"
rand = "0.7"
levenshtein = "1.0"
# serenity = { path = "/home/jude/serenity", features = ["collector", "unstable_discord_api"] }
serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["collector", "unstable_discord_api"] }
sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono"]}
base64 = "0.13.0"

View File

@ -41,7 +41,5 @@ __Other Variables__
### Todo List
* Convert aliases to macros
* Block users from interacting with another users' components
* Help command
* Change all db keys to be discord IDs
* Test everything

View File

@ -56,8 +56,7 @@ CREATE TABLE reminders_new (
-- , CONSTRAINT interval_enabled_mutin CHECK (`enabled` = 1 OR `interval` IS NULL)
# disallow an expiry time if interval is unspecified
-- , CONSTRAINT interval_expires_mutin CHECK (`expires` IS NULL OR `interval` IS NOT NULL)
)
COLLATE utf8mb4_unicode_ci;
);
# import data from other tables
INSERT INTO reminders_new (

View File

@ -1,33 +1,13 @@
USE reminders;
CREATE TABLE macro (
id INT UNSIGNED AUTO_INCREMENT,
guild_id BIGINT UNSIGNED NOT NULL,
guild_id INT UNSIGNED NOT NULL,
name VARCHAR(100) NOT NULL,
description VARCHAR(100),
commands TEXT NOT NULL,
FOREIGN KEY (guild_id) REFERENCES guilds(guild) ON DELETE CASCADE,
FOREIGN KEY (guild_id) REFERENCES guilds(id) ON DELETE CASCADE,
PRIMARY KEY (id)
);
DROP TABLE IF EXISTS events;
CREATE TABLE reminders.todos_new (
id INT UNSIGNED AUTO_INCREMENT UNIQUE NOT NULL,
user_id BIGINT UNSIGNED,
guild_id BIGINT UNSIGNED,
channel_id BIGINT UNSIGNED,
value VARCHAR(2000) NOT NULL,
PRIMARY KEY (id),
INDEX (user_id),
INDEX (guild_id),
INDEX (channel_id)
);
INSERT INTO reminders.todos_new (user_id, guild_id, channel_id, value)
SELECT users.user, guilds.guild, channels.channel, todos.value
FROM todos
INNER JOIN users ON users.id = todos.user_id
INNER JOIN guilds ON guilds.id = todos.guild_id
INNER JOIN channels ON channels.id = todos.channel_id;

View File

@ -127,82 +127,6 @@ You may want to use one of the popular timezones below, otherwise click [here](h
}
}
#[command("restrict")]
#[description("Configure which roles can use commands on the bot")]
#[arg(
name = "role",
description = "The role to configure command permissions for",
kind = "Role",
required = true
)]
#[supports_dm(false)]
#[hook(CHECK_GUILD_PERMISSIONS_HOOK)]
async fn restrict(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let framework = ctx.data.read().await.get::<RegexFramework>().cloned().unwrap();
if let Some(OptionValue::Role(role)) = args.get("role") {
let restricted_commands =
sqlx::query!("SELECT command FROM command_restrictions WHERE role_id = (SELECT id FROM roles WHERE role = ?)", role.0)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| row.command.clone())
.collect::<Vec<String>>();
let restrictable_commands = framework
.commands
.iter()
.filter(|c| c.hooks.contains(&&CHECK_MANAGED_PERMISSIONS_HOOK))
.map(|c| c.names[0].to_string())
.collect::<Vec<String>>();
let len = restrictable_commands.len();
let restrict_pl = ComponentDataModel::Restrict(Restrict {
role_id: *role,
author_id: invoke.author_id(),
guild_id: invoke.guild_id().unwrap(),
});
invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new()
.content(format!(
"Select the commands to allow to {} from below:",
role.mention()
))
.components(|c| {
c.create_action_row(|row| {
row.create_select_menu(|select| {
select
.custom_id(restrict_pl.to_custom_id())
.options(|options| {
for command in restrictable_commands {
options.create_option(|opt| {
opt.label(&command)
.value(&command)
.default_selection(
restricted_commands.contains(&command),
)
});
}
options
})
.min_values(0)
.max_values(len as u64)
})
})
}),
)
.await
.unwrap();
}
}
#[command("macro")]
#[description("Record and replay command sequences")]
#[subcommand("record")]
@ -231,24 +155,53 @@ async fn macro_cmd(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptio
match args.subcommand.clone().unwrap().as_str() {
"record" => {
let macro_buffer = ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap();
{
let mut lock = macro_buffer.write().await;
let guild_id = invoke.guild_id().unwrap();
let name = args.get("name").unwrap().to_string();
let row = sqlx::query!(
"SELECT 1 as _e FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?",
guild_id.0,
name
)
.fetch_one(&pool)
.await;
if row.is_ok() {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().ephemeral().embed(|e| {
e
.title("Unique Name Required")
.description("A macro already exists under this name. Please select a unique name for your macro.")
.color(*THEME_COLOR)
}),
)
.await;
} else {
let macro_buffer = ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap();
let okay = {
let mut lock = macro_buffer.write().await;
if lock.contains_key(&(guild_id, invoke.author_id())) {
false
} else {
lock.insert(
(guild_id, invoke.author_id()),
CommandMacro {
guild_id,
name: args.get("name").unwrap().to_string(),
name,
description: args.get("description").map(|d| d.to_string()),
commands: vec![],
},
);
true
}
};
if okay {
let _ = invoke
.respond(
&ctx,
@ -262,6 +215,22 @@ Any commands ran as part of recording will be inconsequential")
}),
)
.await;
} else {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().ephemeral().embed(|e| {
e.title("Macro Already Recording")
.description(
"You are already recording a macro in this server.
Please use `/macro finish` to end this recording before starting another.",
)
.color(*THEME_COLOR)
}),
)
.await;
}
}
}
"finish" => {
let key = (invoke.guild_id().unwrap(), invoke.author_id());
@ -287,7 +256,7 @@ Any commands ran as part of recording will be inconsequential")
let json = serde_json::to_string(&command_macro.commands).unwrap();
sqlx::query!(
"INSERT INTO macro (guild_id, name, description, commands) VALUES (?, ?, ?, ?)",
"INSERT INTO macro (guild_id, name, description, commands) VALUES ((SELECT id FROM guilds WHERE guild = ?), ?, ?, ?)",
command_macro.guild_id.0,
command_macro.name,
command_macro.description,
@ -326,7 +295,7 @@ Any commands ran as part of recording will be inconsequential")
let macro_name = args.get("name").unwrap().to_string();
match sqlx::query!(
"SELECT commands FROM macro WHERE guild_id = ? AND name = ?",
"SELECT commands FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?",
invoke.guild_id().unwrap().0,
macro_name
)
@ -364,7 +333,7 @@ Any commands ran as part of recording will be inconsequential")
let macro_name = args.get("name").unwrap().to_string();
match sqlx::query!(
"SELECT id FROM macro WHERE guild_id = ? AND name = ?",
"SELECT id FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?",
invoke.guild_id().unwrap().0,
macro_name
)

View File

@ -71,7 +71,7 @@ async fn todo(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions) {
let task = task.to_string();
sqlx::query!(
"INSERT INTO todos (user_id, channel_id, guild_id, value) VALUES (?, ?, ?, ?)",
"INSERT INTO todos (user_id, channel_id, guild_id, value) VALUES ((SELECT id FROM users WHERE user = ?), (SELECT id FROM channels WHERE channel = ?), (SELECT id FROM guilds WHERE guild = ?), ?)",
keys.0,
keys.1,
keys.2,
@ -88,7 +88,11 @@ async fn todo(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions) {
None => {
let values = sqlx::query!(
// fucking braindead mysql use <=> instead of = for null comparison
"SELECT id, value FROM todos WHERE user_id <=> ? AND channel_id <=> ? AND guild_id <=> ?",
"SELECT todos.id, value FROM todos
INNER JOIN users ON todos.user_id = users.id
INNER JOIN channels ON todos.channel_id = channels.id
INNER JOIN guilds ON todos.guild_id = guilds.id
WHERE users.user <=> ? AND channels.channel <=> ? AND guilds.guild <=> ?",
keys.0,
keys.1,
keys.2,

View File

@ -33,7 +33,6 @@ use crate::{
#[serde(tag = "type")]
#[repr(u8)]
pub enum ComponentDataModel {
Restrict(Restrict),
LookPager(LookPager),
DelPager(DelPager),
TodoPager(TodoPager),
@ -57,54 +56,6 @@ impl ComponentDataModel {
pub async fn act(&self, ctx: &Context, component: MessageComponentInteraction) {
match self {
ComponentDataModel::Restrict(restrict) => {
if restrict.author_id == component.user.id {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let _ = sqlx::query!(
"
INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, \"Role\", (SELECT id FROM guilds WHERE guild = ?))
",
restrict.role_id.0,
restrict.guild_id.0
)
.execute(&pool)
.await;
for command in &component.data.values {
let _ = sqlx::query!(
"INSERT INTO command_restrictions (role_id, command) VALUES ((SELECT id FROM roles WHERE role = ?), ?)",
restrict.role_id.0,
command
)
.execute(&pool)
.await;
}
component
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|response| response
.flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
.content("Role permissions updated")
)
})
.await
.unwrap();
} else {
let _ = component
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.flags(
InteractionApplicationCommandCallbackDataFlags::EPHEMERAL,
)
.content("Only the user who performed the command can use these components")
})
})
.await;
}
}
ComponentDataModel::LookPager(pager) => {
let flags = pager.flags;
@ -315,13 +266,6 @@ INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, \"Role\", (SELECT id
}
}
#[derive(Serialize, Deserialize)]
pub struct Restrict {
pub role_id: RoleId,
pub author_id: UserId,
pub guild_id: GuildId,
}
#[derive(Serialize, Deserialize)]
pub struct DelSelector {
pub page: usize,

View File

@ -7,7 +7,6 @@ use std::{
};
use log::info;
use regex::{Regex, RegexBuilder};
use serde::{Deserialize, Serialize};
use serenity::{
async_trait,
@ -34,7 +33,7 @@ use serenity::{
Result as SerenityResult,
};
use crate::LimitExecutors;
use crate::SQLPool;
pub struct CreateGenericResponse {
content: String,
@ -512,12 +511,7 @@ impl Eq for Command {}
pub struct RegexFramework {
pub commands_map: HashMap<String, &'static Command>,
pub commands: HashSet<&'static Command>,
command_matcher: Regex,
dm_regex_matcher: Regex,
default_prefix: String,
client_id: u64,
ignore_bots: bool,
case_insensitive: bool,
dm_enabled: bool,
debug_guild: Option<GuildId>,
hooks: Vec<&'static Hook>,
@ -528,34 +522,17 @@ impl TypeMapKey for RegexFramework {
}
impl RegexFramework {
pub fn new<T: Into<u64>>(client_id: T) -> Self {
pub fn new() -> Self {
Self {
commands_map: HashMap::new(),
commands: HashSet::new(),
command_matcher: Regex::new(r#"^$"#).unwrap(),
dm_regex_matcher: Regex::new(r#"^$"#).unwrap(),
default_prefix: "".to_string(),
client_id: client_id.into(),
ignore_bots: true,
case_insensitive: true,
dm_enabled: true,
debug_guild: None,
hooks: vec![],
}
}
pub fn case_insensitive(mut self, case_insensitive: bool) -> Self {
self.case_insensitive = case_insensitive;
self
}
pub fn default_prefix<T: ToString>(mut self, new_prefix: T) -> Self {
self.default_prefix = new_prefix.to_string();
self
}
pub fn ignore_bots(mut self, ignore_bots: bool) -> Self {
self.ignore_bots = ignore_bots;
@ -590,68 +567,6 @@ impl RegexFramework {
self
}
pub fn build(mut self) -> Self {
{
let command_names;
{
let mut command_names_vec =
self.commands_map.keys().map(|k| &k[..]).collect::<Vec<&str>>();
command_names_vec.sort_unstable_by_key(|a| a.len());
command_names = command_names_vec.join("|");
}
info!("Command names: {}", command_names);
{
let match_string =
r#"^(?:(?:<@ID>\s*)|(?:<@!ID>\s*)|(?P<prefix>\S{1,5}?))(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
.replace("COMMANDS", command_names.as_str())
.replace("ID", self.client_id.to_string().as_str());
self.command_matcher = RegexBuilder::new(match_string.as_str())
.case_insensitive(self.case_insensitive)
.dot_matches_new_line(true)
.build()
.unwrap();
}
}
{
let dm_command_names;
{
let mut command_names_vec = self
.commands_map
.iter()
.filter_map(
|(key, command)| if command.supports_dm { Some(&key[..]) } else { None },
)
.collect::<Vec<&str>>();
command_names_vec.sort_unstable_by_key(|a| a.len());
dm_command_names = command_names_vec.join("|");
}
{
let match_string = r#"^(?:(?:<@ID>\s+)|(?:<@!ID>\s+)|(\$)|())(?P<cmd>COMMANDS)(?:$|\s+(?P<args>.*))$"#
.replace("COMMANDS", dm_command_names.as_str())
.replace("ID", self.client_id.to_string().as_str());
self.dm_regex_matcher = RegexBuilder::new(match_string.as_str())
.case_insensitive(self.case_insensitive)
.dot_matches_new_line(true)
.build()
.unwrap();
}
}
self
}
fn _populate_commands<'a>(
&self,
commands: &'a mut CreateApplicationCommands,
@ -721,6 +636,15 @@ impl RegexFramework {
}
pub async fn execute(&self, ctx: Context, interaction: ApplicationCommandInteraction) {
{
if let Some(guild_id) = interaction.guild_id {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let _ = sqlx::query!("INSERT IGNORE INTO guilds (guild) VALUES (?)", guild_id.0)
.execute(&pool)
.await;
}
}
let command = {
self.commands_map
.get(&interaction.data.name)
@ -748,18 +672,10 @@ impl RegexFramework {
}
}
let user_id = command_invoke.author_id();
if !ctx.check_executing(user_id).await {
ctx.set_executing(user_id).await;
match command.fun {
CommandFnType::Slash(t) => t(&ctx, &mut command_invoke, args).await,
CommandFnType::Multi(m) => m(&ctx, &mut command_invoke).await,
}
ctx.drop_executing(user_id).await;
}
}
pub async fn run_command_from_options(

View File

@ -20,6 +20,14 @@ pub async fn macro_check(
let mut lock = active_recordings.write().await;
if let Some(command_macro) = lock.get_mut(&(guild_id, invoke.author_id())) {
if command_macro.commands.len() >= 5 {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content("5 commands already recorded. Please use `/macro finish` to end recording."),
)
.await;
} else {
command_macro.commands.push(args.clone());
let _ = invoke
@ -28,6 +36,7 @@ pub async fn macro_check(
CreateGenericResponse::new().content("Command recorded to macro"),
)
.await;
}
HookResult::Halt
} else {

View File

@ -10,7 +10,7 @@ mod hooks;
mod models;
mod time_parser;
use std::{collections::HashMap, env, sync::Arc, time::Instant};
use std::{collections::HashMap, env, sync::Arc};
use chrono_tz::Tz;
use dotenv::dotenv;
@ -18,12 +18,11 @@ use log::info;
use serenity::{
async_trait,
client::{bridge::gateway::GatewayIntents, Client},
futures::TryFutureExt,
http::{client::Http, CacheHttp},
model::{
channel::GuildChannel,
gateway::{Activity, Ready},
guild::Guild,
guild::{Guild, GuildUnavailable},
id::{GuildId, UserId},
interactions::Interaction,
},
@ -59,55 +58,12 @@ impl TypeMapKey for PopularTimezones {
type Value = Arc<Vec<Tz>>;
}
struct CurrentlyExecuting;
impl TypeMapKey for CurrentlyExecuting {
type Value = Arc<RwLock<HashMap<UserId, Instant>>>;
}
struct RecordingMacros;
impl TypeMapKey for RecordingMacros {
type Value = Arc<RwLock<HashMap<(GuildId, UserId), CommandMacro>>>;
}
#[async_trait]
trait LimitExecutors {
async fn check_executing(&self, user: UserId) -> bool;
async fn set_executing(&self, user: UserId);
async fn drop_executing(&self, user: UserId);
}
#[async_trait]
impl LimitExecutors for Context {
async fn check_executing(&self, user: UserId) -> bool {
let currently_executing =
self.data.read().await.get::<CurrentlyExecuting>().cloned().unwrap();
let lock = currently_executing.read().await;
lock.get(&user).map_or(false, |now| now.elapsed().as_secs() < 4)
}
async fn set_executing(&self, user: UserId) {
let currently_executing =
self.data.read().await.get::<CurrentlyExecuting>().cloned().unwrap();
let mut lock = currently_executing.write().await;
lock.insert(user, Instant::now());
}
async fn drop_executing(&self, user: UserId) {
let currently_executing =
self.data.read().await.get::<CurrentlyExecuting>().cloned().unwrap();
let mut lock = currently_executing.write().await;
lock.remove(&user);
}
}
struct Handler;
#[async_trait]
@ -148,6 +104,14 @@ DELETE FROM channels WHERE channel = ?
if is_new {
let guild_id = guild.id.as_u64().to_owned();
{
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let _ = sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id)
.execute(&pool)
.await;
}
if let Ok(token) = env::var("DISCORDBOTS_TOKEN") {
let shard_count = ctx.cache.shard_count();
let current_shard_id = shard_id(guild_id, shard_count);
@ -192,6 +156,13 @@ DELETE FROM channels WHERE channel = ?
}
}
async fn guild_delete(&self, ctx: Context, incomplete: GuildUnavailable, _full: Option<Guild>) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let _ = sqlx::query!("DELETE FROM guilds WHERE guild = ?", incomplete.id.0)
.execute(&pool)
.await;
}
async fn ready(&self, ctx: Context, _: Ready) {
ctx.set_activity(Activity::watching("for /remind")).await;
}
@ -199,10 +170,6 @@ DELETE FROM channels WHERE channel = ?
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
match interaction {
Interaction::ApplicationCommand(application_command) => {
if application_command.guild_id.is_none() {
return;
}
let framework = ctx
.data
.read()
@ -232,14 +199,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let http = Http::new_with_token(&token);
let logged_in_id = http.get_current_user().map_ok(|user| user.id.as_u64().to_owned()).await?;
let application_id = http.get_current_application_info().await?.id;
let dm_enabled = env::var("DM_ENABLED").map_or(true, |var| var == "1");
let framework = RegexFramework::new(logged_in_id)
.default_prefix("")
.case_insensitive(env::var("CASE_INSENSITIVE").map_or(true, |var| var == "1"))
let framework = RegexFramework::new()
.ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1"))
.debug_guild(env::var("DEBUG_GUILD").map_or(None, |g| {
Some(GuildId(g.parse::<u64>().expect("DEBUG_GUILD must be a guild ID")))
@ -263,12 +227,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// to-do commands
.add_command(&todo_cmds::TODO_COMMAND)
// moderation commands
.add_command(&moderation_cmds::RESTRICT_COMMAND)
.add_command(&moderation_cmds::TIMEZONE_COMMAND)
.add_command(&moderation_cmds::MACRO_CMD_COMMAND)
.add_hook(&hooks::CHECK_SELF_PERMISSIONS_HOOK)
.add_hook(&hooks::MACRO_CHECK_HOOK)
.build();
.add_hook(&hooks::MACRO_CHECK_HOOK);
let framework_arc = Arc::new(framework);
@ -305,7 +267,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut data = client.data.write().await;
data.insert::<CurrentlyExecuting>(Arc::new(RwLock::new(HashMap::new())));
data.insert::<SQLPool>(pool);
data.insert::<PopularTimezones>(Arc::new(popular_timezones));
data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new()));

View File

@ -14,13 +14,16 @@ impl CommandMacro {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let guild_id = guild_id.into();
sqlx::query!("SELECT * FROM macro WHERE guild_id = ?", guild_id.0)
sqlx::query!(
"SELECT * FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)",
guild_id.0
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| Self {
guild_id: GuildId(row.guild_id),
guild_id,
name: row.name.clone(),
description: row.description.clone(),
commands: serde_json::from_str(&row.commands).unwrap(),