2nd attempt at doing poise stuff

This commit is contained in:
jude 2022-02-19 14:32:03 +00:00
parent 620f054703
commit 84ee7e77c5
18 changed files with 1071 additions and 1035 deletions

474
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -5,8 +5,8 @@ authors = ["jellywx <judesouthworth@pm.me>"]
edition = "2018" edition = "2018"
[dependencies] [dependencies]
poise = { git = "https://github.com/kangalioo/poise", branch = "master" }
dotenv = "0.15" dotenv = "0.15"
humantime = "2.1"
tokio = { version = "1", features = ["process", "full"] } tokio = { version = "1", features = ["process", "full"] }
reqwest = "0.11" reqwest = "0.11"
regex = "1.4" regex = "1.4"
@ -33,20 +33,3 @@ path = "postman"
[dependencies.reminder_web] [dependencies.reminder_web]
path = "web" path = "web"
[dependencies.serenity]
git = "https://github.com/serenity-rs/serenity"
branch = "next"
default-features = false
features = [
"builder",
"client",
"cache",
"gateway",
"http",
"model",
"utils",
"rustls_backend",
"collector",
"unstable_discord_api"
]

View File

@ -1,16 +1,11 @@
use chrono::offset::Utc; use chrono::offset::Utc;
use regex_command_attr::command; use poise::serenity::builder::CreateEmbedFooter;
use serenity::{builder::CreateEmbedFooter, client::Context};
use crate::{ use crate::{models::CtxData, Context, Error, THEME_COLOR};
framework::{CommandInvoke, CreateGenericResponse},
models::CtxData,
THEME_COLOR,
};
fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEmbedFooter { fn footer(ctx: Context<'_>) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEmbedFooter {
let shard_count = ctx.cache.shard_count(); let shard_count = ctx.discord().cache.shard_count();
let shard = ctx.shard_id; let shard = ctx.discord().shard_id;
move |f| { move |f| {
f.text(format!( f.text(format!(
@ -22,15 +17,14 @@ fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEm
} }
} }
#[command] /// Get an overview of bot commands
#[description("Get an overview of the bot commands")] #[poise::command(slash_command)]
async fn help(ctx: &Context, invoke: &mut CommandInvoke) { pub async fn help(ctx: Context<'_>) -> Result<(), Error> {
let footer = footer(ctx); let footer = footer(ctx);
let _ = invoke let _ = ctx
.respond( .send(|m| {
&ctx, m.embed(|e| {
CreateGenericResponse::new().embed(|e| {
e.title("Help") e.title("Help")
.color(*THEME_COLOR) .color(*THEME_COLOR)
.description( .description(
@ -60,21 +54,21 @@ __Advanced Commands__
", ",
) )
.footer(footer) .footer(footer)
}), })
) })
.await; .await;
Ok(())
} }
#[command] /// Get information about the bot
#[aliases("invite")] #[poise::command(slash_command)]
#[description("Get information about the bot")] pub async fn info(ctx: Context<'_>) -> Result<(), Error> {
async fn info(ctx: &Context, invoke: &mut CommandInvoke) {
let footer = footer(ctx); let footer = footer(ctx);
let _ = invoke let _ = ctx
.respond( .send(|m| {
ctx.http.clone(), m.embed(|e| {
CreateGenericResponse::new().embed(|e| {
e.title("Info") e.title("Info")
.description(format!( .description(format!(
"Help: `/help` "Help: `/help`
@ -89,21 +83,19 @@ Use our dashboard: https://reminder-bot.com/",
)) ))
.footer(footer) .footer(footer)
.color(*THEME_COLOR) .color(*THEME_COLOR)
}), })
) })
.await; .await;
Ok(())
} }
#[command] /// Details on supporting the bot and Patreon benefits
#[description("Details on supporting the bot and Patreon benefits")] #[poise::command(slash_command)]
#[group("Info")] pub async fn donate(ctx: Context<'_>) -> Result<(), Error> {
async fn donate(ctx: &Context, invoke: &mut CommandInvoke) {
let footer = footer(ctx); let footer = footer(ctx);
let _ = invoke let _ = ctx.send(|m| m.embed(|e| {
.respond(
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
e.title("Donate") e.title("Donate")
.description("Thinking of adding a monthly contribution? Click below for my Patreon and official bot server :) .description("Thinking of adding a monthly contribution? Click below for my Patreon and official bot server :)
@ -125,38 +117,41 @@ Just $2 USD/month!
}), }),
) )
.await; .await;
Ok(())
} }
#[command] /// Get the link to the online dashboard
#[description("Get the link to the online dashboard")] #[poise::command(slash_command)]
#[group("Info")] pub async fn dashboard(ctx: Context<'_>) -> Result<(), Error> {
async fn dashboard(ctx: &Context, invoke: &mut CommandInvoke) {
let footer = footer(ctx); let footer = footer(ctx);
let _ = invoke let _ = ctx
.respond( .send(|m| {
ctx.http.clone(), m.embed(|e| {
CreateGenericResponse::new().embed(|e| {
e.title("Dashboard") e.title("Dashboard")
.description("**https://reminder-bot.com/dashboard**") .description("**https://reminder-bot.com/dashboard**")
.footer(footer) .footer(footer)
.color(*THEME_COLOR) .color(*THEME_COLOR)
}), })
) })
.await; .await;
Ok(())
} }
#[command] /// View the current time in a user's selected timezone
#[description("View the current time in your selected timezone")] #[poise::command(slash_command)]
#[group("Info")] pub async fn clock(ctx: Context<'_>) -> Result<(), Error> {
async fn clock(ctx: &Context, invoke: &mut CommandInvoke) { ctx.defer_ephemeral().await?;
let ud = ctx.user_data(&invoke.author_id()).await.unwrap();
let now = Utc::now().with_timezone(&ud.timezone());
let _ = invoke let tz = ctx.timezone().await;
.respond( let now = Utc::now().with_timezone(&tz);
ctx.http.clone(),
CreateGenericResponse::new().content(format!("Current time: {}", now.format("%H:%M"))), ctx.send(|m| {
) m.ephemeral(true).content(format!("Time in **{}**: `{}`", tz, now.format("%H:%M")))
.await; })
.await?;
Ok(())
} }

View File

@ -1,4 +1,4 @@
pub mod info_cmds; pub mod info_cmds;
pub mod moderation_cmds; pub mod moderation_cmds;
pub mod reminder_cmds; //pub mod reminder_cmds;
pub mod todo_cmds; //pub mod todo_cmds;

View File

@ -1,44 +1,50 @@
use chrono::offset::Utc; use chrono::offset::Utc;
use chrono_tz::{Tz, TZ_VARIANTS}; use chrono_tz::{Tz, TZ_VARIANTS};
use levenshtein::levenshtein; use levenshtein::levenshtein;
use regex_command_attr::command; use poise::CreateReply;
use serenity::client::Context;
use crate::{ use crate::{
component_models::pager::{MacroPager, Pager},
consts::{EMBED_DESCRIPTION_MAX_LENGTH, THEME_COLOR}, consts::{EMBED_DESCRIPTION_MAX_LENGTH, THEME_COLOR},
framework::{CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue}, hooks::guild_only,
hooks::{CHECK_GUILD_PERMISSIONS_HOOK, GUILD_ONLY_HOOK},
models::{command_macro::CommandMacro, CtxData}, models::{command_macro::CommandMacro, CtxData},
PopularTimezones, RecordingMacros, RegexFramework, SQLPool, Context, Data, Error,
}; };
#[command("timezone")] async fn timezone_autocomplete(ctx: Context<'_>, partial: String) -> Vec<String> {
#[description("Select your timezone")] if partial.is_empty() {
#[arg( ctx.data().popular_timezones.iter().map(|t| t.to_string()).collect::<Vec<String>>()
name = "timezone", } else {
description = "Timezone to use from this list: https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee", TZ_VARIANTS
kind = "String", .iter()
required = false .filter(|tz| tz.to_string().contains(&partial))
)] .take(25)
async fn timezone(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions) { .map(|t| t.to_string())
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); .collect::<Vec<String>>()
let mut user_data = ctx.user_data(invoke.author_id()).await.unwrap(); }
}
/// Select your timezone
#[poise::command(slash_command)]
pub async fn timezone(
ctx: Context<'_>,
#[description = "Timezone to use from this list: https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee"]
#[autocomplete = "timezone_autocomplete"]
timezone: Option<String>,
) -> Result<(), Error> {
let mut user_data = ctx.author_data().await.unwrap();
let footer_text = format!("Current timezone: {}", user_data.timezone); let footer_text = format!("Current timezone: {}", user_data.timezone);
if let Some(OptionValue::String(timezone)) = args.get("timezone") { if let Some(timezone) = timezone {
match timezone.parse::<Tz>() { match timezone.parse::<Tz>() {
Ok(tz) => { Ok(tz) => {
user_data.timezone = timezone.clone(); user_data.timezone = timezone.clone();
user_data.commit_changes(&pool).await; user_data.commit_changes(&ctx.data().database).await;
let now = Utc::now().with_timezone(&tz); let now = Utc::now().with_timezone(&tz);
let _ = invoke ctx.send(|m| {
.respond( m.embed(|e| {
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
e.title("Timezone Set") e.title("Timezone Set")
.description(format!( .description(format!(
"Timezone has been set to **{}**. Your current time should be `{}`", "Timezone has been set to **{}**. Your current time should be `{}`",
@ -46,9 +52,9 @@ async fn timezone(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOption
now.format("%H:%M").to_string() now.format("%H:%M").to_string()
)) ))
.color(*THEME_COLOR) .color(*THEME_COLOR)
}), })
) })
.await; .await?;
} }
Err(_) => { Err(_) => {
@ -56,8 +62,8 @@ async fn timezone(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOption
.iter() .iter()
.filter(|tz| { .filter(|tz| {
timezone.contains(&tz.to_string()) timezone.contains(&tz.to_string())
|| tz.to_string().contains(timezone) || tz.to_string().contains(&timezone)
|| levenshtein(&tz.to_string(), timezone) < 4 || levenshtein(&tz.to_string(), &timezone) < 4
}) })
.take(25) .take(25)
.map(|t| t.to_owned()) .map(|t| t.to_owned())
@ -74,25 +80,21 @@ async fn timezone(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOption
) )
}); });
let _ = invoke ctx.send(|m| {
.respond( m.embed(|e| {
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
e.title("Timezone Not Recognized") e.title("Timezone Not Recognized")
.description("Possibly you meant one of the following timezones, otherwise click [here](https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee):") .description("Possibly you meant one of the following timezones, otherwise click [here](https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee):")
.color(*THEME_COLOR) .color(*THEME_COLOR)
.fields(fields) .fields(fields)
.footer(|f| f.text(footer_text)) .footer(|f| f.text(footer_text))
.url("https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee") .url("https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee")
}), })
) })
.await; .await?;
} }
} }
} else { } else {
let popular_timezones = ctx.data.read().await.get::<PopularTimezones>().cloned().unwrap(); let popular_timezones_iter = ctx.data().popular_timezones.iter().map(|t| {
let popular_timezones_iter = popular_timezones.iter().map(|t| {
( (
t.to_string(), t.to_string(),
format!("🕗 `{}`", Utc::now().with_timezone(t).format("%H:%M").to_string()), format!("🕗 `{}`", Utc::now().with_timezone(t).format("%H:%M").to_string()),
@ -100,10 +102,8 @@ async fn timezone(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOption
) )
}); });
let _ = invoke ctx.send(|m| {
.respond( m.embed(|e| {
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
e.title("Timezone Usage") e.title("Timezone Usage")
.description( .description(
"**Usage:** "**Usage:**
@ -118,137 +118,137 @@ You may want to use one of the popular timezones below, otherwise click [here](h
.fields(popular_timezones_iter) .fields(popular_timezones_iter)
.footer(|f| f.text(footer_text)) .footer(|f| f.text(footer_text))
.url("https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee") .url("https://gist.github.com/JellyWX/913dfc8b63d45192ad6cb54c829324ee")
}), })
})
.await?;
}
Ok(())
}
async fn macro_name_autocomplete(ctx: Context<'_>, partial: String) -> Vec<String> {
sqlx::query!(
"
SELECT name
FROM macro
WHERE
guild_id = (SELECT id FROM guilds WHERE guild = ?)
AND name LIKE CONCAT(?, '%')",
ctx.guild_id().unwrap().0,
partial,
) )
.await; .fetch_all(&ctx.data().database)
} .await
.unwrap_or(vec![])
.iter()
.map(|s| s.name.clone())
.collect()
} }
#[command("macro")] /// Record and replay command sequences
#[description("Record and replay command sequences")] #[poise::command(slash_command, rename = "macro", check = "guild_only")]
#[subcommand("record")] pub async fn macro_base(_ctx: Context<'_>) -> Result<(), Error> {
#[description("Start recording up to 5 commands to replay")] Ok(())
#[arg(name = "name", description = "Name for the new macro", kind = "String", required = true)] }
#[arg(
name = "description",
description = "Description for the new macro",
kind = "String",
required = false
)]
#[subcommand("finish")]
#[description("Finish current recording")]
#[subcommand("list")]
#[description("List recorded macros")]
#[subcommand("run")]
#[description("Run a recorded macro")]
#[arg(name = "name", description = "Name of the macro to run", kind = "String", required = true)]
#[subcommand("delete")]
#[description("Delete a recorded macro")]
#[arg(name = "name", description = "Name of the macro to delete", kind = "String", required = true)]
#[supports_dm(false)]
#[hook(GUILD_ONLY_HOOK)]
#[hook(CHECK_GUILD_PERMISSIONS_HOOK)]
async fn macro_cmd(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
match args.subcommand.clone().unwrap().as_str() { /// Start recording up to 5 commands to replay
"record" => { #[poise::command(slash_command, rename = "record", check = "guild_only")]
let guild_id = invoke.guild_id().unwrap(); pub async fn record_macro(
ctx: Context<'_>,
let name = args.get("name").unwrap().to_string(); #[description = "Name for the new macro"] name: String,
#[description = "Description for the new macro"] description: Option<String>,
) -> Result<(), Error> {
let guild_id = ctx.guild_id().unwrap();
let row = sqlx::query!( let row = sqlx::query!(
"SELECT 1 as _e FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?", "
SELECT 1 as _e FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?",
guild_id.0, guild_id.0,
name name
) )
.fetch_one(&pool) .fetch_one(&ctx.data().database)
.await; .await;
if row.is_ok() { if row.is_ok() {
let _ = invoke ctx.send(|m| {
.respond( m.ephemeral(true).embed(|e| {
&ctx, e.title("Unique Name Required")
CreateGenericResponse::new().ephemeral().embed(|e| { .description(
e "A macro already exists under this name.
.title("Unique Name Required") Please select a unique name for your macro.",
.description("A macro already exists under this name. Please select a unique name for your macro.")
.color(*THEME_COLOR)
}),
) )
.await; .color(*THEME_COLOR)
})
})
.await?;
} else { } else {
let macro_buffer = ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap();
let okay = { let okay = {
let mut lock = macro_buffer.write().await; let mut lock = ctx.data().recording_macros.write().await;
if lock.contains_key(&(guild_id, invoke.author_id())) { if lock.contains_key(&(guild_id, ctx.author().id)) {
false false
} else { } else {
lock.insert( lock.insert(
(guild_id, invoke.author_id()), (guild_id, ctx.author().id),
CommandMacro { CommandMacro { guild_id, name, description, commands: vec![] },
guild_id,
name,
description: args.get("description").map(|d| d.to_string()),
commands: vec![],
},
); );
true true
} }
}; };
if okay { if okay {
let _ = invoke ctx.send(|m| {
.respond( m.ephemeral(true).embed(|e| {
&ctx, e.title("Macro Recording Started")
CreateGenericResponse::new().ephemeral().embed(|e| {
e
.title("Macro Recording Started")
.description( .description(
"Run up to 5 commands, or type `/macro finish` to stop at any point. "Run up to 5 commands, or type `/macro finish` to stop at any point.
Any commands ran as part of recording will be inconsequential") Any commands ran as part of recording will be inconsequential",
.color(*THEME_COLOR)
}),
) )
.await; .color(*THEME_COLOR)
})
})
.await?;
} else { } else {
let _ = invoke ctx.send(|m| {
.respond( m.ephemeral(true).embed(|e| {
&ctx,
CreateGenericResponse::new().ephemeral().embed(|e| {
e.title("Macro Already Recording") e.title("Macro Already Recording")
.description( .description(
"You are already recording a macro in this server. "You are already recording a macro in this server.
Please use `/macro finish` to end this recording before starting another.", Please use `/macro finish` to end this recording before starting another.",
) )
.color(*THEME_COLOR) .color(*THEME_COLOR)
}), })
) })
.await; .await?;
} }
} }
Ok(())
} }
"finish" => {
let key = (invoke.guild_id().unwrap(), invoke.author_id()); /// Finish current macro recording
let macro_buffer = ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap(); #[poise::command(
slash_command,
rename = "finish",
check = "guild_only",
identifying_name = "macro_finish"
)]
pub async fn finish_macro(ctx: Context<'_>) -> Result<(), Error> {
let key = (ctx.guild_id().unwrap(), ctx.author().id);
{ {
let lock = macro_buffer.read().await; let lock = ctx.data().recording_macros.read().await;
let contained = lock.get(&key); let contained = lock.get(&key);
if contained.map_or(true, |cmacro| cmacro.commands.is_empty()) { if contained.map_or(true, |cmacro| cmacro.commands.is_empty()) {
let _ = invoke ctx.send(|m| {
.respond( m.embed(|e| {
&ctx,
CreateGenericResponse::new().embed(|e| {
e.title("No Macro Recorded") e.title("No Macro Recorded")
.description("Use `/macro record` to start recording a macro") .description("Use `/macro record` to start recording a macro")
.color(*THEME_COLOR) .color(*THEME_COLOR)
}), })
) })
.await; .await?;
} else { } else {
let command_macro = contained.unwrap(); let command_macro = contained.unwrap();
let json = serde_json::to_string(&command_macro.commands).unwrap(); let json = serde_json::to_string(&command_macro.commands).unwrap();
@ -260,119 +260,120 @@ Please use `/macro finish` to end this recording before starting another.",
command_macro.description, command_macro.description,
json json
) )
.execute(&pool) .execute(&ctx.data().database)
.await .await
.unwrap(); .unwrap();
let _ = invoke ctx.send(|m| {
.respond( m.embed(|e| {
&ctx,
CreateGenericResponse::new().embed(|e| {
e.title("Macro Recorded") e.title("Macro Recorded")
.description("Use `/macro run` to execute the macro") .description("Use `/macro run` to execute the macro")
.color(*THEME_COLOR) .color(*THEME_COLOR)
}), })
) })
.await; .await?;
} }
} }
{ {
let mut lock = macro_buffer.write().await; let mut lock = ctx.data().recording_macros.write().await;
lock.remove(&key); lock.remove(&key);
} }
Ok(())
} }
"list" => {
let macros = CommandMacro::from_guild(ctx, invoke.guild_id().unwrap()).await; /// List recorded macros
#[poise::command(slash_command, rename = "list", check = "guild_only")]
pub async fn list_macro(ctx: Context<'_>) -> Result<(), Error> {
// let macros = CommandMacro::from_guild(&ctx.data().database, ctx.guild_id().unwrap()).await;
let macros: Vec<CommandMacro<Data, Error>> = vec![];
let resp = show_macro_page(&macros, 0); let resp = show_macro_page(&macros, 0);
invoke.respond(&ctx, resp).await.unwrap(); ctx.send(|m| {
} *m = resp;
"run" => { m
let macro_name = args.get("name").unwrap().to_string(); })
.await?;
Ok(())
}
/// Run a recorded macro
#[poise::command(slash_command, rename = "run", check = "guild_only")]
pub async fn run_macro(
ctx: Context<'_>,
#[description = "Name of macro to run"]
#[autocomplete = "macro_name_autocomplete"]
name: String,
) -> Result<(), Error> {
match sqlx::query!( match sqlx::query!(
"SELECT commands FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?", "
invoke.guild_id().unwrap().0, SELECT commands FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?",
macro_name ctx.guild_id().unwrap().0,
name
) )
.fetch_one(&pool) .fetch_one(&ctx.data().database)
.await .await
{ {
Ok(row) => { Ok(row) => {
invoke.defer(&ctx).await; ctx.defer().await?;
let commands: Vec<CommandOptions> = // TODO TODO TODO!!!!!!!! RUN COMMAND FROM MACRO
serde_json::from_str(&row.commands).unwrap();
let framework = ctx.data.read().await.get::<RegexFramework>().cloned().unwrap();
for command in commands {
framework.run_command_from_options(ctx, invoke, command).await;
}
} }
Err(sqlx::Error::RowNotFound) => { Err(sqlx::Error::RowNotFound) => {
let _ = invoke ctx.say(format!("Macro \"{}\" not found", name)).await?;
.respond(
&ctx,
CreateGenericResponse::new()
.content(format!("Macro \"{}\" not found", macro_name)),
)
.await;
} }
Err(e) => { Err(e) => {
panic!("{}", e); panic!("{}", e);
} }
} }
}
"delete" => {
let macro_name = args.get("name").unwrap().to_string();
Ok(())
}
/// Delete a recorded macro
#[poise::command(slash_command, rename = "delete", check = "guild_only")]
pub async fn delete_macro(
ctx: Context<'_>,
#[description = "Name of macro to delete"]
#[autocomplete = "macro_name_autocomplete"]
name: String,
) -> Result<(), Error> {
match sqlx::query!( match sqlx::query!(
"SELECT id FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?", "
invoke.guild_id().unwrap().0, SELECT id FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?",
macro_name ctx.guild_id().unwrap().0,
name
) )
.fetch_one(&pool) .fetch_one(&ctx.data().database)
.await .await
{ {
Ok(row) => { Ok(row) => {
sqlx::query!("DELETE FROM macro WHERE id = ?", row.id) sqlx::query!("DELETE FROM macro WHERE id = ?", row.id)
.execute(&pool) .execute(&ctx.data().database)
.await .await
.unwrap(); .unwrap();
let _ = invoke ctx.say(format!("Macro \"{}\" deleted", name)).await?;
.respond(
&ctx,
CreateGenericResponse::new()
.content(format!("Macro \"{}\" deleted", macro_name)),
)
.await;
} }
Err(sqlx::Error::RowNotFound) => { Err(sqlx::Error::RowNotFound) => {
let _ = invoke ctx.say(format!("Macro \"{}\" not found", name)).await?;
.respond(
&ctx,
CreateGenericResponse::new()
.content(format!("Macro \"{}\" not found", macro_name)),
)
.await;
} }
Err(e) => { Err(e) => {
panic!("{}", e); panic!("{}", e);
} }
} }
}
_ => {} Ok(())
}
} }
pub fn max_macro_page(macros: &[CommandMacro]) -> usize { pub fn max_macro_page<U, E>(macros: &[CommandMacro<U, E>]) -> usize {
let mut skipped_char_count = 0; let mut skipped_char_count = 0;
macros macros
@ -396,15 +397,30 @@ pub fn max_macro_page(macros: &[CommandMacro]) -> usize {
}) })
} }
pub fn show_macro_page(macros: &[CommandMacro], page: usize) -> CreateGenericResponse { pub fn show_macro_page<U, E>(macros: &[CommandMacro<U, E>], page: usize) -> CreateReply {
let pager = MacroPager::new(page); let mut reply = CreateReply::default();
if macros.is_empty() { reply.embed(|e| {
return CreateGenericResponse::new().embed(|e| {
e.title("Macros") e.title("Macros")
.description("No Macros Set Up. Use `/macro record` to get started.") .description("No Macros Set Up. Use `/macro record` to get started.")
.color(*THEME_COLOR) .color(*THEME_COLOR)
}); });
reply
/*
let pager = MacroPager::new(page);
if macros.is_empty() {
let mut reply = CreateReply::default();
reply.embed(|e| {
e.title("Macros")
.description("No Macros Set Up. Use `/macro record` to get started.")
.color(*THEME_COLOR)
});
return reply;
} }
let pages = max_macro_page(macros); let pages = max_macro_page(macros);
@ -447,7 +463,9 @@ pub fn show_macro_page(macros: &[CommandMacro], page: usize) -> CreateGenericRes
let display = display_vec.join("\n"); let display = display_vec.join("\n");
CreateGenericResponse::new() let mut reply = CreateReply::default();
reply
.embed(|e| { .embed(|e| {
e.title("Macros") e.title("Macros")
.description(display) .description(display)
@ -458,5 +476,8 @@ pub fn show_macro_page(macros: &[CommandMacro], page: usize) -> CreateGenericRes
pager.create_button_row(pages, comp); pager.create_button_row(pages, comp);
comp comp
}) });
reply
*/
} }

View File

@ -6,11 +6,12 @@ pub const SELECT_MAX_ENTRIES: usize = 25;
pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"; pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
const THEME_COLOR_FALLBACK: u32 = 0x8fb677; const THEME_COLOR_FALLBACK: u32 = 0x8fb677;
pub const MACRO_MAX_COMMANDS: usize = 5;
use std::{collections::HashSet, env, iter::FromIterator}; use std::{collections::HashSet, env, iter::FromIterator};
use poise::serenity::model::prelude::AttachmentType;
use regex::Regex; use regex::Regex;
use serenity::model::prelude::AttachmentType;
lazy_static! { lazy_static! {
pub static ref DEFAULT_AVATAR: AttachmentType<'static> = ( pub static ref DEFAULT_AVATAR: AttachmentType<'static> = (

View File

@ -1,33 +1,22 @@
use std::{collections::HashMap, env, sync::atomic::Ordering}; use std::{collections::HashMap, env, sync::atomic::Ordering};
use log::{info, warn}; use log::{info, warn};
use serenity::{ use poise::serenity::{client::Context, model::interactions::Interaction, utils::shard_id};
async_trait,
client::{Context, EventHandler},
model::{
channel::GuildChannel,
gateway::{Activity, Ready},
guild::{Guild, UnavailableGuild},
id::GuildId,
interactions::Interaction,
},
utils::shard_id,
};
use crate::{ComponentDataModel, Handler, RegexFramework, ReqwestClient, SQLPool}; use crate::{Data, Error};
#[async_trait] pub async fn listener(ctx: &Context, event: &poise::Event<'_>, data: &Data) -> Result<(), Error> {
impl EventHandler for Handler { match event {
async fn cache_ready(&self, ctx_base: Context, _guilds: Vec<GuildId>) { poise::Event::CacheReady { .. } => {
info!("Cache Ready!"); info!("Cache Ready!");
info!("Preparing to send reminders"); info!("Preparing to send reminders");
if !self.is_loop_running.load(Ordering::Relaxed) { if !data.is_loop_running.load(Ordering::Relaxed) {
let ctx1 = ctx_base.clone(); let ctx1 = ctx.clone();
let ctx2 = ctx_base.clone(); let ctx2 = ctx.clone();
let pool1 = ctx1.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool1 = data.database.clone();
let pool2 = ctx2.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool2 = data.database.clone();
let run_settings = env::var("DONTRUN").unwrap_or_else(|_| "".to_string()); let run_settings = env::var("DONTRUN").unwrap_or_else(|_| "".to_string());
@ -47,41 +36,28 @@ impl EventHandler for Handler {
warn!("Not running web") warn!("Not running web")
} }
self.is_loop_running.swap(true, Ordering::Relaxed); data.is_loop_running.swap(true, Ordering::Relaxed);
} }
} }
poise::Event::ChannelDelete { channel } => {
async fn channel_delete(&self, ctx: Context, channel: &GuildChannel) {
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
sqlx::query!( sqlx::query!(
" "
DELETE FROM channels WHERE channel = ? DELETE FROM channels WHERE channel = ?
", ",
channel.id.as_u64() channel.id.as_u64()
) )
.execute(&pool) .execute(&data.database)
.await .await
.unwrap(); .unwrap();
} }
poise::Event::GuildCreate { guild, is_new } => {
async fn guild_create(&self, ctx: Context, guild: Guild, is_new: bool) { if *is_new {
if is_new {
let guild_id = guild.id.as_u64().to_owned(); let guild_id = guild.id.as_u64().to_owned();
{ sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id)
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); .execute(&data.database)
.await
let _ = sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id) .unwrap();
.execute(&pool)
.await;
}
if let Ok(token) = env::var("DISCORDBOTS_TOKEN") { if let Ok(token) = env::var("DISCORDBOTS_TOKEN") {
let shard_count = ctx.cache.shard_count(); let shard_count = ctx.cache.shard_count();
@ -91,7 +67,9 @@ DELETE FROM channels WHERE channel = ?
.cache .cache
.guilds() .guilds()
.iter() .iter()
.filter(|g| shard_id(g.as_u64().to_owned(), shard_count) == current_shard_id) .filter(|g| {
shard_id(g.as_u64().to_owned(), shard_count) == current_shard_id
})
.count() as u64; .count() as u64;
let mut hm = HashMap::new(); let mut hm = HashMap::new();
@ -99,15 +77,8 @@ DELETE FROM channels WHERE channel = ?
hm.insert("shard_id", current_shard_id); hm.insert("shard_id", current_shard_id);
hm.insert("shard_count", shard_count); hm.insert("shard_count", shard_count);
let client = ctx let response = data
.data .http
.read()
.await
.get::<ReqwestClient>()
.cloned()
.expect("Could not get ReqwestClient from data");
let response = client
.post( .post(
format!( format!(
"https://top.gg/api/bots/{}/stats", "https://top.gg/api/bots/{}/stats",
@ -126,36 +97,20 @@ DELETE FROM channels WHERE channel = ?
} }
} }
} }
poise::Event::GuildDelete { incomplete, full } => {
async fn guild_delete(&self, ctx: Context, incomplete: UnavailableGuild, _full: Option<Guild>) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let _ = sqlx::query!("DELETE FROM guilds WHERE guild = ?", incomplete.id.0) let _ = sqlx::query!("DELETE FROM guilds WHERE guild = ?", incomplete.id.0)
.execute(&pool) .execute(&data.database)
.await; .await;
} }
poise::Event::InteractionCreate { interaction } => match interaction {
async fn ready(&self, ctx: Context, _: Ready) {
ctx.set_activity(Activity::watching("for /remind")).await;
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
match interaction {
Interaction::ApplicationCommand(application_command) => {
let framework = ctx
.data
.read()
.await
.get::<RegexFramework>()
.cloned()
.expect("RegexFramework not found in context");
framework.execute(ctx, application_command).await;
}
Interaction::MessageComponent(component) => { Interaction::MessageComponent(component) => {
let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id); //let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id);
component_model.act(&ctx, component).await; //component_model.act(&ctx, component).await;
} }
_ => {} _ => {}
},
_ => {}
} }
}
Ok(())
} }

View File

@ -1,91 +1,74 @@
use regex_command_attr::check; use poise::{serenity::model::channel::Channel, ApplicationCommandOrAutocompleteInteraction};
use serenity::{client::Context, model::channel::Channel};
use crate::{ use crate::{consts::MACRO_MAX_COMMANDS, Context, Error};
framework::{CommandInvoke, CommandOptions, CreateGenericResponse, HookResult},
moderation_cmds, RecordingMacros,
};
#[check] pub async fn guild_only(ctx: Context<'_>) -> Result<bool, Error> {
pub async fn guild_only( if ctx.guild_id().is_some() {
ctx: &Context, Ok(true)
invoke: &mut CommandInvoke,
_args: &CommandOptions,
) -> HookResult {
if invoke.guild_id().is_some() {
HookResult::Continue
} else { } else {
let _ = invoke let _ = ctx.say("This command can only be used in servers").await;
.respond(
&ctx,
CreateGenericResponse::new().content("This command can only be used in servers"),
)
.await;
HookResult::Halt Ok(false)
} }
} }
#[check] async fn macro_check(ctx: Context<'_>) -> bool {
pub async fn macro_check( if let Context::Application(app_ctx) = ctx {
ctx: &Context, if let ApplicationCommandOrAutocompleteInteraction::ApplicationCommand(interaction) =
invoke: &mut CommandInvoke, app_ctx.interaction
args: &CommandOptions, {
) -> HookResult { if let Some(guild_id) = ctx.guild_id() {
if let Some(guild_id) = invoke.guild_id() { if ctx.command().identifying_name != "macro_finish" {
if args.command != moderation_cmds::MACRO_CMD_COMMAND.names[0] { let mut lock = ctx.data().recording_macros.write().await;
let active_recordings =
ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap();
let mut lock = active_recordings.write().await;
if let Some(command_macro) = lock.get_mut(&(guild_id, invoke.author_id())) { if let Some(command_macro) = lock.get_mut(&(guild_id, ctx.author().id)) {
if command_macro.commands.len() >= 5 { if command_macro.commands.len() >= MACRO_MAX_COMMANDS {
let _ = invoke let _ = ctx.send(|m| {
.respond( m.ephemeral(true).content(
&ctx, "5 commands already recorded. Please use `/macro finish` to end recording.",
CreateGenericResponse::new().content("5 commands already recorded. Please use `/macro finish` to end recording."),
) )
})
.await; .await;
} else { } else {
command_macro.commands.push(args.clone()); // TODO TODO TODO write command to macro
let _ = invoke let _ = ctx
.respond( .send(|m| m.ephemeral(true).content("Command recorded to macro"))
&ctx,
CreateGenericResponse::new().content("Command recorded to macro"),
)
.await; .await;
} }
HookResult::Halt false
} else { } else {
HookResult::Continue true
} }
} else { } else {
HookResult::Continue true
} }
} else { } else {
HookResult::Continue true
}
} else {
true
}
} else {
true
} }
} }
#[check] async fn check_self_permissions(ctx: Context<'_>) -> bool {
pub async fn check_self_permissions( if let Some(guild) = ctx.guild() {
ctx: &Context, let user_id = ctx.discord().cache.current_user_id();
invoke: &mut CommandInvoke,
_args: &CommandOptions,
) -> HookResult {
if let Some(guild) = invoke.guild(&ctx) {
let user_id = ctx.cache.current_user_id();
let manage_webhooks = let manage_webhooks = guild
guild.member_permissions(&ctx, user_id).await.map_or(false, |p| p.manage_webhooks()); .member_permissions(&ctx.discord(), user_id)
let (view_channel, send_messages, embed_links) = invoke .await
.map_or(false, |p| p.manage_webhooks());
let (view_channel, send_messages, embed_links) = ctx
.channel_id() .channel_id()
.to_channel_cached(&ctx) .to_channel_cached(&ctx.discord())
.map(|c| { .map(|c| {
if let Channel::Guild(channel) = c { if let Channel::Guild(channel) = c {
channel.permissions_for_user(ctx, user_id).ok() channel.permissions_for_user(&ctx.discord(), user_id).ok()
} else { } else {
None None
} }
@ -96,12 +79,11 @@ pub async fn check_self_permissions(
}); });
if manage_webhooks && send_messages && embed_links { if manage_webhooks && send_messages && embed_links {
HookResult::Continue true
} else { } else {
let _ = invoke let _ = ctx
.respond( .send(|m| {
&ctx, m.content(format!(
CreateGenericResponse::new().content(format!(
"Please ensure the bot has the correct permissions: "Please ensure the bot has the correct permissions:
{} **View Channel** {} **View Channel**
@ -112,41 +94,17 @@ pub async fn check_self_permissions(
if send_messages { "" } else { "" }, if send_messages { "" } else { "" },
if manage_webhooks { "" } else { "" }, if manage_webhooks { "" } else { "" },
if embed_links { "" } else { "" }, if embed_links { "" } else { "" },
)), ))
) })
.await; .await;
HookResult::Halt false
} }
} else { } else {
HookResult::Continue true
} }
} }
#[check] pub async fn all_checks(ctx: Context<'_>) -> Result<bool, Error> {
pub async fn check_guild_permissions( Ok(macro_check(ctx).await && check_self_permissions(ctx).await)
ctx: &Context,
invoke: &mut CommandInvoke,
_args: &CommandOptions,
) -> HookResult {
if let Some(guild) = invoke.guild(&ctx) {
let permissions = guild.member_permissions(&ctx, invoke.author_id()).await.unwrap();
if !permissions.manage_guild() {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content(
"You must have the \"Manage Server\" permission to use this command",
),
)
.await;
HookResult::Halt
} else {
HookResult::Continue
}
} else {
HookResult::Continue
}
} }

View File

@ -1,5 +1,9 @@
/* /*
Copyright 2021 Paul Colomiets, 2022 Jude Southworth With modifications, 2022 Jude Southworth
Original copyright notice:
Copyright 2021 Paul Colomiets
Permission is hereby granted, free of charge, to any person obtaining a copy of this software Permission is hereby granted, free of charge, to any person obtaining a copy of this software
and associated documentation files (the "Software"), to deal in the Software without restriction, and associated documentation files (the "Software"), to deal in the Software without restriction,

View File

@ -3,70 +3,45 @@
extern crate lazy_static; extern crate lazy_static;
mod commands; mod commands;
mod component_models; // mod component_models;
mod consts; mod consts;
mod event_handlers; mod event_handlers;
mod framework;
mod hooks; mod hooks;
mod interval_parser; mod interval_parser;
mod models; mod models;
mod time_parser; mod time_parser;
mod utils; mod utils;
use std::{ use std::{collections::HashMap, env, sync::atomic::AtomicBool};
collections::HashMap,
env,
sync::{atomic::AtomicBool, Arc},
};
use chrono_tz::Tz; use chrono_tz::Tz;
use dotenv::dotenv; use dotenv::dotenv;
use log::info; use poise::serenity::model::{
use serenity::{ gateway::{Activity, GatewayIntents},
client::Client,
http::client::Http,
model::{
gateway::GatewayIntents,
id::{GuildId, UserId}, id::{GuildId, UserId},
},
prelude::TypeMapKey,
}; };
use sqlx::mysql::MySqlPool; use sqlx::{MySql, Pool};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::{ use crate::{
commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds}, commands::{info_cmds, moderation_cmds},
component_models::ComponentDataModel,
consts::THEME_COLOR, consts::THEME_COLOR,
framework::RegexFramework, event_handlers::listener,
hooks::all_checks,
models::command_macro::CommandMacro, models::command_macro::CommandMacro,
utils::register_application_commands,
}; };
struct SQLPool; type Database = MySql;
impl TypeMapKey for SQLPool { type Error = Box<dyn std::error::Error + Send + Sync>;
type Value = MySqlPool; type Context<'a> = poise::Context<'a, Data, Error>;
}
struct ReqwestClient; pub struct Data {
database: Pool<Database>,
impl TypeMapKey for ReqwestClient { http: reqwest::Client,
type Value = Arc<reqwest::Client>; recording_macros: RwLock<HashMap<(GuildId, UserId), CommandMacro<Data, Error>>>,
} popular_timezones: Vec<Tz>,
struct PopularTimezones;
impl TypeMapKey for PopularTimezones {
type Value = Arc<Vec<Tz>>;
}
struct RecordingMacros;
impl TypeMapKey for RecordingMacros {
type Value = Arc<RwLock<HashMap<(GuildId, UserId), CommandMacro>>>;
}
struct Handler {
is_loop_running: AtomicBool, is_loop_running: AtomicBool,
} }
@ -76,85 +51,77 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
dotenv()?; dotenv()?;
let token = env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment"); let discord_token = env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment");
let application_id = { let options = poise::FrameworkOptions {
let http = Http::new_with_token(&token); commands: vec![
info_cmds::help(),
http.get_current_application_info().await?.id info_cmds::info(),
info_cmds::donate(),
info_cmds::clock(),
info_cmds::dashboard(),
moderation_cmds::timezone(),
poise::Command {
subcommands: vec![
moderation_cmds::delete_macro(),
moderation_cmds::finish_macro(),
moderation_cmds::list_macro(),
moderation_cmds::record_macro(),
moderation_cmds::run_macro(),
],
..moderation_cmds::macro_base()
},
],
allowed_mentions: None,
command_check: Some(|ctx| Box::pin(all_checks(ctx))),
listener: |ctx, event, _framework, data| Box::pin(listener(ctx, event, data)),
..Default::default()
}; };
let dm_enabled = env::var("DM_ENABLED").map_or(true, |var| var == "1"); let database =
Pool::connect(&env::var("DATABASE_URL").expect("No database URL provided")).await.unwrap();
let framework = RegexFramework::new()
.ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1"))
.debug_guild(env::var("DEBUG_GUILD").map_or(None, |g| {
Some(GuildId(g.parse::<u64>().expect("DEBUG_GUILD must be a guild ID")))
}))
.dm_enabled(dm_enabled)
// info commands
.add_command(&info_cmds::HELP_COMMAND)
.add_command(&info_cmds::INFO_COMMAND)
.add_command(&info_cmds::DONATE_COMMAND)
.add_command(&info_cmds::DASHBOARD_COMMAND)
.add_command(&info_cmds::CLOCK_COMMAND)
// reminder commands
.add_command(&reminder_cmds::TIMER_COMMAND)
.add_command(&reminder_cmds::REMIND_COMMAND)
// management commands
.add_command(&reminder_cmds::DELETE_COMMAND)
.add_command(&reminder_cmds::LOOK_COMMAND)
.add_command(&reminder_cmds::PAUSE_COMMAND)
.add_command(&reminder_cmds::OFFSET_COMMAND)
.add_command(&reminder_cmds::NUDGE_COMMAND)
// to-do commands
.add_command(&todo_cmds::TODO_COMMAND)
// moderation commands
.add_command(&moderation_cmds::TIMEZONE_COMMAND)
.add_command(&moderation_cmds::MACRO_CMD_COMMAND)
.add_hook(&hooks::CHECK_SELF_PERMISSIONS_HOOK)
.add_hook(&hooks::MACRO_CHECK_HOOK);
let framework_arc = Arc::new(framework);
let mut client = Client::builder(&token)
.intents(GatewayIntents::GUILDS)
.application_id(application_id.0)
.event_handler(Handler { is_loop_running: AtomicBool::from(false) })
.await
.expect("Error occurred creating client");
{
let pool = MySqlPool::connect(
&env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment"),
)
.await
.unwrap();
let popular_timezones = sqlx::query!( let popular_timezones = sqlx::query!(
"SELECT timezone FROM users GROUP BY timezone ORDER BY COUNT(timezone) DESC LIMIT 21" "
SELECT timezone FROM users GROUP BY timezone ORDER BY COUNT(timezone) DESC LIMIT 21
"
) )
.fetch_all(&pool) .fetch_all(&database)
.await .await
.unwrap() .unwrap()
.iter() .iter()
.map(|t| t.timezone.parse::<Tz>().unwrap()) .map(|t| t.timezone.parse::<Tz>().unwrap())
.collect::<Vec<Tz>>(); .collect::<Vec<Tz>>();
let mut data = client.data.write().await; poise::Framework::build()
.token(discord_token)
.user_data_setup(move |ctx, _bot, framework| {
Box::pin(async move {
ctx.set_activity(Activity::watching("for /remind")).await;
data.insert::<SQLPool>(pool); register_application_commands(
data.insert::<PopularTimezones>(Arc::new(popular_timezones)); ctx,
data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new())); framework,
data.insert::<RegexFramework>(framework_arc.clone()); env::var("DEBUG_GUILD")
data.insert::<RecordingMacros>(Arc::new(RwLock::new(HashMap::new()))); .map(|inner| GuildId(inner.parse().expect("DEBUG_GUILD not valid")))
} .ok(),
)
.await
.unwrap();
framework_arc.build_slash(&client.cache_and_http.http).await; Ok(Data {
http: reqwest::Client::new(),
info!("Starting client as autosharded"); database,
popular_timezones,
client.start_autosharded().await?; recording_macros: Default::default(),
is_loop_running: AtomicBool::new(false),
})
})
})
.options(options)
.client_settings(move |client_builder| client_builder.intents(GatewayIntents::GUILDS))
.run_autosharded()
.await?;
Ok(()) Ok(())
} }

View File

@ -1,5 +1,5 @@
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use serenity::model::channel::Channel; use poise::serenity::model::channel::Channel;
use sqlx::MySqlPool; use sqlx::MySqlPool;
pub struct ChannelData { pub struct ChannelData {

View File

@ -1,33 +1,25 @@
use serenity::{client::Context, model::id::GuildId}; use poise::serenity::{
client::Context,
model::{
id::GuildId, interactions::application_command::ApplicationCommandInteractionDataOption,
},
};
use serde::Serialize;
use crate::{framework::CommandOptions, SQLPool}; #[derive(Serialize)]
pub struct RecordedCommand<U, E> {
#[serde(skip)]
action: for<'a> fn(
poise::ApplicationContext<'a, U, E>,
&'a [ApplicationCommandInteractionDataOption],
) -> poise::BoxFuture<'a, Result<(), poise::FrameworkError<'a, U, E>>>,
command_name: String,
options: Vec<ApplicationCommandInteractionDataOption>,
}
pub struct CommandMacro { pub struct CommandMacro<U, E> {
pub guild_id: GuildId, pub guild_id: GuildId,
pub name: String, pub name: String,
pub description: Option<String>, pub description: Option<String>,
pub commands: Vec<CommandOptions>, pub commands: Vec<RecordedCommand<U, E>>,
}
impl CommandMacro {
pub async fn from_guild(ctx: &Context, guild_id: impl Into<GuildId>) -> Vec<Self> {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let guild_id = guild_id.into();
sqlx::query!(
"SELECT * FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)",
guild_id.0
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| Self {
guild_id,
name: row.name.clone(),
description: row.description.clone(),
commands: serde_json::from_str(&row.commands).unwrap(),
})
.collect::<Vec<Self>>()
}
} }

View File

@ -5,62 +5,47 @@ pub mod timer;
pub mod user_data; pub mod user_data;
use chrono_tz::Tz; use chrono_tz::Tz;
use serenity::{ use poise::serenity::{async_trait, model::id::UserId};
async_trait,
model::id::{ChannelId, UserId},
prelude::Context,
};
use crate::{ use crate::{
models::{channel_data::ChannelData, user_data::UserData}, models::{channel_data::ChannelData, user_data::UserData},
SQLPool, Context,
}; };
#[async_trait] #[async_trait]
pub trait CtxData { pub trait CtxData {
async fn user_data<U: Into<UserId> + Send + Sync>( async fn user_data<U: Into<UserId> + Send>(
&self, &self,
user_id: U, user_id: U,
) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>>; ) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>>;
async fn timezone<U: Into<UserId> + Send + Sync>(&self, user_id: U) -> Tz; async fn author_data(&self) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>>;
async fn channel_data<C: Into<ChannelId> + Send + Sync>( async fn timezone(&self) -> Tz;
&self,
channel_id: C, async fn channel_data(&self) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>>;
) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>>;
} }
#[async_trait] #[async_trait]
impl CtxData for Context { impl CtxData for Context<'_> {
async fn user_data<U: Into<UserId> + Send + Sync>( async fn user_data<U: Into<UserId> + Send>(
&self, &self,
user_id: U, user_id: U,
) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>> { ) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user_id.into(); UserData::from_user(user_id, &self.discord(), &self.data().database).await
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
let user = user_id.to_user(self).await.unwrap();
UserData::from_user(&user, &self, &pool).await
} }
async fn timezone<U: Into<UserId> + Send + Sync>(&self, user_id: U) -> Tz { async fn author_data(&self) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user_id.into(); UserData::from_user(&self.author().id, &self.discord(), &self.data().database).await
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
UserData::timezone_of(user_id, &pool).await
} }
async fn channel_data<C: Into<ChannelId> + Send + Sync>( async fn timezone(&self) -> Tz {
&self, UserData::timezone_of(self.author().id, &self.data().database).await
channel_id: C, }
) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>> {
let channel_id = channel_id.into();
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
let channel = channel_id.to_channel_cached(&self).unwrap(); async fn channel_data(&self) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>> {
let channel = self.channel_id().to_channel_cached(&self.discord()).unwrap();
ChannelData::from_channel(&channel, &pool).await ChannelData::from_channel(&channel, &self.data().database).await
} }
} }

View File

@ -2,8 +2,7 @@ use std::{collections::HashSet, fmt::Display};
use chrono::{Duration, NaiveDateTime, Utc}; use chrono::{Duration, NaiveDateTime, Utc};
use chrono_tz::Tz; use chrono_tz::Tz;
use serenity::{ use poise::serenity::{
client::Context,
http::CacheHttp, http::CacheHttp,
model::{ model::{
channel::GuildChannel, channel::GuildChannel,
@ -15,15 +14,14 @@ use serenity::{
use sqlx::MySqlPool; use sqlx::MySqlPool;
use crate::{ use crate::{
consts, consts::{DAY, DEFAULT_AVATAR, MAX_TIME, MIN_INTERVAL},
consts::{DAY, MAX_TIME, MIN_INTERVAL},
interval_parser::Interval, interval_parser::Interval,
models::{ models::{
channel_data::ChannelData, channel_data::ChannelData,
reminder::{content::Content, errors::ReminderError, helper::generate_uid, Reminder}, reminder::{content::Content, errors::ReminderError, helper::generate_uid, Reminder},
user_data::UserData, user_data::UserData,
}, },
SQLPool, Context,
}; };
async fn create_webhook( async fn create_webhook(
@ -31,7 +29,7 @@ async fn create_webhook(
channel: GuildChannel, channel: GuildChannel,
name: impl Display, name: impl Display,
) -> SerenityResult<Webhook> { ) -> SerenityResult<Webhook> {
channel.create_webhook_with_avatar(ctx.http(), name, consts::DEFAULT_AVATAR.clone()).await channel.create_webhook_with_avatar(ctx.http(), name, DEFAULT_AVATAR.clone()).await
} }
#[derive(Hash, PartialEq, Eq)] #[derive(Hash, PartialEq, Eq)]
@ -145,7 +143,7 @@ pub struct MultiReminderBuilder<'a> {
expires: Option<NaiveDateTime>, expires: Option<NaiveDateTime>,
content: Content, content: Content,
set_by: Option<u32>, set_by: Option<u32>,
ctx: &'a Context, ctx: &'a Context<'a>,
guild_id: Option<GuildId>, guild_id: Option<GuildId>,
} }
@ -210,8 +208,6 @@ impl<'a> MultiReminderBuilder<'a> {
} }
pub async fn build(self) -> (HashSet<ReminderError>, HashSet<ReminderScope>) { pub async fn build(self) -> (HashSet<ReminderError>, HashSet<ReminderScope>) {
let pool = self.ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let mut errors = HashSet::new(); let mut errors = HashSet::new();
let mut ok_locs = HashSet::new(); let mut ok_locs = HashSet::new();
@ -225,12 +221,17 @@ impl<'a> MultiReminderBuilder<'a> {
for scope in self.scopes { for scope in self.scopes {
let db_channel_id = match scope { let db_channel_id = match scope {
ReminderScope::User(user_id) => { ReminderScope::User(user_id) => {
if let Ok(user) = UserId(user_id).to_user(&self.ctx).await { if let Ok(user) = UserId(user_id).to_user(&self.ctx.discord()).await {
let user_data = let user_data = UserData::from_user(
UserData::from_user(&user, &self.ctx, &pool).await.unwrap(); &user,
&self.ctx.discord(),
&self.ctx.data().database,
)
.await
.unwrap();
if let Some(guild_id) = self.guild_id { if let Some(guild_id) = self.guild_id {
if guild_id.member(&self.ctx, user).await.is_err() { if guild_id.member(&self.ctx.discord(), user).await.is_err() {
Err(ReminderError::InvalidTag) Err(ReminderError::InvalidTag)
} else { } else {
Ok(user_data.dm_channel) Ok(user_data.dm_channel)
@ -243,26 +244,36 @@ impl<'a> MultiReminderBuilder<'a> {
} }
} }
ReminderScope::Channel(channel_id) => { ReminderScope::Channel(channel_id) => {
let channel = ChannelId(channel_id).to_channel(&self.ctx).await.unwrap(); let channel =
ChannelId(channel_id).to_channel(&self.ctx.discord()).await.unwrap();
if let Some(guild_channel) = channel.clone().guild() { if let Some(guild_channel) = channel.clone().guild() {
if Some(guild_channel.guild_id) != self.guild_id { if Some(guild_channel.guild_id) != self.guild_id {
Err(ReminderError::InvalidTag) Err(ReminderError::InvalidTag)
} else { } else {
let mut channel_data = let mut channel_data =
ChannelData::from_channel(&channel, &pool).await.unwrap(); ChannelData::from_channel(&channel, &self.ctx.data().database)
.await
.unwrap();
if channel_data.webhook_id.is_none() if channel_data.webhook_id.is_none()
|| channel_data.webhook_token.is_none() || channel_data.webhook_token.is_none()
{ {
match create_webhook(&self.ctx, guild_channel, "Reminder").await match create_webhook(
&self.ctx.discord(),
guild_channel,
"Reminder",
)
.await
{ {
Ok(webhook) => { Ok(webhook) => {
channel_data.webhook_id = channel_data.webhook_id =
Some(webhook.id.as_u64().to_owned()); Some(webhook.id.as_u64().to_owned());
channel_data.webhook_token = webhook.token; channel_data.webhook_token = webhook.token;
channel_data.commit_changes(&pool).await; channel_data
.commit_changes(&self.ctx.data().database)
.await;
Ok(channel_data.id) Ok(channel_data.id)
} }
@ -282,7 +293,7 @@ impl<'a> MultiReminderBuilder<'a> {
match db_channel_id { match db_channel_id {
Ok(c) => { Ok(c) => {
let builder = ReminderBuilder { let builder = ReminderBuilder {
pool: pool.clone(), pool: self.ctx.data().database.clone(),
uid: generate_uid(), uid: generate_uid(),
channel: c, channel: c,
utc_time: self.utc_time, utc_time: self.utc_time,

View File

@ -1,6 +1,6 @@
use poise::serenity::model::id::ChannelId;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_repr::*; use serde_repr::*;
use serenity::model::id::ChannelId;
#[derive(Serialize_repr, Deserialize_repr, Copy, Clone, Debug)] #[derive(Serialize_repr, Deserialize_repr, Copy, Clone, Debug)]
#[repr(u8)] #[repr(u8)]

View File

@ -6,15 +6,12 @@ pub mod look_flags;
use chrono::{NaiveDateTime, TimeZone}; use chrono::{NaiveDateTime, TimeZone};
use chrono_tz::Tz; use chrono_tz::Tz;
use serenity::{ use poise::serenity::model::id::{ChannelId, GuildId, UserId};
client::Context, use sqlx::Executor;
model::id::{ChannelId, GuildId, UserId},
};
use sqlx::MySqlPool;
use crate::{ use crate::{
models::reminder::look_flags::{LookFlags, TimeDisplayType}, models::reminder::look_flags::{LookFlags, TimeDisplayType},
SQLPool, Context, Data, Database,
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -33,7 +30,10 @@ pub struct Reminder {
} }
impl Reminder { impl Reminder {
pub async fn from_uid(pool: &MySqlPool, uid: String) -> Option<Self> { pub async fn from_uid(
pool: impl Executor<'_, Database = Database>,
uid: String,
) -> Option<Self> {
sqlx::query_as_unchecked!( sqlx::query_as_unchecked!(
Self, Self,
" "
@ -70,12 +70,10 @@ WHERE
} }
pub async fn from_channel<C: Into<ChannelId>>( pub async fn from_channel<C: Into<ChannelId>>(
ctx: &Context, ctx: &Context<'_>,
channel_id: C, channel_id: C,
flags: &LookFlags, flags: &LookFlags,
) -> Vec<Self> { ) -> Vec<Self> {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let enabled = if flags.show_disabled { "0,1" } else { "1" }; let enabled = if flags.show_disabled { "0,1" } else { "1" };
let channel_id = channel_id.into(); let channel_id = channel_id.into();
@ -113,16 +111,18 @@ ORDER BY
channel_id.as_u64(), channel_id.as_u64(),
enabled, enabled,
) )
.fetch_all(&pool) .fetch_all(&ctx.data().database)
.await .await
.unwrap() .unwrap()
} }
pub async fn from_guild(ctx: &Context, guild_id: Option<GuildId>, user: UserId) -> Vec<Self> { pub async fn from_guild(
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); ctx: &Context<'_>,
guild_id: Option<GuildId>,
user: UserId,
) -> Vec<Self> {
if let Some(guild_id) = guild_id { if let Some(guild_id) = guild_id {
let guild_opt = guild_id.to_guild_cached(&ctx); let guild_opt = guild_id.to_guild_cached(&ctx.discord());
if let Some(guild) = guild_opt { if let Some(guild) = guild_opt {
let channels = guild let channels = guild
@ -163,7 +163,7 @@ WHERE
", ",
channels channels
) )
.fetch_all(&pool) .fetch_all(&ctx.data().database)
.await .await
} else { } else {
sqlx::query_as_unchecked!( sqlx::query_as_unchecked!(
@ -196,7 +196,7 @@ WHERE
", ",
guild_id.as_u64() guild_id.as_u64()
) )
.fetch_all(&pool) .fetch_all(&ctx.data().database)
.await .await
} }
} else { } else {
@ -230,7 +230,7 @@ WHERE
", ",
user.as_u64() user.as_u64()
) )
.fetch_all(&pool) .fetch_all(&ctx.data().database)
.await .await
} }
.unwrap() .unwrap()

View File

@ -1,9 +1,6 @@
use chrono_tz::Tz; use chrono_tz::Tz;
use log::error; use log::error;
use serenity::{ use poise::serenity::{http::CacheHttp, model::id::UserId};
http::CacheHttp,
model::{id::UserId, user::User},
};
use sqlx::MySqlPool; use sqlx::MySqlPool;
use crate::consts::LOCAL_TIMEZONE; use crate::consts::LOCAL_TIMEZONE;
@ -11,7 +8,6 @@ use crate::consts::LOCAL_TIMEZONE;
pub struct UserData { pub struct UserData {
pub id: u32, pub id: u32,
pub user: u64, pub user: u64,
pub name: String,
pub dm_channel: u32, pub dm_channel: u32,
pub timezone: String, pub timezone: String,
} }
@ -40,20 +36,20 @@ SELECT timezone FROM users WHERE user = ?
.unwrap() .unwrap()
} }
pub async fn from_user( pub async fn from_user<U: Into<UserId>>(
user: &User, user: U,
ctx: impl CacheHttp, ctx: impl CacheHttp,
pool: &MySqlPool, pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> { ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user.id.as_u64().to_owned(); let user_id = user.into();
match sqlx::query_as_unchecked!( match sqlx::query_as_unchecked!(
Self, Self,
" "
SELECT id, user, name, dm_channel, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ? SELECT id, user, dm_channel, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ?
", ",
*LOCAL_TIMEZONE, *LOCAL_TIMEZONE,
user_id user_id.0
) )
.fetch_one(pool) .fetch_one(pool)
.await .await
@ -61,27 +57,24 @@ SELECT id, user, name, dm_channel, IF(timezone IS NULL, ?, timezone) AS timezone
Ok(c) => Ok(c), Ok(c) => Ok(c),
Err(sqlx::Error::RowNotFound) => { Err(sqlx::Error::RowNotFound) => {
let dm_channel = user.create_dm_channel(ctx).await?; let dm_channel = user_id.create_dm_channel(ctx).await?;
let dm_id = dm_channel.id.as_u64().to_owned();
let pool_c = pool.clone(); let pool_c = pool.clone();
sqlx::query!( sqlx::query!(
" "
INSERT IGNORE INTO channels (channel) VALUES (?) INSERT IGNORE INTO channels (channel) VALUES (?)
", ",
dm_id dm_channel.id.0
) )
.execute(&pool_c) .execute(&pool_c)
.await?; .await?;
sqlx::query!( sqlx::query!(
" "
INSERT INTO users (user, name, dm_channel, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?) INSERT INTO users (user, dm_channel, timezone) VALUES (?, (SELECT id FROM channels WHERE channel = ?), ?)
", ",
user_id, user_id.0,
user.name, dm_channel.id.0,
dm_id,
*LOCAL_TIMEZONE *LOCAL_TIMEZONE
) )
.execute(&pool_c) .execute(&pool_c)
@ -90,9 +83,9 @@ INSERT INTO users (user, name, dm_channel, timezone) VALUES (?, ?, (SELECT id FR
Ok(sqlx::query_as_unchecked!( Ok(sqlx::query_as_unchecked!(
Self, Self,
" "
SELECT id, user, name, dm_channel, timezone FROM users WHERE user = ? SELECT id, user, dm_channel, timezone FROM users WHERE user = ?
", ",
user_id user_id.0
) )
.fetch_one(pool) .fetch_one(pool)
.await?) .await?)
@ -109,9 +102,8 @@ SELECT id, user, name, dm_channel, timezone FROM users WHERE user = ?
pub async fn commit_changes(&self, pool: &MySqlPool) { pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!( sqlx::query!(
" "
UPDATE users SET name = ?, timezone = ? WHERE id = ? UPDATE users SET timezone = ? WHERE id = ?
", ",
self.name,
self.timezone, self.timezone,
self.id self.id
) )

View File

@ -1,9 +1,39 @@
use serenity::{ use poise::serenity::{
builder::CreateApplicationCommands,
http::CacheHttp, http::CacheHttp,
model::id::{GuildId, UserId}, model::id::{GuildId, UserId},
}; };
use crate::consts::{CNC_GUILD, SUBSCRIPTION_ROLES}; use crate::{
consts::{CNC_GUILD, SUBSCRIPTION_ROLES},
Data, Error,
};
pub async fn register_application_commands(
ctx: &poise::serenity::client::Context,
framework: &poise::Framework<Data, Error>,
guild_id: Option<GuildId>,
) -> Result<(), poise::serenity::Error> {
let mut commands_builder = CreateApplicationCommands::default();
let commands = &framework.options().commands;
for command in commands {
if let Some(slash_command) = command.create_as_slash_command() {
commands_builder.add_application_command(slash_command);
}
if let Some(context_menu_command) = command.create_as_context_menu_command() {
commands_builder.add_application_command(context_menu_command);
}
}
let commands_builder = poise::serenity::json::Value::Array(commands_builder.0);
if let Some(guild_id) = guild_id {
ctx.http.create_guild_application_commands(guild_id.0, &commands_builder).await?;
} else {
ctx.http.create_global_application_commands(&commands_builder).await?;
}
Ok(())
}
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool { pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
if let Some(subscription_guild) = *CNC_GUILD { if let Some(subscription_guild) = *CNC_GUILD {