From 1556318d0759144e277107b2cb7508baa824d4b2 Mon Sep 17 00:00:00 2001 From: jellywx Date: Fri, 16 Jul 2021 21:28:51 +0100 Subject: [PATCH] turned models into module --- src/commands/info_cmds.rs | 21 +- src/commands/moderation_cmds.rs | 62 +---- src/commands/reminder_cmds.rs | 63 ++--- src/commands/todo_cmds.rs | 6 +- src/framework.rs | 11 +- src/main.rs | 6 +- src/models.rs | 452 -------------------------------- src/models/channel_data.rs | 67 +++++ src/models/guild_data.rs | 79 ++++++ src/models/mod.rs | 77 ++++++ src/models/timer.rs | 50 ++++ src/models/user_data.rs | 146 +++++++++++ 12 files changed, 475 insertions(+), 565 deletions(-) delete mode 100644 src/models.rs create mode 100644 src/models/channel_data.rs create mode 100644 src/models/guild_data.rs create mode 100644 src/models/mod.rs create mode 100644 src/models/timer.rs create mode 100644 src/models/user_data.rs diff --git a/src/commands/info_cmds.rs b/src/commands/info_cmds.rs index bd5456e..775dd6d 100644 --- a/src/commands/info_cmds.rs +++ b/src/commands/info_cmds.rs @@ -1,18 +1,22 @@ use regex_command_attr::command; -use serenity::{client::Context, model::channel::Message}; +use serenity::{builder::CreateEmbedFooter, client::Context, model::channel::Message}; use chrono::offset::Utc; use crate::{ - command_help, consts::DEFAULT_PREFIX, get_ctx_data, language_manager::LanguageManager, - models::UserData, FrameworkCtx, THEME_COLOR, + command_help, + consts::DEFAULT_PREFIX, + get_ctx_data, + language_manager::LanguageManager, + models::{user_data::UserData, CtxGuildData}, + FrameworkCtx, THEME_COLOR, }; -use crate::models::CtxGuildData; -use serenity::builder::CreateEmbedFooter; -use std::sync::Arc; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::{ + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; #[command] #[can_blacklist(false)] @@ -202,7 +206,6 @@ async fn clock(ctx: &Context, msg: &Message, _args: String) { let language = UserData::language_of(&msg.author, &pool).await; let timezone = UserData::timezone_of(&msg.author, &pool).await; - let meridian = UserData::meridian_of(&msg.author, &pool).await; let now = Utc::now().with_timezone(&timezone); @@ -212,7 +215,7 @@ async fn clock(ctx: &Context, msg: &Message, _args: String) { .channel_id .say( &ctx, - clock_display.replacen("{}", &now.format(meridian.fmt_str()).to_string(), 1), + clock_display.replacen("{}", &now.format("%H:%M").to_string(), 1), ) .await; } diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index 9d24a73..207209b 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -24,11 +24,10 @@ use crate::{ consts::{REGEX_ALIAS, REGEX_CHANNEL, REGEX_COMMANDS, REGEX_ROLE, THEME_COLOR}, framework::SendIterator, get_ctx_data, - models::{ChannelData, GuildData, UserData}, + models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData, CtxGuildData}, FrameworkCtx, PopularTimezones, }; -use crate::models::CtxGuildData; use std::{collections::HashMap, iter}; #[command] @@ -113,11 +112,7 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) { let content = lm .get(&user_data.language, "timezone/set_p") .replacen("{timezone}", &user_data.timezone, 1) - .replacen( - "{time}", - &now.format(user_data.meridian().fmt_str_short()).to_string(), - 1, - ); + .replacen("{time}", &now.format("%H:%M").to_string(), 1); let _ = msg.channel_id @@ -154,10 +149,7 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) { tz.to_string(), format!( "🕗 `{}`", - Utc::now() - .with_timezone(tz) - .format(user_data.meridian().fmt_str_short()) - .to_string() + Utc::now().with_timezone(tz).format("%H:%M").to_string() ), true, ) @@ -211,10 +203,7 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) { t.to_string(), format!( "🕗 `{}`", - Utc::now() - .with_timezone(t) - .format(user_data.meridian().fmt_str_short()) - .to_string() + Utc::now().with_timezone(t).format("%H:%M").to_string() ), true, ) @@ -252,49 +241,6 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) { } } -#[command("meridian")] -async fn change_meridian(ctx: &Context, msg: &Message, args: String) { - let (pool, lm) = get_ctx_data(&ctx).await; - - let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); - - if &args == "12" { - user_data.meridian_time = true; - - user_data.commit_changes(&pool).await; - - let _ = msg - .channel_id - .send_message(&ctx, |m| { - m.embed(|e| { - e.title(lm.get(&user_data.language, "meridian/title")) - .color(*THEME_COLOR) - .description(lm.get(&user_data.language, "meridian/12")) - }) - }) - .await; - } else if &args == "24" { - user_data.meridian_time = false; - - user_data.commit_changes(&pool).await; - - let _ = msg - .channel_id - .send_message(&ctx, |m| { - m.embed(|e| { - e.title(lm.get(&user_data.language, "meridian/title")) - .color(*THEME_COLOR) - .description(lm.get(&user_data.language, "meridian/24")) - }) - }) - .await; - } else { - let prefix = ctx.prefix(msg.guild_id).await; - - command_help(ctx, msg, lm, &prefix, &user_data.language, "meridian").await; - } -} - #[command("lang")] async fn language(ctx: &Context, msg: &Message, args: String) { let (pool, lm) = get_ctx_data(&ctx).await; diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index d987222..e3c1ec5 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -24,11 +24,14 @@ use crate::{ }, framework::SendIterator, get_ctx_data, - models::{ChannelData, CtxGuildData, GuildData, Timer, UserData}, + models::{ + channel_data::ChannelData, guild_data::GuildData, timer::Timer, user_data::UserData, + CtxGuildData, + }, time_parser::{natural_parser, TimeParser}, }; -use chrono::{offset::TimeZone, NaiveDateTime}; +use chrono::NaiveDateTime; use rand::{rngs::OsRng, seq::IteratorRandom}; @@ -136,7 +139,6 @@ async fn pause(ctx: &Context, msg: &Message, args: String) { let language = UserData::language_of(&msg.author, &pool).await; let timezone = UserData::timezone_of(&msg.author, &pool).await; - let meridian = UserData::meridian_of(&msg.author, &pool).await; let mut channel = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool) .await @@ -172,13 +174,9 @@ async fn pause(ctx: &Context, msg: &Message, args: String) { channel.commit_changes(&pool).await; - let content = lm.get(&language, "pause/paused_until").replace( - "{}", - &timezone - .timestamp(timestamp, 0) - .format(meridian.fmt_str()) - .to_string(), - ); + let content = lm + .get(&language, "pause/paused_until") + .replace("{}", &format!("", timestamp)); let _ = msg.channel_id.say(&ctx, content).await; } @@ -864,7 +862,6 @@ impl ReminderScope { #[derive(PartialEq, Eq, Hash, Debug)] enum ReminderError { - LongTime, LongInterval, PastTime, ShortInterval, @@ -891,7 +888,6 @@ trait ToResponse { impl ToResponse for ReminderError { fn to_response(&self) -> &'static str { match self { - Self::LongTime => "remind/long_time", Self::LongInterval => "interval/long_interval", Self::PastTime => "remind/past_time", Self::ShortInterval => "interval/short_interval", @@ -904,7 +900,6 @@ impl ToResponse for ReminderError { fn to_response_natural(&self) -> &'static str { match self { - Self::LongTime => "natural/long_time", Self::InvalidTime => "natural/invalid_time", _ => self.to_response(), } @@ -1601,7 +1596,7 @@ async fn create_reminder<'a, U: Into, T: TryInto>( expires_parser: Option, interval: Option, content: &mut Content, -) -> Result<(), ReminderError> { +) -> Result { let user_id = user_id.into(); if let Some(g_id) = guild_id { @@ -1681,11 +1676,10 @@ async fn create_reminder<'a, U: Into, T: TryInto>( .as_secs() as i64; if time >= unix_time - 10 { - if time > unix_time + *MAX_TIME { - Err(ReminderError::LongTime) - } else { - sqlx::query!( - " + let uid = generate_uid(); + + sqlx::query!( + " INSERT INTO reminders ( uid, content, @@ -1710,23 +1704,22 @@ INSERT INTO reminders ( (SELECT id FROM users WHERE user = ? LIMIT 1) ) ", - generate_uid(), - content.content, - content.tts, - content.attachment, - content.attachment_name, - db_channel_id, - time as u32, - expires, - interval, - user_id - ) - .execute(pool) - .await - .unwrap(); + uid, + content.content, + content.tts, + content.attachment, + content.attachment_name, + db_channel_id, + time, + expires, + interval, + user_id + ) + .execute(pool) + .await + .unwrap(); - Ok(()) - } + Ok(uid) } else if time < 0 { // case required for if python returns -1 Err(ReminderError::InvalidTime) diff --git a/src/commands/todo_cmds.rs b/src/commands/todo_cmds.rs index 2ae1b2a..951b84e 100644 --- a/src/commands/todo_cmds.rs +++ b/src/commands/todo_cmds.rs @@ -12,8 +12,10 @@ use serenity::{ use std::fmt; -use crate::models::CtxGuildData; -use crate::{command_help, get_ctx_data, models::UserData}; +use crate::{ + command_help, get_ctx_data, + models::{user_data::UserData, CtxGuildData}, +}; use sqlx::MySqlPool; use std::convert::TryFrom; diff --git a/src/framework.rs b/src/framework.rs index 58ea26f..de0b2ed 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -8,7 +8,7 @@ use serenity::{ model::{ channel::{Channel, GuildChannel, Message}, guild::{Guild, Member}, - id::ChannelId, + id::{ChannelId, MessageId}, }, Result as SerenityResult, }; @@ -19,10 +19,11 @@ use regex::{Match, Regex, RegexBuilder}; use std::{collections::HashMap, fmt}; -use crate::language_manager::LanguageManager; -use crate::models::{CtxGuildData, GuildData, UserData}; -use crate::{models::ChannelData, LimitExecutors, SQLPool}; -use serenity::model::id::MessageId; +use crate::{ + language_manager::LanguageManager, + models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData, CtxGuildData}, + LimitExecutors, SQLPool, +}; type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, ()>; diff --git a/src/main.rs b/src/main.rs index 415c225..2bd290a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,7 +36,7 @@ use crate::{ consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR}, framework::RegexFramework, language_manager::LanguageManager, - models::GuildData, + models::{guild_data::GuildData, user_data::UserData}, }; use inflector::Inflector; @@ -46,7 +46,6 @@ use dashmap::DashMap; use tokio::sync::RwLock; -use crate::models::UserData; use chrono::Utc; use chrono_tz::Tz; use serenity::model::prelude::{ @@ -294,7 +293,7 @@ DELETE FROM guilds WHERE guild = ? .replacen("{timezone}", &user_data.timezone, 1) .replacen( "{time}", - &now.format(user_data.meridian().fmt_str_short()).to_string(), + &now.format("%H:%M").to_string(), 1, ); @@ -399,7 +398,6 @@ async fn main() -> Result<(), Box> { .add_command("blacklist", &moderation_cmds::BLACKLIST_COMMAND) .add_command("restrict", &moderation_cmds::RESTRICT_COMMAND) .add_command("timezone", &moderation_cmds::TIMEZONE_COMMAND) - .add_command("meridian", &moderation_cmds::CHANGE_MERIDIAN_COMMAND) .add_command("prefix", &moderation_cmds::PREFIX_COMMAND) .add_command("lang", &moderation_cmds::LANGUAGE_COMMAND) .add_command("pause", &reminder_cmds::PAUSE_COMMAND) diff --git a/src/models.rs b/src/models.rs deleted file mode 100644 index 1a01579..0000000 --- a/src/models.rs +++ /dev/null @@ -1,452 +0,0 @@ -use serenity::{ - async_trait, - http::CacheHttp, - model::{ - channel::Channel, - guild::Guild, - id::{GuildId, UserId}, - user::User, - }, - prelude::Context, -}; - -use sqlx::MySqlPool; - -use chrono::NaiveDateTime; -use chrono_tz::Tz; - -use log::error; - -use crate::{ - consts::{DEFAULT_PREFIX, LOCAL_LANGUAGE, LOCAL_TIMEZONE}, - GuildDataCache, SQLPool, -}; - -use std::sync::Arc; -use tokio::sync::RwLock; - -#[async_trait] -pub trait CtxGuildData { - async fn guild_data + Send + Sync>( - &self, - guild_id: G, - ) -> Result>, sqlx::Error>; - - async fn prefix + Send + Sync>(&self, guild_id: Option) -> String; -} - -#[async_trait] -impl CtxGuildData for Context { - async fn guild_data + Send + Sync>( - &self, - guild_id: G, - ) -> Result>, sqlx::Error> { - let guild_id = guild_id.into(); - - let guild = guild_id.to_guild_cached(&self.cache).await.unwrap(); - - let guild_cache = self - .data - .read() - .await - .get::() - .cloned() - .unwrap(); - - let x = if let Some(guild_data) = guild_cache.get(&guild_id) { - Ok(guild_data.clone()) - } else { - let pool = self.data.read().await.get::().cloned().unwrap(); - - match GuildData::from_guild(guild, &pool).await { - Ok(d) => { - let lock = Arc::new(RwLock::new(d)); - - guild_cache.insert(guild_id, lock.clone()); - - Ok(lock) - } - - Err(e) => Err(e), - } - }; - - x - } - - async fn prefix + Send + Sync>(&self, guild_id: Option) -> String { - if let Some(guild_id) = guild_id { - self.guild_data(guild_id) - .await - .unwrap() - .read() - .await - .prefix - .clone() - } else { - DEFAULT_PREFIX.clone() - } - } -} - -pub struct GuildData { - pub id: u32, - pub name: Option, - pub prefix: String, -} - -impl GuildData { - pub async fn from_guild(guild: Guild, pool: &MySqlPool) -> Result { - let guild_id = guild.id.as_u64().to_owned(); - - match sqlx::query_as!( - Self, - " -SELECT id, name, prefix FROM guilds WHERE guild = ? - ", - guild_id - ) - .fetch_one(pool) - .await - { - Ok(mut g) => { - g.name = Some(guild.name); - - Ok(g) - } - - Err(sqlx::Error::RowNotFound) => { - sqlx::query!( - " -INSERT INTO guilds (guild, name, prefix) VALUES (?, ?, ?) - ", - guild_id, - guild.name, - *DEFAULT_PREFIX - ) - .execute(&pool.clone()) - .await?; - - Ok(sqlx::query_as!( - Self, - " -SELECT id, name, prefix FROM guilds WHERE guild = ? - ", - guild_id - ) - .fetch_one(pool) - .await?) - } - - Err(e) => { - error!("Unexpected error in guild query: {:?}", e); - - Err(e) - } - } - } - - pub async fn commit_changes(&self, pool: &MySqlPool) { - sqlx::query!( - " -UPDATE guilds SET name = ?, prefix = ? WHERE id = ? - ", - self.name, - self.prefix, - self.id - ) - .execute(pool) - .await - .unwrap(); - } -} - -pub struct ChannelData { - pub id: u32, - pub name: Option, - pub nudge: i16, - pub blacklisted: bool, - pub webhook_id: Option, - pub webhook_token: Option, - pub paused: bool, - pub paused_until: Option, -} - -impl ChannelData { - pub async fn from_channel( - channel: Channel, - pool: &MySqlPool, - ) -> Result> { - let channel_id = channel.id().as_u64().to_owned(); - - if let Ok(c) = sqlx::query_as_unchecked!(Self, - " -SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ? - ", channel_id) - .fetch_one(pool) - .await { - - Ok(c) - } - else { - let props = channel.guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name)); - - let (guild_id, channel_name) = if let Some((a, b)) = props { - (Some(a), Some(b)) - } else { - (None, None) - }; - - sqlx::query!( - " -INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?)) - ", channel_id, channel_name, guild_id) - .execute(&pool.clone()) - .await?; - - Ok(sqlx::query_as_unchecked!(Self, - " -SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ? - ", channel_id) - .fetch_one(pool) - .await?) - } - } - - pub async fn commit_changes(&self, pool: &MySqlPool) { - sqlx::query!( - " -UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until = ? WHERE id = ? - ", self.name, self.nudge, self.blacklisted, self.webhook_id, self.webhook_token, self.paused, self.paused_until, self.id) - .execute(pool) - .await.unwrap(); - } -} - -pub struct UserData { - pub id: u32, - pub user: u64, - pub name: String, - pub dm_channel: u32, - pub language: String, - pub timezone: String, - pub meridian_time: bool, -} - -pub struct MeridianType(bool); - -impl MeridianType { - pub fn fmt_str(&self) -> &str { - if self.0 { - "%Y-%m-%d %I:%M:%S %p" - } else { - "%Y-%m-%d %H:%M:%S" - } - } - - pub fn fmt_str_short(&self) -> &str { - if self.0 { - "%I:%M %p" - } else { - "%H:%M" - } - } -} - -impl UserData { - pub async fn language_of(user: U, pool: &MySqlPool) -> String - where - U: Into, - { - let user_id = user.into().as_u64().to_owned(); - - match sqlx::query!( - " -SELECT language FROM users WHERE user = ? - ", - user_id - ) - .fetch_one(pool) - .await - { - Ok(r) => r.language, - - Err(_) => LOCAL_LANGUAGE.clone(), - } - } - - pub async fn timezone_of(user: U, pool: &MySqlPool) -> Tz - where - U: Into, - { - let user_id = user.into().as_u64().to_owned(); - - match sqlx::query!( - " -SELECT timezone FROM users WHERE user = ? - ", - user_id - ) - .fetch_one(pool) - .await - { - Ok(r) => r.timezone, - - Err(_) => LOCAL_TIMEZONE.clone(), - } - .parse() - .unwrap() - } - - pub async fn meridian_of(user: U, pool: &MySqlPool) -> MeridianType - where - U: Into, - { - let user_id = user.into().as_u64().to_owned(); - - match sqlx::query!( - " -SELECT meridian_time FROM users WHERE user = ? - ", - user_id - ) - .fetch_one(pool) - .await - { - Ok(r) => MeridianType(r.meridian_time != 0), - - Err(_) => MeridianType(false), - } - } - - pub async fn from_user( - user: &User, - ctx: impl CacheHttp, - pool: &MySqlPool, - ) -> Result> { - let user_id = user.id.as_u64().to_owned(); - - match sqlx::query_as_unchecked!( - Self, - " -SELECT id, user, name, dm_channel, IF(language IS NULL, ?, language) AS language, IF(timezone IS NULL, ?, timezone) AS timezone, meridian_time FROM users WHERE user = ? - ", - *LOCAL_LANGUAGE, *LOCAL_TIMEZONE, user_id - ) - .fetch_one(pool) - .await - { - Ok(c) => Ok(c), - - Err(sqlx::Error::RowNotFound) => { - let dm_channel = user.create_dm_channel(ctx).await?; - let dm_id = dm_channel.id.as_u64().to_owned(); - - let pool_c = pool.clone(); - - sqlx::query!( - " -INSERT IGNORE INTO channels (channel) VALUES (?) - ", - dm_id - ) - .execute(&pool_c) - .await?; - - sqlx::query!( - " -INSERT INTO users (user, name, dm_channel, language, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?, ?) - ", user_id, user.name, dm_id, *LOCAL_LANGUAGE, *LOCAL_TIMEZONE) - .execute(&pool_c) - .await?; - - Ok(sqlx::query_as_unchecked!( - Self, - " -SELECT id, user, name, dm_channel, language, timezone, meridian_time FROM users WHERE user = ? - ", - user_id - ) - .fetch_one(pool) - .await?) - } - - Err(e) => { - error!("Error querying for user: {:?}", e); - - Err(Box::new(e)) - }, - } - } - - pub async fn commit_changes(&self, pool: &MySqlPool) { - sqlx::query!( - " -UPDATE users SET name = ?, language = ?, timezone = ?, meridian_time = ? WHERE id = ? - ", - self.name, - self.language, - self.timezone, - self.meridian_time, - self.id - ) - .execute(pool) - .await - .unwrap(); - } - - pub fn timezone(&self) -> Tz { - self.timezone.parse().unwrap() - } - - pub fn meridian(&self) -> MeridianType { - MeridianType(self.meridian_time) - } -} - -pub struct Timer { - pub name: String, - pub start_time: NaiveDateTime, - pub owner: u64, -} - -impl Timer { - pub async fn from_owner(owner: u64, pool: &MySqlPool) -> Vec { - sqlx::query_as_unchecked!( - Timer, - " -SELECT name, start_time, owner FROM timers WHERE owner = ? - ", - owner - ) - .fetch_all(pool) - .await - .unwrap() - } - - pub async fn count_from_owner(owner: u64, pool: &MySqlPool) -> u32 { - sqlx::query!( - " -SELECT COUNT(1) as count FROM timers WHERE owner = ? - ", - owner - ) - .fetch_one(pool) - .await - .unwrap() - .count as u32 - } - - pub async fn create(name: &str, owner: u64, pool: &MySqlPool) { - sqlx::query!( - " -INSERT INTO timers (name, owner) VALUES (?, ?) - ", - name, - owner - ) - .execute(pool) - .await - .unwrap(); - } -} diff --git a/src/models/channel_data.rs b/src/models/channel_data.rs new file mode 100644 index 0000000..a9fd65f --- /dev/null +++ b/src/models/channel_data.rs @@ -0,0 +1,67 @@ +use serenity::model::channel::Channel; + +use sqlx::MySqlPool; + +use chrono::NaiveDateTime; + +pub struct ChannelData { + pub id: u32, + pub name: Option, + pub nudge: i16, + pub blacklisted: bool, + pub webhook_id: Option, + pub webhook_token: Option, + pub paused: bool, + pub paused_until: Option, +} + +impl ChannelData { + pub async fn from_channel( + channel: Channel, + pool: &MySqlPool, + ) -> Result> { + let channel_id = channel.id().as_u64().to_owned(); + + if let Ok(c) = sqlx::query_as_unchecked!(Self, + " +SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ? + ", channel_id) + .fetch_one(pool) + .await { + + Ok(c) + } + else { + let props = channel.guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name)); + + let (guild_id, channel_name) = if let Some((a, b)) = props { + (Some(a), Some(b)) + } else { + (None, None) + }; + + sqlx::query!( + " +INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?)) + ", channel_id, channel_name, guild_id) + .execute(&pool.clone()) + .await?; + + Ok(sqlx::query_as_unchecked!(Self, + " +SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ? + ", channel_id) + .fetch_one(pool) + .await?) + } + } + + pub async fn commit_changes(&self, pool: &MySqlPool) { + sqlx::query!( + " +UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until = ? WHERE id = ? + ", self.name, self.nudge, self.blacklisted, self.webhook_id, self.webhook_token, self.paused, self.paused_until, self.id) + .execute(pool) + .await.unwrap(); + } +} diff --git a/src/models/guild_data.rs b/src/models/guild_data.rs new file mode 100644 index 0000000..81dcaca --- /dev/null +++ b/src/models/guild_data.rs @@ -0,0 +1,79 @@ +use serenity::model::guild::Guild; + +use sqlx::MySqlPool; + +use log::error; + +use crate::consts::DEFAULT_PREFIX; + +pub struct GuildData { + pub id: u32, + pub name: Option, + pub prefix: String, +} + +impl GuildData { + pub async fn from_guild(guild: Guild, pool: &MySqlPool) -> Result { + let guild_id = guild.id.as_u64().to_owned(); + + match sqlx::query_as!( + Self, + " +SELECT id, name, prefix FROM guilds WHERE guild = ? + ", + guild_id + ) + .fetch_one(pool) + .await + { + Ok(mut g) => { + g.name = Some(guild.name); + + Ok(g) + } + + Err(sqlx::Error::RowNotFound) => { + sqlx::query!( + " +INSERT INTO guilds (guild, name, prefix) VALUES (?, ?, ?) + ", + guild_id, + guild.name, + *DEFAULT_PREFIX + ) + .execute(&pool.clone()) + .await?; + + Ok(sqlx::query_as!( + Self, + " +SELECT id, name, prefix FROM guilds WHERE guild = ? + ", + guild_id + ) + .fetch_one(pool) + .await?) + } + + Err(e) => { + error!("Unexpected error in guild query: {:?}", e); + + Err(e) + } + } + } + + pub async fn commit_changes(&self, pool: &MySqlPool) { + sqlx::query!( + " +UPDATE guilds SET name = ?, prefix = ? WHERE id = ? + ", + self.name, + self.prefix, + self.id + ) + .execute(pool) + .await + .unwrap(); + } +} diff --git a/src/models/mod.rs b/src/models/mod.rs new file mode 100644 index 0000000..384dd7f --- /dev/null +++ b/src/models/mod.rs @@ -0,0 +1,77 @@ +pub mod channel_data; +pub mod guild_data; +pub mod timer; +pub mod user_data; + +use serenity::{async_trait, model::id::GuildId, prelude::Context}; + +use crate::{consts::DEFAULT_PREFIX, GuildDataCache, SQLPool}; + +use guild_data::GuildData; + +use std::sync::Arc; +use tokio::sync::RwLock; + +#[async_trait] +pub trait CtxGuildData { + async fn guild_data + Send + Sync>( + &self, + guild_id: G, + ) -> Result>, sqlx::Error>; + + async fn prefix + Send + Sync>(&self, guild_id: Option) -> String; +} + +#[async_trait] +impl CtxGuildData for Context { + async fn guild_data + Send + Sync>( + &self, + guild_id: G, + ) -> Result>, sqlx::Error> { + let guild_id = guild_id.into(); + + let guild = guild_id.to_guild_cached(&self.cache).await.unwrap(); + + let guild_cache = self + .data + .read() + .await + .get::() + .cloned() + .unwrap(); + + let x = if let Some(guild_data) = guild_cache.get(&guild_id) { + Ok(guild_data.clone()) + } else { + let pool = self.data.read().await.get::().cloned().unwrap(); + + match GuildData::from_guild(guild, &pool).await { + Ok(d) => { + let lock = Arc::new(RwLock::new(d)); + + guild_cache.insert(guild_id, lock.clone()); + + Ok(lock) + } + + Err(e) => Err(e), + } + }; + + x + } + + async fn prefix + Send + Sync>(&self, guild_id: Option) -> String { + if let Some(guild_id) = guild_id { + self.guild_data(guild_id) + .await + .unwrap() + .read() + .await + .prefix + .clone() + } else { + DEFAULT_PREFIX.clone() + } + } +} diff --git a/src/models/timer.rs b/src/models/timer.rs new file mode 100644 index 0000000..8a56b9f --- /dev/null +++ b/src/models/timer.rs @@ -0,0 +1,50 @@ +use sqlx::MySqlPool; + +use chrono::NaiveDateTime; + +pub struct Timer { + pub name: String, + pub start_time: NaiveDateTime, + pub owner: u64, +} + +impl Timer { + pub async fn from_owner(owner: u64, pool: &MySqlPool) -> Vec { + sqlx::query_as_unchecked!( + Timer, + " +SELECT name, start_time, owner FROM timers WHERE owner = ? + ", + owner + ) + .fetch_all(pool) + .await + .unwrap() + } + + pub async fn count_from_owner(owner: u64, pool: &MySqlPool) -> u32 { + sqlx::query!( + " +SELECT COUNT(1) as count FROM timers WHERE owner = ? + ", + owner + ) + .fetch_one(pool) + .await + .unwrap() + .count as u32 + } + + pub async fn create(name: &str, owner: u64, pool: &MySqlPool) { + sqlx::query!( + " +INSERT INTO timers (name, owner) VALUES (?, ?) + ", + name, + owner + ) + .execute(pool) + .await + .unwrap(); + } +} diff --git a/src/models/user_data.rs b/src/models/user_data.rs new file mode 100644 index 0000000..fe365a1 --- /dev/null +++ b/src/models/user_data.rs @@ -0,0 +1,146 @@ +use serenity::{ + http::CacheHttp, + model::{id::UserId, user::User}, +}; + +use sqlx::MySqlPool; + +use chrono_tz::Tz; + +use log::error; + +use crate::consts::{LOCAL_LANGUAGE, LOCAL_TIMEZONE}; + +pub struct UserData { + pub id: u32, + pub user: u64, + pub name: String, + pub dm_channel: u32, + pub language: String, + pub timezone: String, +} + +impl UserData { + pub async fn language_of(user: U, pool: &MySqlPool) -> String + where + U: Into, + { + let user_id = user.into().as_u64().to_owned(); + + match sqlx::query!( + " +SELECT language FROM users WHERE user = ? + ", + user_id + ) + .fetch_one(pool) + .await + { + Ok(r) => r.language, + + Err(_) => LOCAL_LANGUAGE.clone(), + } + } + + pub async fn timezone_of(user: U, pool: &MySqlPool) -> Tz + where + U: Into, + { + let user_id = user.into().as_u64().to_owned(); + + match sqlx::query!( + " +SELECT timezone FROM users WHERE user = ? + ", + user_id + ) + .fetch_one(pool) + .await + { + Ok(r) => r.timezone, + + Err(_) => LOCAL_TIMEZONE.clone(), + } + .parse() + .unwrap() + } + + pub async fn from_user( + user: &User, + ctx: impl CacheHttp, + pool: &MySqlPool, + ) -> Result> { + let user_id = user.id.as_u64().to_owned(); + + match sqlx::query_as_unchecked!( + Self, + " +SELECT id, user, name, dm_channel, IF(language IS NULL, ?, language) AS language, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ? + ", + *LOCAL_LANGUAGE, *LOCAL_TIMEZONE, user_id + ) + .fetch_one(pool) + .await + { + Ok(c) => Ok(c), + + Err(sqlx::Error::RowNotFound) => { + let dm_channel = user.create_dm_channel(ctx).await?; + let dm_id = dm_channel.id.as_u64().to_owned(); + + let pool_c = pool.clone(); + + sqlx::query!( + " +INSERT IGNORE INTO channels (channel) VALUES (?) + ", + dm_id + ) + .execute(&pool_c) + .await?; + + sqlx::query!( + " +INSERT INTO users (user, name, dm_channel, language, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?, ?) + ", user_id, user.name, dm_id, *LOCAL_LANGUAGE, *LOCAL_TIMEZONE) + .execute(&pool_c) + .await?; + + Ok(sqlx::query_as_unchecked!( + Self, + " +SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ? + ", + user_id + ) + .fetch_one(pool) + .await?) + } + + Err(e) => { + error!("Error querying for user: {:?}", e); + + Err(Box::new(e)) + }, + } + } + + pub async fn commit_changes(&self, pool: &MySqlPool) { + sqlx::query!( + " +UPDATE users SET name = ?, language = ?, timezone = ? WHERE id = ? + ", + self.name, + self.language, + self.timezone, + self.id + ) + .execute(pool) + .await + .unwrap(); + } + + pub fn timezone(&self) -> Tz { + self.timezone.parse().unwrap() + } +}