2nd attempt at doing poise stuff
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use chrono::NaiveDateTime;
|
||||
use serenity::model::channel::Channel;
|
||||
use poise::serenity::model::channel::Channel;
|
||||
use sqlx::MySqlPool;
|
||||
|
||||
pub struct ChannelData {
|
||||
|
@ -1,33 +1,25 @@
|
||||
use serenity::{client::Context, model::id::GuildId};
|
||||
use poise::serenity::{
|
||||
client::Context,
|
||||
model::{
|
||||
id::GuildId, interactions::application_command::ApplicationCommandInteractionDataOption,
|
||||
},
|
||||
};
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::{framework::CommandOptions, SQLPool};
|
||||
#[derive(Serialize)]
|
||||
pub struct RecordedCommand<U, E> {
|
||||
#[serde(skip)]
|
||||
action: for<'a> fn(
|
||||
poise::ApplicationContext<'a, U, E>,
|
||||
&'a [ApplicationCommandInteractionDataOption],
|
||||
) -> poise::BoxFuture<'a, Result<(), poise::FrameworkError<'a, U, E>>>,
|
||||
command_name: String,
|
||||
options: Vec<ApplicationCommandInteractionDataOption>,
|
||||
}
|
||||
|
||||
pub struct CommandMacro {
|
||||
pub struct CommandMacro<U, E> {
|
||||
pub guild_id: GuildId,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub commands: Vec<CommandOptions>,
|
||||
}
|
||||
|
||||
impl CommandMacro {
|
||||
pub async fn from_guild(ctx: &Context, guild_id: impl Into<GuildId>) -> Vec<Self> {
|
||||
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
|
||||
let guild_id = guild_id.into();
|
||||
|
||||
sqlx::query!(
|
||||
"SELECT * FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)",
|
||||
guild_id.0
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|row| Self {
|
||||
guild_id,
|
||||
name: row.name.clone(),
|
||||
description: row.description.clone(),
|
||||
commands: serde_json::from_str(&row.commands).unwrap(),
|
||||
})
|
||||
.collect::<Vec<Self>>()
|
||||
}
|
||||
pub commands: Vec<RecordedCommand<U, E>>,
|
||||
}
|
||||
|
@ -5,62 +5,47 @@ pub mod timer;
|
||||
pub mod user_data;
|
||||
|
||||
use chrono_tz::Tz;
|
||||
use serenity::{
|
||||
async_trait,
|
||||
model::id::{ChannelId, UserId},
|
||||
prelude::Context,
|
||||
};
|
||||
use poise::serenity::{async_trait, model::id::UserId};
|
||||
|
||||
use crate::{
|
||||
models::{channel_data::ChannelData, user_data::UserData},
|
||||
SQLPool,
|
||||
Context,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
pub trait CtxData {
|
||||
async fn user_data<U: Into<UserId> + Send + Sync>(
|
||||
async fn user_data<U: Into<UserId> + Send>(
|
||||
&self,
|
||||
user_id: U,
|
||||
) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>>;
|
||||
|
||||
async fn timezone<U: Into<UserId> + Send + Sync>(&self, user_id: U) -> Tz;
|
||||
async fn author_data(&self) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>>;
|
||||
|
||||
async fn channel_data<C: Into<ChannelId> + Send + Sync>(
|
||||
&self,
|
||||
channel_id: C,
|
||||
) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>>;
|
||||
async fn timezone(&self) -> Tz;
|
||||
|
||||
async fn channel_data(&self) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CtxData for Context {
|
||||
async fn user_data<U: Into<UserId> + Send + Sync>(
|
||||
impl CtxData for Context<'_> {
|
||||
async fn user_data<U: Into<UserId> + Send>(
|
||||
&self,
|
||||
user_id: U,
|
||||
) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>> {
|
||||
let user_id = user_id.into();
|
||||
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
|
||||
|
||||
let user = user_id.to_user(self).await.unwrap();
|
||||
|
||||
UserData::from_user(&user, &self, &pool).await
|
||||
UserData::from_user(user_id, &self.discord(), &self.data().database).await
|
||||
}
|
||||
|
||||
async fn timezone<U: Into<UserId> + Send + Sync>(&self, user_id: U) -> Tz {
|
||||
let user_id = user_id.into();
|
||||
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
|
||||
|
||||
UserData::timezone_of(user_id, &pool).await
|
||||
async fn author_data(&self) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>> {
|
||||
UserData::from_user(&self.author().id, &self.discord(), &self.data().database).await
|
||||
}
|
||||
|
||||
async fn channel_data<C: Into<ChannelId> + Send + Sync>(
|
||||
&self,
|
||||
channel_id: C,
|
||||
) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>> {
|
||||
let channel_id = channel_id.into();
|
||||
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
|
||||
async fn timezone(&self) -> Tz {
|
||||
UserData::timezone_of(self.author().id, &self.data().database).await
|
||||
}
|
||||
|
||||
let channel = channel_id.to_channel_cached(&self).unwrap();
|
||||
async fn channel_data(&self) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>> {
|
||||
let channel = self.channel_id().to_channel_cached(&self.discord()).unwrap();
|
||||
|
||||
ChannelData::from_channel(&channel, &pool).await
|
||||
ChannelData::from_channel(&channel, &self.data().database).await
|
||||
}
|
||||
}
|
||||
|
@ -2,8 +2,7 @@ use std::{collections::HashSet, fmt::Display};
|
||||
|
||||
use chrono::{Duration, NaiveDateTime, Utc};
|
||||
use chrono_tz::Tz;
|
||||
use serenity::{
|
||||
client::Context,
|
||||
use poise::serenity::{
|
||||
http::CacheHttp,
|
||||
model::{
|
||||
channel::GuildChannel,
|
||||
@ -15,15 +14,14 @@ use serenity::{
|
||||
use sqlx::MySqlPool;
|
||||
|
||||
use crate::{
|
||||
consts,
|
||||
consts::{DAY, MAX_TIME, MIN_INTERVAL},
|
||||
consts::{DAY, DEFAULT_AVATAR, MAX_TIME, MIN_INTERVAL},
|
||||
interval_parser::Interval,
|
||||
models::{
|
||||
channel_data::ChannelData,
|
||||
reminder::{content::Content, errors::ReminderError, helper::generate_uid, Reminder},
|
||||
user_data::UserData,
|
||||
},
|
||||
SQLPool,
|
||||
Context,
|
||||
};
|
||||
|
||||
async fn create_webhook(
|
||||
@ -31,7 +29,7 @@ async fn create_webhook(
|
||||
channel: GuildChannel,
|
||||
name: impl Display,
|
||||
) -> SerenityResult<Webhook> {
|
||||
channel.create_webhook_with_avatar(ctx.http(), name, consts::DEFAULT_AVATAR.clone()).await
|
||||
channel.create_webhook_with_avatar(ctx.http(), name, DEFAULT_AVATAR.clone()).await
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq)]
|
||||
@ -145,7 +143,7 @@ pub struct MultiReminderBuilder<'a> {
|
||||
expires: Option<NaiveDateTime>,
|
||||
content: Content,
|
||||
set_by: Option<u32>,
|
||||
ctx: &'a Context,
|
||||
ctx: &'a Context<'a>,
|
||||
guild_id: Option<GuildId>,
|
||||
}
|
||||
|
||||
@ -210,8 +208,6 @@ impl<'a> MultiReminderBuilder<'a> {
|
||||
}
|
||||
|
||||
pub async fn build(self) -> (HashSet<ReminderError>, HashSet<ReminderScope>) {
|
||||
let pool = self.ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
|
||||
|
||||
let mut errors = HashSet::new();
|
||||
|
||||
let mut ok_locs = HashSet::new();
|
||||
@ -225,12 +221,17 @@ impl<'a> MultiReminderBuilder<'a> {
|
||||
for scope in self.scopes {
|
||||
let db_channel_id = match scope {
|
||||
ReminderScope::User(user_id) => {
|
||||
if let Ok(user) = UserId(user_id).to_user(&self.ctx).await {
|
||||
let user_data =
|
||||
UserData::from_user(&user, &self.ctx, &pool).await.unwrap();
|
||||
if let Ok(user) = UserId(user_id).to_user(&self.ctx.discord()).await {
|
||||
let user_data = UserData::from_user(
|
||||
&user,
|
||||
&self.ctx.discord(),
|
||||
&self.ctx.data().database,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if let Some(guild_id) = self.guild_id {
|
||||
if guild_id.member(&self.ctx, user).await.is_err() {
|
||||
if guild_id.member(&self.ctx.discord(), user).await.is_err() {
|
||||
Err(ReminderError::InvalidTag)
|
||||
} else {
|
||||
Ok(user_data.dm_channel)
|
||||
@ -243,26 +244,36 @@ impl<'a> MultiReminderBuilder<'a> {
|
||||
}
|
||||
}
|
||||
ReminderScope::Channel(channel_id) => {
|
||||
let channel = ChannelId(channel_id).to_channel(&self.ctx).await.unwrap();
|
||||
let channel =
|
||||
ChannelId(channel_id).to_channel(&self.ctx.discord()).await.unwrap();
|
||||
|
||||
if let Some(guild_channel) = channel.clone().guild() {
|
||||
if Some(guild_channel.guild_id) != self.guild_id {
|
||||
Err(ReminderError::InvalidTag)
|
||||
} else {
|
||||
let mut channel_data =
|
||||
ChannelData::from_channel(&channel, &pool).await.unwrap();
|
||||
ChannelData::from_channel(&channel, &self.ctx.data().database)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if channel_data.webhook_id.is_none()
|
||||
|| channel_data.webhook_token.is_none()
|
||||
{
|
||||
match create_webhook(&self.ctx, guild_channel, "Reminder").await
|
||||
match create_webhook(
|
||||
&self.ctx.discord(),
|
||||
guild_channel,
|
||||
"Reminder",
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(webhook) => {
|
||||
channel_data.webhook_id =
|
||||
Some(webhook.id.as_u64().to_owned());
|
||||
channel_data.webhook_token = webhook.token;
|
||||
|
||||
channel_data.commit_changes(&pool).await;
|
||||
channel_data
|
||||
.commit_changes(&self.ctx.data().database)
|
||||
.await;
|
||||
|
||||
Ok(channel_data.id)
|
||||
}
|
||||
@ -282,7 +293,7 @@ impl<'a> MultiReminderBuilder<'a> {
|
||||
match db_channel_id {
|
||||
Ok(c) => {
|
||||
let builder = ReminderBuilder {
|
||||
pool: pool.clone(),
|
||||
pool: self.ctx.data().database.clone(),
|
||||
uid: generate_uid(),
|
||||
channel: c,
|
||||
utc_time: self.utc_time,
|
||||
|
@ -1,6 +1,6 @@
|
||||
use poise::serenity::model::id::ChannelId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_repr::*;
|
||||
use serenity::model::id::ChannelId;
|
||||
|
||||
#[derive(Serialize_repr, Deserialize_repr, Copy, Clone, Debug)]
|
||||
#[repr(u8)]
|
||||
|
@ -6,15 +6,12 @@ pub mod look_flags;
|
||||
|
||||
use chrono::{NaiveDateTime, TimeZone};
|
||||
use chrono_tz::Tz;
|
||||
use serenity::{
|
||||
client::Context,
|
||||
model::id::{ChannelId, GuildId, UserId},
|
||||
};
|
||||
use sqlx::MySqlPool;
|
||||
use poise::serenity::model::id::{ChannelId, GuildId, UserId};
|
||||
use sqlx::Executor;
|
||||
|
||||
use crate::{
|
||||
models::reminder::look_flags::{LookFlags, TimeDisplayType},
|
||||
SQLPool,
|
||||
Context, Data, Database,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -33,7 +30,10 @@ pub struct Reminder {
|
||||
}
|
||||
|
||||
impl Reminder {
|
||||
pub async fn from_uid(pool: &MySqlPool, uid: String) -> Option<Self> {
|
||||
pub async fn from_uid(
|
||||
pool: impl Executor<'_, Database = Database>,
|
||||
uid: String,
|
||||
) -> Option<Self> {
|
||||
sqlx::query_as_unchecked!(
|
||||
Self,
|
||||
"
|
||||
@ -70,12 +70,10 @@ WHERE
|
||||
}
|
||||
|
||||
pub async fn from_channel<C: Into<ChannelId>>(
|
||||
ctx: &Context,
|
||||
ctx: &Context<'_>,
|
||||
channel_id: C,
|
||||
flags: &LookFlags,
|
||||
) -> Vec<Self> {
|
||||
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
|
||||
|
||||
let enabled = if flags.show_disabled { "0,1" } else { "1" };
|
||||
let channel_id = channel_id.into();
|
||||
|
||||
@ -113,16 +111,18 @@ ORDER BY
|
||||
channel_id.as_u64(),
|
||||
enabled,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.fetch_all(&ctx.data().database)
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub async fn from_guild(ctx: &Context, guild_id: Option<GuildId>, user: UserId) -> Vec<Self> {
|
||||
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
|
||||
|
||||
pub async fn from_guild(
|
||||
ctx: &Context<'_>,
|
||||
guild_id: Option<GuildId>,
|
||||
user: UserId,
|
||||
) -> Vec<Self> {
|
||||
if let Some(guild_id) = guild_id {
|
||||
let guild_opt = guild_id.to_guild_cached(&ctx);
|
||||
let guild_opt = guild_id.to_guild_cached(&ctx.discord());
|
||||
|
||||
if let Some(guild) = guild_opt {
|
||||
let channels = guild
|
||||
@ -163,7 +163,7 @@ WHERE
|
||||
",
|
||||
channels
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.fetch_all(&ctx.data().database)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query_as_unchecked!(
|
||||
@ -196,7 +196,7 @@ WHERE
|
||||
",
|
||||
guild_id.as_u64()
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.fetch_all(&ctx.data().database)
|
||||
.await
|
||||
}
|
||||
} else {
|
||||
@ -230,7 +230,7 @@ WHERE
|
||||
",
|
||||
user.as_u64()
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.fetch_all(&ctx.data().database)
|
||||
.await
|
||||
}
|
||||
.unwrap()
|
||||
|
@ -1,9 +1,6 @@
|
||||
use chrono_tz::Tz;
|
||||
use log::error;
|
||||
use serenity::{
|
||||
http::CacheHttp,
|
||||
model::{id::UserId, user::User},
|
||||
};
|
||||
use poise::serenity::{http::CacheHttp, model::id::UserId};
|
||||
use sqlx::MySqlPool;
|
||||
|
||||
use crate::consts::LOCAL_TIMEZONE;
|
||||
@ -11,7 +8,6 @@ use crate::consts::LOCAL_TIMEZONE;
|
||||
pub struct UserData {
|
||||
pub id: u32,
|
||||
pub user: u64,
|
||||
pub name: String,
|
||||
pub dm_channel: u32,
|
||||
pub timezone: String,
|
||||
}
|
||||
@ -40,20 +36,20 @@ SELECT timezone FROM users WHERE user = ?
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub async fn from_user(
|
||||
user: &User,
|
||||
pub async fn from_user<U: Into<UserId>>(
|
||||
user: U,
|
||||
ctx: impl CacheHttp,
|
||||
pool: &MySqlPool,
|
||||
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
|
||||
let user_id = user.id.as_u64().to_owned();
|
||||
let user_id = user.into();
|
||||
|
||||
match sqlx::query_as_unchecked!(
|
||||
Self,
|
||||
"
|
||||
SELECT id, user, name, dm_channel, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ?
|
||||
SELECT id, user, dm_channel, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ?
|
||||
",
|
||||
*LOCAL_TIMEZONE,
|
||||
user_id
|
||||
user_id.0
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
@ -61,27 +57,24 @@ SELECT id, user, name, dm_channel, IF(timezone IS NULL, ?, timezone) AS timezone
|
||||
Ok(c) => Ok(c),
|
||||
|
||||
Err(sqlx::Error::RowNotFound) => {
|
||||
let dm_channel = user.create_dm_channel(ctx).await?;
|
||||
let dm_id = dm_channel.id.as_u64().to_owned();
|
||||
|
||||
let dm_channel = user_id.create_dm_channel(ctx).await?;
|
||||
let pool_c = pool.clone();
|
||||
|
||||
sqlx::query!(
|
||||
"
|
||||
INSERT IGNORE INTO channels (channel) VALUES (?)
|
||||
",
|
||||
dm_id
|
||||
dm_channel.id.0
|
||||
)
|
||||
.execute(&pool_c)
|
||||
.await?;
|
||||
|
||||
sqlx::query!(
|
||||
"
|
||||
INSERT INTO users (user, name, dm_channel, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?)
|
||||
INSERT INTO users (user, dm_channel, timezone) VALUES (?, (SELECT id FROM channels WHERE channel = ?), ?)
|
||||
",
|
||||
user_id,
|
||||
user.name,
|
||||
dm_id,
|
||||
user_id.0,
|
||||
dm_channel.id.0,
|
||||
*LOCAL_TIMEZONE
|
||||
)
|
||||
.execute(&pool_c)
|
||||
@ -90,9 +83,9 @@ INSERT INTO users (user, name, dm_channel, timezone) VALUES (?, ?, (SELECT id FR
|
||||
Ok(sqlx::query_as_unchecked!(
|
||||
Self,
|
||||
"
|
||||
SELECT id, user, name, dm_channel, timezone FROM users WHERE user = ?
|
||||
SELECT id, user, dm_channel, timezone FROM users WHERE user = ?
|
||||
",
|
||||
user_id
|
||||
user_id.0
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?)
|
||||
@ -109,9 +102,8 @@ SELECT id, user, name, dm_channel, timezone FROM users WHERE user = ?
|
||||
pub async fn commit_changes(&self, pool: &MySqlPool) {
|
||||
sqlx::query!(
|
||||
"
|
||||
UPDATE users SET name = ?, timezone = ? WHERE id = ?
|
||||
UPDATE users SET timezone = ? WHERE id = ?
|
||||
",
|
||||
self.name,
|
||||
self.timezone,
|
||||
self.id
|
||||
)
|
||||
|
Reference in New Issue
Block a user