linked everything together

This commit is contained in:
jellywx 2021-09-10 18:09:25 +01:00
parent c148cdf556
commit 471948bed3
5 changed files with 385 additions and 290 deletions

433
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -20,7 +20,6 @@ num-integer = "0.1"
serde = "1.0"
serde_json = "1.0"
rand = "0.7"
Inflector = "0.11"
levenshtein = "1.0"
serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["collector", "unstable_discord_api"] }
sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono"]}

View File

@ -1,17 +1,12 @@
use std::{
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use chrono::offset::Utc;
use regex_command_attr::command;
use serenity::{builder::CreateEmbedFooter, client::Context, model::channel::Message};
use serenity::{builder::CreateEmbedFooter, client::Context};
use crate::{
consts::DEFAULT_PREFIX,
framework::{CommandInvoke, CreateGenericResponse},
models::{user_data::UserData, CtxData},
FrameworkCtx, THEME_COLOR,
models::CtxData,
THEME_COLOR,
};
fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEmbedFooter {
@ -121,8 +116,8 @@ async fn dashboard(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) {
#[description("View the current time in your selected timezone")]
#[group("Info")]
async fn clock(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) {
let ud = ctx.user_data(&msg.author).await.unwrap();
let now = Utc::now().with_timezone(ud.timezone());
let ud = ctx.user_data(&invoke.author_id()).await.unwrap();
let now = Utc::now().with_timezone(&ud.timezone());
invoke
.respond(

View File

@ -14,15 +14,19 @@ use serenity::{
framework::Framework,
futures::prelude::future::BoxFuture,
http::Http,
json::Value,
model::{
channel::{Channel, GuildChannel, Message},
guild::{Guild, Member},
id::{ChannelId, GuildId, MessageId, UserId},
interactions::{
application_command::{ApplicationCommandInteraction, ApplicationCommandOptionType},
application_command::{
ApplicationCommand, ApplicationCommandInteraction, ApplicationCommandOptionType,
},
InteractionResponseType,
},
},
prelude::TypeMapKey,
FutureExt, Result as SerenityResult,
};
@ -38,18 +42,6 @@ pub enum PermissionLevel {
Restricted,
}
pub struct Args {
pub args: HashMap<String, String>,
}
impl Args {
pub fn named<D: ToString>(&self, name: D) -> Option<&String> {
let name = name.to_string();
self.args.get(&name)
}
}
pub struct CreateGenericResponse {
content: String,
embed: Option<CreateEmbed>,
@ -294,7 +286,7 @@ pub struct Arg {
type SlashCommandFn = for<'fut> fn(
&'fut Context,
&'fut (dyn CommandInvoke + Sync + Send),
Args,
HashMap<String, String>,
) -> BoxFuture<'fut, ()>;
type TextCommandFn = for<'fut> fn(
@ -443,6 +435,11 @@ pub struct RegexFramework {
case_insensitive: bool,
dm_enabled: bool,
default_text_fun: TextCommandFn,
debug_guild: Option<GuildId>,
}
impl TypeMapKey for RegexFramework {
type Value = Arc<RegexFramework>;
}
fn drop_text<'fut>(
@ -467,6 +464,7 @@ impl RegexFramework {
case_insensitive: true,
dm_enabled: true,
default_text_fun: drop_text,
debug_guild: None,
}
}
@ -504,6 +502,12 @@ impl RegexFramework {
self
}
pub fn debug_guild(mut self, guild_id: Option<GuildId>) -> Self {
self.debug_guild = guild_id;
self
}
pub fn build(mut self) -> Self {
{
let command_names;
@ -571,6 +575,133 @@ impl RegexFramework {
self
}
pub async fn build_slash(&self, http: impl AsRef<Http>) {
info!("Building slash commands...");
match self.debug_guild {
None => {
ApplicationCommand::set_global_application_commands(&http, |commands| {
for command in &self.commands {
commands.create_application_command(|c| {
c.name(command.names[0]).description(command.desc);
for arg in command.args {
c.create_option(|o| {
o.name(arg.name)
.description(arg.description)
.kind(arg.kind)
.required(arg.required)
});
}
c
});
}
commands
})
.await;
}
Some(debug_guild) => {
debug_guild
.set_application_commands(&http, |commands| {
for command in &self.commands {
commands.create_application_command(|c| {
c.name(command.names[0]).description(command.desc);
for arg in command.args {
c.create_option(|o| {
o.name(arg.name)
.description(arg.description)
.kind(arg.kind)
.required(arg.required)
});
}
c
});
}
commands
})
.await;
}
}
info!("Slash commands built!");
}
pub async fn execute(&self, ctx: Context, interaction: ApplicationCommandInteraction) {
let command = {
self.commands_map
.get(&interaction.data.name)
.expect(&format!(
"Received invalid command: {}",
interaction.data.name
))
};
let guild = interaction.guild(ctx.cache.clone()).unwrap();
let member = interaction.clone().member.unwrap();
if command.check_permissions(&ctx, &guild, &member).await {
let mut args = HashMap::new();
for arg in interaction
.data
.options
.iter()
.filter(|o| o.value.is_some())
{
args.insert(
arg.name.clone(),
match arg.value.clone().unwrap() {
Value::Bool(b) => {
if b {
arg.name.clone()
} else {
String::new()
}
}
Value::Number(n) => n.to_string(),
Value::String(s) => s,
_ => String::new(),
},
);
}
if !ctx.check_executing(interaction.author_id()).await {
ctx.set_executing(interaction.author_id()).await;
match command.fun {
CommandFnType::Slash(t) => t(&ctx, &interaction, args).await,
CommandFnType::Multi(m) => m(&ctx, &interaction).await,
_ => (),
}
ctx.drop_executing(interaction.author_id()).await;
}
} else if command.required_permissions == PermissionLevel::Restricted {
let _ = interaction
.respond(
ctx.http.clone(),
CreateGenericResponse::new().content(
"You must have the `Manage Server` permission to use this command.",
),
)
.await;
} else if command.required_permissions == PermissionLevel::Managed {
let _ = interaction
.respond(
ctx.http.clone(),
CreateGenericResponse::new().content(
"You must have `Manage Messages` or have a role capable of sending reminders to that channel. Please talk to your server admin, and ask them to use the `/restrict` command to specify allowed roles.",
),
)
.await;
}
}
}
enum PermissionCheck {

View File

@ -9,11 +9,9 @@ mod time_parser;
use std::{collections::HashMap, env, sync::Arc, time::Instant};
use chrono::Utc;
use chrono_tz::Tz;
use dashmap::DashMap;
use dotenv::dotenv;
use inflector::Inflector;
use log::info;
use serenity::{
async_trait,
@ -23,11 +21,10 @@ use serenity::{
http::{client::Http, CacheHttp},
model::{
channel::{GuildChannel, Message},
gateway::{Activity, Ready},
guild::{Guild, GuildUnavailable},
id::{GuildId, UserId},
interactions::{
Interaction, InteractionApplicationCommandCallbackDataFlags, InteractionResponseType,
},
interactions::Interaction,
},
prelude::{Context, EventHandler, TypeMapKey},
utils::shard_id,
@ -39,11 +36,7 @@ use crate::{
commands::info_cmds,
consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR},
framework::RegexFramework,
models::{
guild_data::GuildData,
reminder::{Reminder, ReminderAction},
user_data::UserData,
},
models::guild_data::GuildData,
};
struct GuildDataCache;
@ -64,12 +57,6 @@ impl TypeMapKey for ReqwestClient {
type Value = Arc<reqwest::Client>;
}
struct FrameworkCtx;
impl TypeMapKey for FrameworkCtx {
type Value = Arc<RegexFramework>;
}
struct PopularTimezones;
impl TypeMapKey for PopularTimezones {
@ -139,6 +126,18 @@ struct Handler;
#[async_trait]
impl EventHandler for Handler {
async fn cache_ready(&self, ctx: Context, _: Vec<GuildId>) {
let framework = ctx
.data
.read()
.await
.get::<RegexFramework>()
.cloned()
.expect("RegexFramework not found in context");
framework.build_slash(ctx).await;
}
async fn channel_delete(&self, ctx: Context, channel: &GuildChannel) {
let pool = ctx
.data
@ -256,6 +255,31 @@ DELETE FROM guilds WHERE guild = ?
.await
.unwrap();
}
async fn ready(&self, ctx: Context, _: Ready) {
ctx.set_activity(Activity::watching("for /remind")).await;
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
match interaction {
Interaction::ApplicationCommand(application_command) => {
if application_command.guild_id.is_none() {
return;
}
let framework = ctx
.data
.read()
.await
.get::<RegexFramework>()
.cloned()
.expect("RegexFramework not found in context");
framework.execute(ctx, application_command).await;
}
_ => {}
}
}
}
#[tokio::main]
@ -280,13 +304,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.default_prefix(DEFAULT_PREFIX.clone())
.case_insensitive(env::var("CASE_INSENSITIVE").map_or(true, |var| var == "1"))
.ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1"))
.debug_guild(env::var("DEBUG_GUILD").map_or(None, |g| {
Some(GuildId(
g.parse::<u64>().expect("DEBUG_GUILD must be a guild ID"),
))
}))
.dm_enabled(dm_enabled)
// info commands
//.add_command("help", &info_cmds::HELP_COMMAND)
.add_command(&info_cmds::INFO_COMMAND)
.add_command(&info_cmds::DONATE_COMMAND)
//.add_command("dashboard", &info_cmds::DASHBOARD_COMMAND)
//.add_command("clock", &info_cmds::CLOCK_COMMAND)
.add_command(&info_cmds::DASHBOARD_COMMAND)
.add_command(&info_cmds::CLOCK_COMMAND)
// reminder commands
/*
.add_command("timer", &reminder_cmds::TIMER_COMMAND)
@ -364,7 +393,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
data.insert::<SQLPool>(pool);
data.insert::<PopularTimezones>(Arc::new(popular_timezones));
data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new()));
data.insert::<FrameworkCtx>(framework_arc.clone());
data.insert::<RegexFramework>(framework_arc.clone());
}
if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| {