more commands. fixed an issue with text only commands

This commit is contained in:
2021-09-11 00:14:23 +01:00
parent 471948bed3
commit 9b5333dc87
18 changed files with 562 additions and 897 deletions

View File

@ -15,51 +15,67 @@ pub struct ChannelData {
impl ChannelData {
pub async fn from_channel(
channel: Channel,
channel: &Channel,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let channel_id = channel.id().as_u64().to_owned();
if let Ok(c) = sqlx::query_as_unchecked!(Self,
if let Ok(c) = sqlx::query_as_unchecked!(
Self,
"
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
", channel_id)
.fetch_one(pool)
.await {
",
channel_id
)
.fetch_one(pool)
.await
{
Ok(c)
}
else {
let props = channel.guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name));
} else {
let props = channel.to_owned().guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name));
let (guild_id, channel_name) = if let Some((a, b)) = props {
(Some(a), Some(b))
} else {
(None, None)
};
let (guild_id, channel_name) = if let Some((a, b)) = props { (Some(a), Some(b)) } else { (None, None) };
sqlx::query!(
"
INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?))
", channel_id, channel_name, guild_id)
.execute(&pool.clone())
.await?;
",
channel_id,
channel_name,
guild_id
)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as_unchecked!(Self,
Ok(sqlx::query_as_unchecked!(
Self,
"
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
", channel_id)
.fetch_one(pool)
.await?)
",
channel_id
)
.fetch_one(pool)
.await?)
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until = ? WHERE id = ?
", self.name, self.nudge, self.blacklisted, self.webhook_id, self.webhook_token, self.paused, self.paused_until, self.id)
.execute(pool)
.await.unwrap();
UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until \
= ? WHERE id = ?
",
self.name,
self.nudge,
self.blacklisted,
self.webhook_id,
self.webhook_token,
self.paused,
self.paused_until,
self.id
)
.execute(pool)
.await
.unwrap();
}
}

View File

@ -6,15 +6,18 @@ pub mod user_data;
use std::sync::Arc;
use guild_data::GuildData;
use serenity::{
async_trait,
model::id::{GuildId, UserId},
model::id::{ChannelId, GuildId, UserId},
prelude::Context,
};
use tokio::sync::RwLock;
use crate::{consts::DEFAULT_PREFIX, models::user_data::UserData, GuildDataCache, SQLPool};
use crate::{
consts::DEFAULT_PREFIX,
models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData},
GuildDataCache, SQLPool,
};
#[async_trait]
pub trait CtxData {
@ -23,12 +26,17 @@ pub trait CtxData {
guild_id: G,
) -> Result<Arc<RwLock<GuildData>>, sqlx::Error>;
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String;
async fn user_data<U: Into<UserId> + Send + Sync>(
&self,
user_id: U,
) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>>;
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String;
async fn channel_data<C: Into<ChannelId> + Send + Sync>(
&self,
channel_id: C,
) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>>;
}
#[async_trait]
@ -41,13 +49,7 @@ impl CtxData for Context {
let guild = guild_id.to_guild_cached(&self.cache).unwrap();
let guild_cache = self
.data
.read()
.await
.get::<GuildDataCache>()
.cloned()
.unwrap();
let guild_cache = self.data.read().await.get::<GuildDataCache>().cloned().unwrap();
let x = if let Some(guild_data) = guild_cache.get(&guild_id) {
Ok(guild_data.clone())
@ -70,6 +72,14 @@ impl CtxData for Context {
x
}
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String {
if let Some(guild_id) = guild_id {
self.guild_data(guild_id).await.unwrap().read().await.prefix.clone()
} else {
DEFAULT_PREFIX.clone()
}
}
async fn user_data<U: Into<UserId> + Send + Sync>(
&self,
user_id: U,
@ -82,17 +92,15 @@ impl CtxData for Context {
UserData::from_user(&user, &self, &pool).await
}
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String {
if let Some(guild_id) = guild_id {
self.guild_data(guild_id)
.await
.unwrap()
.read()
.await
.prefix
.clone()
} else {
DEFAULT_PREFIX.clone()
}
async fn channel_data<C: Into<ChannelId> + Send + Sync>(
&self,
channel_id: C,
) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>> {
let channel_id = channel_id.into();
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
let channel = channel_id.to_channel_cached(&self).unwrap();
ChannelData::from_channel(&channel, &pool).await
}
}

View File

@ -38,10 +38,7 @@ async fn create_webhook(
include_bytes!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/assets/",
env!(
"WEBHOOK_AVATAR",
"WEBHOOK_AVATAR not provided for compilation"
)
env!("WEBHOOK_AVATAR", "WEBHOOK_AVATAR not provided for compilation")
)) as &[u8],
env!("WEBHOOK_AVATAR"),
),
@ -230,14 +227,7 @@ impl<'a> MultiReminderBuilder<'a> {
}
pub async fn build(mut self) -> (HashSet<ReminderError>, HashSet<ReminderScope>) {
let pool = self
.ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.unwrap();
let pool = self.ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let mut errors = HashSet::new();
@ -296,7 +286,7 @@ impl<'a> MultiReminderBuilder<'a> {
Err(ReminderError::InvalidTag)
} else {
let mut channel_data =
ChannelData::from_channel(channel, &pool).await.unwrap();
ChannelData::from_channel(&channel, &pool).await.unwrap();
if channel_data.webhook_id.is_none()
|| channel_data.webhook_token.is_none()

View File

@ -12,12 +12,7 @@ pub struct Content {
impl Content {
pub fn new() -> Self {
Self {
content: "".to_string(),
tts: false,
attachment: None,
attachment_name: None,
}
Self { content: "".to_string(), tts: false, attachment: None, attachment_name: None }
}
pub async fn build<S: ToString>(content: S, message: &Message) -> Result<Self, ContentError> {

View File

@ -42,22 +42,50 @@ pub enum ReminderError {
impl ReminderError {
pub fn display(&self, is_natural: bool) -> String {
match self {
ReminderError::LongTime => "That time is too far in the future. Please specify a shorter time.".to_string(),
ReminderError::LongInterval => format!("Please ensure the interval specified is less than {max_time} days", max_time = *MAX_TIME / 86_400),
ReminderError::PastTime => "Please ensure the time provided is in the future. If the time should be in the future, please be more specific with the definition.".to_string(),
ReminderError::ShortInterval => format!("Please ensure the interval provided is longer than {min_interval} seconds", min_interval = *MIN_INTERVAL),
ReminderError::InvalidTag => "Couldn't find a location by your tag. Your tag must be either a channel or a user (not a role)".to_string(),
ReminderError::InvalidTime => if is_natural {
"Your time failed to process. Please make it as clear as possible, for example `\"16th of july\"` or `\"in 20 minutes\"`".to_string()
} else {
"Make sure the time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] etc. or `day/month/year-hour:minute:second`".to_string()
},
ReminderError::InvalidExpiration => if is_natural {
"Your expiration time failed to process. Please make it as clear as possible, for example `\"16th of july\"` or `\"in 20 minutes\"`".to_string()
} else {
"Make sure the expiration time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] etc. or `day/month/year-hour:minute:second`".to_string()
},
ReminderError::DiscordError(s) => format!("A Discord error occurred: **{}**", s)
ReminderError::LongTime => {
"That time is too far in the future. Please specify a shorter time.".to_string()
}
ReminderError::LongInterval => format!(
"Please ensure the interval specified is less than {max_time} days",
max_time = *MAX_TIME / 86_400
),
ReminderError::PastTime => {
"Please ensure the time provided is in the future. If the time should be in \
the future, please be more specific with the definition."
.to_string()
}
ReminderError::ShortInterval => format!(
"Please ensure the interval provided is longer than {min_interval} seconds",
min_interval = *MIN_INTERVAL
),
ReminderError::InvalidTag => {
"Couldn't find a location by your tag. Your tag must be either a channel or \
a user (not a role)"
.to_string()
}
ReminderError::InvalidTime => {
if is_natural {
"Your time failed to process. Please make it as clear as possible, for example `\"16th of july\"` \
or `\"in 20 minutes\"`"
.to_string()
} else {
"Make sure the time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] etc. or \
`day/month/year-hour:minute:second`"
.to_string()
}
}
ReminderError::InvalidExpiration => {
if is_natural {
"Your expiration time failed to process. Please make it as clear as possible, for example `\"16th \
of july\"` or `\"in 20 minutes\"`"
.to_string()
} else {
"Make sure the expiration time you have provided is in the format of [num][s/m/h/d][num][s/m/h/d] \
etc. or `day/month/year-hour:minute:second`"
.to_string()
}
}
ReminderError::DiscordError(s) => format!("A Discord error occurred: **{}**", s),
}
}
}

View File

@ -10,9 +10,8 @@ pub fn longhand_displacement(seconds: u64) -> String {
let mut sections = vec![];
for (var, name) in [days, hours, minutes, seconds]
.iter()
.zip(["days", "hours", "minutes", "seconds"].iter())
for (var, name) in
[days, hours, minutes, seconds].iter().zip(["days", "hours", "minutes", "seconds"].iter())
{
if *var > 0 {
sections.push(format!("{} {}", var, name));
@ -26,14 +25,7 @@ pub fn generate_uid() -> String {
let mut generator: OsRng = Default::default();
(0..64)
.map(|_| {
CHARACTERS
.chars()
.choose(&mut generator)
.unwrap()
.to_owned()
.to_string()
})
.map(|_| CHARACTERS.chars().choose(&mut generator).unwrap().to_owned().to_string())
.collect::<Vec<String>>()
.join("")
}

View File

@ -329,18 +329,14 @@ WHERE
self.display_content(),
time_display,
longhand_displacement(interval as u64),
self.set_by
.map(|i| format!("<@{}>", i))
.unwrap_or_else(|| "unknown".to_string())
self.set_by.map(|i| format!("<@{}>", i)).unwrap_or_else(|| "unknown".to_string())
)
} else {
format!(
"'{}' *occurs next at* **{}** (set by {})",
self.display_content(),
time_display,
self.set_by
.map(|i| format!("<@{}>", i))
.unwrap_or_else(|| "unknown".to_string())
self.set_by.map(|i| format!("<@{}>", i)).unwrap_or_else(|| "unknown".to_string())
)
}
}
@ -380,9 +376,7 @@ WHERE
pub fn signed_action<U: Into<u64>>(&self, member_id: U, action: ReminderAction) -> String {
let s_key = hmac::Key::new(
hmac::HMAC_SHA256,
env::var("SECRET_KEY")
.expect("No SECRET_KEY provided")
.as_bytes(),
env::var("SECRET_KEY").expect("No SECRET_KEY provided").as_bytes(),
);
let mut context = hmac::Context::with_key(&s_key);

View File

@ -52,7 +52,8 @@ SELECT timezone FROM users WHERE user = ?
"
SELECT id, user, name, dm_channel, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ?
",
*LOCAL_TIMEZONE, user_id
*LOCAL_TIMEZONE,
user_id
)
.fetch_one(pool)
.await
@ -77,9 +78,14 @@ INSERT IGNORE INTO channels (channel) VALUES (?)
sqlx::query!(
"
INSERT INTO users (user, name, dm_channel, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?)
", user_id, user.name, dm_id, *LOCAL_TIMEZONE)
.execute(&pool_c)
.await?;
",
user_id,
user.name,
dm_id,
*LOCAL_TIMEZONE
)
.execute(&pool_c)
.await?;
Ok(sqlx::query_as_unchecked!(
Self,
@ -96,7 +102,7 @@ SELECT id, user, name, dm_channel, timezone FROM users WHERE user = ?
error!("Error querying for user: {:?}", e);
Err(Box::new(e))
},
}
}
}