241 lines
6.0 KiB
Rust
241 lines
6.0 KiB
Rust
use crate::structures::CommandFun;
|
|
use proc_macro::TokenStream;
|
|
use proc_macro2::Span;
|
|
use proc_macro2::TokenStream as TokenStream2;
|
|
use quote::{format_ident, quote, ToTokens};
|
|
use syn::{
|
|
braced, bracketed, parenthesized,
|
|
parse::{Error, Parse, ParseStream, Result as SynResult},
|
|
parse_quote,
|
|
punctuated::Punctuated,
|
|
spanned::Spanned,
|
|
token::{Comma, Mut},
|
|
Ident, Lifetime, Lit, Type,
|
|
};
|
|
|
|
pub trait LitExt {
|
|
fn to_str(&self) -> String;
|
|
fn to_bool(&self) -> bool;
|
|
fn to_ident(&self) -> Ident;
|
|
}
|
|
|
|
impl LitExt for Lit {
|
|
fn to_str(&self) -> String {
|
|
match self {
|
|
Lit::Str(s) => s.value(),
|
|
Lit::ByteStr(s) => unsafe { String::from_utf8_unchecked(s.value()) },
|
|
Lit::Char(c) => c.value().to_string(),
|
|
Lit::Byte(b) => (b.value() as char).to_string(),
|
|
_ => panic!("values must be a (byte)string or a char"),
|
|
}
|
|
}
|
|
|
|
fn to_bool(&self) -> bool {
|
|
if let Lit::Bool(b) = self {
|
|
b.value
|
|
} else {
|
|
self.to_str()
|
|
.parse()
|
|
.unwrap_or_else(|_| panic!("expected bool from {:?}", self))
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
fn to_ident(&self) -> Ident {
|
|
Ident::new(&self.to_str(), self.span())
|
|
}
|
|
}
|
|
|
|
pub trait IdentExt2: Sized {
|
|
fn to_uppercase(&self) -> Self;
|
|
fn with_suffix(&self, suf: &str) -> Ident;
|
|
}
|
|
|
|
impl IdentExt2 for Ident {
|
|
#[inline]
|
|
fn to_uppercase(&self) -> Self {
|
|
format_ident!("{}", self.to_string().to_uppercase())
|
|
}
|
|
|
|
#[inline]
|
|
fn with_suffix(&self, suffix: &str) -> Ident {
|
|
format_ident!("{}_{}", self.to_string().to_uppercase(), suffix)
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
pub fn into_stream(e: Error) -> TokenStream {
|
|
e.to_compile_error().into()
|
|
}
|
|
|
|
macro_rules! propagate_err {
|
|
($res:expr) => {{
|
|
match $res {
|
|
Ok(v) => v,
|
|
Err(e) => return $crate::util::into_stream(e),
|
|
}
|
|
}};
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct Bracketed<T>(pub Punctuated<T, Comma>);
|
|
|
|
impl<T: Parse> Parse for Bracketed<T> {
|
|
fn parse(input: ParseStream<'_>) -> SynResult<Self> {
|
|
let content;
|
|
bracketed!(content in input);
|
|
|
|
Ok(Bracketed(content.parse_terminated(T::parse)?))
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct Braced<T>(pub Punctuated<T, Comma>);
|
|
|
|
impl<T: Parse> Parse for Braced<T> {
|
|
fn parse(input: ParseStream<'_>) -> SynResult<Self> {
|
|
let content;
|
|
braced!(content in input);
|
|
|
|
Ok(Braced(content.parse_terminated(T::parse)?))
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct Parenthesised<T>(pub Punctuated<T, Comma>);
|
|
|
|
impl<T: Parse> Parse for Parenthesised<T> {
|
|
fn parse(input: ParseStream<'_>) -> SynResult<Self> {
|
|
let content;
|
|
parenthesized!(content in input);
|
|
|
|
Ok(Parenthesised(content.parse_terminated(T::parse)?))
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct AsOption<T>(pub Option<T>);
|
|
|
|
impl<T: ToTokens> ToTokens for AsOption<T> {
|
|
fn to_tokens(&self, stream: &mut TokenStream2) {
|
|
match &self.0 {
|
|
Some(o) => stream.extend(quote!(Some(#o))),
|
|
None => stream.extend(quote!(None)),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T> Default for AsOption<T> {
|
|
#[inline]
|
|
fn default() -> Self {
|
|
AsOption(None)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct Argument {
|
|
pub mutable: Option<Mut>,
|
|
pub name: Ident,
|
|
pub kind: Type,
|
|
}
|
|
|
|
impl ToTokens for Argument {
|
|
fn to_tokens(&self, stream: &mut TokenStream2) {
|
|
let Argument {
|
|
mutable,
|
|
name,
|
|
kind,
|
|
} = self;
|
|
|
|
stream.extend(quote! {
|
|
#mutable #name: #kind
|
|
});
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
pub fn generate_type_validation(have: Type, expect: Type) -> syn::Stmt {
|
|
parse_quote! {
|
|
serenity::static_assertions::assert_type_eq_all!(#have, #expect);
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
pub enum DeclarFor {
|
|
Command,
|
|
Help,
|
|
Check,
|
|
}
|
|
|
|
pub fn create_declaration_validations(fun: &mut CommandFun, dec_for: DeclarFor) -> SynResult<()> {
|
|
let len = match dec_for {
|
|
DeclarFor::Command => 3,
|
|
DeclarFor::Help => 6,
|
|
DeclarFor::Check => 4,
|
|
};
|
|
|
|
if fun.args.len() > len {
|
|
return Err(Error::new(
|
|
fun.args.last().unwrap().span(),
|
|
format_args!("function's arity exceeds more than {} arguments", len),
|
|
));
|
|
}
|
|
|
|
let context: Type = parse_quote!(&serenity::client::Context);
|
|
let message: Type = parse_quote!(&serenity::model::channel::Message);
|
|
let args: Type = parse_quote!(String);
|
|
let options: Type = parse_quote!(&serenity::framework::standard::CommandOptions);
|
|
let hoptions: Type = parse_quote!(&'static serenity::framework::standard::HelpOptions);
|
|
let groups: Type = parse_quote!(&[&'static serenity::framework::standard::CommandGroup]);
|
|
let owners: Type = parse_quote!(std::collections::HashSet<serenity::model::id::UserId>);
|
|
|
|
let mut index = 0;
|
|
|
|
let mut spoof_or_check = |kind: Type, name: &str| {
|
|
match fun.args.get(index) {
|
|
Some(x) => fun.body.insert(0, generate_type_validation(x.kind.clone(), kind)),
|
|
None => fun.args.push(Argument {
|
|
mutable: None,
|
|
name: Ident::new(name, Span::call_site()),
|
|
kind,
|
|
}),
|
|
}
|
|
|
|
index += 1;
|
|
};
|
|
|
|
spoof_or_check(context, "_ctx");
|
|
spoof_or_check(message, "_msg");
|
|
|
|
if dec_for == DeclarFor::Check {
|
|
spoof_or_check(options, "_options");
|
|
|
|
return Ok(());
|
|
}
|
|
|
|
spoof_or_check(args, "_args");
|
|
|
|
if dec_for == DeclarFor::Help {
|
|
spoof_or_check(hoptions, "_hoptions");
|
|
spoof_or_check(groups, "_groups");
|
|
spoof_or_check(owners, "_owners");
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[inline]
|
|
pub fn create_return_type_validation(r#fn: &mut CommandFun, expect: Type) {
|
|
let stmt = generate_type_validation(r#fn.ret.clone(), expect);
|
|
r#fn.body.insert(0, stmt);
|
|
}
|
|
|
|
#[inline]
|
|
pub fn populate_fut_lifetimes_on_refs(args: &mut Vec<Argument>) {
|
|
for arg in args {
|
|
if let Type::Reference(reference) = &mut arg.kind {
|
|
reference.lifetime = Some(Lifetime::new("'fut", Span::call_site()));
|
|
}
|
|
}
|
|
}
|