2nd attempt at doing poise stuff

This commit is contained in:
jude
2022-02-19 14:32:03 +00:00
parent 620f054703
commit 84ee7e77c5
18 changed files with 1071 additions and 1035 deletions

View File

@ -3,70 +3,45 @@
extern crate lazy_static;
mod commands;
mod component_models;
// mod component_models;
mod consts;
mod event_handlers;
mod framework;
mod hooks;
mod interval_parser;
mod models;
mod time_parser;
mod utils;
use std::{
collections::HashMap,
env,
sync::{atomic::AtomicBool, Arc},
};
use std::{collections::HashMap, env, sync::atomic::AtomicBool};
use chrono_tz::Tz;
use dotenv::dotenv;
use log::info;
use serenity::{
client::Client,
http::client::Http,
model::{
gateway::GatewayIntents,
id::{GuildId, UserId},
},
prelude::TypeMapKey,
use poise::serenity::model::{
gateway::{Activity, GatewayIntents},
id::{GuildId, UserId},
};
use sqlx::mysql::MySqlPool;
use sqlx::{MySql, Pool};
use tokio::sync::RwLock;
use crate::{
commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds},
component_models::ComponentDataModel,
commands::{info_cmds, moderation_cmds},
consts::THEME_COLOR,
framework::RegexFramework,
event_handlers::listener,
hooks::all_checks,
models::command_macro::CommandMacro,
utils::register_application_commands,
};
struct SQLPool;
type Database = MySql;
impl TypeMapKey for SQLPool {
type Value = MySqlPool;
}
type Error = Box<dyn std::error::Error + Send + Sync>;
type Context<'a> = poise::Context<'a, Data, Error>;
struct ReqwestClient;
impl TypeMapKey for ReqwestClient {
type Value = Arc<reqwest::Client>;
}
struct PopularTimezones;
impl TypeMapKey for PopularTimezones {
type Value = Arc<Vec<Tz>>;
}
struct RecordingMacros;
impl TypeMapKey for RecordingMacros {
type Value = Arc<RwLock<HashMap<(GuildId, UserId), CommandMacro>>>;
}
struct Handler {
pub struct Data {
database: Pool<Database>,
http: reqwest::Client,
recording_macros: RwLock<HashMap<(GuildId, UserId), CommandMacro<Data, Error>>>,
popular_timezones: Vec<Tz>,
is_loop_running: AtomicBool,
}
@ -76,85 +51,77 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
dotenv()?;
let token = env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment");
let discord_token = env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment");
let application_id = {
let http = Http::new_with_token(&token);
http.get_current_application_info().await?.id
let options = poise::FrameworkOptions {
commands: vec![
info_cmds::help(),
info_cmds::info(),
info_cmds::donate(),
info_cmds::clock(),
info_cmds::dashboard(),
moderation_cmds::timezone(),
poise::Command {
subcommands: vec![
moderation_cmds::delete_macro(),
moderation_cmds::finish_macro(),
moderation_cmds::list_macro(),
moderation_cmds::record_macro(),
moderation_cmds::run_macro(),
],
..moderation_cmds::macro_base()
},
],
allowed_mentions: None,
command_check: Some(|ctx| Box::pin(all_checks(ctx))),
listener: |ctx, event, _framework, data| Box::pin(listener(ctx, event, data)),
..Default::default()
};
let dm_enabled = env::var("DM_ENABLED").map_or(true, |var| var == "1");
let database =
Pool::connect(&env::var("DATABASE_URL").expect("No database URL provided")).await.unwrap();
let framework = RegexFramework::new()
.ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1"))
.debug_guild(env::var("DEBUG_GUILD").map_or(None, |g| {
Some(GuildId(g.parse::<u64>().expect("DEBUG_GUILD must be a guild ID")))
}))
.dm_enabled(dm_enabled)
// info commands
.add_command(&info_cmds::HELP_COMMAND)
.add_command(&info_cmds::INFO_COMMAND)
.add_command(&info_cmds::DONATE_COMMAND)
.add_command(&info_cmds::DASHBOARD_COMMAND)
.add_command(&info_cmds::CLOCK_COMMAND)
// reminder commands
.add_command(&reminder_cmds::TIMER_COMMAND)
.add_command(&reminder_cmds::REMIND_COMMAND)
// management commands
.add_command(&reminder_cmds::DELETE_COMMAND)
.add_command(&reminder_cmds::LOOK_COMMAND)
.add_command(&reminder_cmds::PAUSE_COMMAND)
.add_command(&reminder_cmds::OFFSET_COMMAND)
.add_command(&reminder_cmds::NUDGE_COMMAND)
// to-do commands
.add_command(&todo_cmds::TODO_COMMAND)
// moderation commands
.add_command(&moderation_cmds::TIMEZONE_COMMAND)
.add_command(&moderation_cmds::MACRO_CMD_COMMAND)
.add_hook(&hooks::CHECK_SELF_PERMISSIONS_HOOK)
.add_hook(&hooks::MACRO_CHECK_HOOK);
let popular_timezones = sqlx::query!(
"
SELECT timezone FROM users GROUP BY timezone ORDER BY COUNT(timezone) DESC LIMIT 21
"
)
.fetch_all(&database)
.await
.unwrap()
.iter()
.map(|t| t.timezone.parse::<Tz>().unwrap())
.collect::<Vec<Tz>>();
let framework_arc = Arc::new(framework);
poise::Framework::build()
.token(discord_token)
.user_data_setup(move |ctx, _bot, framework| {
Box::pin(async move {
ctx.set_activity(Activity::watching("for /remind")).await;
let mut client = Client::builder(&token)
.intents(GatewayIntents::GUILDS)
.application_id(application_id.0)
.event_handler(Handler { is_loop_running: AtomicBool::from(false) })
.await
.expect("Error occurred creating client");
register_application_commands(
ctx,
framework,
env::var("DEBUG_GUILD")
.map(|inner| GuildId(inner.parse().expect("DEBUG_GUILD not valid")))
.ok(),
)
.await
.unwrap();
{
let pool = MySqlPool::connect(
&env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment"),
)
.await
.unwrap();
let popular_timezones = sqlx::query!(
"SELECT timezone FROM users GROUP BY timezone ORDER BY COUNT(timezone) DESC LIMIT 21"
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|t| t.timezone.parse::<Tz>().unwrap())
.collect::<Vec<Tz>>();
let mut data = client.data.write().await;
data.insert::<SQLPool>(pool);
data.insert::<PopularTimezones>(Arc::new(popular_timezones));
data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new()));
data.insert::<RegexFramework>(framework_arc.clone());
data.insert::<RecordingMacros>(Arc::new(RwLock::new(HashMap::new())));
}
framework_arc.build_slash(&client.cache_and_http.http).await;
info!("Starting client as autosharded");
client.start_autosharded().await?;
Ok(Data {
http: reqwest::Client::new(),
database,
popular_timezones,
recording_macros: Default::default(),
is_loop_running: AtomicBool::new(false),
})
})
})
.options(options)
.client_settings(move |client_builder| client_builder.intents(GatewayIntents::GUILDS))
.run_autosharded()
.await?;
Ok(())
}