diff --git a/Cargo.lock b/Cargo.lock index ce837a8..11fb074 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1470,6 +1470,15 @@ version = "0.6.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5eb417147ba9860a96cfe72a0b93bf88fee1744b5636ec99ab20c1aa9376581" +[[package]] +name = "regex_command_attr" +version = "0.2.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "remove_dir_all" version = "0.5.3" @@ -1889,6 +1898,9 @@ version = "1.1.0" dependencies = [ "dotenv", "lazy_static", + "log 0.4.13", + "regex", + "regex_command_attr", "reqwest", "serenity", "songbird", diff --git a/Cargo.toml b/Cargo.toml index fc80d19..dfba20f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,8 @@ tokio = { version = "1.0", features = ["fs", "process", "io-util"] } lazy_static = "1.4" reqwest = "0.11" songbird = "0.1" +regex = "1.4" +log = "0.4" + +[dependencies.regex_command_attr] +path = "./regex_command_attr" diff --git a/regex_command_attr/Cargo.toml b/regex_command_attr/Cargo.toml new file mode 100644 index 0000000..dbe01d1 --- /dev/null +++ b/regex_command_attr/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "regex_command_attr" +version = "0.2.0" +authors = ["acdenisSK ", "jellywx "] +edition = "2018" +description = "Procedural macros for command creation for the RegexFramework for serenity." + +[lib] +proc-macro = true + +[dependencies] +quote = "^1.0" +syn = { version = "^1.0", features = ["full", "derive", "extra-traits"] } +proc-macro2 = "1.0" diff --git a/regex_command_attr/src/attributes.rs b/regex_command_attr/src/attributes.rs new file mode 100644 index 0000000..d1b6386 --- /dev/null +++ b/regex_command_attr/src/attributes.rs @@ -0,0 +1,300 @@ +use proc_macro2::Span; +use syn::parse::{Error, Result}; +use syn::spanned::Spanned; +use syn::{Attribute, Ident, Lit, LitStr, Meta, NestedMeta, Path}; + +use crate::structures::PermissionLevel; +use crate::util::{AsOption, LitExt}; + +use std::fmt::{self, Write}; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ValueKind { + // #[] + Name, + + // #[ = ] + Equals, + + // #[([, , , ...])] + List, + + // #[()] + SingleList, +} + +impl fmt::Display for ValueKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ValueKind::Name => f.pad("`#[]`"), + ValueKind::Equals => f.pad("`#[ = ]`"), + ValueKind::List => f.pad("`#[([, , , ...])]`"), + ValueKind::SingleList => f.pad("`#[()]`"), + } + } +} + +fn to_ident(p: Path) -> Result { + if p.segments.is_empty() { + return Err(Error::new( + p.span(), + "cannot convert an empty path to an identifier", + )); + } + + if p.segments.len() > 1 { + return Err(Error::new( + p.span(), + "the path must not have more than one segment", + )); + } + + if !p.segments[0].arguments.is_empty() { + return Err(Error::new( + p.span(), + "the singular path segment must not have any arguments", + )); + } + + Ok(p.segments[0].ident.clone()) +} + +#[derive(Debug)] +pub struct Values { + pub name: Ident, + pub literals: Vec, + pub kind: ValueKind, + pub span: Span, +} + +impl Values { + #[inline] + pub fn new(name: Ident, kind: ValueKind, literals: Vec, span: Span) -> Self { + Values { + name, + literals, + kind, + span, + } + } +} + +pub fn parse_values(attr: &Attribute) -> Result { + let meta = attr.parse_meta()?; + + match meta { + Meta::Path(path) => { + let name = to_ident(path)?; + + Ok(Values::new(name, ValueKind::Name, Vec::new(), attr.span())) + } + Meta::List(meta) => { + let name = to_ident(meta.path)?; + let nested = meta.nested; + + if nested.is_empty() { + return Err(Error::new(attr.span(), "list cannot be empty")); + } + + let mut lits = Vec::with_capacity(nested.len()); + + for meta in nested { + match meta { + NestedMeta::Lit(l) => lits.push(l), + NestedMeta::Meta(m) => match m { + Meta::Path(path) => { + let i = to_ident(path)?; + lits.push(Lit::Str(LitStr::new(&i.to_string(), i.span()))) + } + Meta::List(_) | Meta::NameValue(_) => { + return Err(Error::new(attr.span(), "cannot nest a list; only accept literals and identifiers at this level")) + } + }, + } + } + + let kind = if lits.len() == 1 { + ValueKind::SingleList + } else { + ValueKind::List + }; + + Ok(Values::new(name, kind, lits, attr.span())) + } + Meta::NameValue(meta) => { + let name = to_ident(meta.path)?; + let lit = meta.lit; + + Ok(Values::new(name, ValueKind::Equals, vec![lit], attr.span())) + } + } +} + +#[derive(Debug, Clone)] +struct DisplaySlice<'a, T>(&'a [T]); + +impl<'a, T: fmt::Display> fmt::Display for DisplaySlice<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut iter = self.0.iter().enumerate(); + + match iter.next() { + None => f.write_str("nothing")?, + Some((idx, elem)) => { + write!(f, "{}: {}", idx, elem)?; + + for (idx, elem) in iter { + f.write_char('\n')?; + write!(f, "{}: {}", idx, elem)?; + } + } + } + + Ok(()) + } +} + +#[inline] +fn is_form_acceptable(expect: &[ValueKind], kind: ValueKind) -> bool { + if expect.contains(&ValueKind::List) && kind == ValueKind::SingleList { + true + } else { + expect.contains(&kind) + } +} + +#[inline] +fn validate(values: &Values, forms: &[ValueKind]) -> Result<()> { + if !is_form_acceptable(forms, values.kind) { + return Err(Error::new( + values.span, + // Using the `_args` version here to avoid an allocation. + format_args!( + "the attribute must be in of these forms:\n{}", + DisplaySlice(forms) + ), + )); + } + + Ok(()) +} + +#[inline] +pub fn parse(values: Values) -> Result { + T::parse(values) +} + +pub trait AttributeOption: Sized { + fn parse(values: Values) -> Result; +} + +impl AttributeOption for Vec { + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::List])?; + + Ok(values + .literals + .into_iter() + .map(|lit| lit.to_str()) + .collect()) + } +} + +impl AttributeOption for String { + #[inline] + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::Equals, ValueKind::SingleList])?; + + Ok(values.literals[0].to_str()) + } +} + +impl AttributeOption for bool { + #[inline] + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::Name, ValueKind::SingleList])?; + + Ok(values.literals.get(0).map_or(true, |l| l.to_bool())) + } +} + +impl AttributeOption for Ident { + #[inline] + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::SingleList])?; + + Ok(values.literals[0].to_ident()) + } +} + +impl AttributeOption for Vec { + #[inline] + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::List])?; + + Ok(values.literals.into_iter().map(|l| l.to_ident()).collect()) + } +} + +impl AttributeOption for Option { + fn parse(values: Values) -> Result { + validate( + &values, + &[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList], + )?; + + Ok(values.literals.get(0).map(|l| l.to_str())) + } +} + +impl AttributeOption for PermissionLevel { + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::SingleList])?; + + Ok(values + .literals + .get(0) + .map(|l| PermissionLevel::from_str(&*l.to_str()).unwrap()) + .unwrap()) + } +} + +impl AttributeOption for AsOption { + #[inline] + fn parse(values: Values) -> Result { + Ok(AsOption(Some(T::parse(values)?))) + } +} + +macro_rules! attr_option_num { + ($($n:ty),*) => { + $( + impl AttributeOption for $n { + fn parse(values: Values) -> Result { + validate(&values, &[ValueKind::SingleList])?; + + Ok(match &values.literals[0] { + Lit::Int(l) => l.base10_parse::<$n>()?, + l => { + let s = l.to_str(); + // Use `as_str` to guide the compiler to use `&str`'s parse method. + // We don't want to use our `parse` method here (`impl AttributeOption for String`). + match s.as_str().parse::<$n>() { + Ok(n) => n, + Err(_) => return Err(Error::new(l.span(), "invalid integer")), + } + } + }) + } + } + + impl AttributeOption for Option<$n> { + #[inline] + fn parse(values: Values) -> Result { + <$n as AttributeOption>::parse(values).map(Some) + } + } + )* + } +} + +attr_option_num!(u16, u32, usize); diff --git a/regex_command_attr/src/consts.rs b/regex_command_attr/src/consts.rs new file mode 100644 index 0000000..94ca381 --- /dev/null +++ b/regex_command_attr/src/consts.rs @@ -0,0 +1,5 @@ +pub mod suffixes { + pub const COMMAND: &str = "COMMAND"; +} + +pub use self::suffixes::*; diff --git a/regex_command_attr/src/lib.rs b/regex_command_attr/src/lib.rs new file mode 100644 index 0000000..e72451d --- /dev/null +++ b/regex_command_attr/src/lib.rs @@ -0,0 +1,100 @@ +#![deny(rust_2018_idioms)] +// FIXME: Remove this in a foreseeable future. +// Currently exists for backwards compatibility to previous Rust versions. +#![recursion_limit = "128"] + +#[allow(unused_extern_crates)] +extern crate proc_macro; + +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse::Error, parse_macro_input, parse_quote, spanned::Spanned, Lit}; + +pub(crate) mod attributes; +pub(crate) mod consts; +pub(crate) mod structures; + +#[macro_use] +pub(crate) mod util; + +use attributes::*; +use consts::*; +use structures::*; +use util::*; + +macro_rules! match_options { + ($v:expr, $values:ident, $options:ident, $span:expr => [$($name:ident);*]) => { + match $v { + $( + stringify!($name) => $options.$name = propagate_err!($crate::attributes::parse($values)), + )* + _ => { + return Error::new($span, format_args!("invalid attribute: {:?}", $v)) + .to_compile_error() + .into(); + }, + } + }; +} + +#[proc_macro_attribute] +pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream { + let mut fun = parse_macro_input!(input as CommandFun); + + let lit_name = if !attr.is_empty() { + parse_macro_input!(attr as Lit).to_str() + } else { + fun.name.to_string() + }; + + let mut options = Options::new(); + + for attribute in &fun.attributes { + let span = attribute.span(); + let values = propagate_err!(parse_values(attribute)); + + let name = values.name.to_string(); + let name = &name[..]; + + match_options!(name, values, options, span => [ + permission_level + ]); + } + + let Options { permission_level } = options; + + propagate_err!(create_declaration_validations(&mut fun, DeclarFor::Command)); + + let res = parse_quote!(serenity::framework::standard::CommandResult); + create_return_type_validation(&mut fun, res); + + let visibility = fun.visibility; + let name = fun.name.clone(); + let body = fun.body; + let ret = fun.ret; + + let n = name.with_suffix(COMMAND); + + let cooked = fun.cooked.clone(); + + let command_path = quote!(crate::framework::Command); + + populate_fut_lifetimes_on_refs(&mut fun.args); + let args = fun.args; + + (quote! { + #(#cooked)* + pub static #n: #command_path = #command_path { + func: #name, + name: #lit_name, + required_perms: #permission_level, + }; + + #visibility fn #name<'fut> (#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, #ret> { + use ::serenity::futures::future::FutureExt; + + async move { #(#body)* }.boxed() + } + }) + .into() +} diff --git a/regex_command_attr/src/structures.rs b/regex_command_attr/src/structures.rs new file mode 100644 index 0000000..ed7952b --- /dev/null +++ b/regex_command_attr/src/structures.rs @@ -0,0 +1,236 @@ +use crate::util::{Argument, Parenthesised}; +use proc_macro2::Span; +use proc_macro2::TokenStream as TokenStream2; +use quote::{quote, ToTokens}; +use syn::{ + braced, + parse::{Error, Parse, ParseStream, Result}, + spanned::Spanned, + Attribute, Block, FnArg, Ident, Pat, Path, PathSegment, ReturnType, Stmt, Token, Type, + Visibility, +}; + +fn parse_argument(arg: FnArg) -> Result { + match arg { + FnArg::Typed(typed) => { + let pat = typed.pat; + let kind = typed.ty; + + match *pat { + Pat::Ident(id) => { + let name = id.ident; + let mutable = id.mutability; + + Ok(Argument { + mutable, + name, + kind: *kind, + }) + } + Pat::Wild(wild) => { + let token = wild.underscore_token; + + let name = Ident::new("_", token.spans[0]); + + Ok(Argument { + mutable: None, + name, + kind: *kind, + }) + } + _ => Err(Error::new( + pat.span(), + format_args!("unsupported pattern: {:?}", pat), + )), + } + } + FnArg::Receiver(_) => Err(Error::new( + arg.span(), + format_args!("`self` arguments are prohibited: {:?}", arg), + )), + } +} + +/// Test if the attribute is cooked. +fn is_cooked(attr: &Attribute) -> bool { + const COOKED_ATTRIBUTE_NAMES: &[&str] = &[ + "cfg", "cfg_attr", "doc", "derive", "inline", "allow", "warn", "deny", "forbid", + ]; + + COOKED_ATTRIBUTE_NAMES.iter().any(|n| attr.path.is_ident(n)) +} + +/// Removes cooked attributes from a vector of attributes. Uncooked attributes are left in the vector. +/// +/// # Return +/// +/// Returns a vector of cooked attributes that have been removed from the input vector. +fn remove_cooked(attrs: &mut Vec) -> Vec { + let mut cooked = Vec::new(); + + // FIXME: Replace with `Vec::drain_filter` once it is stable. + let mut i = 0; + while i < attrs.len() { + if !is_cooked(&attrs[i]) { + i += 1; + continue; + } + + cooked.push(attrs.remove(i)); + } + + cooked +} + +#[derive(Debug)] +pub struct CommandFun { + /// `#[...]`-style attributes. + pub attributes: Vec, + /// Populated cooked attributes. These are attributes outside of the realm of this crate's procedural macros + /// and will appear in generated output. + pub cooked: Vec, + pub visibility: Visibility, + pub name: Ident, + pub args: Vec, + pub ret: Type, + pub body: Vec, +} + +impl Parse for CommandFun { + fn parse(input: ParseStream<'_>) -> Result { + let mut attributes = input.call(Attribute::parse_outer)?; + + // `#[doc = "..."]` is a cooked attribute but it is special-cased for commands. + for attr in &mut attributes { + // Rename documentation comment attributes (`#[doc = "..."]`) to `#[description = "..."]`. + if attr.path.is_ident("doc") { + attr.path = Path::from(PathSegment::from(Ident::new( + "description", + Span::call_site(), + ))); + } + } + + let cooked = remove_cooked(&mut attributes); + + let visibility = input.parse::()?; + + input.parse::()?; + + input.parse::()?; + let name = input.parse()?; + + // (...) + let Parenthesised(args) = input.parse::>()?; + + let ret = match input.parse::()? { + ReturnType::Type(_, t) => (*t).clone(), + ReturnType::Default => { + return Err(input + .error("expected a result type of either `CommandResult` or `CheckResult`")) + } + }; + + // { ... } + let bcont; + braced!(bcont in input); + let body = bcont.call(Block::parse_within)?; + + let args = args + .into_iter() + .map(parse_argument) + .collect::>>()?; + + Ok(Self { + attributes, + cooked, + visibility, + name, + args, + ret, + body, + }) + } +} + +impl ToTokens for CommandFun { + fn to_tokens(&self, stream: &mut TokenStream2) { + let Self { + attributes: _, + cooked, + visibility, + name, + args, + ret, + body, + } = self; + + stream.extend(quote! { + #(#cooked)* + #visibility async fn #name (#(#args),*) -> #ret { + #(#body)* + } + }); + } +} + +#[derive(Debug)] +pub enum PermissionLevel { + Unrestricted, + Managed, + Restricted, +} + +impl Default for PermissionLevel { + fn default() -> Self { + Self::Unrestricted + } +} + +impl PermissionLevel { + pub fn from_str(s: &str) -> Option { + Some(match s.to_uppercase().as_str() { + "UNRESTRICTED" => Self::Unrestricted, + "MANAGED" => Self::Managed, + "RESTRICTED" => Self::Restricted, + _ => return None, + }) + } +} + +impl ToTokens for PermissionLevel { + fn to_tokens(&self, stream: &mut TokenStream2) { + let path = quote!(crate::framework::PermissionLevel); + let variant; + + match self { + Self::Unrestricted => { + variant = quote!(Unrestricted); + } + + Self::Managed => { + variant = quote!(Managed); + } + + Self::Restricted => { + variant = quote!(Restricted); + } + } + + stream.extend(quote! { + #path::#variant + }); + } +} + +#[derive(Debug, Default)] +pub struct Options { + pub permission_level: PermissionLevel, +} + +impl Options { + #[inline] + pub fn new() -> Self { + Self::default() + } +} diff --git a/regex_command_attr/src/util.rs b/regex_command_attr/src/util.rs new file mode 100644 index 0000000..0298135 --- /dev/null +++ b/regex_command_attr/src/util.rs @@ -0,0 +1,244 @@ +use crate::structures::CommandFun; +use proc_macro::TokenStream; +use proc_macro2::Span; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + braced, bracketed, parenthesized, + parse::{Error, Parse, ParseStream, Result as SynResult}, + parse_quote, + punctuated::Punctuated, + spanned::Spanned, + token::{Comma, Mut}, + Ident, Lifetime, Lit, Type, +}; + +pub trait LitExt { + fn to_str(&self) -> String; + fn to_bool(&self) -> bool; + fn to_ident(&self) -> Ident; +} + +impl LitExt for Lit { + fn to_str(&self) -> String { + match self { + Lit::Str(s) => s.value(), + Lit::ByteStr(s) => unsafe { String::from_utf8_unchecked(s.value()) }, + Lit::Char(c) => c.value().to_string(), + Lit::Byte(b) => (b.value() as char).to_string(), + _ => panic!("values must be a (byte)string or a char"), + } + } + + fn to_bool(&self) -> bool { + if let Lit::Bool(b) = self { + b.value + } else { + self.to_str() + .parse() + .unwrap_or_else(|_| panic!("expected bool from {:?}", self)) + } + } + + #[inline] + fn to_ident(&self) -> Ident { + Ident::new(&self.to_str(), self.span()) + } +} + +pub trait IdentExt2: Sized { + fn to_uppercase(&self) -> Self; + fn with_suffix(&self, suf: &str) -> Ident; +} + +impl IdentExt2 for Ident { + #[inline] + fn to_uppercase(&self) -> Self { + format_ident!("{}", self.to_string().to_uppercase()) + } + + #[inline] + fn with_suffix(&self, suffix: &str) -> Ident { + format_ident!("{}_{}", self.to_string().to_uppercase(), suffix) + } +} + +#[inline] +pub fn into_stream(e: Error) -> TokenStream { + e.to_compile_error().into() +} + +macro_rules! propagate_err { + ($res:expr) => {{ + match $res { + Ok(v) => v, + Err(e) => return $crate::util::into_stream(e), + } + }}; +} + +#[derive(Debug)] +pub struct Bracketed(pub Punctuated); + +impl Parse for Bracketed { + fn parse(input: ParseStream<'_>) -> SynResult { + let content; + bracketed!(content in input); + + Ok(Bracketed(content.parse_terminated(T::parse)?)) + } +} + +#[derive(Debug)] +pub struct Braced(pub Punctuated); + +impl Parse for Braced { + fn parse(input: ParseStream<'_>) -> SynResult { + let content; + braced!(content in input); + + Ok(Braced(content.parse_terminated(T::parse)?)) + } +} + +#[derive(Debug)] +pub struct Parenthesised(pub Punctuated); + +impl Parse for Parenthesised { + fn parse(input: ParseStream<'_>) -> SynResult { + let content; + parenthesized!(content in input); + + Ok(Parenthesised(content.parse_terminated(T::parse)?)) + } +} + +#[derive(Debug)] +pub struct AsOption(pub Option); + +impl ToTokens for AsOption { + fn to_tokens(&self, stream: &mut TokenStream2) { + match &self.0 { + Some(o) => stream.extend(quote!(Some(#o))), + None => stream.extend(quote!(None)), + } + } +} + +impl Default for AsOption { + #[inline] + fn default() -> Self { + AsOption(None) + } +} + +#[derive(Debug)] +pub struct Argument { + pub mutable: Option, + pub name: Ident, + pub kind: Type, +} + +impl ToTokens for Argument { + fn to_tokens(&self, stream: &mut TokenStream2) { + let Argument { + mutable, + name, + kind, + } = self; + + stream.extend(quote! { + #mutable #name: #kind + }); + } +} + +#[inline] +pub fn generate_type_validation(have: Type, expect: Type) -> syn::Stmt { + parse_quote! { + serenity::static_assertions::assert_type_eq_all!(#have, #expect); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DeclarFor { + Command, + Help, + Check, +} + +pub fn create_declaration_validations(fun: &mut CommandFun, dec_for: DeclarFor) -> SynResult<()> { + let len = match dec_for { + DeclarFor::Command => 3, + DeclarFor::Help => 6, + DeclarFor::Check => 4, + }; + + if fun.args.len() > len { + return Err(Error::new( + fun.args.last().unwrap().span(), + format_args!("function's arity exceeds more than {} arguments", len), + )); + } + + let context: Type = parse_quote!(&serenity::client::Context); + let message: Type = parse_quote!(&serenity::model::channel::Message); + let args: Type = parse_quote!(serenity::framework::standard::Args); + let args2: Type = parse_quote!(&mut serenity::framework::standard::Args); + let options: Type = parse_quote!(&serenity::framework::standard::CommandOptions); + let hoptions: Type = parse_quote!(&'static serenity::framework::standard::HelpOptions); + let groups: Type = parse_quote!(&[&'static serenity::framework::standard::CommandGroup]); + let owners: Type = parse_quote!(std::collections::HashSet); + + let mut index = 0; + + let mut spoof_or_check = |kind: Type, name: &str| { + match fun.args.get(index) { + Some(x) => fun + .body + .insert(0, generate_type_validation(x.kind.clone(), kind)), + None => fun.args.push(Argument { + mutable: None, + name: Ident::new(name, Span::call_site()), + kind, + }), + } + + index += 1; + }; + + spoof_or_check(context, "_ctx"); + spoof_or_check(message, "_msg"); + + if dec_for == DeclarFor::Check { + spoof_or_check(args2, "_args"); + spoof_or_check(options, "_options"); + + return Ok(()); + } + + spoof_or_check(args, "_args"); + + if dec_for == DeclarFor::Help { + spoof_or_check(hoptions, "_hoptions"); + spoof_or_check(groups, "_groups"); + spoof_or_check(owners, "_owners"); + } + + Ok(()) +} + +#[inline] +pub fn create_return_type_validation(r#fn: &mut CommandFun, expect: Type) { + let stmt = generate_type_validation(r#fn.ret.clone(), expect); + r#fn.body.insert(0, stmt); +} + +#[inline] +pub fn populate_fut_lifetimes_on_refs(args: &mut Vec) { + for arg in args { + if let Type::Reference(reference) = &mut arg.kind { + reference.lifetime = Some(Lifetime::new("'fut", Span::call_site())); + } + } +} diff --git a/src/framework.rs b/src/framework.rs new file mode 100644 index 0000000..1684f1f --- /dev/null +++ b/src/framework.rs @@ -0,0 +1,339 @@ +use serenity::{ + async_trait, + client::Context, + constants::MESSAGE_CODE_LIMIT, + framework::{standard::Args, Framework}, + futures::prelude::future::BoxFuture, + http::Http, + model::{ + channel::{Channel, GuildChannel, Message}, + guild::{Guild, Member}, + id::ChannelId, + }, + Result as SerenityResult, +}; + +use log::{error, info, warn}; + +use regex::{Match, Regex, RegexBuilder}; + +use std::{collections::HashMap, fmt}; + +use crate::{guild_data::GuildData, MySQL}; +use serenity::framework::standard::{CommandResult, Delimiter}; + +type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, Args) -> BoxFuture<'fut, CommandResult>; + +#[derive(Debug, PartialEq)] +pub enum PermissionLevel { + Unrestricted, + Managed, + Restricted, +} + +pub struct Command { + pub name: &'static str, + pub required_perms: PermissionLevel, + pub func: CommandFn, +} + +impl Command { + async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool { + if self.required_perms == PermissionLevel::Unrestricted { + true + } else { + let permissions = guild.member_permissions(&ctx, &member.user).await.unwrap(); + + if permissions.manage_guild() && self.required_perms == PermissionLevel::Managed { + return true; + } + + if self.required_perms == PermissionLevel::Managed { + let pool = ctx + .data + .read() + .await + .get::() + .cloned() + .expect("Could not get SQLPool from data"); + + match sqlx::query!( + " +SELECT role + FROM roles + WHERE guild_id = ? + ", + guild.id.as_u64() + ) + .fetch_all(&pool) + .await + { + Ok(rows) => { + let role_ids = member + .roles + .iter() + .map(|r| *r.as_u64()) + .collect::>(); + + for row in rows { + if role_ids.contains(&row.role) || &row.role == guild.id.as_u64() { + return true; + } + } + + false + } + + Err(sqlx::Error::RowNotFound) => false, + + Err(e) => { + warn!("Unexpected error occurred querying roles: {:?}", e); + + false + } + } + } else { + false + } + } + } +} + +impl fmt::Debug for Command { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Command") + .field("name", &self.name) + .field("required_perms", &self.required_perms) + .finish() + } +} + +#[async_trait] +pub trait SendIterator { + async fn say_lines( + self, + http: impl AsRef + Send + Sync + 'async_trait, + content: impl Iterator + Send + 'async_trait, + ) -> SerenityResult<()>; +} + +#[async_trait] +impl SendIterator for ChannelId { + async fn say_lines( + self, + http: impl AsRef + Send + Sync + 'async_trait, + content: impl Iterator + Send + 'async_trait, + ) -> SerenityResult<()> { + let mut current_content = String::new(); + + for line in content { + if current_content.len() + line.len() > MESSAGE_CODE_LIMIT as usize { + self.send_message(&http, |m| { + m.allowed_mentions(|am| am.empty_parse()) + .content(¤t_content) + }) + .await?; + + current_content = line; + } else { + current_content = format!("{}\n{}", current_content, line); + } + } + if !current_content.is_empty() { + self.send_message(&http, |m| { + m.allowed_mentions(|am| am.empty_parse()) + .content(¤t_content) + }) + .await?; + } + + Ok(()) + } +} + +pub struct RegexFramework { + commands: HashMap, + command_matcher: Regex, + default_prefix: String, + client_id: u64, + ignore_bots: bool, + case_insensitive: bool, +} + +impl RegexFramework { + pub fn new>(client_id: T) -> Self { + Self { + commands: HashMap::new(), + command_matcher: Regex::new(r#"^$"#).unwrap(), + default_prefix: "".to_string(), + client_id: client_id.into(), + ignore_bots: true, + case_insensitive: true, + } + } + + pub fn case_insensitive(mut self, case_insensitive: bool) -> Self { + self.case_insensitive = case_insensitive; + + self + } + + pub fn default_prefix(mut self, new_prefix: T) -> Self { + self.default_prefix = new_prefix.to_string(); + + self + } + + pub fn ignore_bots(mut self, ignore_bots: bool) -> Self { + self.ignore_bots = ignore_bots; + + self + } + + pub fn add_command(mut self, name: S, command: &'static Command) -> Self { + self.commands.insert(name.to_string(), command); + + self + } + + pub fn build(mut self) -> Self { + let command_names; + + { + let mut command_names_vec = self.commands.keys().map(|k| &k[..]).collect::>(); + + command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len())); + + command_names = command_names_vec.join("|"); + } + + info!("Command names: {}", command_names); + + { + let match_string = r#"^(?:(?:<@ID>\s*)|(?:<@!ID>\s*)|(?P\S{1,5}?))(?PCOMMANDS)(?:$|\s+(?P.*))$"# + .replace("COMMANDS", command_names.as_str()) + .replace("ID", self.client_id.to_string().as_str()); + + self.command_matcher = RegexBuilder::new(match_string.as_str()) + .case_insensitive(self.case_insensitive) + .dot_matches_new_line(true) + .build() + .unwrap(); + } + + self + } +} + +enum PermissionCheck { + None, // No permissions + All, // Sufficient permissions +} + +#[async_trait] +impl Framework for RegexFramework { + async fn dispatch(&self, ctx: Context, msg: Message) { + async fn check_self_permissions( + ctx: &Context, + channel: &GuildChannel, + ) -> SerenityResult { + let user_id = ctx.cache.current_user_id().await; + + let channel_perms = channel.permissions_for_user(ctx, user_id).await?; + + Ok( + if channel_perms.send_messages() && channel_perms.embed_links() { + PermissionCheck::All + } else { + PermissionCheck::None + }, + ) + } + + async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option>) -> bool { + if let Some(prefix) = prefix_opt { + let pool = ctx + .data + .read() + .await + .get::() + .cloned() + .expect("Could not get SQLPool from data"); + + let guild_prefix = match GuildData::get_from_id(guild.clone(), pool.clone()).await { + Some(guild_data) => guild_data.prefix, + + None => { + GuildData::create_from_guild(guild, pool).await.unwrap(); + String::from("?") + } + }; + + guild_prefix.as_str() == prefix.as_str() + } else { + true + } + } + + // gate to prevent analysing messages unnecessarily + if msg.author.bot || msg.content.is_empty() { + } + // Guild Command + else if let (Some(guild), Some(Channel::Guild(channel))) = + (msg.guild(&ctx).await, msg.channel(&ctx).await) + { + if let Some(full_match) = self.command_matcher.captures(&msg.content) { + if check_prefix(&ctx, &guild, full_match.name("prefix")).await { + match check_self_permissions(&ctx, &channel).await { + Ok(perms) => match perms { + PermissionCheck::All => { + let command = self + .commands + .get(&full_match.name("cmd").unwrap().as_str().to_lowercase()) + .unwrap(); + + let args = full_match + .name("args") + .map(|m| m.as_str()) + .unwrap_or("") + .to_string(); + + let member = guild.member(&ctx, &msg.author).await.unwrap(); + + if command.check_permissions(&ctx, &guild, &member).await { + dbg!(command.name); + + (command.func)( + &ctx, + &msg, + Args::new(&args, &[Delimiter::Single(' ')]), + ) + .await + .unwrap(); + } else if command.required_perms == PermissionLevel::Restricted { + let _ = msg.channel_id.say(&ctx, "You must either be an Admin or have a role specified in `?roles` to do this command").await; + } else if command.required_perms == PermissionLevel::Managed { + let _ = msg + .channel_id + .say(&ctx, "You must be an Admin to do this command") + .await; + } + } + + PermissionCheck::None => { + warn!("Missing enough permissions for guild {}", guild.id); + } + }, + + Err(e) => { + error!( + "Error occurred getting permissions in guild {}: {:?}", + guild.id, e + ); + } + } + } + } + } + } +} diff --git a/src/guilddata.rs b/src/guild_data.rs similarity index 94% rename from src/guilddata.rs rename to src/guild_data.rs index 803238a..90f432c 100644 --- a/src/guilddata.rs +++ b/src/guild_data.rs @@ -25,7 +25,7 @@ SELECT id, prefix, volume, allow_greets match guild_data { Ok(g) => Some(g), - Err(sqlx::Error::RowNotFound) => Self::create_from_guild(guild, db_pool).await.ok(), + Err(sqlx::Error::RowNotFound) => Self::create_from_guild(&guild, db_pool).await.ok(), Err(e) => { println!("{:?}", e); @@ -36,7 +36,7 @@ SELECT id, prefix, volume, allow_greets } pub async fn create_from_guild( - guild: Guild, + guild: &Guild, db_pool: MySqlPool, ) -> Result> { sqlx::query!( @@ -62,7 +62,7 @@ INSERT IGNORE INTO roles (guild_id, role) .await?; Ok(GuildData { - id: *guild.id.as_u64(), + id: guild.id.as_u64().to_owned(), prefix: String::from("?"), volume: 100, allow_greets: true, diff --git a/src/main.rs b/src/main.rs index cf8a34a..ae17e58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,18 +4,18 @@ extern crate lazy_static; extern crate reqwest; mod error; -mod guilddata; +mod framework; +mod guild_data; mod sound; -use guilddata::GuildData; +use guild_data::GuildData; use sound::Sound; +use regex_command_attr::command; + use serenity::{ client::{bridge::gateway::GatewayIntents, Client, Context}, - framework::standard::{ - macros::{check, command, group, hook}, - Args, CommandError, CommandResult, DispatchError, Reason, StandardFramework, - }, + framework::standard::{Args, CommandResult}, http::Http, model::{ channel::{Channel, Message}, @@ -34,12 +34,11 @@ use songbird::{ Call, SerenityInit, }; -type CheckResult = Result<(), Reason>; - use sqlx::mysql::MySqlPool; use dotenv::dotenv; +use crate::framework::RegexFramework; use std::{collections::HashMap, env, sync::Arc, time::Duration}; use tokio::sync::MutexGuard; @@ -72,164 +71,6 @@ lazy_static! { }; } -#[group] -#[commands( - info, - help, - list_sounds, - change_public, - search_sounds, - show_popular_sounds, - show_random_sounds, - set_greet_sound -)] -#[checks(self_perm_check)] -struct AllUsers; - -#[group] -#[commands(play, upload_new_sound, change_volume, delete_sound, stop_playing)] -#[checks(self_perm_check, role_check)] -struct RoleManagedUsers; - -#[group] -#[commands(change_prefix, set_allowed_roles, allow_greet_sounds)] -#[checks(self_perm_check, permission_check)] -struct PermissionManagedUsers; - -#[check] -#[name("self_perm_check")] -async fn self_perm_check(ctx: &Context, msg: &Message, _args: &mut Args) -> CheckResult { - let channel_o = msg.channel(&ctx).await; - - if let Some(channel_e) = channel_o { - if let Channel::Guild(channel) = channel_e { - let permissions_r = channel - .permissions_for_user(&ctx, &ctx.cache.current_user_id().await) - .await; - - if let Ok(permissions) = permissions_r { - if permissions.send_messages() && permissions.embed_links() { - Ok(()) - } else { - Err(Reason::Log( - "Bot does not have enough permissions".to_string(), - )) - } - } else { - Err(Reason::Log("No perms found".to_string())) - } - } else { - Err(Reason::Log("No DM commands".to_string())) - } - } else { - Err(Reason::Log("Channel not available".to_string())) - } -} - -#[check] -#[name("role_check")] -async fn role_check(ctx: &Context, msg: &Message, _args: &mut Args) -> CheckResult { - async fn check_for_roles(ctx: &&Context, msg: &&Message) -> CheckResult { - let pool = ctx - .data - .read() - .await - .get::() - .cloned() - .expect("Could not get SQLPool from data"); - - let guild_opt = msg.guild(&ctx).await; - - match guild_opt { - Some(guild) => { - let member_res = guild.member(*ctx, msg.author.id).await; - - match member_res { - Ok(member) => { - let user_roles: String = member - .roles - .iter() - .map(|r| (*r.as_u64()).to_string()) - .collect::>() - .join(", "); - - let guild_id = *msg.guild_id.unwrap().as_u64(); - - let role_res = sqlx::query!( - " -SELECT COUNT(1) as count - FROM roles - WHERE - (guild_id = ? AND role IN (?)) OR - (role = ?) - ", - guild_id, - user_roles, - guild_id - ) - .fetch_one(&pool) - .await; - - match role_res { - Ok(role_count) => { - if role_count.count > 0 { - Ok(()) - } - else { - Err(Reason::User("User has not got a sufficient role. Use `?roles` to set up role restrictions".to_string())) - } - } - - Err(_) => { - Err(Reason::User("User has not got a sufficient role. Use `?roles` to set up role restrictions".to_string())) - } - } - } - - Err(_) => Err(Reason::User( - "Unexpected error looking up user roles".to_string(), - )), - } - } - - None => Err(Reason::User( - "Unexpected error looking up guild".to_string(), - )), - } - } - - if perform_permission_check(ctx, &msg).await.is_ok() { - Ok(()) - } else { - check_for_roles(&ctx, &msg).await - } -} - -#[check] -#[name("permission_check")] -async fn permission_check(ctx: &Context, msg: &Message, _args: &mut Args) -> CheckResult { - perform_permission_check(ctx, &msg).await -} - -async fn perform_permission_check(ctx: &Context, msg: &&Message) -> CheckResult { - if let Some(guild) = msg.guild(&ctx).await { - if guild - .member_permissions(&ctx, &msg.author) - .await - .unwrap() - .manage_guild() - { - Ok(()) - } else { - Err(Reason::User(String::from( - "User needs `Manage Guild` permission", - ))) - } - } else { - Err(Reason::User(String::from("Guild not cached"))) - } -} - // create event handler for bot struct Handler; @@ -403,28 +244,6 @@ async fn join_channel( } } -#[hook] -async fn log_errors(_: &Context, m: &Message, cmd_name: &str, error: Result<(), CommandError>) { - if let Err(e) = error { - println!("Error in command {} ({}): {:?}", cmd_name, m.content, e); - } -} - -#[hook] -async fn dispatch_error_hook(ctx: &Context, msg: &Message, error: DispatchError) { - match error { - DispatchError::CheckFailed(_f, reason) => { - if let Reason::User(description) = reason { - let _ = msg - .reply(ctx, format!("You cannot do this command: {}", description)) - .await; - } - } - - _ => {} - } -} - // entry point #[tokio::main] async fn main() -> Result<(), Box> { @@ -436,46 +255,16 @@ async fn main() -> Result<(), Box> { let logged_in_id = http.get_current_user().await?.id; - let framework = StandardFramework::new() - .configure(|c| { - c.dynamic_prefix(|ctx, msg| { - Box::pin(async move { - let pool = ctx - .data - .read() - .await - .get::() - .cloned() - .expect("Could not get SQLPool from data"); - - let guild = match msg.guild(&ctx.cache).await { - Some(guild) => guild, - - None => { - return Some(String::from("?")); - } - }; - - match GuildData::get_from_id(guild.clone(), pool.clone()).await { - Some(guild_data) => Some(guild_data.prefix), - - None => { - GuildData::create_from_guild(guild, pool).await.unwrap(); - Some(String::from("?")) - } - } - }) - }) - .allow_dm(false) - .ignore_bots(true) - .ignore_webhooks(true) - .on_mention(Some(logged_in_id)) - }) - .group(&ALLUSERS_GROUP) - .group(&ROLEMANAGEDUSERS_GROUP) - .group(&PERMISSIONMANAGEDUSERS_GROUP) - .after(log_errors) - .on_dispatch_error(dispatch_error_hook); + let framework = RegexFramework::new(logged_in_id) + .default_prefix("?") + .case_insensitive(true) + .ignore_bots(true) + // info commands + .add_command("help", &HELP_COMMAND) + .add_command("info", &INFO_COMMAND) + .add_command("invite", &INFO_COMMAND) + .add_command("donate", &INFO_COMMAND) + .build(); let mut client = Client::builder(&env::var("DISCORD_TOKEN").expect("Missing token from environment")) @@ -518,8 +307,7 @@ async fn main() -> Result<(), Box> { Ok(()) } -#[command("play")] -#[aliases("p")] +#[command] async fn play(ctx: &Context, msg: &Message, args: Args) -> CommandResult { let guild = match msg.guild(&ctx.cache).await { Some(guild) => guild, diff --git a/src/sound.rs b/src/sound.rs index ce3280c..f86c68c 100644 --- a/src/sound.rs +++ b/src/sound.rs @@ -130,7 +130,7 @@ SELECT src .await .unwrap(); - return record.src; + record.src } pub async fn store_sound_source(