extracted event handler. removed custom sharding code. extracted util functions

This commit is contained in:
jude 2022-02-19 13:28:24 +00:00
parent cb471c52f3
commit 620f054703
8 changed files with 214 additions and 246 deletions

View File

@ -30,14 +30,10 @@ __Other Variables__
* `LOCAL_TIMEZONE` - default `UTC`, necessary for calculations in the natural language processor * `LOCAL_TIMEZONE` - default `UTC`, necessary for calculations in the natural language processor
* `SUBSCRIPTION_ROLES` - default `None`, accepts a list of Discord role IDs that are given to subscribed users * `SUBSCRIPTION_ROLES` - default `None`, accepts a list of Discord role IDs that are given to subscribed users
* `CNC_GUILD` - default `None`, accepts a single Discord guild ID for the server that the subscription roles belong to * `CNC_GUILD` - default `None`, accepts a single Discord guild ID for the server that the subscription roles belong to
* `IGNORE_BOTS` - default `1`, if `1`, Reminder Bot will ignore all other bots
* `PYTHON_LOCATION` - default `venv/bin/python3`. Can be changed if your Python executable is located somewhere else * `PYTHON_LOCATION` - default `venv/bin/python3`. Can be changed if your Python executable is located somewhere else
* `THEME_COLOR` - default `8fb677`. Specifies the hex value of the color to use on info message embeds * `THEME_COLOR` - default `8fb677`. Specifies the hex value of the color to use on info message embeds
* `SHARD_COUNT` - default `None`, accepts the number of shards that are being ran
* `SHARD_RANGE` - default `None`, if `SHARD_COUNT` is specified, specifies what range of shards to start on this process
* `DM_ENABLED` - default `1`, if `1`, Reminder Bot will respond to direct messages * `DM_ENABLED` - default `1`, if `1`, Reminder Bot will respond to direct messages
### Todo List ### Todo List
* Convert aliases to macros * Convert aliases to macros
* Help command

View File

@ -11,7 +11,6 @@ use regex_command_attr::command;
use serenity::{builder::CreateEmbed, client::Context, model::channel::Channel}; use serenity::{builder::CreateEmbed, client::Context, model::channel::Channel};
use crate::{ use crate::{
check_guild_subscription, check_subscription,
component_models::{ component_models::{
pager::{DelPager, LookPager, Pager}, pager::{DelPager, LookPager, Pager},
ComponentDataModel, DelSelector, ComponentDataModel, DelSelector,
@ -33,6 +32,7 @@ use crate::{
CtxData, CtxData,
}, },
time_parser::natural_parser, time_parser::natural_parser,
utils::{check_guild_subscription, check_subscription},
SQLPool, SQLPool,
}; };
@ -323,7 +323,7 @@ async fn look(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions) {
.iter() .iter()
.map(|reminder| reminder.display(&flags, &timezone)) .map(|reminder| reminder.display(&flags, &timezone))
.fold(0, |t, r| t + r.len()) .fold(0, |t, r| t + r.len())
.div_ceil(&EMBED_DESCRIPTION_MAX_LENGTH); .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH);
let pager = LookPager::new(flags, timezone); let pager = LookPager::new(flags, timezone);

View File

@ -3,7 +3,6 @@ pub(crate) mod pager;
use std::io::Cursor; use std::io::Cursor;
use chrono_tz::Tz; use chrono_tz::Tz;
use num_integer::Integer;
use rmp_serde::Serializer; use rmp_serde::Serializer;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serenity::{ use serenity::{
@ -79,7 +78,7 @@ impl ComponentDataModel {
.iter() .iter()
.map(|reminder| reminder.display(&flags, &pager.timezone)) .map(|reminder| reminder.display(&flags, &pager.timezone))
.fold(0, |t, r| t + r.len()) .fold(0, |t, r| t + r.len())
.div_ceil(&EMBED_DESCRIPTION_MAX_LENGTH); .div_ceil(EMBED_DESCRIPTION_MAX_LENGTH);
let channel_name = let channel_name =
if let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx) { if let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx) {

View File

@ -1,6 +1,5 @@
pub const DAY: u64 = 86_400; pub const DAY: u64 = 86_400;
pub const HOUR: u64 = 3_600;
pub const MINUTE: u64 = 60;
pub const EMBED_DESCRIPTION_MAX_LENGTH: usize = 4000; pub const EMBED_DESCRIPTION_MAX_LENGTH: usize = 4000;
pub const SELECT_MAX_ENTRIES: usize = 25; pub const SELECT_MAX_ENTRIES: usize = 25;

161
src/event_handlers.rs Normal file
View File

@ -0,0 +1,161 @@
use std::{collections::HashMap, env, sync::atomic::Ordering};
use log::{info, warn};
use serenity::{
async_trait,
client::{Context, EventHandler},
model::{
channel::GuildChannel,
gateway::{Activity, Ready},
guild::{Guild, UnavailableGuild},
id::GuildId,
interactions::Interaction,
},
utils::shard_id,
};
use crate::{ComponentDataModel, Handler, RegexFramework, ReqwestClient, SQLPool};
#[async_trait]
impl EventHandler for Handler {
async fn cache_ready(&self, ctx_base: Context, _guilds: Vec<GuildId>) {
info!("Cache Ready!");
info!("Preparing to send reminders");
if !self.is_loop_running.load(Ordering::Relaxed) {
let ctx1 = ctx_base.clone();
let ctx2 = ctx_base.clone();
let pool1 = ctx1.data.read().await.get::<SQLPool>().cloned().unwrap();
let pool2 = ctx2.data.read().await.get::<SQLPool>().cloned().unwrap();
let run_settings = env::var("DONTRUN").unwrap_or_else(|_| "".to_string());
if !run_settings.contains("postman") {
tokio::spawn(async move {
postman::initialize(ctx1, &pool1).await;
});
} else {
warn!("Not running postman")
}
if !run_settings.contains("web") {
tokio::spawn(async move {
reminder_web::initialize(ctx2, pool2).await.unwrap();
});
} else {
warn!("Not running web")
}
self.is_loop_running.swap(true, Ordering::Relaxed);
}
}
async fn channel_delete(&self, ctx: Context, channel: &GuildChannel) {
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
sqlx::query!(
"
DELETE FROM channels WHERE channel = ?
",
channel.id.as_u64()
)
.execute(&pool)
.await
.unwrap();
}
async fn guild_create(&self, ctx: Context, guild: Guild, is_new: bool) {
if is_new {
let guild_id = guild.id.as_u64().to_owned();
{
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let _ = sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id)
.execute(&pool)
.await;
}
if let Ok(token) = env::var("DISCORDBOTS_TOKEN") {
let shard_count = ctx.cache.shard_count();
let current_shard_id = shard_id(guild_id, shard_count);
let guild_count = ctx
.cache
.guilds()
.iter()
.filter(|g| shard_id(g.as_u64().to_owned(), shard_count) == current_shard_id)
.count() as u64;
let mut hm = HashMap::new();
hm.insert("server_count", guild_count);
hm.insert("shard_id", current_shard_id);
hm.insert("shard_count", shard_count);
let client = ctx
.data
.read()
.await
.get::<ReqwestClient>()
.cloned()
.expect("Could not get ReqwestClient from data");
let response = client
.post(
format!(
"https://top.gg/api/bots/{}/stats",
ctx.cache.current_user_id().as_u64()
)
.as_str(),
)
.header("Authorization", token)
.json(&hm)
.send()
.await;
if let Err(res) = response {
println!("DiscordBots Response: {:?}", res);
}
}
}
}
async fn guild_delete(&self, ctx: Context, incomplete: UnavailableGuild, _full: Option<Guild>) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let _ = sqlx::query!("DELETE FROM guilds WHERE guild = ?", incomplete.id.0)
.execute(&pool)
.await;
}
async fn ready(&self, ctx: Context, _: Ready) {
ctx.set_activity(Activity::watching("for /remind")).await;
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
match interaction {
Interaction::ApplicationCommand(application_command) => {
let framework = ctx
.data
.read()
.await
.get::<RegexFramework>()
.cloned()
.expect("RegexFramework not found in context");
framework.execute(ctx, application_command).await;
}
Interaction::MessageComponent(component) => {
let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id);
component_model.act(&ctx, component).await;
}
_ => {}
}
}
}

View File

@ -1,40 +1,35 @@
#![feature(int_roundings)]
#[macro_use] #[macro_use]
extern crate lazy_static; extern crate lazy_static;
mod commands; mod commands;
mod component_models; mod component_models;
mod consts; mod consts;
mod event_handlers;
mod framework; mod framework;
mod hooks; mod hooks;
mod interval_parser; mod interval_parser;
mod models; mod models;
mod time_parser; mod time_parser;
mod utils;
use std::{ use std::{
collections::HashMap, collections::HashMap,
env, env,
sync::{ sync::{atomic::AtomicBool, Arc},
atomic::{AtomicBool, Ordering},
Arc,
},
}; };
use chrono_tz::Tz; use chrono_tz::Tz;
use dotenv::dotenv; use dotenv::dotenv;
use log::{info, warn}; use log::info;
use serenity::{ use serenity::{
async_trait,
client::Client, client::Client,
http::{client::Http, CacheHttp}, http::client::Http,
model::{ model::{
channel::GuildChannel, gateway::GatewayIntents,
gateway::{Activity, GatewayIntents, Ready},
guild::{Guild, UnavailableGuild},
id::{GuildId, UserId}, id::{GuildId, UserId},
interactions::Interaction,
}, },
prelude::{Context, EventHandler, TypeMapKey}, prelude::TypeMapKey,
utils::shard_id,
}; };
use sqlx::mysql::MySqlPool; use sqlx::mysql::MySqlPool;
use tokio::sync::RwLock; use tokio::sync::RwLock;
@ -42,7 +37,7 @@ use tokio::sync::RwLock;
use crate::{ use crate::{
commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds}, commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds},
component_models::ComponentDataModel, component_models::ComponentDataModel,
consts::{CNC_GUILD, SUBSCRIPTION_ROLES, THEME_COLOR}, consts::THEME_COLOR,
framework::RegexFramework, framework::RegexFramework,
models::command_macro::CommandMacro, models::command_macro::CommandMacro,
}; };
@ -75,150 +70,6 @@ struct Handler {
is_loop_running: AtomicBool, is_loop_running: AtomicBool,
} }
#[async_trait]
impl EventHandler for Handler {
async fn cache_ready(&self, ctx_base: Context, _guilds: Vec<GuildId>) {
info!("Cache Ready!");
info!("Preparing to send reminders");
if !self.is_loop_running.load(Ordering::Relaxed) {
let ctx1 = ctx_base.clone();
let ctx2 = ctx_base.clone();
let pool1 = ctx1.data.read().await.get::<SQLPool>().cloned().unwrap();
let pool2 = ctx2.data.read().await.get::<SQLPool>().cloned().unwrap();
let run_settings = env::var("DONTRUN").unwrap_or_else(|_| "".to_string());
if !run_settings.contains("postman") {
tokio::spawn(async move {
postman::initialize(ctx1, &pool1).await;
});
} else {
warn!("Not running postman")
}
if !run_settings.contains("web") {
tokio::spawn(async move {
reminder_web::initialize(ctx2, pool2).await.unwrap();
});
} else {
warn!("Not running web")
}
self.is_loop_running.swap(true, Ordering::Relaxed);
}
}
async fn channel_delete(&self, ctx: Context, channel: &GuildChannel) {
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
sqlx::query!(
"
DELETE FROM channels WHERE channel = ?
",
channel.id.as_u64()
)
.execute(&pool)
.await
.unwrap();
}
async fn guild_create(&self, ctx: Context, guild: Guild, is_new: bool) {
if is_new {
let guild_id = guild.id.as_u64().to_owned();
{
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let _ = sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id)
.execute(&pool)
.await;
}
if let Ok(token) = env::var("DISCORDBOTS_TOKEN") {
let shard_count = ctx.cache.shard_count();
let current_shard_id = shard_id(guild_id, shard_count);
let guild_count = ctx
.cache
.guilds()
.iter()
.filter(|g| shard_id(g.as_u64().to_owned(), shard_count) == current_shard_id)
.count() as u64;
let mut hm = HashMap::new();
hm.insert("server_count", guild_count);
hm.insert("shard_id", current_shard_id);
hm.insert("shard_count", shard_count);
let client = ctx
.data
.read()
.await
.get::<ReqwestClient>()
.cloned()
.expect("Could not get ReqwestClient from data");
let response = client
.post(
format!(
"https://top.gg/api/bots/{}/stats",
ctx.cache.current_user_id().as_u64()
)
.as_str(),
)
.header("Authorization", token)
.json(&hm)
.send()
.await;
if let Err(res) = response {
println!("DiscordBots Response: {:?}", res);
}
}
}
}
async fn guild_delete(&self, ctx: Context, incomplete: UnavailableGuild, _full: Option<Guild>) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let _ = sqlx::query!("DELETE FROM guilds WHERE guild = ?", incomplete.id.0)
.execute(&pool)
.await;
}
async fn ready(&self, ctx: Context, _: Ready) {
ctx.set_activity(Activity::watching("for /remind")).await;
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
match interaction {
Interaction::ApplicationCommand(application_command) => {
let framework = ctx
.data
.read()
.await
.get::<RegexFramework>()
.cloned()
.expect("RegexFramework not found in context");
framework.execute(ctx, application_command).await;
}
Interaction::MessageComponent(component) => {
let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id);
component_model.act(&ctx, component).await;
}
_ => {}
}
}
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
env_logger::init(); env_logger::init();
@ -301,65 +152,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
framework_arc.build_slash(&client.cache_and_http.http).await; framework_arc.build_slash(&client.cache_and_http.http).await;
if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| { info!("Starting client as autosharded");
let mut split =
sr.split(',').map(|val| val.parse::<u64>().expect("SHARD_RANGE not an integer"));
(split.next(), split.next()) client.start_autosharded().await?;
}) {
let total_shards = env::var("SHARD_COUNT")
.map(|shard_count| shard_count.parse::<u64>().ok())
.ok()
.flatten()
.expect("No SHARD_COUNT provided, but SHARD_RANGE was provided");
assert!(lower < upper, "SHARD_RANGE lower limit is not less than the upper limit");
info!("Starting client fragment with shards {}-{}/{}", lower, upper, total_shards);
client.start_shard_range([lower, upper], total_shards).await?;
} else if let Ok(total_shards) = env::var("SHARD_COUNT")
.map(|shard_count| shard_count.parse::<u64>().expect("SHARD_COUNT not an integer"))
{
info!("Starting client with {} shards", total_shards);
client.start_shards(total_shards).await?;
} else {
info!("Starting client as autosharded");
client.start_autosharded().await?;
}
Ok(()) Ok(())
} }
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
if let Some(subscription_guild) = *CNC_GUILD {
let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await;
if let Ok(member) = guild_member {
for role in member.roles {
if SUBSCRIPTION_ROLES.contains(role.as_u64()) {
return true;
}
}
}
false
} else {
true
}
}
pub async fn check_guild_subscription(
cache_http: impl CacheHttp,
guild_id: impl Into<GuildId>,
) -> bool {
if let Some(guild) = cache_http.cache().unwrap().guild(guild_id) {
let owner = guild.owner_id;
check_subscription(&cache_http, owner).await
} else {
false
}
}

View File

@ -1,25 +1,6 @@
use num_integer::Integer;
use rand::{rngs::OsRng, seq::IteratorRandom}; use rand::{rngs::OsRng, seq::IteratorRandom};
use crate::consts::{CHARACTERS, DAY, HOUR, MINUTE}; use crate::consts::CHARACTERS;
pub fn longhand_displacement(seconds: u64) -> String {
let (days, seconds) = seconds.div_rem(&DAY);
let (hours, seconds) = seconds.div_rem(&HOUR);
let (minutes, seconds) = seconds.div_rem(&MINUTE);
let mut sections = vec![];
for (var, name) in
[days, hours, minutes, seconds].iter().zip(["days", "hours", "minutes", "seconds"].iter())
{
if *var > 0 {
sections.push(format!("{} {}", var, name));
}
}
sections.join(", ")
}
pub fn generate_uid() -> String { pub fn generate_uid() -> String {
let mut generator: OsRng = Default::default(); let mut generator: OsRng = Default::default();

37
src/utils.rs Normal file
View File

@ -0,0 +1,37 @@
use serenity::{
http::CacheHttp,
model::id::{GuildId, UserId},
};
use crate::consts::{CNC_GUILD, SUBSCRIPTION_ROLES};
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
if let Some(subscription_guild) = *CNC_GUILD {
let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await;
if let Ok(member) = guild_member {
for role in member.roles {
if SUBSCRIPTION_ROLES.contains(role.as_u64()) {
return true;
}
}
}
false
} else {
true
}
}
pub async fn check_guild_subscription(
cache_http: impl CacheHttp,
guild_id: impl Into<GuildId>,
) -> bool {
if let Some(guild) = cache_http.cache().unwrap().guild(guild_id) {
let owner = guild.owner_id;
check_subscription(&cache_http, owner).await
} else {
false
}
}