components

This commit is contained in:
jellywx 2021-09-11 20:40:58 +01:00
parent 9b5333dc87
commit 3e547861ea
11 changed files with 310 additions and 237 deletions

23
Cargo.lock generated
View File

@ -1230,7 +1230,7 @@ dependencies = [
"regex",
"regex_command_attr",
"reqwest",
"ring",
"rmp-serde",
"serde",
"serde_json",
"serenity",
@ -1302,6 +1302,27 @@ dependencies = [
"winapi",
]
[[package]]
name = "rmp"
version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f55e5fa1446c4d5dd1f5daeed2a4fe193071771a2636274d0d7a3b082aa7ad6"
dependencies = [
"byteorder",
"num-traits",
]
[[package]]
name = "rmp-serde"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "723ecff9ad04f4ad92fe1c8ca6c20d2196d9286e9c60727c4cb5511629260e9d"
dependencies = [
"byteorder",
"rmp",
"serde",
]
[[package]]
name = "rsa"
version = "0.4.1"

View File

@ -19,11 +19,11 @@ lazy_static = "1.4"
num-integer = "0.1"
serde = "1.0"
serde_json = "1.0"
rmp-serde = "0.15"
rand = "0.7"
levenshtein = "1.0"
serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["collector", "unstable_discord_api"] }
sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono"]}
ring = "0.16"
base64 = "0.13.0"
[dependencies.regex_command_attr]

View File

@ -17,6 +17,7 @@ use serenity::{
};
use crate::{
component_models::{ComponentDataModel, Restrict},
consts::{REGEX_ALIAS, REGEX_COMMANDS, THEME_COLOR},
framework::{CommandInvoke, CreateGenericResponse, PermissionLevel},
models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData, CtxData},
@ -264,6 +265,8 @@ async fn restrict(
let len = restrictable_commands.len();
let restrict_pl = ComponentDataModel::Restrict(Restrict { role_id: role });
invoke
.respond(
ctx.http.clone(),
@ -273,7 +276,7 @@ async fn restrict(
c.create_action_row(|row| {
row.create_select_menu(|select| {
select
.custom_id("test_id")
.custom_id(restrict_pl.to_custom_id())
.options(|options| {
for command in restrictable_commands {
options.create_option(|opt| {

View File

@ -10,14 +10,19 @@ use num_integer::Integer;
use regex_command_attr::command;
use serenity::{
client::Context,
model::channel::{Channel, Message},
futures::StreamExt,
model::{
channel::{Channel, Message},
id::ChannelId,
misc::Mentionable,
},
};
use crate::{
check_subscription_on_message,
consts::{
REGEX_CHANNEL_USER, REGEX_NATURAL_COMMAND_1, REGEX_NATURAL_COMMAND_2, REGEX_REMIND_COMMAND,
THEME_COLOR,
EMBED_DESCRIPTION_MAX_LENGTH, REGEX_CHANNEL_USER, REGEX_NATURAL_COMMAND_1,
REGEX_NATURAL_COMMAND_2, REGEX_REMIND_COMMAND, THEME_COLOR,
},
framework::{CommandInvoke, CreateGenericResponse},
models::{
@ -26,7 +31,7 @@ use crate::{
reminder::{
builder::{MultiReminderBuilder, ReminderScope},
content::Content,
look_flags::LookFlags,
look_flags::{LookFlags, TimeDisplayType},
Reminder,
},
timer::Timer,
@ -116,146 +121,249 @@ async fn pause(
}
}
/*
#[command]
#[permission_level(Restricted)]
async fn offset(ctx: &Context, msg: &Message, args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
#[command("offset")]
#[description("Move all reminders in the current server by a certain amount of time. Times get added together")]
#[arg(
name = "hours",
description = "Number of hours to offset by",
kind = "Integer",
required = false
)]
#[arg(
name = "minutes",
description = "Number of minutes to offset by",
kind = "Integer",
required = false
)]
#[arg(
name = "seconds",
description = "Number of seconds to offset by",
kind = "Integer",
required = false
)]
#[required_permissions(Restricted)]
async fn offset(
ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync),
args: HashMap<String, String>,
) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let combined_time = args.get("hours").map_or(0, |h| h.parse::<i64>().unwrap() * 3600)
+ args.get("minutes").map_or(0, |m| m.parse::<i64>().unwrap() * 60)
+ args.get("seconds").map_or(0, |s| s.parse::<i64>().unwrap());
if args.is_empty() {
let prefix = ctx.prefix(msg.guild_id).await;
command_help(ctx, msg, lm, &prefix, &user_data.language, "offset").await;
if combined_time == 0 {
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new()
.content("Please specify one of `hours`, `minutes` or `seconds`"),
)
.await;
} else {
let parser = TimeParser::new(&args, user_data.timezone());
if let Some(guild) = invoke.guild(ctx.cache.clone()) {
let channels = guild
.channels
.iter()
.filter(|(channel_id, channel)| match channel {
Channel::Guild(guild_channel) => guild_channel.is_text_based(),
_ => false,
})
.map(|(id, _)| id.0.to_string())
.collect::<Vec<String>>()
.join(",");
if let Ok(displacement) = parser.displacement() {
if let Some(guild) = msg.guild(&ctx) {
let guild_data = GuildData::from_guild(guild, &pool).await.unwrap();
sqlx::query!(
"
sqlx::query!(
"
UPDATE reminders
INNER JOIN `channels`
ON `channels`.id = reminders.channel_id
SET
reminders.`utc_time` = reminders.`utc_time` + ?
WHERE channels.guild_id = ?
",
displacement,
guild_data.id
)
.execute(&pool)
.await
.unwrap();
} else {
sqlx::query!(
"
UPDATE reminders SET `utc_time` = `utc_time` + ? WHERE reminders.channel_id = ?
",
displacement,
user_data.dm_channel
)
.execute(&pool)
.await
.unwrap();
}
let response = lm.get(&user_data.language, "offset/success").replacen(
"{}",
&displacement.to_string(),
1,
);
let _ = msg.channel_id.say(&ctx, response).await;
INNER JOIN
`channels` ON `channels`.id = reminders.channel_id
SET reminders.`utc_time` = reminders.`utc_time` + ?
WHERE FIND_IN_SET(channels.`channel`, ?)",
combined_time,
channels
)
.execute(&pool)
.await
.unwrap();
} else {
let _ =
msg.channel_id.say(&ctx, lm.get(&user_data.language, "offset/invalid_time")).await;
sqlx::query!(
"UPDATE reminders INNER JOIN `channels` ON `channels`.id = reminders.channel_id SET reminders.`utc_time` = reminders.`utc_time` + ? WHERE channels.`channel` = ?",
combined_time,
invoke.channel_id().0
)
.execute(&pool)
.await
.unwrap();
}
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new()
.content(format!("All reminders offset by {} seconds", combined_time)),
)
.await;
}
}
#[command]
#[permission_level(Restricted)]
async fn nudge(ctx: &Context, msg: &Message, args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
#[command("nudge")]
#[description("Nudge all future reminders on this channel by a certain amount (don't use for DST! See `/offset`)")]
#[arg(
name = "minutes",
description = "Number of minutes to nudge new reminders by",
kind = "Integer",
required = false
)]
#[arg(
name = "seconds",
description = "Number of seconds to nudge new reminders by",
kind = "Integer",
required = false
)]
#[required_permissions(Restricted)]
async fn nudge(
ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync),
args: HashMap<String, String>,
) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let language = UserData::language_of(&msg.author, &pool).await;
let timezone = UserData::timezone_of(&msg.author, &pool).await;
let combined_time = args.get("minutes").map_or(0, |m| m.parse::<i64>().unwrap() * 60)
+ args.get("seconds").map_or(0, |s| s.parse::<i64>().unwrap());
let mut channel =
ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool).await.unwrap();
if args.is_empty() {
let content = lm
.get(&language, "nudge/no_argument")
.replace("{nudge}", &format!("{}s", &channel.nudge.to_string()));
let _ = msg.channel_id.say(&ctx, content).await;
if combined_time < i16::MIN as i64 || combined_time > i16::MAX as i64 {
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().content("Nudge times must be less than 500 minutes"),
)
.await;
} else {
let parser = TimeParser::new(&args, timezone);
let nudge_time = parser.displacement();
let mut channel_data = ctx.channel_data(invoke.channel_id()).await.unwrap();
match nudge_time {
Ok(displacement) => {
if displacement < i16::MIN as i64 || displacement > i16::MAX as i64 {
let _ = msg.channel_id.say(&ctx, lm.get(&language, "nudge/invalid_time")).await;
} else {
channel.nudge = displacement as i16;
channel_data.nudge = combined_time as i16;
channel_data.commit_changes(&pool).await;
channel.commit_changes(&pool).await;
let response = lm.get(&language, "nudge/success").replacen(
"{}",
&displacement.to_string(),
1,
);
let _ = msg.channel_id.say(&ctx, response).await;
}
}
Err(_) => {
let _ = msg.channel_id.say(&ctx, lm.get(&language, "nudge/invalid_time")).await;
}
}
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().content(format!(
"Future reminders will be nudged by {} seconds",
combined_time
)),
)
.await;
}
}
#[command("look")]
#[permission_level(Managed)]
async fn look(ctx: &Context, msg: &Message, args: String) {
let (pool, _lm) = get_ctx_data(&ctx).await;
#[description("View reminders on a specific channel")]
#[arg(
name = "channel",
description = "The channel to view reminders on",
kind = "Channel",
required = false
)]
#[arg(
name = "disabled",
description = "Whether to show disabled reminders or not",
kind = "Boolean",
required = false
)]
#[arg(
name = "relative",
description = "Whether to display times as relative or exact times",
kind = "Boolean",
required = false
)]
#[required_permissions(Managed)]
async fn look(
ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync),
args: HashMap<String, String>,
) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let timezone = UserData::timezone_of(&msg.author, &pool).await;
let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await;
let flags = LookFlags::from_string(&args);
let flags = LookFlags {
show_disabled: args.get("disabled").map(|b| b == "true").unwrap_or(true),
channel_id: args.get("channel").map(|c| ChannelId(c.parse::<u64>().unwrap())),
time_display: args.get("relative").map_or(TimeDisplayType::Relative, |b| {
if b == "true" {
TimeDisplayType::Relative
} else {
TimeDisplayType::Absolute
}
}),
};
let channel_opt = msg.channel_id.to_channel_cached(&ctx);
let channel_opt = invoke.channel_id().to_channel_cached(&ctx);
let channel_id = if let Some(Channel::Guild(channel)) = channel_opt {
if Some(channel.guild_id) == msg.guild_id {
flags.channel_id.unwrap_or(msg.channel_id)
if Some(channel.guild_id) == invoke.guild_id() {
flags.channel_id.unwrap_or(invoke.channel_id())
} else {
msg.channel_id
invoke.channel_id()
}
} else {
msg.channel_id
invoke.channel_id()
};
let reminders = Reminder::from_channel(ctx, channel_id, &flags).await;
if reminders.is_empty() {
let _ = msg.channel_id.say(&ctx, "No reminders on specified channel").await;
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().content("No reminders on specified channel"),
)
.await;
} else {
let display = reminders.iter().map(|reminder| reminder.display(&flags, &timezone));
let mut char_count = 0;
let _ = msg.channel_id.say_lines(&ctx, display).await;
let display = reminders
.iter()
.map(|reminder| reminder.display(&flags, &timezone))
.take_while(|p| {
char_count += p.len();
char_count < EMBED_DESCRIPTION_MAX_LENGTH
})
.collect::<Vec<String>>()
.join("\n");
let pages = reminders
.iter()
.map(|reminder| reminder.display(&flags, &timezone))
.fold(0, |t, r| t + r.len())
.div_ceil(EMBED_DESCRIPTION_MAX_LENGTH);
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new()
.embed(|e| {
e.title(format!("Reminders on {}", channel_id.mention()))
.description(display)
.footer(|f| f.text(format!("Page {} of {}", 1, pages)))
})
.components(|comp| {
comp.create_action_row(|row| {
row.create_button(|b| b.label("⏮️").custom_id(".1"))
.create_button(|b| b.label("◀️").custom_id(".2"))
.create_button(|b| b.label("▶️").custom_id(".3"))
.create_button(|b| b.label("⏭️").custom_id(".4"))
})
}),
)
.await;
}
}
/*
#[command("del")]
#[permission_level(Managed)]
async fn delete(ctx: &Context, msg: &Message, _args: String) {

View File

@ -0,0 +1,51 @@
use std::io::Cursor;
use rmp_serde::Serializer;
use serde::{Deserialize, Serialize};
use serenity::model::{
id::{ChannelId, RoleId},
interactions::message_component::MessageComponentInteraction,
};
use crate::models::reminder::look_flags::LookFlags;
#[derive(Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ComponentDataModel {
Restrict(Restrict),
LookPager(LookPager),
}
impl ComponentDataModel {
pub fn to_custom_id(&self) -> String {
let mut buf = Vec::new();
self.serialize(&mut Serializer::new(&mut buf)).unwrap();
base64::encode(buf)
}
pub fn from_custom_id(data: &String) -> Self {
let buf = base64::decode(data).unwrap();
let cur = Cursor::new(buf);
rmp_serde::from_read(cur).unwrap()
}
pub async fn act(&self, component: MessageComponentInteraction) {
match self {
ComponentDataModel::Restrict(restrict) => {
println!("{:?}", component.data.values);
}
ComponentDataModel::LookPager(pager) => {}
}
}
}
#[derive(Deserialize, Serialize)]
pub struct Restrict {
pub role_id: RoleId,
}
#[derive(Deserialize, Serialize)]
pub struct LookPager {
pub flags: LookFlags,
pub page_request: u16,
}

View File

@ -1,6 +1,7 @@
pub const DAY: u64 = 86_400;
pub const HOUR: u64 = 3_600;
pub const MINUTE: u64 = 60;
pub const EMBED_DESCRIPTION_MAX_LENGTH: usize = 4000;
pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";

View File

@ -1,7 +1,9 @@
#![feature(int_roundings)]
#[macro_use]
extern crate lazy_static;
mod commands;
mod component_models;
mod consts;
mod framework;
mod models;
@ -34,6 +36,7 @@ use tokio::sync::RwLock;
use crate::{
commands::{info_cmds, moderation_cmds, reminder_cmds},
component_models::ComponentDataModel,
consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR},
framework::RegexFramework,
models::guild_data::GuildData,
@ -253,6 +256,10 @@ DELETE FROM guilds WHERE guild = ?
framework.execute(ctx, application_command).await;
}
Interaction::MessageComponent(component) => {
let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id);
component_model.act(component).await;
}
_ => {}
}
}
@ -298,13 +305,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.add_command("n", &reminder_cmds::NATURAL_COMMAND)
.add_command("", &reminder_cmds::NATURAL_COMMAND)
// management commands
.add_command("look", &reminder_cmds::LOOK_COMMAND)
.add_command("del", &reminder_cmds::DELETE_COMMAND)
*/
.add_command(&reminder_cmds::LOOK_COMMAND)
.add_command(&reminder_cmds::PAUSE_COMMAND)
.add_command(&reminder_cmds::OFFSET_COMMAND)
.add_command(&reminder_cmds::NUDGE_COMMAND)
/*
.add_command("offset", &reminder_cmds::OFFSET_COMMAND)
.add_command("nudge", &reminder_cmds::NUDGE_COMMAND)
// to-do commands
.add_command("todo", &todo_cmds::TODO_USER_COMMAND)
.add_command("todo user", &todo_cmds::TODO_USER_COMMAND)

View File

@ -1,14 +1,16 @@
use serde::{Deserialize, Serialize};
use serenity::model::id::ChannelId;
use crate::consts::REGEX_CHANNEL;
#[derive(Serialize, Deserialize)]
pub enum TimeDisplayType {
Absolute,
Relative,
Absolute = 0,
Relative = 1,
}
#[derive(Serialize, Deserialize)]
pub struct LookFlags {
pub limit: u16,
pub show_disabled: bool,
pub channel_id: Option<ChannelId>,
pub time_display: TimeDisplayType,
@ -16,44 +18,6 @@ pub struct LookFlags {
impl Default for LookFlags {
fn default() -> Self {
Self {
limit: u16::MAX,
show_disabled: true,
channel_id: None,
time_display: TimeDisplayType::Relative,
}
}
}
impl LookFlags {
pub fn from_string(args: &str) -> Self {
let mut new_flags: Self = Default::default();
for arg in args.split(' ') {
match arg {
"enabled" => {
new_flags.show_disabled = false;
}
"time" => {
new_flags.time_display = TimeDisplayType::Absolute;
}
param => {
if let Ok(val) = param.parse::<u16>() {
new_flags.limit = val;
} else if let Some(channel) = REGEX_CHANNEL
.captures(arg)
.map(|cap| cap.get(1))
.flatten()
.map(|c| c.as_str().parse::<u64>().unwrap())
{
new_flags.channel_id = Some(ChannelId(channel));
}
}
}
}
new_flags
Self { show_disabled: true, channel_id: None, time_display: TimeDisplayType::Relative }
}
}

View File

@ -11,7 +11,6 @@ use std::{
use chrono::{NaiveDateTime, TimeZone};
use chrono_tz::Tz;
use ring::hmac;
use serenity::{
client::Context,
model::id::{ChannelId, GuildId, UserId},
@ -27,31 +26,6 @@ use crate::{
SQLPool,
};
#[derive(Clone, Copy)]
pub enum ReminderAction {
Delete,
}
impl ToString for ReminderAction {
fn to_string(&self) -> String {
match self {
Self::Delete => String::from("del"),
}
}
}
impl TryFrom<&str> for ReminderAction {
type Error = ();
fn try_from(value: &str) -> Result<Self, Self::Error> {
match value {
"del" => Ok(Self::Delete),
_ => Err(()),
}
}
}
#[derive(Debug)]
pub struct Reminder {
pub id: u32,
@ -178,12 +152,9 @@ WHERE
FIND_IN_SET(reminders.enabled, ?)
ORDER BY
reminders.utc_time
LIMIT
?
",
channel_id.as_u64(),
enabled,
flags.limit
)
.fetch_all(&pool)
.await
@ -341,59 +312,6 @@ WHERE
}
}
pub async fn from_interaction<U: Into<u64>>(
ctx: &Context,
member_id: U,
payload: String,
) -> Result<(Self, ReminderAction), InteractionError> {
let sections = payload.split('.').collect::<Vec<&str>>();
if sections.len() != 3 {
Err(InteractionError::InvalidFormat)
} else {
let action = ReminderAction::try_from(sections[0])
.map_err(|_| InteractionError::InvalidAction)?;
let reminder_id = u32::from_le_bytes(
base64::decode(sections[1])
.map_err(|_| InteractionError::InvalidBase64)?
.try_into()
.map_err(|_| InteractionError::InvalidSize)?,
);
if let Some(reminder) = Self::from_id(ctx, reminder_id).await {
if reminder.signed_action(member_id, action) == payload {
Ok((reminder, action))
} else {
Err(InteractionError::SignatureMismatch)
}
} else {
Err(InteractionError::NoReminder)
}
}
}
pub fn signed_action<U: Into<u64>>(&self, member_id: U, action: ReminderAction) -> String {
let s_key = hmac::Key::new(
hmac::HMAC_SHA256,
env::var("SECRET_KEY").expect("No SECRET_KEY provided").as_bytes(),
);
let mut context = hmac::Context::with_key(&s_key);
context.update(&self.id.to_le_bytes());
context.update(&member_id.into().to_le_bytes());
let signature = context.sign();
format!(
"{}.{}.{}",
action.to_string(),
base64::encode(self.id.to_le_bytes()),
base64::encode(&signature)
)
}
pub async fn delete(&self, ctx: &Context) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();