From 620f054703ad9ca87d0fe5b8d1acee967e1679ac Mon Sep 17 00:00:00 2001 From: jude Date: Sat, 19 Feb 2022 13:28:24 +0000 Subject: [PATCH] extracted event handler. removed custom sharding code. extracted util functions --- README.md | 4 - src/commands/reminder_cmds.rs | 4 +- src/component_models/mod.rs | 3 +- src/consts.rs | 3 +- src/event_handlers.rs | 161 ++++++++++++++++++++++++ src/main.rs | 227 ++-------------------------------- src/models/reminder/helper.rs | 21 +--- src/utils.rs | 37 ++++++ 8 files changed, 214 insertions(+), 246 deletions(-) create mode 100644 src/event_handlers.rs create mode 100644 src/utils.rs diff --git a/README.md b/README.md index 1b52b54..0f20ff0 100644 --- a/README.md +++ b/README.md @@ -30,14 +30,10 @@ __Other Variables__ * `LOCAL_TIMEZONE` - default `UTC`, necessary for calculations in the natural language processor * `SUBSCRIPTION_ROLES` - default `None`, accepts a list of Discord role IDs that are given to subscribed users * `CNC_GUILD` - default `None`, accepts a single Discord guild ID for the server that the subscription roles belong to -* `IGNORE_BOTS` - default `1`, if `1`, Reminder Bot will ignore all other bots * `PYTHON_LOCATION` - default `venv/bin/python3`. Can be changed if your Python executable is located somewhere else * `THEME_COLOR` - default `8fb677`. Specifies the hex value of the color to use on info message embeds -* `SHARD_COUNT` - default `None`, accepts the number of shards that are being ran -* `SHARD_RANGE` - default `None`, if `SHARD_COUNT` is specified, specifies what range of shards to start on this process * `DM_ENABLED` - default `1`, if `1`, Reminder Bot will respond to direct messages ### Todo List * Convert aliases to macros -* Help command diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index 82b6455..8696c41 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -11,7 +11,6 @@ use regex_command_attr::command; use serenity::{builder::CreateEmbed, client::Context, model::channel::Channel}; use crate::{ - check_guild_subscription, check_subscription, component_models::{ pager::{DelPager, LookPager, Pager}, ComponentDataModel, DelSelector, @@ -33,6 +32,7 @@ use crate::{ CtxData, }, time_parser::natural_parser, + utils::{check_guild_subscription, check_subscription}, SQLPool, }; @@ -323,7 +323,7 @@ async fn look(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions) { .iter() .map(|reminder| reminder.display(&flags, &timezone)) .fold(0, |t, r| t + r.len()) - .div_ceil(&EMBED_DESCRIPTION_MAX_LENGTH); + .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH); let pager = LookPager::new(flags, timezone); diff --git a/src/component_models/mod.rs b/src/component_models/mod.rs index 18e0c2e..bf16b2b 100644 --- a/src/component_models/mod.rs +++ b/src/component_models/mod.rs @@ -3,7 +3,6 @@ pub(crate) mod pager; use std::io::Cursor; use chrono_tz::Tz; -use num_integer::Integer; use rmp_serde::Serializer; use serde::{Deserialize, Serialize}; use serenity::{ @@ -79,7 +78,7 @@ impl ComponentDataModel { .iter() .map(|reminder| reminder.display(&flags, &pager.timezone)) .fold(0, |t, r| t + r.len()) - .div_ceil(&EMBED_DESCRIPTION_MAX_LENGTH); + .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH); let channel_name = if let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx) { diff --git a/src/consts.rs b/src/consts.rs index e1dc822..3e6de0e 100644 --- a/src/consts.rs +++ b/src/consts.rs @@ -1,6 +1,5 @@ pub const DAY: u64 = 86_400; -pub const HOUR: u64 = 3_600; -pub const MINUTE: u64 = 60; + pub const EMBED_DESCRIPTION_MAX_LENGTH: usize = 4000; pub const SELECT_MAX_ENTRIES: usize = 25; diff --git a/src/event_handlers.rs b/src/event_handlers.rs new file mode 100644 index 0000000..3f42bf4 --- /dev/null +++ b/src/event_handlers.rs @@ -0,0 +1,161 @@ +use std::{collections::HashMap, env, sync::atomic::Ordering}; + +use log::{info, warn}; +use serenity::{ + async_trait, + client::{Context, EventHandler}, + model::{ + channel::GuildChannel, + gateway::{Activity, Ready}, + guild::{Guild, UnavailableGuild}, + id::GuildId, + interactions::Interaction, + }, + utils::shard_id, +}; + +use crate::{ComponentDataModel, Handler, RegexFramework, ReqwestClient, SQLPool}; + +#[async_trait] +impl EventHandler for Handler { + async fn cache_ready(&self, ctx_base: Context, _guilds: Vec) { + info!("Cache Ready!"); + info!("Preparing to send reminders"); + + if !self.is_loop_running.load(Ordering::Relaxed) { + let ctx1 = ctx_base.clone(); + let ctx2 = ctx_base.clone(); + + let pool1 = ctx1.data.read().await.get::().cloned().unwrap(); + let pool2 = ctx2.data.read().await.get::().cloned().unwrap(); + + let run_settings = env::var("DONTRUN").unwrap_or_else(|_| "".to_string()); + + if !run_settings.contains("postman") { + tokio::spawn(async move { + postman::initialize(ctx1, &pool1).await; + }); + } else { + warn!("Not running postman") + } + + if !run_settings.contains("web") { + tokio::spawn(async move { + reminder_web::initialize(ctx2, pool2).await.unwrap(); + }); + } else { + warn!("Not running web") + } + + self.is_loop_running.swap(true, Ordering::Relaxed); + } + } + + async fn channel_delete(&self, ctx: Context, channel: &GuildChannel) { + let pool = ctx + .data + .read() + .await + .get::() + .cloned() + .expect("Could not get SQLPool from data"); + + sqlx::query!( + " +DELETE FROM channels WHERE channel = ? + ", + channel.id.as_u64() + ) + .execute(&pool) + .await + .unwrap(); + } + + async fn guild_create(&self, ctx: Context, guild: Guild, is_new: bool) { + if is_new { + let guild_id = guild.id.as_u64().to_owned(); + + { + let pool = ctx.data.read().await.get::().cloned().unwrap(); + + let _ = sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id) + .execute(&pool) + .await; + } + + if let Ok(token) = env::var("DISCORDBOTS_TOKEN") { + let shard_count = ctx.cache.shard_count(); + let current_shard_id = shard_id(guild_id, shard_count); + + let guild_count = ctx + .cache + .guilds() + .iter() + .filter(|g| shard_id(g.as_u64().to_owned(), shard_count) == current_shard_id) + .count() as u64; + + let mut hm = HashMap::new(); + hm.insert("server_count", guild_count); + hm.insert("shard_id", current_shard_id); + hm.insert("shard_count", shard_count); + + let client = ctx + .data + .read() + .await + .get::() + .cloned() + .expect("Could not get ReqwestClient from data"); + + let response = client + .post( + format!( + "https://top.gg/api/bots/{}/stats", + ctx.cache.current_user_id().as_u64() + ) + .as_str(), + ) + .header("Authorization", token) + .json(&hm) + .send() + .await; + + if let Err(res) = response { + println!("DiscordBots Response: {:?}", res); + } + } + } + } + + async fn guild_delete(&self, ctx: Context, incomplete: UnavailableGuild, _full: Option) { + let pool = ctx.data.read().await.get::().cloned().unwrap(); + let _ = sqlx::query!("DELETE FROM guilds WHERE guild = ?", incomplete.id.0) + .execute(&pool) + .await; + } + + async fn ready(&self, ctx: Context, _: Ready) { + ctx.set_activity(Activity::watching("for /remind")).await; + } + + async fn interaction_create(&self, ctx: Context, interaction: Interaction) { + match interaction { + Interaction::ApplicationCommand(application_command) => { + let framework = ctx + .data + .read() + .await + .get::() + .cloned() + .expect("RegexFramework not found in context"); + + framework.execute(ctx, application_command).await; + } + Interaction::MessageComponent(component) => { + let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id); + component_model.act(&ctx, component).await; + } + _ => {} + } + } +} diff --git a/src/main.rs b/src/main.rs index 2530322..1912073 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,40 +1,35 @@ +#![feature(int_roundings)] #[macro_use] extern crate lazy_static; mod commands; mod component_models; mod consts; +mod event_handlers; mod framework; mod hooks; mod interval_parser; mod models; mod time_parser; +mod utils; use std::{ collections::HashMap, env, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::{atomic::AtomicBool, Arc}, }; use chrono_tz::Tz; use dotenv::dotenv; -use log::{info, warn}; +use log::info; use serenity::{ - async_trait, client::Client, - http::{client::Http, CacheHttp}, + http::client::Http, model::{ - channel::GuildChannel, - gateway::{Activity, GatewayIntents, Ready}, - guild::{Guild, UnavailableGuild}, + gateway::GatewayIntents, id::{GuildId, UserId}, - interactions::Interaction, }, - prelude::{Context, EventHandler, TypeMapKey}, - utils::shard_id, + prelude::TypeMapKey, }; use sqlx::mysql::MySqlPool; use tokio::sync::RwLock; @@ -42,7 +37,7 @@ use tokio::sync::RwLock; use crate::{ commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds}, component_models::ComponentDataModel, - consts::{CNC_GUILD, SUBSCRIPTION_ROLES, THEME_COLOR}, + consts::THEME_COLOR, framework::RegexFramework, models::command_macro::CommandMacro, }; @@ -75,150 +70,6 @@ struct Handler { is_loop_running: AtomicBool, } -#[async_trait] -impl EventHandler for Handler { - async fn cache_ready(&self, ctx_base: Context, _guilds: Vec) { - info!("Cache Ready!"); - info!("Preparing to send reminders"); - - if !self.is_loop_running.load(Ordering::Relaxed) { - let ctx1 = ctx_base.clone(); - let ctx2 = ctx_base.clone(); - - let pool1 = ctx1.data.read().await.get::().cloned().unwrap(); - let pool2 = ctx2.data.read().await.get::().cloned().unwrap(); - - let run_settings = env::var("DONTRUN").unwrap_or_else(|_| "".to_string()); - - if !run_settings.contains("postman") { - tokio::spawn(async move { - postman::initialize(ctx1, &pool1).await; - }); - } else { - warn!("Not running postman") - } - - if !run_settings.contains("web") { - tokio::spawn(async move { - reminder_web::initialize(ctx2, pool2).await.unwrap(); - }); - } else { - warn!("Not running web") - } - - self.is_loop_running.swap(true, Ordering::Relaxed); - } - } - - async fn channel_delete(&self, ctx: Context, channel: &GuildChannel) { - let pool = ctx - .data - .read() - .await - .get::() - .cloned() - .expect("Could not get SQLPool from data"); - - sqlx::query!( - " -DELETE FROM channels WHERE channel = ? - ", - channel.id.as_u64() - ) - .execute(&pool) - .await - .unwrap(); - } - - async fn guild_create(&self, ctx: Context, guild: Guild, is_new: bool) { - if is_new { - let guild_id = guild.id.as_u64().to_owned(); - - { - let pool = ctx.data.read().await.get::().cloned().unwrap(); - - let _ = sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id) - .execute(&pool) - .await; - } - - if let Ok(token) = env::var("DISCORDBOTS_TOKEN") { - let shard_count = ctx.cache.shard_count(); - let current_shard_id = shard_id(guild_id, shard_count); - - let guild_count = ctx - .cache - .guilds() - .iter() - .filter(|g| shard_id(g.as_u64().to_owned(), shard_count) == current_shard_id) - .count() as u64; - - let mut hm = HashMap::new(); - hm.insert("server_count", guild_count); - hm.insert("shard_id", current_shard_id); - hm.insert("shard_count", shard_count); - - let client = ctx - .data - .read() - .await - .get::() - .cloned() - .expect("Could not get ReqwestClient from data"); - - let response = client - .post( - format!( - "https://top.gg/api/bots/{}/stats", - ctx.cache.current_user_id().as_u64() - ) - .as_str(), - ) - .header("Authorization", token) - .json(&hm) - .send() - .await; - - if let Err(res) = response { - println!("DiscordBots Response: {:?}", res); - } - } - } - } - - async fn guild_delete(&self, ctx: Context, incomplete: UnavailableGuild, _full: Option) { - let pool = ctx.data.read().await.get::().cloned().unwrap(); - let _ = sqlx::query!("DELETE FROM guilds WHERE guild = ?", incomplete.id.0) - .execute(&pool) - .await; - } - - async fn ready(&self, ctx: Context, _: Ready) { - ctx.set_activity(Activity::watching("for /remind")).await; - } - - async fn interaction_create(&self, ctx: Context, interaction: Interaction) { - match interaction { - Interaction::ApplicationCommand(application_command) => { - let framework = ctx - .data - .read() - .await - .get::() - .cloned() - .expect("RegexFramework not found in context"); - - framework.execute(ctx, application_command).await; - } - Interaction::MessageComponent(component) => { - let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id); - component_model.act(&ctx, component).await; - } - _ => {} - } - } -} - #[tokio::main] async fn main() -> Result<(), Box> { env_logger::init(); @@ -301,65 +152,9 @@ async fn main() -> Result<(), Box> { framework_arc.build_slash(&client.cache_and_http.http).await; - if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| { - let mut split = - sr.split(',').map(|val| val.parse::().expect("SHARD_RANGE not an integer")); + info!("Starting client as autosharded"); - (split.next(), split.next()) - }) { - let total_shards = env::var("SHARD_COUNT") - .map(|shard_count| shard_count.parse::().ok()) - .ok() - .flatten() - .expect("No SHARD_COUNT provided, but SHARD_RANGE was provided"); - - assert!(lower < upper, "SHARD_RANGE lower limit is not less than the upper limit"); - - info!("Starting client fragment with shards {}-{}/{}", lower, upper, total_shards); - - client.start_shard_range([lower, upper], total_shards).await?; - } else if let Ok(total_shards) = env::var("SHARD_COUNT") - .map(|shard_count| shard_count.parse::().expect("SHARD_COUNT not an integer")) - { - info!("Starting client with {} shards", total_shards); - - client.start_shards(total_shards).await?; - } else { - info!("Starting client as autosharded"); - - client.start_autosharded().await?; - } + client.start_autosharded().await?; Ok(()) } - -pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into) -> bool { - if let Some(subscription_guild) = *CNC_GUILD { - let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await; - - if let Ok(member) = guild_member { - for role in member.roles { - if SUBSCRIPTION_ROLES.contains(role.as_u64()) { - return true; - } - } - } - - false - } else { - true - } -} - -pub async fn check_guild_subscription( - cache_http: impl CacheHttp, - guild_id: impl Into, -) -> bool { - if let Some(guild) = cache_http.cache().unwrap().guild(guild_id) { - let owner = guild.owner_id; - - check_subscription(&cache_http, owner).await - } else { - false - } -} diff --git a/src/models/reminder/helper.rs b/src/models/reminder/helper.rs index 3156f52..54cb7b9 100644 --- a/src/models/reminder/helper.rs +++ b/src/models/reminder/helper.rs @@ -1,25 +1,6 @@ -use num_integer::Integer; use rand::{rngs::OsRng, seq::IteratorRandom}; -use crate::consts::{CHARACTERS, DAY, HOUR, MINUTE}; - -pub fn longhand_displacement(seconds: u64) -> String { - let (days, seconds) = seconds.div_rem(&DAY); - let (hours, seconds) = seconds.div_rem(&HOUR); - let (minutes, seconds) = seconds.div_rem(&MINUTE); - - let mut sections = vec![]; - - for (var, name) in - [days, hours, minutes, seconds].iter().zip(["days", "hours", "minutes", "seconds"].iter()) - { - if *var > 0 { - sections.push(format!("{} {}", var, name)); - } - } - - sections.join(", ") -} +use crate::consts::CHARACTERS; pub fn generate_uid() -> String { let mut generator: OsRng = Default::default(); diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..aa37258 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,37 @@ +use serenity::{ + http::CacheHttp, + model::id::{GuildId, UserId}, +}; + +use crate::consts::{CNC_GUILD, SUBSCRIPTION_ROLES}; + +pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into) -> bool { + if let Some(subscription_guild) = *CNC_GUILD { + let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await; + + if let Ok(member) = guild_member { + for role in member.roles { + if SUBSCRIPTION_ROLES.contains(role.as_u64()) { + return true; + } + } + } + + false + } else { + true + } +} + +pub async fn check_guild_subscription( + cache_http: impl CacheHttp, + guild_id: impl Into, +) -> bool { + if let Some(guild) = cache_http.cache().unwrap().guild(guild_id) { + let owner = guild.owner_id; + + check_subscription(&cache_http, owner).await + } else { + false + } +}