This commit is contained in:
jellywx 2021-09-12 16:59:19 +01:00
parent bae0433bd9
commit 395a8481f1
10 changed files with 242 additions and 134 deletions

View File

@ -27,4 +27,4 @@ sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql",
base64 = "0.13.0" base64 = "0.13.0"
[dependencies.regex_command_attr] [dependencies.regex_command_attr]
path = "./regex_command_attr" path = "command_attributes"

View File

@ -9,9 +9,7 @@ use serenity::{
client::Context, client::Context,
model::{ model::{
channel::Message, channel::Message,
guild::ActionRole::Create,
id::{ChannelId, MessageId, RoleId}, id::{ChannelId, MessageId, RoleId},
interactions::message_component::ButtonStyle,
misc::Mentionable, misc::Mentionable,
}, },
}; };
@ -19,7 +17,9 @@ use serenity::{
use crate::{ use crate::{
component_models::{ComponentDataModel, Restrict}, component_models::{ComponentDataModel, Restrict},
consts::{REGEX_ALIAS, REGEX_COMMANDS, THEME_COLOR}, consts::{REGEX_ALIAS, REGEX_COMMANDS, THEME_COLOR},
framework::{CommandInvoke, CreateGenericResponse, PermissionLevel}, framework::{
CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue, PermissionLevel,
},
models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData, CtxData}, models::{channel_data::ChannelData, guild_data::GuildData, user_data::UserData, CtxData},
PopularTimezones, RegexFramework, SQLPool, PopularTimezones, RegexFramework, SQLPool,
}; };
@ -38,14 +38,14 @@ use crate::{
async fn blacklist( async fn blacklist(
ctx: &Context, ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync), invoke: &(dyn CommandInvoke + Send + Sync),
args: HashMap<String, String>, args: CommandOptions,
) { ) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let channel = match args.get("channel") { let channel = match args.get("channel") {
Some(channel_id) => ChannelId(channel_id.parse::<u64>().unwrap()), Some(OptionValue::Channel(channel_id)) => *channel_id,
None => invoke.channel_id(), _ => invoke.channel_id(),
} }
.to_channel_cached(&ctx) .to_channel_cached(&ctx)
.unwrap(); .unwrap();
@ -82,17 +82,13 @@ async fn blacklist(
kind = "String", kind = "String",
required = false required = false
)] )]
async fn timezone( async fn timezone(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) {
ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync),
args: HashMap<String, String>,
) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let mut user_data = ctx.user_data(invoke.author_id()).await.unwrap(); let mut user_data = ctx.user_data(invoke.author_id()).await.unwrap();
let footer_text = format!("Current timezone: {}", user_data.timezone); let footer_text = format!("Current timezone: {}", user_data.timezone);
if let Some(timezone) = args.get("timezone") { if let Some(OptionValue::String(timezone)) = args.get("timezone") {
match timezone.parse::<Tz>() { match timezone.parse::<Tz>() {
Ok(tz) => { Ok(tz) => {
user_data.timezone = timezone.clone(); user_data.timezone = timezone.clone();
@ -237,16 +233,11 @@ async fn prefix(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args:
)] )]
#[supports_dm(false)] #[supports_dm(false)]
#[required_permissions(Restricted)] #[required_permissions(Restricted)]
async fn restrict( async fn restrict(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) {
ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync),
args: HashMap<String, String>,
) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let framework = ctx.data.read().await.get::<RegexFramework>().cloned().unwrap(); let framework = ctx.data.read().await.get::<RegexFramework>().cloned().unwrap();
let role = RoleId(args.get("role").unwrap().parse::<u64>().unwrap()); if let Some(OptionValue::Role(role)) = args.get("role") {
let restricted_commands = let restricted_commands =
sqlx::query!("SELECT command FROM command_restrictions WHERE role_id = ?", role.0) sqlx::query!("SELECT command FROM command_restrictions WHERE role_id = ?", role.0)
.fetch_all(&pool) .fetch_all(&pool)
@ -265,13 +256,16 @@ async fn restrict(
let len = restrictable_commands.len(); let len = restrictable_commands.len();
let restrict_pl = ComponentDataModel::Restrict(Restrict { role_id: role }); let restrict_pl = ComponentDataModel::Restrict(Restrict { role_id: *role });
invoke invoke
.respond( .respond(
ctx.http.clone(), ctx.http.clone(),
CreateGenericResponse::new() CreateGenericResponse::new()
.content(format!("Select the commands to allow to {} from below:", role.mention())) .content(format!(
"Select the commands to allow to {} from below:",
role.mention()
))
.components(|c| { .components(|c| {
c.create_action_row(|row| { c.create_action_row(|row| {
row.create_select_menu(|select| { row.create_select_menu(|select| {
@ -280,7 +274,9 @@ async fn restrict(
.options(|options| { .options(|options| {
for command in restrictable_commands { for command in restrictable_commands {
options.create_option(|opt| { options.create_option(|opt| {
opt.label(&command).value(&command).default_selection( opt.label(&command)
.value(&command)
.default_selection(
restricted_commands.contains(&command), restricted_commands.contains(&command),
) )
}); });
@ -297,6 +293,7 @@ async fn restrict(
.await .await
.unwrap(); .unwrap();
} }
}
/* /*
#[command("alias")] #[command("alias")]

View File

@ -24,7 +24,7 @@ use crate::{
EMBED_DESCRIPTION_MAX_LENGTH, REGEX_CHANNEL_USER, REGEX_NATURAL_COMMAND_1, EMBED_DESCRIPTION_MAX_LENGTH, REGEX_CHANNEL_USER, REGEX_NATURAL_COMMAND_1,
REGEX_NATURAL_COMMAND_2, REGEX_REMIND_COMMAND, THEME_COLOR, REGEX_NATURAL_COMMAND_2, REGEX_REMIND_COMMAND, THEME_COLOR,
}, },
framework::{CommandInvoke, CreateGenericResponse}, framework::{CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue},
models::{ models::{
channel_data::ChannelData, channel_data::ChannelData,
guild_data::GuildData, guild_data::GuildData,
@ -52,11 +52,7 @@ use crate::{
)] )]
#[supports_dm(false)] #[supports_dm(false)]
#[required_permissions(Restricted)] #[required_permissions(Restricted)]
async fn pause( async fn pause(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) {
ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync),
args: HashMap<String, String>,
) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await;
@ -64,7 +60,7 @@ async fn pause(
let mut channel = ctx.channel_data(invoke.channel_id()).await.unwrap(); let mut channel = ctx.channel_data(invoke.channel_id()).await.unwrap();
match args.get("until") { match args.get("until") {
Some(until) => { Some(OptionValue::String(until)) => {
let parsed = natural_parser(until, &timezone.to_string()).await; let parsed = natural_parser(until, &timezone.to_string()).await;
if let Some(timestamp) = parsed { if let Some(timestamp) = parsed {
@ -94,7 +90,7 @@ async fn pause(
.await; .await;
} }
} }
None => { _ => {
channel.paused = !channel.paused; channel.paused = !channel.paused;
channel.paused_until = None; channel.paused_until = None;
@ -142,16 +138,12 @@ async fn pause(
required = false required = false
)] )]
#[required_permissions(Restricted)] #[required_permissions(Restricted)]
async fn offset( async fn offset(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) {
ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync),
args: HashMap<String, String>,
) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let combined_time = args.get("hours").map_or(0, |h| h.parse::<i64>().unwrap() * 3600) let combined_time = args.get("hours").map_or(0, |h| h.as_i64().unwrap() * 3600)
+ args.get("minutes").map_or(0, |m| m.parse::<i64>().unwrap() * 60) + args.get("minutes").map_or(0, |m| m.as_i64().unwrap() * 60)
+ args.get("seconds").map_or(0, |s| s.parse::<i64>().unwrap()); + args.get("seconds").map_or(0, |s| s.as_i64().unwrap());
if combined_time == 0 { if combined_time == 0 {
let _ = invoke let _ = invoke
@ -223,15 +215,11 @@ WHERE FIND_IN_SET(channels.`channel`, ?)",
required = false required = false
)] )]
#[required_permissions(Restricted)] #[required_permissions(Restricted)]
async fn nudge( async fn nudge(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) {
ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync),
args: HashMap<String, String>,
) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let combined_time = args.get("minutes").map_or(0, |m| m.parse::<i64>().unwrap() * 60) let combined_time = args.get("minutes").map_or(0, |m| m.as_i64().unwrap() * 60)
+ args.get("seconds").map_or(0, |s| s.parse::<i64>().unwrap()); + args.get("seconds").map_or(0, |s| s.as_i64().unwrap());
if combined_time < i16::MIN as i64 || combined_time > i16::MAX as i64 { if combined_time < i16::MIN as i64 || combined_time > i16::MAX as i64 {
let _ = invoke let _ = invoke
@ -279,20 +267,16 @@ async fn nudge(
required = false required = false
)] )]
#[required_permissions(Managed)] #[required_permissions(Managed)]
async fn look( async fn look(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) {
ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync),
args: HashMap<String, String>,
) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap(); let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await; let timezone = UserData::timezone_of(&invoke.author_id(), &pool).await;
let flags = LookFlags { let flags = LookFlags {
show_disabled: args.get("disabled").map(|b| b == "true").unwrap_or(true), show_disabled: args.get("disabled").map(|i| i.as_bool()).flatten().unwrap_or(true),
channel_id: args.get("channel").map(|c| ChannelId(c.parse::<u64>().unwrap())), channel_id: args.get("channel").map(|i| i.as_channel_id()).flatten(),
time_display: args.get("relative").map_or(TimeDisplayType::Relative, |b| { time_display: args.get("relative").map_or(TimeDisplayType::Relative, |b| {
if b == "true" { if b.as_bool() == Some(true) {
TimeDisplayType::Relative TimeDisplayType::Relative
} else { } else {
TimeDisplayType::Absolute TimeDisplayType::Absolute
@ -473,11 +457,7 @@ INSERT INTO events (event_name, bulk_count, guild_id, user_id) VALUES ('delete',
#[description("Delete a timer")] #[description("Delete a timer")]
#[arg(name = "name", description = "Name of the timer to delete", kind = "String", required = true)] #[arg(name = "name", description = "Name of the timer to delete", kind = "String", required = true)]
#[required_permissions(Managed)] #[required_permissions(Managed)]
async fn timer( async fn timer(ctx: &Context, invoke: &(dyn CommandInvoke + Send + Sync), args: CommandOptions) {
ctx: &Context,
invoke: &(dyn CommandInvoke + Send + Sync),
args: HashMap<String, String>,
) {
fn time_difference(start_time: NaiveDateTime) -> String { fn time_difference(start_time: NaiveDateTime) -> String {
let unix_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; let unix_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64;
let now = NaiveDateTime::from_timestamp(unix_time, 0); let now = NaiveDateTime::from_timestamp(unix_time, 0);
@ -495,8 +475,8 @@ async fn timer(
let owner = invoke.guild_id().map(|g| g.0).unwrap_or_else(|| invoke.author_id().0); let owner = invoke.guild_id().map(|g| g.0).unwrap_or_else(|| invoke.author_id().0);
match args.get("").map(|s| s.as_str()) { match args.subcommand.clone().unwrap().as_str() {
Some("start") => { "start" => {
let count = Timer::count_from_owner(owner, &pool).await; let count = Timer::count_from_owner(owner, &pool).await;
if count >= 25 { if count >= 25 {
@ -508,7 +488,7 @@ async fn timer(
) )
.await; .await;
} else { } else {
let name = args.get("name").unwrap(); let name = args.get("name").unwrap().to_string();
if name.len() <= 32 { if name.len() <= 32 {
Timer::create(&name, owner, &pool).await; Timer::create(&name, owner, &pool).await;
@ -530,8 +510,8 @@ async fn timer(
} }
} }
} }
Some("delete") => { "delete" => {
let name = args.get("name").unwrap(); let name = args.get("name").unwrap().to_string();
let exists = sqlx::query!( let exists = sqlx::query!(
" "
@ -570,7 +550,7 @@ DELETE FROM timers WHERE owner = ? AND name = ?
.await; .await;
} }
} }
Some("list") => { "list" => {
let timers = Timer::from_owner(owner, &pool).await; let timers = Timer::from_owner(owner, &pool).await;
if timers.len() > 0 { if timers.len() > 0 {

View File

@ -18,13 +18,15 @@ use serenity::{
model::{ model::{
channel::{Channel, GuildChannel, Message}, channel::{Channel, GuildChannel, Message},
guild::{Guild, Member}, guild::{Guild, Member},
id::{ChannelId, GuildId, MessageId, UserId}, id::{ChannelId, GuildId, MessageId, RoleId, UserId},
interactions::{ interactions::{
application_command::{ application_command::{
ApplicationCommand, ApplicationCommandInteraction, ApplicationCommandOptionType, ApplicationCommand, ApplicationCommandInteraction, ApplicationCommandOption,
ApplicationCommandOptionType,
}, },
InteractionResponseType, InteractionResponseType,
}, },
prelude::application_command::ApplicationCommandInteractionDataOption,
}, },
prelude::TypeMapKey, prelude::TypeMapKey,
FutureExt, Result as SerenityResult, FutureExt, Result as SerenityResult,
@ -281,10 +283,166 @@ pub struct Arg {
pub options: &'static [&'static Self], pub options: &'static [&'static Self],
} }
pub enum OptionValue {
String(String),
Integer(i64),
Boolean(bool),
User(UserId),
Channel(ChannelId),
Role(RoleId),
Mentionable(u64),
Number(f64),
}
impl OptionValue {
pub fn as_i64(&self) -> Option<i64> {
match self {
OptionValue::Integer(i) => Some(*i),
_ => None,
}
}
pub fn as_bool(&self) -> Option<bool> {
match self {
OptionValue::Boolean(b) => Some(*b),
_ => None,
}
}
pub fn as_channel_id(&self) -> Option<ChannelId> {
match self {
OptionValue::Channel(c) => Some(*c),
_ => None,
}
}
pub fn to_string(&self) -> String {
match self {
OptionValue::String(s) => s.to_string(),
OptionValue::Integer(i) => i.to_string(),
OptionValue::Boolean(b) => b.to_string(),
OptionValue::User(u) => u.to_string(),
OptionValue::Channel(c) => c.to_string(),
OptionValue::Role(r) => r.to_string(),
OptionValue::Mentionable(m) => m.to_string(),
OptionValue::Number(n) => n.to_string(),
}
}
}
pub struct CommandOptions {
pub command: String,
pub subcommand: Option<String>,
pub subcommand_group: Option<String>,
pub options: HashMap<String, OptionValue>,
}
impl CommandOptions {
pub fn get(&self, key: &str) -> Option<&OptionValue> {
self.options.get(key)
}
}
impl From<ApplicationCommandInteraction> for CommandOptions {
fn from(interaction: ApplicationCommandInteraction) -> Self {
fn match_option(
option: ApplicationCommandInteractionDataOption,
cmd_opts: &mut CommandOptions,
) {
match option.kind {
ApplicationCommandOptionType::SubCommand => {
cmd_opts.subcommand = Some(option.name);
for opt in option.options {
match_option(opt, cmd_opts);
}
}
ApplicationCommandOptionType::SubCommandGroup => {
cmd_opts.subcommand_group = Some(option.name);
for opt in option.options {
match_option(opt, cmd_opts);
}
}
ApplicationCommandOptionType::String => {
cmd_opts.options.insert(
option.name,
OptionValue::String(option.value.unwrap().to_string()),
);
}
ApplicationCommandOptionType::Integer => {
cmd_opts.options.insert(
option.name,
OptionValue::Integer(option.value.map(|m| m.as_i64()).flatten().unwrap()),
);
}
ApplicationCommandOptionType::Boolean => {
cmd_opts.options.insert(
option.name,
OptionValue::Boolean(option.value.map(|m| m.as_bool()).flatten().unwrap()),
);
}
ApplicationCommandOptionType::User => {
cmd_opts.options.insert(
option.name,
OptionValue::User(UserId(
option.value.map(|m| m.as_u64()).flatten().unwrap(),
)),
);
}
ApplicationCommandOptionType::Channel => {
cmd_opts.options.insert(
option.name,
OptionValue::Channel(ChannelId(
option.value.map(|m| m.as_u64()).flatten().unwrap(),
)),
);
}
ApplicationCommandOptionType::Role => {
cmd_opts.options.insert(
option.name,
OptionValue::Role(RoleId(
option.value.map(|m| m.as_u64()).flatten().unwrap(),
)),
);
}
ApplicationCommandOptionType::Mentionable => {
cmd_opts.options.insert(
option.name,
OptionValue::Mentionable(
option.value.map(|m| m.as_u64()).flatten().unwrap(),
),
);
}
ApplicationCommandOptionType::Number => {
cmd_opts.options.insert(
option.name,
OptionValue::Number(option.value.map(|m| m.as_f64()).flatten().unwrap()),
);
}
_ => {}
}
}
let mut cmd_opts = Self {
command: interaction.data.name,
subcommand: None,
subcommand_group: None,
options: Default::default(),
};
for option in interaction.data.options {
match_option(option, &mut cmd_opts)
}
cmd_opts
}
}
type SlashCommandFn = for<'fut> fn( type SlashCommandFn = for<'fut> fn(
&'fut Context, &'fut Context,
&'fut (dyn CommandInvoke + Sync + Send), &'fut (dyn CommandInvoke + Sync + Send),
HashMap<String, String>, CommandOptions,
) -> BoxFuture<'fut, ()>; ) -> BoxFuture<'fut, ()>;
type TextCommandFn = for<'fut> fn( type TextCommandFn = for<'fut> fn(
@ -631,34 +789,7 @@ impl RegexFramework {
let member = interaction.clone().member.unwrap(); let member = interaction.clone().member.unwrap();
if command.check_permissions(&ctx, &guild, &member).await { if command.check_permissions(&ctx, &guild, &member).await {
let mut args = HashMap::new(); let args = CommandOptions::from(interaction.clone());
for arg in interaction.data.options.iter() {
if let Some(value) = &arg.value {
args.insert(
arg.name.clone(),
match value {
Value::Bool(b) => b.to_string(),
Value::Number(n) => n.to_string(),
Value::String(s) => s.to_owned(),
_ => String::new(),
},
);
} else {
args.insert("".to_string(), arg.name.clone());
for sub_arg in arg.options.iter().filter(|o| o.value.is_some()) {
args.insert(
sub_arg.name.clone(),
match sub_arg.value.as_ref().unwrap() {
Value::Bool(b) => b.to_string(),
Value::Number(n) => n.to_string(),
Value::String(s) => s.to_owned(),
_ => String::new(),
},
);
}
}
}
if !ctx.check_executing(interaction.author_id()).await { if !ctx.check_executing(interaction.author_id()).await {
ctx.set_executing(interaction.author_id()).await; ctx.set_executing(interaction.author_id()).await;