From 0e3c514f35b6e4dcd580ef468f248666a8c7e00c Mon Sep 17 00:00:00 2001 From: jude Date: Thu, 6 Aug 2020 19:18:30 +0100 Subject: [PATCH] command macros that will hopefully work --- Cargo.lock | 12 +- Cargo.toml | 9 +- regex_command_attr/Cargo.toml | 14 ++ regex_command_attr/src/attributes.rs | 293 +++++++++++++++++++++++++++ regex_command_attr/src/consts.rs | 5 + regex_command_attr/src/lib.rs | 113 +++++++++++ regex_command_attr/src/structures.rs | 235 +++++++++++++++++++++ regex_command_attr/src/util.rs | 242 ++++++++++++++++++++++ src/framework.rs | 70 ++----- src/main.rs | 14 +- 10 files changed, 938 insertions(+), 69 deletions(-) create mode 100644 regex_command_attr/Cargo.toml create mode 100644 regex_command_attr/src/attributes.rs create mode 100644 regex_command_attr/src/consts.rs create mode 100644 regex_command_attr/src/lib.rs create mode 100644 regex_command_attr/src/structures.rs create mode 100644 regex_command_attr/src/util.rs diff --git a/Cargo.lock b/Cargo.lock index c755cef..c8f888b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1058,12 +1058,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26412eb97c6b088a6997e05f69403a802a92d520de2f8e63c2b65f9e0f47c4e8" [[package]] -name = "reminder-rs" +name = "regex_command_attr" +version = "0.2.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "reminder_rs" version = "0.1.0" dependencies = [ "async-trait", "dotenv", "regex", + "regex_command_attr", "reqwest", "serenity", "sqlx", diff --git a/Cargo.toml b/Cargo.toml index 84661b6..1df279e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,9 @@ [package] -name = "reminder-rs" +name = "reminder_rs" version = "0.1.0" -authors = ["jude "] +authors = ["jellywx "] edition = "2018" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] serenity = {git = "https://github.com/acdenisSK/serenity", branch = "await_next"} dotenv = "0.15" @@ -14,3 +12,6 @@ reqwest = "0.10.6" sqlx = {version = "0.3.5", default-features = false, features = ["runtime-tokio", "macros", "mysql", "bigdecimal"]} regex = "1.3.9" async-trait = "0.1.36" + +[dependencies.regex_command_attr] +path = "./regex_command_attr" diff --git a/regex_command_attr/Cargo.toml b/regex_command_attr/Cargo.toml new file mode 100644 index 0000000..dbe01d1 --- /dev/null +++ b/regex_command_attr/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "regex_command_attr" +version = "0.2.0" +authors = ["acdenisSK ", "jellywx "] +edition = "2018" +description = "Procedural macros for command creation for the RegexFramework for serenity." + +[lib] +proc-macro = true + +[dependencies] +quote = "^1.0" +syn = { version = "^1.0", features = ["full", "derive", "extra-traits"] } +proc-macro2 = "1.0" diff --git a/regex_command_attr/src/attributes.rs b/regex_command_attr/src/attributes.rs new file mode 100644 index 0000000..d4c2a27 --- /dev/null +++ b/regex_command_attr/src/attributes.rs @@ -0,0 +1,293 @@ +use proc_macro2::Span; +use syn::parse::{Error, Result}; +use syn::spanned::Spanned; +use syn::{Attribute, Ident, Lit, LitStr, Meta, NestedMeta, Path}; + +use crate::structures::PermissionLevel; +use crate::util::{AsOption, LitExt}; + +use std::fmt::{self, Write}; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ValueKind { + // #[] + Name, + + // #[ = ] + Equals, + + // #[([, , , ...])] + List, + + // #[()] + SingleList, +} + +impl fmt::Display for ValueKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ValueKind::Name => f.pad("`#[]`"), + ValueKind::Equals => f.pad("`#[ = ]`"), + ValueKind::List => f.pad("`#[([, , , ...])]`"), + ValueKind::SingleList => f.pad("`#[()]`"), + } + } +} + +fn to_ident(p: Path) -> Result { + if p.segments.is_empty() { + return Err(Error::new( + p.span(), + "cannot convert an empty path to an identifier", + )); + } + + if p.segments.len() > 1 { + return Err(Error::new( + p.span(), + "the path must not have more than one segment", + )); + } + + if !p.segments[0].arguments.is_empty() { + return Err(Error::new( + p.span(), + "the singular path segment must not have any arguments", + )); + } + + Ok(p.segments[0].ident.clone()) +} + +#[derive(Debug)] +pub struct Values { + pub name: Ident, + pub literals: Vec, + pub kind: ValueKind, + pub span: Span, +} + +impl Values { + #[inline] + pub fn new(name: Ident, kind: ValueKind, literals: Vec, span: Span) -> Self { + Values { + name, + literals, + kind, + span, + } + } +} + +pub fn parse_values(attr: &Attribute) -> Result { + let meta = attr.parse_meta()?; + + match meta { + Meta::Path(path) => { + let name = to_ident(path)?; + + Ok(Values::new(name, ValueKind::Name, Vec::new(), attr.span())) + } + Meta::List(meta) => { + let name = to_ident(meta.path)?; + let nested = meta.nested; + + if nested.is_empty() { + return Err(Error::new(attr.span(), "list cannot be empty")); + } + + let mut lits = Vec::with_capacity(nested.len()); + + for meta in nested { + match meta { + NestedMeta::Lit(l) => lits.push(l), + NestedMeta::Meta(m) => match m { + Meta::Path(path) => { + let i = to_ident(path)?; + lits.push(Lit::Str(LitStr::new(&i.to_string(), i.span()))) + } + Meta::List(_) | Meta::NameValue(_) => { + return Err(Error::new(attr.span(), "cannot nest a list; only accept literals and identifiers at this level")) + } + }, + } + } + + let kind = if lits.len() == 1 { + ValueKind::SingleList + } else { + ValueKind::List + }; + + Ok(Values::new(name, kind, lits, attr.span())) + } + Meta::NameValue(meta) => { + let name = to_ident(meta.path)?; + let lit = meta.lit; + + Ok(Values::new(name, ValueKind::Equals, vec![lit], attr.span())) + } + } +} + +#[derive(Debug, Clone)] +struct DisplaySlice<'a, T>(&'a [T]); + +impl<'a, T: fmt::Display> fmt::Display for DisplaySlice<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut iter = self.0.iter().enumerate(); + + match iter.next() { + None => f.write_str("nothing")?, + Some((idx, elem)) => { + write!(f, "{}: {}", idx, elem)?; + + for (idx, elem) in iter { + f.write_char('\n')?; + write!(f, "{}: {}", idx, elem)?; + } + } + } + + Ok(()) + } +} + +#[inline] +fn is_form_acceptable(expect: &[ValueKind], kind: ValueKind) -> bool { + if expect.contains(&ValueKind::List) && kind == ValueKind::SingleList { + true + } else { + expect.contains(&kind) + } +} + +#[inline] +fn validate(values: &Values, forms: &[ValueKind]) -> Result<()> { + if !is_form_acceptable(forms, values.kind) { + return Err(Error::new( + values.span, + // Using the `_args` version here to avoid an allocation. + format_args!( + "the attribute must be in of these forms:\n{}", + DisplaySlice(forms) + ), + )); + } + + Ok(()) +} + +#[inline] +pub fn parse(values: Values) -> Result { + T::parse(values) +} + +pub trait AttributeOption: Sized { + fn parse(values: Values) -> Result; +} + +impl AttributeOption for Vec { + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::List])?; + + Ok(values + .literals + .into_iter() + .map(|lit| lit.to_str()) + .collect()) + } +} + +impl AttributeOption for String { + #[inline] + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::Equals, ValueKind::SingleList])?; + + Ok(values.literals[0].to_str()) + } +} + +impl AttributeOption for bool { + #[inline] + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::Name, ValueKind::SingleList])?; + + Ok(values.literals.get(0).map_or(true, |l| l.to_bool())) + } +} + +impl AttributeOption for Ident { + #[inline] + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::SingleList])?; + + Ok(values.literals[0].to_ident()) + } +} + +impl AttributeOption for Vec { + #[inline] + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::List])?; + + Ok(values.literals.into_iter().map(|l| l.to_ident()).collect()) + } +} + +impl AttributeOption for Option { + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList])?; + + Ok(values.literals.get(0).map(|l| l.to_str())) + } +} + +impl AttributeOption for PermissionLevel { + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::SingleList])?; + + Ok(values.literals.get(0).map(|l| PermissionLevel::from_str(&*l.to_str()).unwrap()).unwrap()) + } +} + +impl AttributeOption for AsOption { + #[inline] + fn parse(values: Values) -> Result { + Ok(AsOption(Some(T::parse(values)?))) + } +} + +macro_rules! attr_option_num { + ($($n:ty),*) => { + $( + impl AttributeOption for $n { + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::SingleList])?; + + Ok(match &values.literals[0] { + Lit::Int(l) => l.base10_parse::<$n>()?, + l => { + let s = l.to_str(); + // Use `as_str` to guide the compiler to use `&str`'s parse method. + // We don't want to use our `parse` method here (`impl AttributeOption for String`). + match s.as_str().parse::<$n>() { + Ok(n) => n, + Err(_) => return Err(Error::new(l.span(), "invalid integer")), + } + } + }) + } + } + + impl AttributeOption for Option<$n> { + #[inline] + fn parse(values: Values) -> Result { + <$n as AttributeOption>::parse(values).map(Some) + } + } + )* + } +} + +attr_option_num!(u16, u32, usize); diff --git a/regex_command_attr/src/consts.rs b/regex_command_attr/src/consts.rs new file mode 100644 index 0000000..94ca381 --- /dev/null +++ b/regex_command_attr/src/consts.rs @@ -0,0 +1,5 @@ +pub mod suffixes { + pub const COMMAND: &str = "COMMAND"; +} + +pub use self::suffixes::*; diff --git a/regex_command_attr/src/lib.rs b/regex_command_attr/src/lib.rs new file mode 100644 index 0000000..6dbf7c6 --- /dev/null +++ b/regex_command_attr/src/lib.rs @@ -0,0 +1,113 @@ +#![deny(rust_2018_idioms)] +// FIXME: Remove this in a foreseeable future. +// Currently exists for backwards compatibility to previous Rust versions. +#![recursion_limit = "128"] + +#[allow(unused_extern_crates)] +extern crate proc_macro; + +use proc_macro::TokenStream; +use quote::quote; +use syn::{ + parse::Error, + parse_macro_input, parse_quote, + spanned::Spanned, + Lit, +}; + +pub(crate) mod attributes; +pub(crate) mod consts; +pub(crate) mod structures; + +#[macro_use] +pub(crate) mod util; + +use attributes::*; +use consts::*; +use structures::*; +use util::*; + +macro_rules! match_options { + ($v:expr, $values:ident, $options:ident, $span:expr => [$($name:ident);*]) => { + match $v { + $( + stringify!($name) => $options.$name = propagate_err!($crate::attributes::parse($values)), + )* + _ => { + return Error::new($span, format_args!("invalid attribute: {:?}", $v)) + .to_compile_error() + .into(); + }, + } + }; +} + +#[proc_macro_attribute] +pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { + let mut fun = parse_macro_input!(input as CommandFun); + + let lit_name = if !attr.is_empty() { + parse_macro_input!(attr as Lit).to_str() + } else { + fun.name.to_string() + }; + + let mut options = Options::new(); + + for attribute in &fun.attributes { + let span = attribute.span(); + let values = propagate_err!(parse_values(attribute)); + + let name = values.name.to_string(); + let name = &name[..]; + + match_options!(name, values, options, span => [ + permission_level; + supports_dm; + can_blacklist + ]); + } + + let Options { + permission_level, + supports_dm, + can_blacklist, + } = options; + + propagate_err!(create_declaration_validations(&mut fun, DeclarFor::Command)); + + let res = parse_quote!(serenity::framework::standard::CommandResult); + create_return_type_validation(&mut fun, res); + + let visibility = fun.visibility; + let name = fun.name.clone(); + let body = fun.body; + let ret = fun.ret; + + let n = name.with_suffix(COMMAND); + + let cooked = fun.cooked.clone(); + + let command_path = quote!(crate::framework::Command); + + populate_fut_lifetimes_on_refs(&mut fun.args); + let args = fun.args; + + (quote! { + #(#cooked)* + pub static #n: #command_path = #command_path { + func: #name, + name: #lit_name, + required_perms: #permission_level, + supports_dm: #supports_dm, + can_blacklist: #can_blacklist, + }; + + #visibility fn #name<'fut> (#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, #ret> { + use ::serenity::futures::future::FutureExt; + + async move { #(#body)* }.boxed() + } + }) + .into() +} diff --git a/regex_command_attr/src/structures.rs b/regex_command_attr/src/structures.rs new file mode 100644 index 0000000..6f8d45d --- /dev/null +++ b/regex_command_attr/src/structures.rs @@ -0,0 +1,235 @@ +use crate::util::{Argument, Parenthesised}; +use proc_macro2::Span; +use proc_macro2::TokenStream as TokenStream2; +use quote::{quote, ToTokens}; +use syn::{ + braced, + parse::{Error, Parse, ParseStream, Result}, + spanned::Spanned, + Attribute, Block, FnArg, Ident, Pat, Path, PathSegment, ReturnType, Stmt, + Token, Type, Visibility, +}; + +fn parse_argument(arg: FnArg) -> Result { + match arg { + FnArg::Typed(typed) => { + let pat = typed.pat; + let kind = typed.ty; + + match *pat { + Pat::Ident(id) => { + let name = id.ident; + let mutable = id.mutability; + + Ok(Argument { + mutable, + name, + kind: *kind, + }) + } + Pat::Wild(wild) => { + let token = wild.underscore_token; + + let name = Ident::new("_", token.spans[0]); + + Ok(Argument { + mutable: None, + name, + kind: *kind, + }) + } + _ => Err(Error::new( + pat.span(), + format_args!("unsupported pattern: {:?}", pat), + )), + } + } + FnArg::Receiver(_) => Err(Error::new( + arg.span(), + format_args!("`self` arguments are prohibited: {:?}", arg), + )), + } +} + +/// Test if the attribute is cooked. +fn is_cooked(attr: &Attribute) -> bool { + const COOKED_ATTRIBUTE_NAMES: &[&str] = &[ + "cfg", + "cfg_attr", + "doc", + "derive", + "inline", + "allow", + "warn", + "deny", + "forbid", + ]; + + COOKED_ATTRIBUTE_NAMES.iter().any(|n| attr.path.is_ident(n)) +} + +/// Removes cooked attributes from a vector of attributes. Uncooked attributes are left in the vector. +/// +/// # Return +/// +/// Returns a vector of cooked attributes that have been removed from the input vector. +fn remove_cooked(attrs: &mut Vec) -> Vec { + let mut cooked = Vec::new(); + + // FIXME: Replace with `Vec::drain_filter` once it is stable. + let mut i = 0; + while i < attrs.len() { + if !is_cooked(&attrs[i]) { + i += 1; + continue; + } + + cooked.push(attrs.remove(i)); + } + + cooked +} + +#[derive(Debug)] +pub struct CommandFun { + /// `#[...]`-style attributes. + pub attributes: Vec, + /// Populated cooked attributes. These are attributes outside of the realm of this crate's procedural macros + /// and will appear in generated output. + pub cooked: Vec, + pub visibility: Visibility, + pub name: Ident, + pub args: Vec, + pub ret: Type, + pub body: Vec, +} + +impl Parse for CommandFun { + fn parse(input: ParseStream<'_>) -> Result { + let mut attributes = input.call(Attribute::parse_outer)?; + + // `#[doc = "..."]` is a cooked attribute but it is special-cased for commands. + for attr in &mut attributes { + // Rename documentation comment attributes (`#[doc = "..."]`) to `#[description = "..."]`. + if attr.path.is_ident("doc") { + attr.path = Path::from(PathSegment::from(Ident::new( + "description", + Span::call_site(), + ))); + } + } + + let cooked = remove_cooked(&mut attributes); + + let visibility = input.parse::()?; + + input.parse::()?; + + input.parse::()?; + let name = input.parse()?; + + // (...) + let Parenthesised(args) = input.parse::>()?; + + let ret = match input.parse::()? { + ReturnType::Type(_, t) => (*t).clone(), + ReturnType::Default => { + return Err(input + .error("expected a result type of either `CommandResult` or `CheckResult`")) + } + }; + + // { ... } + let bcont; + braced!(bcont in input); + let body = bcont.call(Block::parse_within)?; + + let args = args + .into_iter() + .map(parse_argument) + .collect::>>()?; + + Ok(Self { + attributes, + cooked, + visibility, + name, + args, + ret, + body, + }) + } +} + +impl ToTokens for CommandFun { + fn to_tokens(&self, stream: &mut TokenStream2) { + let Self { + attributes: _, + cooked, + visibility, + name, + args, + ret, + body, + } = self; + + stream.extend(quote! { + #(#cooked)* + #visibility async fn #name (#(#args),*) -> #ret { + #(#body)* + } + }); + } +} + +#[derive(Debug)] +pub enum PermissionLevel { + Unrestricted, + Managed, + Restricted, +} + +impl Default for PermissionLevel { + fn default() -> Self { + Self::Unrestricted + } +} + +impl PermissionLevel { + pub fn from_str(s: &str) -> Option { + Some(match s.to_uppercase().as_str() { + "UNRESTRICTED" => Self::Unrestricted, + "MANAGED" => Self::Managed, + "RESTRICTED" => Self::Restricted, + _ => return None, + }) + } +} + +impl ToTokens for PermissionLevel { + fn to_tokens(&self, stream: &mut TokenStream2) { + let path = quote!(crate::framework::PermissionLevel); + + stream.extend(quote! { + #path::Unrestricted + }); + } +} + +#[derive(Debug, Default)] +pub struct Options { + pub permission_level: PermissionLevel, + pub supports_dm: bool, + pub can_blacklist: bool, +} + +impl Options { + #[inline] + pub fn new() -> Self { + let mut options = Self::default(); + + options.can_blacklist = true; + + options + } +} diff --git a/regex_command_attr/src/util.rs b/regex_command_attr/src/util.rs new file mode 100644 index 0000000..b219201 --- /dev/null +++ b/regex_command_attr/src/util.rs @@ -0,0 +1,242 @@ +use crate::structures::CommandFun; +use proc_macro::TokenStream; +use proc_macro2::Span; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + braced, bracketed, parenthesized, + parse::{Error, Parse, ParseStream, Result as SynResult}, + parse_quote, + punctuated::Punctuated, + spanned::Spanned, + token::{Comma, Mut}, + Ident, Lifetime, Lit, Type, +}; + +pub trait LitExt { + fn to_str(&self) -> String; + fn to_bool(&self) -> bool; + fn to_ident(&self) -> Ident; +} + +impl LitExt for Lit { + fn to_str(&self) -> String { + match self { + Lit::Str(s) => s.value(), + Lit::ByteStr(s) => unsafe { String::from_utf8_unchecked(s.value()) }, + Lit::Char(c) => c.value().to_string(), + Lit::Byte(b) => (b.value() as char).to_string(), + _ => panic!("values must be a (byte)string or a char"), + } + } + + fn to_bool(&self) -> bool { + if let Lit::Bool(b) = self { + b.value + } else { + self.to_str() + .parse() + .unwrap_or_else(|_| panic!("expected bool from {:?}", self)) + } + } + + #[inline] + fn to_ident(&self) -> Ident { + Ident::new(&self.to_str(), self.span()) + } +} + +pub trait IdentExt2: Sized { + fn to_uppercase(&self) -> Self; + fn with_suffix(&self, suf: &str) -> Ident; +} + +impl IdentExt2 for Ident { + #[inline] + fn to_uppercase(&self) -> Self { + format_ident!("{}", self.to_string().to_uppercase()) + } + + #[inline] + fn with_suffix(&self, suffix: &str) -> Ident { + format_ident!("{}_{}", self.to_string().to_uppercase(), suffix) + } +} + +#[inline] +pub fn into_stream(e: Error) -> TokenStream { + e.to_compile_error().into() +} + +macro_rules! propagate_err { + ($res:expr) => {{ + match $res { + Ok(v) => v, + Err(e) => return $crate::util::into_stream(e), + } + }}; +} + +#[derive(Debug)] +pub struct Bracketed(pub Punctuated); + +impl Parse for Bracketed { + fn parse(input: ParseStream<'_>) -> SynResult { + let content; + bracketed!(content in input); + + Ok(Bracketed(content.parse_terminated(T::parse)?)) + } +} + +#[derive(Debug)] +pub struct Braced(pub Punctuated); + +impl Parse for Braced { + fn parse(input: ParseStream<'_>) -> SynResult { + let content; + braced!(content in input); + + Ok(Braced(content.parse_terminated(T::parse)?)) + } +} + +#[derive(Debug)] +pub struct Parenthesised(pub Punctuated); + +impl Parse for Parenthesised { + fn parse(input: ParseStream<'_>) -> SynResult { + let content; + parenthesized!(content in input); + + Ok(Parenthesised(content.parse_terminated(T::parse)?)) + } +} + +#[derive(Debug)] +pub struct AsOption(pub Option); + +impl ToTokens for AsOption { + fn to_tokens(&self, stream: &mut TokenStream2) { + match &self.0 { + Some(o) => stream.extend(quote!(Some(#o))), + None => stream.extend(quote!(None)), + } + } +} + +impl Default for AsOption { + #[inline] + fn default() -> Self { + AsOption(None) + } +} + +#[derive(Debug)] +pub struct Argument { + pub mutable: Option, + pub name: Ident, + pub kind: Type, +} + +impl ToTokens for Argument { + fn to_tokens(&self, stream: &mut TokenStream2) { + let Argument { + mutable, + name, + kind, + } = self; + + stream.extend(quote! { + #mutable #name: #kind + }); + } +} + +#[inline] +pub fn generate_type_validation(have: Type, expect: Type) -> syn::Stmt { + parse_quote! { + serenity::static_assertions::assert_type_eq_all!(#have, #expect); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DeclarFor { + Command, + Help, + Check, +} + +pub fn create_declaration_validations(fun: &mut CommandFun, dec_for: DeclarFor) -> SynResult<()> { + let len = match dec_for { + DeclarFor::Command => 3, + DeclarFor::Help => 6, + DeclarFor::Check => 4, + }; + + if fun.args.len() > len { + return Err(Error::new( + fun.args.last().unwrap().span(), + format_args!("function's arity exceeds more than {} arguments", len), + )); + } + + let context: Type = parse_quote!(&serenity::client::Context); + let message: Type = parse_quote!(&serenity::model::channel::Message); + let args: Type = parse_quote!(serenity::framework::standard::Args); + let args2: Type = parse_quote!(&mut serenity::framework::standard::Args); + let options: Type = parse_quote!(&serenity::framework::standard::CommandOptions); + let hoptions: Type = parse_quote!(&'static serenity::framework::standard::HelpOptions); + let groups: Type = parse_quote!(&[&'static serenity::framework::standard::CommandGroup]); + let owners: Type = parse_quote!(std::collections::HashSet); + + let mut index = 0; + + let mut spoof_or_check = |kind: Type, name: &str| { + match fun.args.get(index) { + Some(x) => fun.body.insert(0, generate_type_validation(x.kind.clone(), kind)), + None => fun.args.push(Argument { + mutable: None, + name: Ident::new(name, Span::call_site()), + kind, + }), + } + + index += 1; + }; + + spoof_or_check(context, "_ctx"); + spoof_or_check(message, "_msg"); + + if dec_for == DeclarFor::Check { + spoof_or_check(args2, "_args"); + spoof_or_check(options, "_options"); + + return Ok(()); + } + + spoof_or_check(args, "_args"); + + if dec_for == DeclarFor::Help { + spoof_or_check(hoptions, "_hoptions"); + spoof_or_check(groups, "_groups"); + spoof_or_check(owners, "_owners"); + } + + Ok(()) +} + +#[inline] +pub fn create_return_type_validation(r#fn: &mut CommandFun, expect: Type) { + let stmt = generate_type_validation(r#fn.ret.clone(), expect); + r#fn.body.insert(0, stmt); +} + +#[inline] +pub fn populate_fut_lifetimes_on_refs(args: &mut Vec) { + for arg in args { + if let Type::Reference(reference) = &mut arg.kind { + reference.lifetime = Some(Lifetime::new("'fut", Span::call_site())); + } + } +} diff --git a/src/framework.rs b/src/framework.rs index 3c3dc45..5d85aec 100644 --- a/src/framework.rs +++ b/src/framework.rs @@ -6,13 +6,7 @@ use serenity::{ model::channel::Message, }; -use std::{ - collections::HashSet, - hash::{ - Hash, - Hasher - }, -}; +use std::collections::HashMap; use serenity::framework::standard::CommandFn; @@ -23,63 +17,25 @@ pub enum PermissionLevel { } pub struct Command { - name: String, - required_perms: PermissionLevel, - can_blacklist: bool, - supports_dm: bool, - func: CommandFn, + pub name: &'static str, + pub required_perms: PermissionLevel, + pub can_blacklist: bool, + pub supports_dm: bool, + pub func: CommandFn, } -impl Hash for Command { - fn hash(&self, state: &mut H) { - self.name.hash(state); - } -} - -impl PartialEq for Command { - fn eq(&self, other: &Self) -> bool { - self.name == other.name - } -} - -impl Eq for Command {} - // create event handler for bot pub struct RegexFramework { - commands: HashSet, + commands: HashMap, command_names: String, default_prefix: String, ignore_bots: bool, } -impl Command { - pub fn from(name: &str, required_perms: PermissionLevel, func: CommandFn) -> Self { - Command { - name: name.to_string(), - required_perms, - can_blacklist: true, - supports_dm: false, - func, - } - } - - pub fn can_blacklist(&mut self, can_blacklist: bool) -> &mut Self { - self.can_blacklist = can_blacklist; - - self - } - - pub fn supports_dm(&mut self, supports_dm: bool) -> &mut Self { - self.supports_dm = supports_dm; - - self - } -} - impl RegexFramework { pub fn new() -> Self { Self { - commands: HashSet::new(), + commands: HashMap::new(), command_names: String::new(), default_prefix: String::from("$"), ignore_bots: true, @@ -98,17 +54,17 @@ impl RegexFramework { self } - pub fn add_command(mut self, command: Command) -> Self { - self.commands.insert(command); + pub fn add_command(mut self, name: String, command: &'static Command) -> Self { + self.commands.insert(name, command); self } pub fn build(mut self) -> Self { self.command_names = self.commands - .iter() - .map(|c| c.name.clone()) - .collect::>() + .keys() + .map(|k| &k[..]) + .collect::>() .join("|"); self diff --git a/src/main.rs b/src/main.rs index 6e4d3cc..223dc00 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,13 +12,12 @@ use serenity::{ }, framework::standard::{ Args, CommandResult, - macros::{ - command, - } }, prelude::TypeMapKey, }; +use regex_command_attr::command; + use sqlx::{ Pool, mysql::{ @@ -33,7 +32,7 @@ use std::{ env, }; -use crate::framework::{RegexFramework, Command, PermissionLevel}; +use crate::framework::RegexFramework; struct SQLPool; @@ -56,13 +55,14 @@ async fn main() -> Result<(), Box> { let framework = RegexFramework::new() .ignore_bots(true) .default_prefix("$") - .add_command(Command::from("help", PermissionLevel::Unrestricted, help_command)) + .add_command("help".to_string(), &HELP_COMMAND) + .add_command("h".to_string(), &HELP_COMMAND) .build(); let mut client = Client::new(&env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment")) .intents(GatewayIntents::GUILD_MESSAGES | GatewayIntents::GUILDS | GatewayIntents::DIRECT_MESSAGES) .framework(framework) - .await.expect("Error occured creating client"); + .await.expect("Error occurred creating client"); client.start_autosharded().await?; @@ -70,7 +70,7 @@ async fn main() -> Result<(), Box> { } #[command] -async fn help_command(_ctx: &Context, _msg: &Message, _args: Args) -> CommandResult { +async fn help(_ctx: &Context, _msg: &Message, _args: Args) -> CommandResult { println!("Help command called"); Ok(())