Refactor macros

This commit is contained in:
jude 2024-02-06 20:08:59 +00:00
parent e4e9af2bb4
commit def43bfa78
14 changed files with 98 additions and 333 deletions

8
Cargo.lock generated
View File

@ -2067,9 +2067,9 @@ checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a"
[[package]]
name = "poise"
version = "0.6.1-rc1"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bde9f83da70341825e4116c06ffd5f1b23155121c2801ea29a8e178a7fc9857"
checksum = "1819d5a45e3590ef33754abce46432570c54a120798bdbf893112b4211fa09a6"
dependencies = [
"async-trait",
"derivative",
@ -2084,9 +2084,9 @@ dependencies = [
[[package]]
name = "poise_macros"
version = "0.6.1-rc1"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "130fd27280c82e5ab5b147838b5ff9f9da33603fbadfff8ff613de530c12922d"
checksum = "8fa2c123c961e78315cd3deac7663177f12be4460f5440dbf62a7ed37b1effea"
dependencies = [
"darling",
"proc-macro2",

View File

@ -7,7 +7,7 @@ license = "AGPL-3.0 only"
description = "Reminder Bot for Discord, now in Rust"
[dependencies]
poise = "0.6.1-rc1"
poise = "0.6.1"
dotenv = "0.15"
tokio = { version = "1", features = ["process", "full"] }
reqwest = "0.11"

View File

@ -17,7 +17,13 @@ pub async fn delete_macro(
) -> Result<(), Error> {
match sqlx::query!(
"
SELECT id FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?",
SELECT m.id
FROM macro m
INNER JOIN guilds
ON guilds.guild = m.guild_id
WHERE guild = ?
AND m.name = ?
",
ctx.guild_id().unwrap().get(),
name
)

View File

@ -28,11 +28,11 @@ pub async fn list_macro(ctx: Context<'_>) -> Result<(), Error> {
Ok(())
}
pub fn max_macro_page<U, E>(macros: &[CommandMacro<U, E>]) -> usize {
pub fn max_macro_page(macros: &[CommandMacro]) -> usize {
((macros.len() as f64) / 25.0).ceil() as usize
}
pub fn show_macro_page<U, E>(macros: &[CommandMacro<U, E>], page: usize) -> CreateReply {
pub fn show_macro_page(macros: &[CommandMacro], page: usize) -> CreateReply {
let pager = MacroPager::new(page);
if macros.is_empty() {

View File

@ -1,229 +0,0 @@
use lazy_regex::regex;
use poise::{serenity_prelude::CommandOptionType, CreateReply};
use regex::Captures;
use serde_json::{json, Value};
use crate::{models::command_macro::RawCommandMacro, Context, Error, GuildId};
struct Alias {
name: String,
command: String,
}
/// Migrate old $alias reminder commands to macros. Only macro names that are not taken will be used.
#[poise::command(
slash_command,
rename = "migrate",
guild_only = true,
default_member_permissions = "MANAGE_GUILD",
identifying_name = "migrate_macro"
)]
pub async fn migrate_macro(ctx: Context<'_>) -> Result<(), Error> {
let guild_id = ctx.guild_id().unwrap();
let mut transaction = ctx.data().database.begin().await?;
let aliases = sqlx::query_as!(
Alias,
"SELECT name, command FROM command_aliases WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)",
guild_id.get()
)
.fetch_all(&mut *transaction)
.await?;
let mut added_aliases = 0;
for alias in aliases {
match parse_text_command(guild_id, alias.name, &alias.command) {
Some(cmd_macro) => {
sqlx::query!(
"INSERT INTO macro (guild_id, name, description, commands) VALUES ((SELECT id FROM guilds WHERE guild = ?), ?, ?, ?)",
cmd_macro.guild_id.get(),
cmd_macro.name,
cmd_macro.description,
cmd_macro.commands
)
.execute(&mut *transaction)
.await?;
added_aliases += 1;
}
None => {}
}
}
transaction.commit().await?;
ctx.send(CreateReply::default().content(format!("Added {} macros.", added_aliases))).await?;
Ok(())
}
fn parse_text_command(
guild_id: GuildId,
alias_name: String,
command: &str,
) -> Option<RawCommandMacro> {
match command.split_once(" ") {
Some((command_word, args)) => {
let command_word = command_word.to_lowercase();
if command_word == "r"
|| command_word == "i"
|| command_word == "remind"
|| command_word == "interval"
{
let matcher = regex!(
r#"(?P<mentions>(?:<@\d+>\s+|<@!\d+>\s+|<#\d+>\s+)*)(?P<time>(?:(?:\d+)(?:s|m|h|d|:|/|-|))+)(?:\s+(?P<interval>(?:(?:\d+)(?:s|m|h|d|))+))?(?:\s+(?P<expires>(?:(?:\d+)(?:s|m|h|d|:|/|-|))+))?\s+(?P<content>.*)"#s
);
match matcher.captures(&args) {
Some(captures) => {
let mut args: Vec<Value> = vec![];
if let Some(group) = captures.name("time") {
let content = group.as_str();
args.push(json!({
"name": "time",
"value": content,
"type": CommandOptionType::String,
}));
}
if let Some(group) = captures.name("content") {
let content = group.as_str();
args.push(json!({
"name": "content",
"value": content,
"type": CommandOptionType::String,
}));
}
if let Some(group) = captures.name("interval") {
let content = group.as_str();
args.push(json!({
"name": "interval",
"value": content,
"type": CommandOptionType::String,
}));
}
if let Some(group) = captures.name("expires") {
let content = group.as_str();
args.push(json!({
"name": "expires",
"value": content,
"type": CommandOptionType::String,
}));
}
if let Some(group) = captures.name("mentions") {
let content = group.as_str();
args.push(json!({
"name": "channels",
"value": content,
"type": CommandOptionType::String,
}));
}
Some(RawCommandMacro {
guild_id,
name: alias_name,
description: None,
commands: json!([
{
"command_name": "remind",
"options": args,
}
]),
})
}
None => None,
}
} else if command_word == "n" || command_word == "natural" {
let matcher_primary = regex!(
r#"(?P<time>.*?)(?:\s+)(?:send|say)(?:\s+)(?P<content>.*?)(?:(?:\s+)to(?:\s+)(?P<mentions>((?:<@\d+>)|(?:<@!\d+>)|(?:<#\d+>)|(?:\s+))+))?$"#s
);
let matcher_secondary = regex!(
r#"(?P<msg>.*)(?:\s+)every(?:\s+)(?P<interval>.*?)(?:(?:\s+)(?:until|for)(?:\s+)(?P<expires>.*?))?$"#s
);
match matcher_primary.captures(&args) {
Some(captures) => {
let captures_secondary = matcher_secondary.captures(&args);
let mut args: Vec<Value> = vec![];
if let Some(group) = captures.name("time") {
let content = group.as_str();
args.push(json!({
"name": "time",
"value": content,
"type": CommandOptionType::String,
}));
}
if let Some(group) = captures.name("content") {
let content = group.as_str();
args.push(json!({
"name": "content",
"value": content,
"type": CommandOptionType::String,
}));
}
if let Some(group) =
captures_secondary.as_ref().and_then(|c: &Captures| c.name("interval"))
{
let content = group.as_str();
args.push(json!({
"name": "interval",
"value": content,
"type": CommandOptionType::String,
}));
}
if let Some(group) =
captures_secondary.and_then(|c: Captures| c.name("expires"))
{
let content = group.as_str();
args.push(json!({
"name": "expires",
"value": content,
"type": CommandOptionType::String,
}));
}
if let Some(group) = captures.name("mentions") {
let content = group.as_str();
args.push(json!({
"name": "channels",
"value": content,
"type": CommandOptionType::String,
}));
}
Some(RawCommandMacro {
guild_id,
name: alias_name,
description: None,
commands: json!([
{
"command_name": "remind",
"options": args,
}
]),
})
}
None => None,
}
} else {
None
}
}
None => None,
}
}

View File

@ -2,7 +2,6 @@ use crate::{Context, Error};
pub mod delete;
pub mod list;
pub mod migrate;
pub mod record;
pub mod run;

View File

@ -112,7 +112,7 @@ pub async fn finish_macro(ctx: Context<'_>) -> Result<(), Error> {
let lock = ctx.data().recording_macros.read().await;
let contained = lock.get(&key);
if contained.map_or(true, |cmacro| cmacro.commands.is_empty()) {
if contained.map_or(true, |r#macro| r#macro.commands.is_empty()) {
ctx.send(
CreateReply::default().embed(
CreateEmbed::new()

View File

@ -1,7 +1,4 @@
use poise::{
serenity_prelude::{CommandOption, CreateEmbed},
CreateReply,
};
use poise::{serenity_prelude::CreateEmbed, CreateReply};
use super::super::autocomplete::macro_name_autocomplete;
use crate::{models::command_macro::guild_command_macro, Context, Data, Error, THEME_COLOR};
@ -35,20 +32,7 @@ pub async fn run_macro(
.await?;
for command in command_macro.commands {
if let Some(action) = command.action {
match (action)(poise::ApplicationContext { args: &command.options, ..ctx })
.await
{
Ok(()) => {}
Err(e) => {
println!("{:?}", e);
}
}
} else {
Context::Application(ctx)
.say(format!("Command \"{}\" not found", command.command_name))
.await?;
}
command.execute(poise::ApplicationContext { ..ctx }).await;
}
}

View File

@ -125,15 +125,15 @@ pub async fn offset(
if combined_time == 0 {
ctx.say("Please specify one of `hours`, `minutes` or `seconds`").await?;
} else {
if let Some(guild) = ctx.guild() {
let channels = guild
if let Some(channels) = ctx.guild().map(|guild| {
guild
.channels
.iter()
.filter(|(_, channel)| channel.is_text_based())
.map(|(id, _)| id.get().to_string())
.collect::<Vec<String>>()
.join(",");
.join(",")
}) {
sqlx::query!(
"
UPDATE reminders
@ -224,9 +224,7 @@ pub async fn look(
}),
};
let channel_opt = ctx.channel_id().to_channel_cached(&ctx.cache());
let channel_id = if let Some(channel) = channel_opt {
let channel_id = if let Some(channel) = ctx.channel_id().to_channel_cached(&ctx.cache()) {
if Some(channel.guild_id) == ctx.guild_id() {
flags.channel_id.unwrap_or_else(|| ctx.channel_id())
} else {
@ -795,17 +793,14 @@ fn create_response(
),
};
let mut embed = CreateEmbed::default();
embed
CreateEmbed::default()
.title(format!(
"{n} Reminder{s} Set",
n = successes.len(),
s = if successes.len() > 1 { "s" } else { "" }
))
.description(format!("{}\n\n{}", success_part, error_part))
.color(*THEME_COLOR);
embed
.color(*THEME_COLOR)
}
fn parse_mention_list(mentions: &str) -> Vec<ReminderScope> {

View File

@ -6,6 +6,18 @@ async fn macro_check(ctx: Context<'_>) -> bool {
if let Context::Application(app_ctx) = ctx {
if let Some(guild_id) = ctx.guild_id() {
if ctx.command().identifying_name != "finish_macro" {
if ctx.command().identifying_name != "remind" {
let _ = ctx
.send(
CreateReply::default()
.ephemeral(true)
.content("Macro recording only supports `/remind`. Please stop recording with `/macro finish` before using other commands.")
)
.await;
return false;
}
let mut lock = ctx.data().recording_macros.write().await;
if let Some(command_macro) = lock.get_mut(&(guild_id, ctx.author().id)) {
@ -18,12 +30,7 @@ async fn macro_check(ctx: Context<'_>) -> bool {
)
.await;
} else {
let recorded = RecordedCommand {
action: None,
command_name: ctx.command().identifying_name.clone(),
options: app_ctx.interaction.data.options.clone(),
};
let recorded = RecordedCommand::from_context(app_ctx).unwrap();
command_macro.commands.push(recorded);
let _ = ctx

View File

@ -50,7 +50,7 @@ type ApplicationContext<'a> = poise::ApplicationContext<'a, Data, Error>;
pub struct Data {
database: Pool<Database>,
http: reqwest::Client,
recording_macros: RwLock<HashMap<(GuildId, UserId), CommandMacro<Data, Error>>>,
recording_macros: RwLock<HashMap<(GuildId, UserId), CommandMacro>>,
popular_timezones: Vec<Tz>,
_broadcast: Sender<()>,
}
@ -132,7 +132,6 @@ async fn _main(tx: Sender<()>) -> Result<(), Box<dyn StdError + Send + Sync>> {
command_macro::list::list_macro(),
command_macro::record::record_macro(),
command_macro::run::run_macro(),
command_macro::migrate::migrate_macro(),
],
..command_macro::macro_base()
},

View File

@ -1,47 +1,66 @@
use poise::serenity_prelude::{model::id::GuildId, CommandDataOption};
use chrono_tz::Tz;
use poise::serenity_prelude::model::id::GuildId;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::{Context, Data, Error};
type Func<U, E> = for<'a> fn(
poise::ApplicationContext<'a, U, E>,
) -> poise::BoxFuture<'a, Result<(), poise::FrameworkError<'a, U, E>>>;
fn default_none<U, E>() -> Option<Func<U, E>> {
None
}
use crate::{ApplicationContext, Context};
#[derive(Serialize, Deserialize)]
pub struct RecordedCommand<U, E> {
#[serde(skip)]
#[serde(default = "default_none::<U, E>")]
pub action: Option<Func<U, E>>,
pub command_name: String,
pub options: Vec<CommandDataOption>,
#[serde(tag = "command_name")]
pub enum RecordedCommand {
Remind(RemindOptions),
}
pub struct CommandMacro<U, E> {
impl RecordedCommand {
pub fn from_context(ctx: ApplicationContext) -> Option<Self> {
match ctx.command().identifying_name.as_str() {
"remind" => Some(Self::Remind(RemindOptions {
time: "10 seconds".to_string(),
content: "message".to_string(),
channels: None,
interval: None,
expires: None,
tts: None,
timezone: None,
})),
_ => None,
}
}
pub async fn execute(&self, ctx: ApplicationContext<'_>) {
match self {
RecordedCommand::Remind(_) => {}
}
}
}
#[derive(Serialize, Deserialize, Default)]
pub struct RemindOptions {
time: String,
content: String,
channels: Option<String>,
interval: Option<String>,
expires: Option<String>,
tts: Option<bool>,
timezone: Option<Tz>,
}
pub struct CommandMacro {
pub guild_id: GuildId,
pub name: String,
pub description: Option<String>,
pub commands: Vec<RecordedCommand<U, E>>,
pub commands: Vec<RecordedCommand>,
}
pub struct RawCommandMacro {
pub guild_id: GuildId,
pub name: String,
pub description: Option<String>,
pub commands: Value,
}
pub async fn guild_command_macro(
ctx: &Context<'_>,
name: &str,
) -> Option<CommandMacro<Data, Error>> {
pub async fn guild_command_macro(ctx: &Context<'_>, name: &str) -> Option<CommandMacro> {
let row = sqlx::query!(
"
SELECT * FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?
SELECT m.id, m.name, m.description, m.commands
FROM macro m
INNER JOIN guilds g
ON g.id = m.guild_id
WHERE guild = ?
AND m.name = ?
",
ctx.guild_id().unwrap().get(),
name
@ -50,20 +69,7 @@ SELECT * FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND
.await
.ok()?;
let mut commands: Vec<RecordedCommand<Data, Error>> =
serde_json::from_str(&row.commands).unwrap();
for recorded_command in &mut commands {
let command = &ctx
.framework()
.options()
.commands
.iter()
.find(|c| c.identifying_name == recorded_command.command_name);
recorded_command.action = command.map(|c| c.slash_action).flatten();
}
let commands: Vec<RecordedCommand> = serde_json::from_str(&row.commands).unwrap();
let command_macro = CommandMacro {
guild_id: ctx.guild_id().unwrap(),
name: row.name,

View File

@ -25,7 +25,7 @@ pub trait CtxData {
async fn channel_data(&self) -> Result<ChannelData, Error>;
async fn command_macros(&self) -> Result<Vec<CommandMacro<Data, Error>>, Error>;
async fn command_macros(&self) -> Result<Vec<CommandMacro>, Error>;
}
#[async_trait]
@ -57,7 +57,7 @@ impl CtxData for Context<'_> {
ChannelData::from_channel(&channel, &self.data().database).await
}
async fn command_macros(&self) -> Result<Vec<CommandMacro<Data, Error>>, Error> {
async fn command_macros(&self) -> Result<Vec<CommandMacro>, Error> {
self.data().command_macros(self.guild_id().unwrap()).await
}
}
@ -66,7 +66,7 @@ impl Data {
pub(crate) async fn command_macros(
&self,
guild_id: GuildId,
) -> Result<Vec<CommandMacro<Data, Error>>, Error> {
) -> Result<Vec<CommandMacro>, Error> {
let rows = sqlx::query!(
"SELECT name, description, commands FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)",
guild_id.get()

View File

@ -31,9 +31,7 @@ pub async fn check_guild_subscription(
cache_http: impl CacheHttp,
guild_id: impl Into<GuildId>,
) -> bool {
if let Some(guild) = cache_http.cache().unwrap().guild(guild_id) {
let owner = guild.owner_id;
if let Some(owner) = cache_http.cache().unwrap().guild(guild_id).map(|g| g.owner_id) {
check_subscription(&cache_http, owner).await
} else {
false