From b2207e308a4348a1812d0780966cd757226056f4 Mon Sep 17 00:00:00 2001 From: jellywx Date: Thu, 16 Sep 2021 15:42:50 +0100 Subject: [PATCH] optimized packing slightly. restrict interactions --- Cargo.lock | 12 ++++++ Cargo.toml | 1 + src/commands/moderation_cmds.rs | 8 +++- src/commands/reminder_cmds.rs | 72 ++----------------------------- src/component_models/mod.rs | 48 +++++++++++++++++++-- src/framework.rs | 7 ++- src/models/reminder/look_flags.rs | 4 +- 7 files changed, 76 insertions(+), 76 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eaf4330..d737f43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1237,6 +1237,7 @@ dependencies = [ "rmp-serde", "serde", "serde_json", + "serde_repr", "serenity", "sqlx", "tokio", @@ -1446,6 +1447,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_repr" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98d0516900518c29efa217c298fa1f4e6c6ffc85ae29fd7f4ee48f176e1a9ed5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "serde_urlencoded" version = "0.7.0" diff --git a/Cargo.toml b/Cargo.toml index bda59b8..3a7b0ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ lazy_static = "1.4" num-integer = "0.1" serde = "1.0" serde_json = "1.0" +serde_repr = "0.1" rmp-serde = "0.15" rand = "0.7" levenshtein = "1.0" diff --git a/src/commands/moderation_cmds.rs b/src/commands/moderation_cmds.rs index a68f9e1..aeaed4e 100644 --- a/src/commands/moderation_cmds.rs +++ b/src/commands/moderation_cmds.rs @@ -239,7 +239,7 @@ async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), arg if let Some(OptionValue::Role(role)) = args.get("role") { let restricted_commands = - sqlx::query!("SELECT command FROM command_restrictions WHERE role_id = ?", role.0) + sqlx::query!("SELECT command FROM command_restrictions WHERE role_id = (SELECT id FROM roles WHERE role = ?)", role.0) .fetch_all(&pool) .await .unwrap() @@ -256,7 +256,11 @@ async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), arg let len = restrictable_commands.len(); - let restrict_pl = ComponentDataModel::Restrict(Restrict { role_id: *role }); + let restrict_pl = ComponentDataModel::Restrict(Restrict { + role_id: *role, + author_id: invoke.author_id(), + guild_id: invoke.guild_id().unwrap(), + }); invoke .respond( diff --git a/src/commands/reminder_cmds.rs b/src/commands/reminder_cmds.rs index 20599a7..6223d8c 100644 --- a/src/commands/reminder_cmds.rs +++ b/src/commands/reminder_cmds.rs @@ -405,22 +405,13 @@ async fn look(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: C } } -/* #[command("del")] #[description("Delete reminders")] -#[permission_level(Managed)] +#[required_permissions(Managed)] async fn delete(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { - let pool = ctx.data.read().await.get::().cloned().unwrap(); - - let _ = msg.channel_id.say(&ctx, lm.get(&user_data.language, "del/listing")).await; - - let mut reminder_ids: Vec = vec![]; - - let reminders = Reminder::from_guild(ctx, msg.guild_id, msg.author.id).await; + let reminders = Reminder::from_guild(ctx, invoke.guild_id(), invoke.author_id()).await; let enumerated_reminders = reminders.iter().enumerate().map(|(count, reminder)| { - reminder_ids.push(reminder.id); - format!( "**{}**: '{}' *<#{}>* at ", count + 1, @@ -430,65 +421,8 @@ async fn delete(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync)) { ) }); - let _ = msg.channel_id.say_lines(&ctx, enumerated_reminders).await; - let _ = msg.channel_id.say(&ctx, lm.get(&user_data.language, "del/listed")).await; - - let reply = - msg.channel_id.await_reply(&ctx).author_id(msg.author.id).channel_id(msg.channel_id).await; - - if let Some(content) = reply.map(|m| m.content.replace(",", " ")) { - let parts = content.split(' ').filter(|i| !i.is_empty()).collect::>(); - - let valid_parts = parts - .iter() - .filter_map(|i| { - i.parse::() - .ok() - .filter(|val| val > &0) - .map(|val| reminder_ids.get(val - 1)) - .flatten() - }) - .map(|item| item.to_string()) - .collect::>(); - - if parts.len() == valid_parts.len() { - let joined = valid_parts.join(","); - - let count_row = sqlx::query!( - " -SELECT COUNT(1) AS count FROM reminders WHERE FIND_IN_SET(id, ?) - ", - joined - ) - .fetch_one(&pool) - .await - .unwrap(); - - sqlx::query!( - " -DELETE FROM reminders WHERE FIND_IN_SET(id, ?) - ", - joined - ) - .execute(&pool) - .await - .unwrap(); - - let content = lm.get(&user_data.language, "del/count").replacen( - "{}", - &count_row.count.to_string(), - 1, - ); - - let _ = msg.channel_id.say(&ctx, content).await; - } else { - let content = lm.get(&user_data.language, "del/count").replacen("{}", "0", 1); - - let _ = msg.channel_id.say(&ctx, content).await; - } - } + let _ = invoke.respond(ctx.http.clone(), CreateGenericResponse::new().content("test")).await; } -*/ #[command("timer")] #[description("Manage timers")] diff --git a/src/component_models/mod.rs b/src/component_models/mod.rs index 2e95c50..ec9b4d1 100644 --- a/src/component_models/mod.rs +++ b/src/component_models/mod.rs @@ -3,16 +3,18 @@ use std::io::Cursor; use chrono_tz::Tz; use rmp_serde::Serializer; use serde::{Deserialize, Serialize}; +use serde_repr::*; use serenity::{ builder::CreateEmbed, client::Context, model::{ channel::Channel, - id::{ChannelId, RoleId}, + id::{ChannelId, GuildId, RoleId, UserId}, interactions::{ message_component::{ButtonStyle, MessageComponentInteraction}, InteractionResponseType, }, + prelude::InteractionApplicationCommandCallbackDataFlags, }, }; @@ -22,10 +24,12 @@ use crate::{ reminder::{look_flags::LookFlags, Reminder}, user_data::UserData, }, + SQLPool, }; #[derive(Deserialize, Serialize)] #[serde(tag = "type")] +#[repr(u8)] pub enum ComponentDataModel { Restrict(Restrict), LookPager(LookPager), @@ -47,7 +51,42 @@ impl ComponentDataModel { pub async fn act(&self, ctx: &Context, component: MessageComponentInteraction) { match self { ComponentDataModel::Restrict(restrict) => { - println!("{:?}", component.data.values); + if restrict.author_id == component.user.id { + let pool = ctx.data.read().await.get::().cloned().unwrap(); + + let _ = sqlx::query!( + " +INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, \"Role\", (SELECT id FROM guilds WHERE guild = ?)) + ", + restrict.role_id.0, + restrict.guild_id.0 + ) + .execute(&pool) + .await; + + for command in &component.data.values { + let _ = sqlx::query!( + "INSERT INTO command_restrictions (role_id, command) VALUES ((SELECT id FROM roles WHERE role = ?), ?)", + restrict.role_id.0, + command + ) + .execute(&pool) + .await; + } + + component + .create_interaction_response(&ctx, |r| { + r.kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|response| response + .flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL) + .content("Role permissions updated") + ) + }) + .await + .unwrap(); + } else { + // tell them they cant do this + } } ComponentDataModel::LookPager(pager) => { let flags = pager.flags; @@ -184,9 +223,12 @@ impl ComponentDataModel { #[derive(Serialize, Deserialize)] pub struct Restrict { pub role_id: RoleId, + pub author_id: UserId, + pub guild_id: GuildId, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize_repr, Deserialize_repr, Debug)] +#[repr(u8)] pub enum PageAction { First = 0, Previous = 1, diff --git a/src/framework.rs b/src/framework.rs index 36c37cf..a714218 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -404,7 +404,12 @@ impl From for CommandOptions { cmd_opts.options.insert( option.name, OptionValue::Role(RoleId( - option.value.map(|m| m.as_u64()).flatten().unwrap(), + option + .value + .map(|m| m.as_str().map(|s| s.parse::().ok())) + .flatten() + .flatten() + .unwrap(), )), ); } diff --git a/src/models/reminder/look_flags.rs b/src/models/reminder/look_flags.rs index 03eb5db..04a684f 100644 --- a/src/models/reminder/look_flags.rs +++ b/src/models/reminder/look_flags.rs @@ -1,9 +1,11 @@ use serde::{Deserialize, Serialize}; +use serde_repr::*; use serenity::model::id::ChannelId; use crate::consts::REGEX_CHANNEL; -#[derive(Serialize, Deserialize, Copy, Clone, Debug)] +#[derive(Serialize_repr, Deserialize_repr, Copy, Clone, Debug)] +#[repr(u8)] pub enum TimeDisplayType { Absolute = 0, Relative = 1,