Compare commits
	
		
			57 Commits
		
	
	
		
			1.5-dead
			...
			postman-in
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| e5ab99f67b | |||
| e47715917e | |||
| 4f9eb58c16 | |||
| c953bc0cd3 | |||
| 610779a293 | |||
| ebd1efa990 | |||
| 5230101a8d | |||
| d8f42c1b25 | |||
| 23c6b3869e | |||
| a21f518b21 | |||
| f1bfc11160 | |||
| 72228911f2 | |||
| db7cca6296 | |||
| e36e718f28 | |||
| 44debf93c5 | |||
| 9b54fba5e5 | |||
| 6cf660c7ee | |||
| 4490f19c04 | |||
| a362a24cfc | |||
| 903daf65e6 | |||
| b310e99085 | |||
| ebabe0e85a | |||
| 6b5d6ae288 | |||
| 379e488f7a | |||
| d84d7ab62b | |||
| a0974795e1 | |||
| a9c91bee93 | |||
| b2207e308a | |||
| 3c1eeed92f | |||
| 395a8481f1 | |||
| bae0433bd9 | |||
| 3e547861ea | |||
| 9b5333dc87 | |||
| 471948bed3 | |||
| c148cdf556 | |||
| 98aed91d21 | |||
| 40630c0014 | |||
| 85a8ae625d | |||
| 43bbcb3fe0 | |||
| 1556318d07 | |||
| ea2b0f4b0a | |||
| f02c04b313 | |||
| 320060b1bd | |||
| bef33c6dac | |||
| 7bcb3c4a70 | |||
| 2e153cffab | |||
| 540f120d7d | |||
| 59ffb505dc | |||
| 2bec2b9e12 | |||
| 507075d9d4 | |||
| 85659f05aa | |||
| eb07ece779 | |||
| 1a09f026c9 | |||
| b31843c478 | |||
| 9109250fe8 | |||
| 2346c2e978 | |||
| a0da4dcf00 | 
							
								
								
									
										964
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										964
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										30
									
								
								Cargo.toml
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								Cargo.toml
									
									
									
									
									
								
							@@ -1,11 +1,10 @@
 | 
			
		||||
[package]
 | 
			
		||||
name = "reminder_rs"
 | 
			
		||||
version = "1.4.13"
 | 
			
		||||
version = "1.6.0-beta3"
 | 
			
		||||
authors = ["jellywx <judesouthworth@pm.me>"]
 | 
			
		||||
edition = "2018"
 | 
			
		||||
 | 
			
		||||
[dependencies]
 | 
			
		||||
dashmap = "4.0"
 | 
			
		||||
dotenv = "0.15"
 | 
			
		||||
humantime = "2.1"
 | 
			
		||||
tokio = { version = "1", features = ["process", "full"] }
 | 
			
		||||
@@ -14,17 +13,34 @@ regex = "1.4"
 | 
			
		||||
log = "0.4"
 | 
			
		||||
env_logger = "0.8"
 | 
			
		||||
chrono = "0.4"
 | 
			
		||||
chrono-tz = "0.5"
 | 
			
		||||
chrono-tz = { version = "0.5", features = ["serde"] }
 | 
			
		||||
lazy_static = "1.4"
 | 
			
		||||
num-integer = "0.1"
 | 
			
		||||
serde = "1.0"
 | 
			
		||||
serde_json = "1.0"
 | 
			
		||||
serde_repr = "0.1"
 | 
			
		||||
rmp-serde = "0.15"
 | 
			
		||||
rand = "0.7"
 | 
			
		||||
Inflector = "0.11"
 | 
			
		||||
levenshtein = "1.0"
 | 
			
		||||
# serenity = { version = "0.10", features = ["collector"] }
 | 
			
		||||
serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["collector"] }
 | 
			
		||||
sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono"]}
 | 
			
		||||
base64 = "0.13.0"
 | 
			
		||||
 | 
			
		||||
[dependencies.regex_command_attr]
 | 
			
		||||
path = "./regex_command_attr"
 | 
			
		||||
path = "command_attributes"
 | 
			
		||||
 | 
			
		||||
[dependencies.serenity]
 | 
			
		||||
git = "https://github.com/serenity-rs/serenity"
 | 
			
		||||
branch = "next"
 | 
			
		||||
default-features = false
 | 
			
		||||
features = [
 | 
			
		||||
    "builder",
 | 
			
		||||
    "client",
 | 
			
		||||
    "cache",
 | 
			
		||||
    "gateway",
 | 
			
		||||
    "http",
 | 
			
		||||
    "model",
 | 
			
		||||
    "utils",
 | 
			
		||||
    "rustls_backend",
 | 
			
		||||
    "collector",
 | 
			
		||||
    "unstable_discord_api"
 | 
			
		||||
]
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										13
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								README.md
									
									
									
									
									
								
							@@ -1,6 +1,5 @@
 | 
			
		||||
# reminder-rs
 | 
			
		||||
Reminder Bot for Discord, now in Rust.
 | 
			
		||||
Old Python version: https://github.com/reminder-bot/bot
 | 
			
		||||
Reminder Bot for Discord.
 | 
			
		||||
 | 
			
		||||
## How do I use it?
 | 
			
		||||
We offer a hosted version of the bot. You can invite it with: **https://invite.reminder-bot.com**. The catch is that repeating 
 | 
			
		||||
@@ -15,7 +14,6 @@ Reminder Bot can be built by running `cargo build --release` in the top level di
 | 
			
		||||
These environment variables must be provided when compiling the bot
 | 
			
		||||
* `DATABASE_URL` - the URL of your MySQL database (`mysql://user[:password]@domain/database`)
 | 
			
		||||
* `WEBHOOK_AVATAR` - accepts the name of an image file located in `$CARGO_MANIFEST_DIR/assets/` to be used as the avatar when creating webhooks. **IMPORTANT: image file must be 128x128 or smaller in size**
 | 
			
		||||
* `STRINGS_FILE` - accepts the name of a compiled strings file located in `$CARGO_MANIFEST_DIR/assets/` to be used for creating messages. Compiled string files can be generated with `compile.py` at https://github.com/reminder-bot/languages
 | 
			
		||||
 | 
			
		||||
### Setting up Python
 | 
			
		||||
Reminder Bot by default looks for a venv within it's working directory to run Python out of. To set up a venv, install `python3-venv` and run `python3 -m venv venv`. Then, run `source venv/bin/activate` to activate the venv, and do `pip install dateparser` to install the required library
 | 
			
		||||
@@ -29,16 +27,17 @@ __Required Variables__
 | 
			
		||||
 | 
			
		||||
__Other Variables__
 | 
			
		||||
* `MIN_INTERVAL` - default `600`, defines the shortest interval the bot should accept
 | 
			
		||||
* `MAX_TIME` - default `1576800000`, defines the maximum time ahead that reminders can be set for
 | 
			
		||||
* `LOCAL_TIMEZONE` - default `UTC`, necessary for calculations in the natural language processor
 | 
			
		||||
* `DEFAULT_PREFIX` - default `$`, used for the default prefix on new guilds
 | 
			
		||||
* `SUBSCRIPTION_ROLES` - default `None`, accepts a list of Discord role IDs that are given to subscribed users
 | 
			
		||||
* `CNC_GUILD` - default `None`, accepts a single Discord guild ID for the server that the subscription roles belong to
 | 
			
		||||
* `IGNORE_BOTS` - default `1`, if `1`, Reminder Bot will ignore all other bots
 | 
			
		||||
* `PYTHON_LOCATION` - default `venv/bin/python3`. Can be changed if your Python executable is located somewhere else
 | 
			
		||||
* `LOCAL_LANGUAGE` - default `EN`. Specifies the string set to fall back to if a string cannot be found (and to be used with new users)
 | 
			
		||||
* `THEME_COLOR` - default `8fb677`. Specifies the hex value of the color to use on info message embeds 
 | 
			
		||||
* `CASE_INSENSITIVE` - default `1`, if `1`, commands will be treated with case insensitivity (so both `$help` and `$HELP` will work)
 | 
			
		||||
* `SHARD_COUNT` - default `None`, accepts the number of shards that are being ran
 | 
			
		||||
* `SHARD_RANGE` - default `None`, if `SHARD_COUNT` is specified, specifies what range of shards to start on this process 
 | 
			
		||||
* `DM_ENABLED` - default `1`, if `1`, Reminder Bot will respond to direct messages
 | 
			
		||||
 | 
			
		||||
### Todo List
 | 
			
		||||
 | 
			
		||||
* Convert aliases to macros
 | 
			
		||||
* Help command
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,10 @@
 | 
			
		||||
[package]
 | 
			
		||||
name = "regex_command_attr"
 | 
			
		||||
version = "0.2.0"
 | 
			
		||||
version = "0.3.6"
 | 
			
		||||
authors = ["acdenisSK <acdenissk69@gmail.com>", "jellywx <judesouthworth@pm.me>"]
 | 
			
		||||
edition = "2018"
 | 
			
		||||
description = "Procedural macros for command creation for the RegexFramework for serenity."
 | 
			
		||||
description = "Procedural macros for command creation for the Serenity library."
 | 
			
		||||
license = "ISC"
 | 
			
		||||
 | 
			
		||||
[lib]
 | 
			
		||||
proc-macro = true
 | 
			
		||||
@@ -12,3 +13,4 @@ proc-macro = true
 | 
			
		||||
quote = "^1.0"
 | 
			
		||||
syn = { version = "^1.0", features = ["full", "derive", "extra-traits"] }
 | 
			
		||||
proc-macro2 = "1.0"
 | 
			
		||||
uuid = { version = "0.8", features = ["v4"] }
 | 
			
		||||
@@ -1,13 +1,17 @@
 | 
			
		||||
use proc_macro2::Span;
 | 
			
		||||
use syn::parse::{Error, Result};
 | 
			
		||||
use syn::spanned::Spanned;
 | 
			
		||||
use syn::{Attribute, Ident, Lit, LitStr, Meta, NestedMeta, Path};
 | 
			
		||||
 | 
			
		||||
use crate::structures::PermissionLevel;
 | 
			
		||||
use crate::util::{AsOption, LitExt};
 | 
			
		||||
 | 
			
		||||
use std::fmt::{self, Write};
 | 
			
		||||
 | 
			
		||||
use proc_macro2::Span;
 | 
			
		||||
use syn::{
 | 
			
		||||
    parse::{Error, Result},
 | 
			
		||||
    spanned::Spanned,
 | 
			
		||||
    Attribute, Ident, Lit, LitStr, Meta, NestedMeta, Path,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    structures::{ApplicationCommandOptionType, Arg},
 | 
			
		||||
    util::{AsOption, LitExt},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone, Copy, PartialEq)]
 | 
			
		||||
pub enum ValueKind {
 | 
			
		||||
    // #[<name>]
 | 
			
		||||
@@ -19,6 +23,9 @@ pub enum ValueKind {
 | 
			
		||||
    // #[<name>([<value>, <value>, <value>, ...])]
 | 
			
		||||
    List,
 | 
			
		||||
 | 
			
		||||
    // #[<name>([<prop> = <value>, <prop> = <value>, ...])]
 | 
			
		||||
    EqualsList,
 | 
			
		||||
 | 
			
		||||
    // #[<name>(<value>)]
 | 
			
		||||
    SingleList,
 | 
			
		||||
}
 | 
			
		||||
@@ -29,6 +36,9 @@ impl fmt::Display for ValueKind {
 | 
			
		||||
            ValueKind::Name => f.pad("`#[<name>]`"),
 | 
			
		||||
            ValueKind::Equals => f.pad("`#[<name> = <value>]`"),
 | 
			
		||||
            ValueKind::List => f.pad("`#[<name>([<value>, <value>, <value>, ...])]`"),
 | 
			
		||||
            ValueKind::EqualsList => {
 | 
			
		||||
                f.pad("`#[<name>([<prop> = <value>, <prop> = <value>, ...])]`")
 | 
			
		||||
            }
 | 
			
		||||
            ValueKind::SingleList => f.pad("`#[<name>(<value>)]`"),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
@@ -36,24 +46,15 @@ impl fmt::Display for ValueKind {
 | 
			
		||||
 | 
			
		||||
fn to_ident(p: Path) -> Result<Ident> {
 | 
			
		||||
    if p.segments.is_empty() {
 | 
			
		||||
        return Err(Error::new(
 | 
			
		||||
            p.span(),
 | 
			
		||||
            "cannot convert an empty path to an identifier",
 | 
			
		||||
        ));
 | 
			
		||||
        return Err(Error::new(p.span(), "cannot convert an empty path to an identifier"));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if p.segments.len() > 1 {
 | 
			
		||||
        return Err(Error::new(
 | 
			
		||||
            p.span(),
 | 
			
		||||
            "the path must not have more than one segment",
 | 
			
		||||
        ));
 | 
			
		||||
        return Err(Error::new(p.span(), "the path must not have more than one segment"));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if !p.segments[0].arguments.is_empty() {
 | 
			
		||||
        return Err(Error::new(
 | 
			
		||||
            p.span(),
 | 
			
		||||
            "the singular path segment must not have any arguments",
 | 
			
		||||
        ));
 | 
			
		||||
        return Err(Error::new(p.span(), "the singular path segment must not have any arguments"));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Ok(p.segments[0].ident.clone())
 | 
			
		||||
@@ -62,24 +63,37 @@ fn to_ident(p: Path) -> Result<Ident> {
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct Values {
 | 
			
		||||
    pub name: Ident,
 | 
			
		||||
    pub literals: Vec<Lit>,
 | 
			
		||||
    pub literals: Vec<(Option<String>, Lit)>,
 | 
			
		||||
    pub kind: ValueKind,
 | 
			
		||||
    pub span: Span,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Values {
 | 
			
		||||
    #[inline]
 | 
			
		||||
    pub fn new(name: Ident, kind: ValueKind, literals: Vec<Lit>, span: Span) -> Self {
 | 
			
		||||
        Values {
 | 
			
		||||
            name,
 | 
			
		||||
            literals,
 | 
			
		||||
            kind,
 | 
			
		||||
            span,
 | 
			
		||||
        }
 | 
			
		||||
    pub fn new(
 | 
			
		||||
        name: Ident,
 | 
			
		||||
        kind: ValueKind,
 | 
			
		||||
        literals: Vec<(Option<String>, Lit)>,
 | 
			
		||||
        span: Span,
 | 
			
		||||
    ) -> Self {
 | 
			
		||||
        Values { name, literals, kind, span }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn parse_values(attr: &Attribute) -> Result<Values> {
 | 
			
		||||
    fn is_list_or_named_list(meta: &NestedMeta) -> ValueKind {
 | 
			
		||||
        match meta {
 | 
			
		||||
            // catch if the nested value is a literal value
 | 
			
		||||
            NestedMeta::Lit(_) => ValueKind::List,
 | 
			
		||||
            // catch if the nested value is a meta value
 | 
			
		||||
            NestedMeta::Meta(m) => match m {
 | 
			
		||||
                // path => some quoted value
 | 
			
		||||
                Meta::Path(_) => ValueKind::List,
 | 
			
		||||
                Meta::List(_) | Meta::NameValue(_) => ValueKind::EqualsList,
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let meta = attr.parse_meta()?;
 | 
			
		||||
 | 
			
		||||
    match meta {
 | 
			
		||||
@@ -96,36 +110,62 @@ pub fn parse_values(attr: &Attribute) -> Result<Values> {
 | 
			
		||||
                return Err(Error::new(attr.span(), "list cannot be empty"));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            let mut lits = Vec::with_capacity(nested.len());
 | 
			
		||||
            if is_list_or_named_list(nested.first().unwrap()) == ValueKind::List {
 | 
			
		||||
                let mut lits = Vec::with_capacity(nested.len());
 | 
			
		||||
 | 
			
		||||
            for meta in nested {
 | 
			
		||||
                match meta {
 | 
			
		||||
                    NestedMeta::Lit(l) => lits.push(l),
 | 
			
		||||
                    NestedMeta::Meta(m) => match m {
 | 
			
		||||
                        Meta::Path(path) => {
 | 
			
		||||
                            let i = to_ident(path)?;
 | 
			
		||||
                            lits.push(Lit::Str(LitStr::new(&i.to_string(), i.span())))
 | 
			
		||||
                        }
 | 
			
		||||
                        Meta::List(_) | Meta::NameValue(_) => {
 | 
			
		||||
                            return Err(Error::new(attr.span(), "cannot nest a list; only accept literals and identifiers at this level"))
 | 
			
		||||
                        }
 | 
			
		||||
                    },
 | 
			
		||||
                for meta in nested {
 | 
			
		||||
                    match meta {
 | 
			
		||||
                        // catch if the nested value is a literal value
 | 
			
		||||
                        NestedMeta::Lit(l) => lits.push((None, l)),
 | 
			
		||||
                        // catch if the nested value is a meta value
 | 
			
		||||
                        NestedMeta::Meta(m) => match m {
 | 
			
		||||
                            // path => some quoted value
 | 
			
		||||
                            Meta::Path(path) => {
 | 
			
		||||
                                let i = to_ident(path)?;
 | 
			
		||||
                                lits.push((None, Lit::Str(LitStr::new(&i.to_string(), i.span()))))
 | 
			
		||||
                            }
 | 
			
		||||
                            Meta::List(_) | Meta::NameValue(_) => {
 | 
			
		||||
                                return Err(Error::new(attr.span(), "cannot nest a list; only accept literals and identifiers at this level"))
 | 
			
		||||
                            }
 | 
			
		||||
                        },
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            let kind = if lits.len() == 1 {
 | 
			
		||||
                ValueKind::SingleList
 | 
			
		||||
                let kind = if lits.len() == 1 { ValueKind::SingleList } else { ValueKind::List };
 | 
			
		||||
 | 
			
		||||
                Ok(Values::new(name, kind, lits, attr.span()))
 | 
			
		||||
            } else {
 | 
			
		||||
                ValueKind::List
 | 
			
		||||
            };
 | 
			
		||||
                let mut lits = Vec::with_capacity(nested.len());
 | 
			
		||||
 | 
			
		||||
            Ok(Values::new(name, kind, lits, attr.span()))
 | 
			
		||||
                for meta in nested {
 | 
			
		||||
                    match meta {
 | 
			
		||||
                        // catch if the nested value is a literal value
 | 
			
		||||
                        NestedMeta::Lit(_) => {
 | 
			
		||||
                            return Err(Error::new(attr.span(), "key-value pairs expected"))
 | 
			
		||||
                        }
 | 
			
		||||
                        // catch if the nested value is a meta value
 | 
			
		||||
                        NestedMeta::Meta(m) => match m {
 | 
			
		||||
                            Meta::NameValue(n) => {
 | 
			
		||||
                                let name = to_ident(n.path)?.to_string();
 | 
			
		||||
                                let value = n.lit;
 | 
			
		||||
 | 
			
		||||
                                lits.push((Some(name), value));
 | 
			
		||||
                            }
 | 
			
		||||
                            Meta::List(_) | Meta::Path(_) => {
 | 
			
		||||
                                return Err(Error::new(attr.span(), "key-value pairs expected"))
 | 
			
		||||
                            }
 | 
			
		||||
                        },
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                Ok(Values::new(name, ValueKind::EqualsList, lits, attr.span()))
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Meta::NameValue(meta) => {
 | 
			
		||||
            let name = to_ident(meta.path)?;
 | 
			
		||||
            let lit = meta.lit;
 | 
			
		||||
 | 
			
		||||
            Ok(Values::new(name, ValueKind::Equals, vec![lit], attr.span()))
 | 
			
		||||
            Ok(Values::new(name, ValueKind::Equals, vec![(None, lit)], attr.span()))
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -168,10 +208,7 @@ fn validate(values: &Values, forms: &[ValueKind]) -> Result<()> {
 | 
			
		||||
        return Err(Error::new(
 | 
			
		||||
            values.span,
 | 
			
		||||
            // Using the `_args` version here to avoid an allocation.
 | 
			
		||||
            format_args!(
 | 
			
		||||
                "the attribute must be in of these forms:\n{}",
 | 
			
		||||
                DisplaySlice(forms)
 | 
			
		||||
            ),
 | 
			
		||||
            format_args!("the attribute must be in of these forms:\n{}", DisplaySlice(forms)),
 | 
			
		||||
        ));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -191,11 +228,7 @@ impl AttributeOption for Vec<String> {
 | 
			
		||||
    fn parse(values: Values) -> Result<Self> {
 | 
			
		||||
        validate(&values, &[ValueKind::List])?;
 | 
			
		||||
 | 
			
		||||
        Ok(values
 | 
			
		||||
            .literals
 | 
			
		||||
            .into_iter()
 | 
			
		||||
            .map(|lit| lit.to_str())
 | 
			
		||||
            .collect())
 | 
			
		||||
        Ok(values.literals.into_iter().map(|(_, l)| l.to_str()).collect())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -204,7 +237,7 @@ impl AttributeOption for String {
 | 
			
		||||
    fn parse(values: Values) -> Result<Self> {
 | 
			
		||||
        validate(&values, &[ValueKind::Equals, ValueKind::SingleList])?;
 | 
			
		||||
 | 
			
		||||
        Ok(values.literals[0].to_str())
 | 
			
		||||
        Ok(values.literals[0].1.to_str())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -213,7 +246,7 @@ impl AttributeOption for bool {
 | 
			
		||||
    fn parse(values: Values) -> Result<Self> {
 | 
			
		||||
        validate(&values, &[ValueKind::Name, ValueKind::SingleList])?;
 | 
			
		||||
 | 
			
		||||
        Ok(values.literals.get(0).map_or(true, |l| l.to_bool()))
 | 
			
		||||
        Ok(values.literals.get(0).map_or(true, |(_, l)| l.to_bool()))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -222,7 +255,7 @@ impl AttributeOption for Ident {
 | 
			
		||||
    fn parse(values: Values) -> Result<Self> {
 | 
			
		||||
        validate(&values, &[ValueKind::SingleList])?;
 | 
			
		||||
 | 
			
		||||
        Ok(values.literals[0].to_ident())
 | 
			
		||||
        Ok(values.literals[0].1.to_ident())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -231,7 +264,7 @@ impl AttributeOption for Vec<Ident> {
 | 
			
		||||
    fn parse(values: Values) -> Result<Self> {
 | 
			
		||||
        validate(&values, &[ValueKind::List])?;
 | 
			
		||||
 | 
			
		||||
        Ok(values.literals.into_iter().map(|l| l.to_ident()).collect())
 | 
			
		||||
        Ok(values.literals.into_iter().map(|(_, l)| l.to_ident()).collect())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -239,15 +272,40 @@ impl AttributeOption for Option<String> {
 | 
			
		||||
    fn parse(values: Values) -> Result<Self> {
 | 
			
		||||
        validate(&values, &[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList])?;
 | 
			
		||||
 | 
			
		||||
        Ok(values.literals.get(0).map(|l| l.to_str()))
 | 
			
		||||
        Ok(values.literals.get(0).map(|(_, l)| l.to_str()))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AttributeOption for PermissionLevel {
 | 
			
		||||
impl AttributeOption for Arg {
 | 
			
		||||
    fn parse(values: Values) -> Result<Self> {
 | 
			
		||||
        validate(&values, &[ValueKind::SingleList])?;
 | 
			
		||||
        validate(&values, &[ValueKind::EqualsList])?;
 | 
			
		||||
 | 
			
		||||
        Ok(values.literals.get(0).map(|l| PermissionLevel::from_str(&*l.to_str()).unwrap()).unwrap())
 | 
			
		||||
        let mut arg: Arg = Default::default();
 | 
			
		||||
 | 
			
		||||
        for (key, value) in &values.literals {
 | 
			
		||||
            match key {
 | 
			
		||||
                Some(s) => match s.as_str() {
 | 
			
		||||
                    "name" => {
 | 
			
		||||
                        arg.name = value.to_str();
 | 
			
		||||
                    }
 | 
			
		||||
                    "description" => {
 | 
			
		||||
                        arg.description = value.to_str();
 | 
			
		||||
                    }
 | 
			
		||||
                    "required" => {
 | 
			
		||||
                        arg.required = value.to_bool();
 | 
			
		||||
                    }
 | 
			
		||||
                    "kind" => arg.kind = ApplicationCommandOptionType::from_str(value.to_str()),
 | 
			
		||||
                    _ => {
 | 
			
		||||
                        return Err(Error::new(key.span(), "unexpected attribute"));
 | 
			
		||||
                    }
 | 
			
		||||
                },
 | 
			
		||||
                _ => {
 | 
			
		||||
                    return Err(Error::new(key.span(), "unnamed attribute"));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(arg)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -265,7 +323,7 @@ macro_rules! attr_option_num {
 | 
			
		||||
                fn parse(values: Values) -> Result<Self> {
 | 
			
		||||
                    validate(&values, &[ValueKind::SingleList])?;
 | 
			
		||||
 | 
			
		||||
                    Ok(match &values.literals[0] {
 | 
			
		||||
                    Ok(match &values.literals[0].1 {
 | 
			
		||||
                        Lit::Int(l) => l.base10_parse::<$n>()?,
 | 
			
		||||
                        l => {
 | 
			
		||||
                            let s = l.to_str();
 | 
			
		||||
							
								
								
									
										10
									
								
								command_attributes/src/consts.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								command_attributes/src/consts.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,10 @@
 | 
			
		||||
pub mod suffixes {
 | 
			
		||||
    pub const COMMAND: &str = "COMMAND";
 | 
			
		||||
    pub const ARG: &str = "ARG";
 | 
			
		||||
    pub const SUBCOMMAND: &str = "SUBCOMMAND";
 | 
			
		||||
    pub const SUBCOMMAND_GROUP: &str = "GROUP";
 | 
			
		||||
    pub const CHECK: &str = "CHECK";
 | 
			
		||||
    pub const HOOK: &str = "HOOK";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub use self::suffixes::*;
 | 
			
		||||
							
								
								
									
										321
									
								
								command_attributes/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										321
									
								
								command_attributes/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,321 @@
 | 
			
		||||
#![deny(rust_2018_idioms)]
 | 
			
		||||
#![deny(broken_intra_doc_links)]
 | 
			
		||||
 | 
			
		||||
use proc_macro::TokenStream;
 | 
			
		||||
use proc_macro2::Ident;
 | 
			
		||||
use quote::quote;
 | 
			
		||||
use syn::{parse::Error, parse_macro_input, parse_quote, spanned::Spanned, Lit, Type};
 | 
			
		||||
use uuid::Uuid;
 | 
			
		||||
 | 
			
		||||
pub(crate) mod attributes;
 | 
			
		||||
pub(crate) mod consts;
 | 
			
		||||
pub(crate) mod structures;
 | 
			
		||||
 | 
			
		||||
#[macro_use]
 | 
			
		||||
pub(crate) mod util;
 | 
			
		||||
 | 
			
		||||
use attributes::*;
 | 
			
		||||
use consts::*;
 | 
			
		||||
use structures::*;
 | 
			
		||||
use util::*;
 | 
			
		||||
 | 
			
		||||
macro_rules! match_options {
 | 
			
		||||
    ($v:expr, $values:ident, $options:ident, $span:expr => [$($name:ident);*]) => {
 | 
			
		||||
        match $v {
 | 
			
		||||
            $(
 | 
			
		||||
                stringify!($name) => $options.$name = propagate_err!($crate::attributes::parse($values)),
 | 
			
		||||
            )*
 | 
			
		||||
            _ => {
 | 
			
		||||
                return Error::new($span, format_args!("invalid attribute: {:?}", $v))
 | 
			
		||||
                    .to_compile_error()
 | 
			
		||||
                    .into();
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[proc_macro_attribute]
 | 
			
		||||
pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
 | 
			
		||||
    enum LastItem {
 | 
			
		||||
        Fun,
 | 
			
		||||
        SubFun,
 | 
			
		||||
        SubGroup,
 | 
			
		||||
        SubGroupFun,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let mut fun = parse_macro_input!(input as CommandFun);
 | 
			
		||||
 | 
			
		||||
    let _name = if !attr.is_empty() {
 | 
			
		||||
        parse_macro_input!(attr as Lit).to_str()
 | 
			
		||||
    } else {
 | 
			
		||||
        fun.name.to_string()
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    let mut hooks: Vec<Ident> = Vec::new();
 | 
			
		||||
    let mut options = Options::new();
 | 
			
		||||
    let mut last_desc = LastItem::Fun;
 | 
			
		||||
 | 
			
		||||
    for attribute in &fun.attributes {
 | 
			
		||||
        let span = attribute.span();
 | 
			
		||||
        let values = propagate_err!(parse_values(attribute));
 | 
			
		||||
 | 
			
		||||
        let name = values.name.to_string();
 | 
			
		||||
        let name = &name[..];
 | 
			
		||||
 | 
			
		||||
        match name {
 | 
			
		||||
            "subcommand" => {
 | 
			
		||||
                let new_subcommand = Subcommand::new(propagate_err!(attributes::parse(values)));
 | 
			
		||||
 | 
			
		||||
                if let Some(subcommand_group) = options.subcommand_groups.last_mut() {
 | 
			
		||||
                    last_desc = LastItem::SubGroupFun;
 | 
			
		||||
                    subcommand_group.subcommands.push(new_subcommand);
 | 
			
		||||
                } else {
 | 
			
		||||
                    last_desc = LastItem::SubFun;
 | 
			
		||||
                    options.subcommands.push(new_subcommand);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            "subcommandgroup" => {
 | 
			
		||||
                let new_group = SubcommandGroup::new(propagate_err!(attributes::parse(values)));
 | 
			
		||||
                last_desc = LastItem::SubGroup;
 | 
			
		||||
 | 
			
		||||
                options.subcommand_groups.push(new_group);
 | 
			
		||||
            }
 | 
			
		||||
            "arg" => {
 | 
			
		||||
                let arg = propagate_err!(attributes::parse(values));
 | 
			
		||||
 | 
			
		||||
                match last_desc {
 | 
			
		||||
                    LastItem::Fun => {
 | 
			
		||||
                        options.cmd_args.push(arg);
 | 
			
		||||
                    }
 | 
			
		||||
                    LastItem::SubFun => {
 | 
			
		||||
                        options.subcommands.last_mut().unwrap().cmd_args.push(arg);
 | 
			
		||||
                    }
 | 
			
		||||
                    LastItem::SubGroup => {
 | 
			
		||||
                        panic!("Argument not expected under subcommand group");
 | 
			
		||||
                    }
 | 
			
		||||
                    LastItem::SubGroupFun => {
 | 
			
		||||
                        options
 | 
			
		||||
                            .subcommand_groups
 | 
			
		||||
                            .last_mut()
 | 
			
		||||
                            .unwrap()
 | 
			
		||||
                            .subcommands
 | 
			
		||||
                            .last_mut()
 | 
			
		||||
                            .unwrap()
 | 
			
		||||
                            .cmd_args
 | 
			
		||||
                            .push(arg);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            "example" => {
 | 
			
		||||
                options.examples.push(propagate_err!(attributes::parse(values)));
 | 
			
		||||
            }
 | 
			
		||||
            "description" => {
 | 
			
		||||
                let line: String = propagate_err!(attributes::parse(values));
 | 
			
		||||
 | 
			
		||||
                match last_desc {
 | 
			
		||||
                    LastItem::Fun => {
 | 
			
		||||
                        util::append_line(&mut options.description, line);
 | 
			
		||||
                    }
 | 
			
		||||
                    LastItem::SubFun => {
 | 
			
		||||
                        util::append_line(
 | 
			
		||||
                            &mut options.subcommands.last_mut().unwrap().description,
 | 
			
		||||
                            line,
 | 
			
		||||
                        );
 | 
			
		||||
                    }
 | 
			
		||||
                    LastItem::SubGroup => {
 | 
			
		||||
                        util::append_line(
 | 
			
		||||
                            &mut options.subcommand_groups.last_mut().unwrap().description,
 | 
			
		||||
                            line,
 | 
			
		||||
                        );
 | 
			
		||||
                    }
 | 
			
		||||
                    LastItem::SubGroupFun => {
 | 
			
		||||
                        util::append_line(
 | 
			
		||||
                            &mut options
 | 
			
		||||
                                .subcommand_groups
 | 
			
		||||
                                .last_mut()
 | 
			
		||||
                                .unwrap()
 | 
			
		||||
                                .subcommands
 | 
			
		||||
                                .last_mut()
 | 
			
		||||
                                .unwrap()
 | 
			
		||||
                                .description,
 | 
			
		||||
                            line,
 | 
			
		||||
                        );
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            "hook" => {
 | 
			
		||||
                hooks.push(propagate_err!(attributes::parse(values)));
 | 
			
		||||
            }
 | 
			
		||||
            _ => {
 | 
			
		||||
                match_options!(name, values, options, span => [
 | 
			
		||||
                    aliases;
 | 
			
		||||
                    group;
 | 
			
		||||
                    can_blacklist;
 | 
			
		||||
                    supports_dm
 | 
			
		||||
                ]);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let Options {
 | 
			
		||||
        aliases,
 | 
			
		||||
        description,
 | 
			
		||||
        group,
 | 
			
		||||
        examples,
 | 
			
		||||
        can_blacklist,
 | 
			
		||||
        supports_dm,
 | 
			
		||||
        mut cmd_args,
 | 
			
		||||
        mut subcommands,
 | 
			
		||||
        mut subcommand_groups,
 | 
			
		||||
    } = options;
 | 
			
		||||
 | 
			
		||||
    let visibility = fun.visibility;
 | 
			
		||||
    let name = fun.name.clone();
 | 
			
		||||
    let body = fun.body;
 | 
			
		||||
 | 
			
		||||
    let root_ident = name.with_suffix(COMMAND);
 | 
			
		||||
 | 
			
		||||
    let command_path = quote!(crate::framework::Command);
 | 
			
		||||
 | 
			
		||||
    populate_fut_lifetimes_on_refs(&mut fun.args);
 | 
			
		||||
 | 
			
		||||
    let mut subcommand_group_idents = subcommand_groups
 | 
			
		||||
        .iter()
 | 
			
		||||
        .map(|subcommand| {
 | 
			
		||||
            root_ident
 | 
			
		||||
                .with_suffix(subcommand.name.replace("-", "_").as_str())
 | 
			
		||||
                .with_suffix(SUBCOMMAND_GROUP)
 | 
			
		||||
        })
 | 
			
		||||
        .collect::<Vec<Ident>>();
 | 
			
		||||
 | 
			
		||||
    let mut subcommand_idents = subcommands
 | 
			
		||||
        .iter()
 | 
			
		||||
        .map(|subcommand| {
 | 
			
		||||
            root_ident
 | 
			
		||||
                .with_suffix(subcommand.name.replace("-", "_").as_str())
 | 
			
		||||
                .with_suffix(SUBCOMMAND)
 | 
			
		||||
        })
 | 
			
		||||
        .collect::<Vec<Ident>>();
 | 
			
		||||
 | 
			
		||||
    let mut arg_idents = cmd_args
 | 
			
		||||
        .iter()
 | 
			
		||||
        .map(|arg| root_ident.with_suffix(arg.name.replace("-", "_").as_str()).with_suffix(ARG))
 | 
			
		||||
        .collect::<Vec<Ident>>();
 | 
			
		||||
 | 
			
		||||
    let mut tokens = quote! {};
 | 
			
		||||
 | 
			
		||||
    tokens.extend(
 | 
			
		||||
        subcommand_groups
 | 
			
		||||
            .iter_mut()
 | 
			
		||||
            .zip(subcommand_group_idents.iter())
 | 
			
		||||
            .map(|(group, group_ident)| group.as_tokens(group_ident))
 | 
			
		||||
            .fold(quote! {}, |mut a, b| {
 | 
			
		||||
                a.extend(b);
 | 
			
		||||
                a
 | 
			
		||||
            }),
 | 
			
		||||
    );
 | 
			
		||||
 | 
			
		||||
    tokens.extend(
 | 
			
		||||
        subcommands
 | 
			
		||||
            .iter_mut()
 | 
			
		||||
            .zip(subcommand_idents.iter())
 | 
			
		||||
            .map(|(subcommand, sc_ident)| subcommand.as_tokens(sc_ident))
 | 
			
		||||
            .fold(quote! {}, |mut a, b| {
 | 
			
		||||
                a.extend(b);
 | 
			
		||||
                a
 | 
			
		||||
            }),
 | 
			
		||||
    );
 | 
			
		||||
 | 
			
		||||
    tokens.extend(
 | 
			
		||||
        cmd_args.iter_mut().zip(arg_idents.iter()).map(|(arg, ident)| arg.as_tokens(ident)).fold(
 | 
			
		||||
            quote! {},
 | 
			
		||||
            |mut a, b| {
 | 
			
		||||
                a.extend(b);
 | 
			
		||||
                a
 | 
			
		||||
            },
 | 
			
		||||
        ),
 | 
			
		||||
    );
 | 
			
		||||
 | 
			
		||||
    arg_idents.append(&mut subcommand_group_idents);
 | 
			
		||||
    arg_idents.append(&mut subcommand_idents);
 | 
			
		||||
 | 
			
		||||
    let args = fun.args;
 | 
			
		||||
 | 
			
		||||
    let variant = if args.len() == 2 {
 | 
			
		||||
        quote!(crate::framework::CommandFnType::Multi)
 | 
			
		||||
    } else {
 | 
			
		||||
        let string: Type = parse_quote!(String);
 | 
			
		||||
 | 
			
		||||
        let final_arg = args.get(2).unwrap();
 | 
			
		||||
 | 
			
		||||
        if final_arg.kind == string {
 | 
			
		||||
            quote!(crate::framework::CommandFnType::Text)
 | 
			
		||||
        } else {
 | 
			
		||||
            quote!(crate::framework::CommandFnType::Slash)
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    tokens.extend(quote! {
 | 
			
		||||
        #[allow(missing_docs)]
 | 
			
		||||
        pub static #root_ident: #command_path = #command_path {
 | 
			
		||||
            fun: #variant(#name),
 | 
			
		||||
            names: &[#_name, #(#aliases),*],
 | 
			
		||||
            desc: #description,
 | 
			
		||||
            group: #group,
 | 
			
		||||
            examples: &[#(#examples),*],
 | 
			
		||||
            can_blacklist: #can_blacklist,
 | 
			
		||||
            supports_dm: #supports_dm,
 | 
			
		||||
            args: &[#(&#arg_idents),*],
 | 
			
		||||
            hooks: &[#(&#hooks),*],
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        #[allow(missing_docs)]
 | 
			
		||||
        #visibility fn #name<'fut> (#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, ()> {
 | 
			
		||||
            use ::serenity::futures::future::FutureExt;
 | 
			
		||||
 | 
			
		||||
            async move {
 | 
			
		||||
                #(#body)*;
 | 
			
		||||
            }.boxed()
 | 
			
		||||
        }
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    tokens.into()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[proc_macro_attribute]
 | 
			
		||||
pub fn check(_attr: TokenStream, input: TokenStream) -> TokenStream {
 | 
			
		||||
    let mut fun = parse_macro_input!(input as CommandFun);
 | 
			
		||||
 | 
			
		||||
    let n = fun.name.clone();
 | 
			
		||||
    let name = n.with_suffix(HOOK);
 | 
			
		||||
    let fn_name = n.with_suffix(CHECK);
 | 
			
		||||
    let visibility = fun.visibility;
 | 
			
		||||
 | 
			
		||||
    let body = fun.body;
 | 
			
		||||
    let ret = fun.ret;
 | 
			
		||||
    populate_fut_lifetimes_on_refs(&mut fun.args);
 | 
			
		||||
    let args = fun.args;
 | 
			
		||||
 | 
			
		||||
    let hook_path = quote!(crate::framework::Hook);
 | 
			
		||||
    let uuid = Uuid::new_v4().as_u128();
 | 
			
		||||
 | 
			
		||||
    (quote! {
 | 
			
		||||
        #[allow(missing_docs)]
 | 
			
		||||
        #visibility fn #fn_name<'fut>(#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, #ret> {
 | 
			
		||||
            use ::serenity::futures::future::FutureExt;
 | 
			
		||||
 | 
			
		||||
            async move {
 | 
			
		||||
                let _output: #ret = { #(#body)* };
 | 
			
		||||
                #[allow(unreachable_code)]
 | 
			
		||||
                _output
 | 
			
		||||
            }.boxed()
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        #[allow(missing_docs)]
 | 
			
		||||
        pub static #name: #hook_path = #hook_path {
 | 
			
		||||
            fun: #fn_name,
 | 
			
		||||
            uuid: #uuid,
 | 
			
		||||
        };
 | 
			
		||||
    })
 | 
			
		||||
    .into()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										331
									
								
								command_attributes/src/structures.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										331
									
								
								command_attributes/src/structures.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,331 @@
 | 
			
		||||
use proc_macro2::TokenStream as TokenStream2;
 | 
			
		||||
use quote::{quote, ToTokens};
 | 
			
		||||
use syn::{
 | 
			
		||||
    braced,
 | 
			
		||||
    parse::{Error, Parse, ParseStream, Result},
 | 
			
		||||
    spanned::Spanned,
 | 
			
		||||
    Attribute, Block, FnArg, Ident, Pat, ReturnType, Stmt, Token, Type, Visibility,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    consts::{ARG, SUBCOMMAND},
 | 
			
		||||
    util::{Argument, IdentExt2, Parenthesised},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
fn parse_argument(arg: FnArg) -> Result<Argument> {
 | 
			
		||||
    match arg {
 | 
			
		||||
        FnArg::Typed(typed) => {
 | 
			
		||||
            let pat = typed.pat;
 | 
			
		||||
            let kind = typed.ty;
 | 
			
		||||
 | 
			
		||||
            match *pat {
 | 
			
		||||
                Pat::Ident(id) => {
 | 
			
		||||
                    let name = id.ident;
 | 
			
		||||
                    let mutable = id.mutability;
 | 
			
		||||
 | 
			
		||||
                    Ok(Argument { mutable, name, kind: *kind })
 | 
			
		||||
                }
 | 
			
		||||
                Pat::Wild(wild) => {
 | 
			
		||||
                    let token = wild.underscore_token;
 | 
			
		||||
 | 
			
		||||
                    let name = Ident::new("_", token.spans[0]);
 | 
			
		||||
 | 
			
		||||
                    Ok(Argument { mutable: None, name, kind: *kind })
 | 
			
		||||
                }
 | 
			
		||||
                _ => Err(Error::new(pat.span(), format_args!("unsupported pattern: {:?}", pat))),
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        FnArg::Receiver(_) => {
 | 
			
		||||
            Err(Error::new(arg.span(), format_args!("`self` arguments are prohibited: {:?}", arg)))
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct CommandFun {
 | 
			
		||||
    /// `#[...]`-style attributes.
 | 
			
		||||
    pub attributes: Vec<Attribute>,
 | 
			
		||||
    /// Populated cooked attributes. These are attributes outside of the realm of this crate's procedural macros
 | 
			
		||||
    /// and will appear in generated output.
 | 
			
		||||
    pub visibility: Visibility,
 | 
			
		||||
    pub name: Ident,
 | 
			
		||||
    pub args: Vec<Argument>,
 | 
			
		||||
    pub ret: Type,
 | 
			
		||||
    pub body: Vec<Stmt>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Parse for CommandFun {
 | 
			
		||||
    fn parse(input: ParseStream<'_>) -> Result<Self> {
 | 
			
		||||
        let attributes = input.call(Attribute::parse_outer)?;
 | 
			
		||||
 | 
			
		||||
        let visibility = input.parse::<Visibility>()?;
 | 
			
		||||
 | 
			
		||||
        input.parse::<Token![async]>()?;
 | 
			
		||||
 | 
			
		||||
        input.parse::<Token![fn]>()?;
 | 
			
		||||
        let name = input.parse()?;
 | 
			
		||||
 | 
			
		||||
        // (...)
 | 
			
		||||
        let Parenthesised(args) = input.parse::<Parenthesised<FnArg>>()?;
 | 
			
		||||
 | 
			
		||||
        let ret = match input.parse::<ReturnType>()? {
 | 
			
		||||
            ReturnType::Type(_, t) => (*t).clone(),
 | 
			
		||||
            ReturnType::Default => Type::Verbatim(quote!(())),
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        // { ... }
 | 
			
		||||
        let bcont;
 | 
			
		||||
        braced!(bcont in input);
 | 
			
		||||
        let body = bcont.call(Block::parse_within)?;
 | 
			
		||||
 | 
			
		||||
        let args = args.into_iter().map(parse_argument).collect::<Result<Vec<_>>>()?;
 | 
			
		||||
 | 
			
		||||
        Ok(Self { attributes, visibility, name, args, ret, body })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ToTokens for CommandFun {
 | 
			
		||||
    fn to_tokens(&self, stream: &mut TokenStream2) {
 | 
			
		||||
        let Self { attributes: _, visibility, name, args, ret, body } = self;
 | 
			
		||||
 | 
			
		||||
        stream.extend(quote! {
 | 
			
		||||
            #visibility async fn #name (#(#args),*) -> #ret {
 | 
			
		||||
                #(#body)*
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub(crate) enum ApplicationCommandOptionType {
 | 
			
		||||
    SubCommand,
 | 
			
		||||
    SubCommandGroup,
 | 
			
		||||
    String,
 | 
			
		||||
    Integer,
 | 
			
		||||
    Boolean,
 | 
			
		||||
    User,
 | 
			
		||||
    Channel,
 | 
			
		||||
    Role,
 | 
			
		||||
    Mentionable,
 | 
			
		||||
    Number,
 | 
			
		||||
    Unknown,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ApplicationCommandOptionType {
 | 
			
		||||
    pub fn from_str(s: String) -> Self {
 | 
			
		||||
        match s.as_str() {
 | 
			
		||||
            "SubCommand" => Self::SubCommand,
 | 
			
		||||
            "SubCommandGroup" => Self::SubCommandGroup,
 | 
			
		||||
            "String" => Self::String,
 | 
			
		||||
            "Integer" => Self::Integer,
 | 
			
		||||
            "Boolean" => Self::Boolean,
 | 
			
		||||
            "User" => Self::User,
 | 
			
		||||
            "Channel" => Self::Channel,
 | 
			
		||||
            "Role" => Self::Role,
 | 
			
		||||
            "Mentionable" => Self::Mentionable,
 | 
			
		||||
            "Number" => Self::Number,
 | 
			
		||||
            _ => Self::Unknown,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ToTokens for ApplicationCommandOptionType {
 | 
			
		||||
    fn to_tokens(&self, stream: &mut TokenStream2) {
 | 
			
		||||
        let path = quote!(
 | 
			
		||||
            serenity::model::interactions::application_command::ApplicationCommandOptionType
 | 
			
		||||
        );
 | 
			
		||||
        let variant = match self {
 | 
			
		||||
            ApplicationCommandOptionType::SubCommand => quote!(SubCommand),
 | 
			
		||||
            ApplicationCommandOptionType::SubCommandGroup => quote!(SubCommandGroup),
 | 
			
		||||
            ApplicationCommandOptionType::String => quote!(String),
 | 
			
		||||
            ApplicationCommandOptionType::Integer => quote!(Integer),
 | 
			
		||||
            ApplicationCommandOptionType::Boolean => quote!(Boolean),
 | 
			
		||||
            ApplicationCommandOptionType::User => quote!(User),
 | 
			
		||||
            ApplicationCommandOptionType::Channel => quote!(Channel),
 | 
			
		||||
            ApplicationCommandOptionType::Role => quote!(Role),
 | 
			
		||||
            ApplicationCommandOptionType::Mentionable => quote!(Mentionable),
 | 
			
		||||
            ApplicationCommandOptionType::Number => quote!(Number),
 | 
			
		||||
            ApplicationCommandOptionType::Unknown => quote!(Unknown),
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        stream.extend(quote! {
 | 
			
		||||
            #path::#variant
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub(crate) struct Arg {
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub description: String,
 | 
			
		||||
    pub kind: ApplicationCommandOptionType,
 | 
			
		||||
    pub required: bool,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Arg {
 | 
			
		||||
    pub fn as_tokens(&self, ident: &Ident) -> TokenStream2 {
 | 
			
		||||
        let arg_path = quote!(crate::framework::Arg);
 | 
			
		||||
        let Arg { name, description, kind, required } = self;
 | 
			
		||||
 | 
			
		||||
        quote! {
 | 
			
		||||
            #[allow(missing_docs)]
 | 
			
		||||
            pub static #ident: #arg_path = #arg_path {
 | 
			
		||||
                name: #name,
 | 
			
		||||
                description: #description,
 | 
			
		||||
                kind: #kind,
 | 
			
		||||
                required: #required,
 | 
			
		||||
                options: &[]
 | 
			
		||||
            };
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Default for Arg {
 | 
			
		||||
    fn default() -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            name: String::new(),
 | 
			
		||||
            description: String::new(),
 | 
			
		||||
            kind: ApplicationCommandOptionType::String,
 | 
			
		||||
            required: false,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub(crate) struct Subcommand {
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub description: String,
 | 
			
		||||
    pub cmd_args: Vec<Arg>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Subcommand {
 | 
			
		||||
    pub fn as_tokens(&mut self, ident: &Ident) -> TokenStream2 {
 | 
			
		||||
        let arg_path = quote!(crate::framework::Arg);
 | 
			
		||||
        let subcommand_path = ApplicationCommandOptionType::SubCommand;
 | 
			
		||||
 | 
			
		||||
        let arg_idents = self
 | 
			
		||||
            .cmd_args
 | 
			
		||||
            .iter()
 | 
			
		||||
            .map(|arg| ident.with_suffix(arg.name.as_str()).with_suffix(ARG))
 | 
			
		||||
            .collect::<Vec<Ident>>();
 | 
			
		||||
 | 
			
		||||
        let mut tokens = self
 | 
			
		||||
            .cmd_args
 | 
			
		||||
            .iter_mut()
 | 
			
		||||
            .zip(arg_idents.iter())
 | 
			
		||||
            .map(|(arg, ident)| arg.as_tokens(ident))
 | 
			
		||||
            .fold(quote! {}, |mut a, b| {
 | 
			
		||||
                a.extend(b);
 | 
			
		||||
                a
 | 
			
		||||
            });
 | 
			
		||||
 | 
			
		||||
        let Subcommand { name, description, .. } = self;
 | 
			
		||||
 | 
			
		||||
        tokens.extend(quote! {
 | 
			
		||||
            #[allow(missing_docs)]
 | 
			
		||||
            pub static #ident: #arg_path = #arg_path {
 | 
			
		||||
                name: #name,
 | 
			
		||||
                description: #description,
 | 
			
		||||
                kind: #subcommand_path,
 | 
			
		||||
                required: false,
 | 
			
		||||
                options: &[#(&#arg_idents),*],
 | 
			
		||||
            };
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        tokens
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Default for Subcommand {
 | 
			
		||||
    fn default() -> Self {
 | 
			
		||||
        Self { name: String::new(), description: String::new(), cmd_args: vec![] }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Subcommand {
 | 
			
		||||
    pub(crate) fn new(name: String) -> Self {
 | 
			
		||||
        Self { name, ..Default::default() }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub(crate) struct SubcommandGroup {
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub description: String,
 | 
			
		||||
    pub subcommands: Vec<Subcommand>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl SubcommandGroup {
 | 
			
		||||
    pub fn as_tokens(&mut self, ident: &Ident) -> TokenStream2 {
 | 
			
		||||
        let arg_path = quote!(crate::framework::Arg);
 | 
			
		||||
        let subcommand_group_path = ApplicationCommandOptionType::SubCommandGroup;
 | 
			
		||||
 | 
			
		||||
        let arg_idents = self
 | 
			
		||||
            .subcommands
 | 
			
		||||
            .iter()
 | 
			
		||||
            .map(|arg| {
 | 
			
		||||
                ident
 | 
			
		||||
                    .with_suffix(self.name.as_str())
 | 
			
		||||
                    .with_suffix(arg.name.as_str())
 | 
			
		||||
                    .with_suffix(SUBCOMMAND)
 | 
			
		||||
            })
 | 
			
		||||
            .collect::<Vec<Ident>>();
 | 
			
		||||
 | 
			
		||||
        let mut tokens = self
 | 
			
		||||
            .subcommands
 | 
			
		||||
            .iter_mut()
 | 
			
		||||
            .zip(arg_idents.iter())
 | 
			
		||||
            .map(|(subcommand, ident)| subcommand.as_tokens(ident))
 | 
			
		||||
            .fold(quote! {}, |mut a, b| {
 | 
			
		||||
                a.extend(b);
 | 
			
		||||
                a
 | 
			
		||||
            });
 | 
			
		||||
 | 
			
		||||
        let SubcommandGroup { name, description, .. } = self;
 | 
			
		||||
 | 
			
		||||
        tokens.extend(quote! {
 | 
			
		||||
            #[allow(missing_docs)]
 | 
			
		||||
            pub static #ident: #arg_path = #arg_path {
 | 
			
		||||
                name: #name,
 | 
			
		||||
                description: #description,
 | 
			
		||||
                kind: #subcommand_group_path,
 | 
			
		||||
                required: false,
 | 
			
		||||
                options: &[#(&#arg_idents),*],
 | 
			
		||||
            };
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        tokens
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Default for SubcommandGroup {
 | 
			
		||||
    fn default() -> Self {
 | 
			
		||||
        Self { name: String::new(), description: String::new(), subcommands: vec![] }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl SubcommandGroup {
 | 
			
		||||
    pub(crate) fn new(name: String) -> Self {
 | 
			
		||||
        Self { name, ..Default::default() }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Default)]
 | 
			
		||||
pub(crate) struct Options {
 | 
			
		||||
    pub aliases: Vec<String>,
 | 
			
		||||
    pub description: String,
 | 
			
		||||
    pub group: String,
 | 
			
		||||
    pub examples: Vec<String>,
 | 
			
		||||
    pub can_blacklist: bool,
 | 
			
		||||
    pub supports_dm: bool,
 | 
			
		||||
    pub cmd_args: Vec<Arg>,
 | 
			
		||||
    pub subcommands: Vec<Subcommand>,
 | 
			
		||||
    pub subcommand_groups: Vec<SubcommandGroup>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Options {
 | 
			
		||||
    #[inline]
 | 
			
		||||
    pub fn new() -> Self {
 | 
			
		||||
        Self { group: "None".to_string(), ..Default::default() }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,6 +1,5 @@
 | 
			
		||||
use proc_macro::TokenStream;
 | 
			
		||||
use proc_macro2::Span;
 | 
			
		||||
use proc_macro2::TokenStream as TokenStream2;
 | 
			
		||||
use proc_macro2::{Span, TokenStream as TokenStream2};
 | 
			
		||||
use quote::{format_ident, quote, ToTokens};
 | 
			
		||||
use syn::{
 | 
			
		||||
    braced, bracketed, parenthesized,
 | 
			
		||||
@@ -158,3 +157,20 @@ pub fn populate_fut_lifetimes_on_refs(args: &mut Vec<Argument>) {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn append_line(desc: &mut String, mut line: String) {
 | 
			
		||||
    if line.starts_with(' ') {
 | 
			
		||||
        line.remove(0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    match line.rfind("\\$") {
 | 
			
		||||
        Some(i) => {
 | 
			
		||||
            desc.push_str(line[..i].trim_end());
 | 
			
		||||
            desc.push(' ');
 | 
			
		||||
        }
 | 
			
		||||
        None => {
 | 
			
		||||
            desc.push_str(&line);
 | 
			
		||||
            desc.push('\n');
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,3 +1,5 @@
 | 
			
		||||
CREATE DATABASE IF NOT EXISTS reminders;
 | 
			
		||||
 | 
			
		||||
SET FOREIGN_KEY_CHECKS=0;
 | 
			
		||||
 | 
			
		||||
USE reminders;
 | 
			
		||||
							
								
								
									
										160
									
								
								migration/01-reminder_message_embed.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										160
									
								
								migration/01-reminder_message_embed.sql
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										13
									
								
								migration/02-macro.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								migration/02-macro.sql
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,13 @@
 | 
			
		||||
USE reminders;
 | 
			
		||||
 | 
			
		||||
CREATE TABLE macro (
 | 
			
		||||
    id INT UNSIGNED AUTO_INCREMENT,
 | 
			
		||||
    guild_id INT UNSIGNED NOT NULL,
 | 
			
		||||
 | 
			
		||||
    name VARCHAR(100) NOT NULL,
 | 
			
		||||
    description VARCHAR(100),
 | 
			
		||||
    commands TEXT NOT NULL,
 | 
			
		||||
 | 
			
		||||
    FOREIGN KEY (guild_id) REFERENCES guilds(id) ON DELETE CASCADE,
 | 
			
		||||
    PRIMARY KEY (id)
 | 
			
		||||
);
 | 
			
		||||
@@ -1,5 +0,0 @@
 | 
			
		||||
pub mod suffixes {
 | 
			
		||||
    pub const COMMAND: &str = "COMMAND";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub use self::suffixes::*;
 | 
			
		||||
@@ -1,102 +0,0 @@
 | 
			
		||||
#![deny(rust_2018_idioms)]
 | 
			
		||||
// FIXME: Remove this in a foreseeable future.
 | 
			
		||||
// Currently exists for backwards compatibility to previous Rust versions.
 | 
			
		||||
#![recursion_limit = "128"]
 | 
			
		||||
 | 
			
		||||
#[allow(unused_extern_crates)]
 | 
			
		||||
extern crate proc_macro;
 | 
			
		||||
 | 
			
		||||
use proc_macro::TokenStream;
 | 
			
		||||
use quote::quote;
 | 
			
		||||
use syn::{parse::Error, parse_macro_input, spanned::Spanned, Lit};
 | 
			
		||||
 | 
			
		||||
pub(crate) mod attributes;
 | 
			
		||||
pub(crate) mod consts;
 | 
			
		||||
pub(crate) mod structures;
 | 
			
		||||
 | 
			
		||||
#[macro_use]
 | 
			
		||||
pub(crate) mod util;
 | 
			
		||||
 | 
			
		||||
use attributes::*;
 | 
			
		||||
use consts::*;
 | 
			
		||||
use structures::*;
 | 
			
		||||
use util::*;
 | 
			
		||||
 | 
			
		||||
macro_rules! match_options {
 | 
			
		||||
    ($v:expr, $values:ident, $options:ident, $span:expr => [$($name:ident);*]) => {
 | 
			
		||||
        match $v {
 | 
			
		||||
            $(
 | 
			
		||||
                stringify!($name) => $options.$name = propagate_err!($crate::attributes::parse($values)),
 | 
			
		||||
            )*
 | 
			
		||||
            _ => {
 | 
			
		||||
                return Error::new($span, format_args!("invalid attribute: {:?}", $v))
 | 
			
		||||
                    .to_compile_error()
 | 
			
		||||
                    .into();
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[proc_macro_attribute]
 | 
			
		||||
pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
 | 
			
		||||
    let mut fun = parse_macro_input!(input as CommandFun);
 | 
			
		||||
 | 
			
		||||
    let lit_name = if !attr.is_empty() {
 | 
			
		||||
        parse_macro_input!(attr as Lit).to_str()
 | 
			
		||||
    } else {
 | 
			
		||||
        fun.name.to_string()
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    let mut options = Options::new();
 | 
			
		||||
 | 
			
		||||
    for attribute in &fun.attributes {
 | 
			
		||||
        let span = attribute.span();
 | 
			
		||||
        let values = propagate_err!(parse_values(attribute));
 | 
			
		||||
 | 
			
		||||
        let name = values.name.to_string();
 | 
			
		||||
        let name = &name[..];
 | 
			
		||||
 | 
			
		||||
        match_options!(name, values, options, span => [
 | 
			
		||||
            permission_level;
 | 
			
		||||
            supports_dm;
 | 
			
		||||
            can_blacklist
 | 
			
		||||
        ]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let Options {
 | 
			
		||||
        permission_level,
 | 
			
		||||
        supports_dm,
 | 
			
		||||
        can_blacklist,
 | 
			
		||||
    } = options;
 | 
			
		||||
 | 
			
		||||
    let visibility = fun.visibility;
 | 
			
		||||
    let name = fun.name.clone();
 | 
			
		||||
    let body = fun.body;
 | 
			
		||||
 | 
			
		||||
    let n = name.with_suffix(COMMAND);
 | 
			
		||||
 | 
			
		||||
    let cooked = fun.cooked.clone();
 | 
			
		||||
 | 
			
		||||
    let command_path = quote!(crate::framework::Command);
 | 
			
		||||
 | 
			
		||||
    populate_fut_lifetimes_on_refs(&mut fun.args);
 | 
			
		||||
    let args = fun.args;
 | 
			
		||||
 | 
			
		||||
    (quote! {
 | 
			
		||||
        #(#cooked)*
 | 
			
		||||
        pub static #n: #command_path = #command_path {
 | 
			
		||||
            func: #name,
 | 
			
		||||
            name: #lit_name,
 | 
			
		||||
            required_perms: #permission_level,
 | 
			
		||||
            supports_dm: #supports_dm,
 | 
			
		||||
            can_blacklist: #can_blacklist,
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        #visibility fn #name<'fut> (#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, ()> {
 | 
			
		||||
            use ::serenity::futures::future::FutureExt;
 | 
			
		||||
 | 
			
		||||
            async move { #(#body)* }.boxed()
 | 
			
		||||
        }
 | 
			
		||||
    })
 | 
			
		||||
    .into()
 | 
			
		||||
}
 | 
			
		||||
@@ -1,231 +0,0 @@
 | 
			
		||||
use crate::util::{Argument, Parenthesised};
 | 
			
		||||
use proc_macro2::Span;
 | 
			
		||||
use proc_macro2::TokenStream as TokenStream2;
 | 
			
		||||
use quote::{quote, ToTokens};
 | 
			
		||||
use syn::{
 | 
			
		||||
    braced,
 | 
			
		||||
    parse::{Error, Parse, ParseStream, Result},
 | 
			
		||||
    spanned::Spanned,
 | 
			
		||||
    Attribute, Block, FnArg, Ident, Pat, Path, PathSegment, Stmt, Token, Visibility,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
fn parse_argument(arg: FnArg) -> Result<Argument> {
 | 
			
		||||
    match arg {
 | 
			
		||||
        FnArg::Typed(typed) => {
 | 
			
		||||
            let pat = typed.pat;
 | 
			
		||||
            let kind = typed.ty;
 | 
			
		||||
 | 
			
		||||
            match *pat {
 | 
			
		||||
                Pat::Ident(id) => {
 | 
			
		||||
                    let name = id.ident;
 | 
			
		||||
                    let mutable = id.mutability;
 | 
			
		||||
 | 
			
		||||
                    Ok(Argument {
 | 
			
		||||
                        mutable,
 | 
			
		||||
                        name,
 | 
			
		||||
                        kind: *kind,
 | 
			
		||||
                    })
 | 
			
		||||
                }
 | 
			
		||||
                Pat::Wild(wild) => {
 | 
			
		||||
                    let token = wild.underscore_token;
 | 
			
		||||
 | 
			
		||||
                    let name = Ident::new("_", token.spans[0]);
 | 
			
		||||
 | 
			
		||||
                    Ok(Argument {
 | 
			
		||||
                        mutable: None,
 | 
			
		||||
                        name,
 | 
			
		||||
                        kind: *kind,
 | 
			
		||||
                    })
 | 
			
		||||
                }
 | 
			
		||||
                _ => Err(Error::new(
 | 
			
		||||
                    pat.span(),
 | 
			
		||||
                    format_args!("unsupported pattern: {:?}", pat),
 | 
			
		||||
                )),
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        FnArg::Receiver(_) => Err(Error::new(
 | 
			
		||||
            arg.span(),
 | 
			
		||||
            format_args!("`self` arguments are prohibited: {:?}", arg),
 | 
			
		||||
        )),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Test if the attribute is cooked.
 | 
			
		||||
fn is_cooked(attr: &Attribute) -> bool {
 | 
			
		||||
    const COOKED_ATTRIBUTE_NAMES: &[&str] = &[
 | 
			
		||||
        "cfg", "cfg_attr", "doc", "derive", "inline", "allow", "warn", "deny", "forbid",
 | 
			
		||||
    ];
 | 
			
		||||
 | 
			
		||||
    COOKED_ATTRIBUTE_NAMES.iter().any(|n| attr.path.is_ident(n))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Removes cooked attributes from a vector of attributes. Uncooked attributes are left in the vector.
 | 
			
		||||
///
 | 
			
		||||
/// # Return
 | 
			
		||||
///
 | 
			
		||||
/// Returns a vector of cooked attributes that have been removed from the input vector.
 | 
			
		||||
fn remove_cooked(attrs: &mut Vec<Attribute>) -> Vec<Attribute> {
 | 
			
		||||
    let mut cooked = Vec::new();
 | 
			
		||||
 | 
			
		||||
    // FIXME: Replace with `Vec::drain_filter` once it is stable.
 | 
			
		||||
    let mut i = 0;
 | 
			
		||||
    while i < attrs.len() {
 | 
			
		||||
        if !is_cooked(&attrs[i]) {
 | 
			
		||||
            i += 1;
 | 
			
		||||
            continue;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        cooked.push(attrs.remove(i));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    cooked
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct CommandFun {
 | 
			
		||||
    /// `#[...]`-style attributes.
 | 
			
		||||
    pub attributes: Vec<Attribute>,
 | 
			
		||||
    /// Populated cooked attributes. These are attributes outside of the realm of this crate's procedural macros
 | 
			
		||||
    /// and will appear in generated output.
 | 
			
		||||
    pub cooked: Vec<Attribute>,
 | 
			
		||||
    pub visibility: Visibility,
 | 
			
		||||
    pub name: Ident,
 | 
			
		||||
    pub args: Vec<Argument>,
 | 
			
		||||
    pub body: Vec<Stmt>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Parse for CommandFun {
 | 
			
		||||
    fn parse(input: ParseStream<'_>) -> Result<Self> {
 | 
			
		||||
        let mut attributes = input.call(Attribute::parse_outer)?;
 | 
			
		||||
 | 
			
		||||
        // `#[doc = "..."]` is a cooked attribute but it is special-cased for commands.
 | 
			
		||||
        for attr in &mut attributes {
 | 
			
		||||
            // Rename documentation comment attributes (`#[doc = "..."]`) to `#[description = "..."]`.
 | 
			
		||||
            if attr.path.is_ident("doc") {
 | 
			
		||||
                attr.path = Path::from(PathSegment::from(Ident::new(
 | 
			
		||||
                    "description",
 | 
			
		||||
                    Span::call_site(),
 | 
			
		||||
                )));
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let cooked = remove_cooked(&mut attributes);
 | 
			
		||||
 | 
			
		||||
        let visibility = input.parse::<Visibility>()?;
 | 
			
		||||
 | 
			
		||||
        input.parse::<Token![async]>()?;
 | 
			
		||||
 | 
			
		||||
        input.parse::<Token![fn]>()?;
 | 
			
		||||
        let name = input.parse()?;
 | 
			
		||||
 | 
			
		||||
        // (...)
 | 
			
		||||
        let Parenthesised(args) = input.parse::<Parenthesised<FnArg>>()?;
 | 
			
		||||
 | 
			
		||||
        // { ... }
 | 
			
		||||
        let bcont;
 | 
			
		||||
        braced!(bcont in input);
 | 
			
		||||
        let body = bcont.call(Block::parse_within)?;
 | 
			
		||||
 | 
			
		||||
        let args = args
 | 
			
		||||
            .into_iter()
 | 
			
		||||
            .map(parse_argument)
 | 
			
		||||
            .collect::<Result<Vec<_>>>()?;
 | 
			
		||||
 | 
			
		||||
        Ok(Self {
 | 
			
		||||
            attributes,
 | 
			
		||||
            cooked,
 | 
			
		||||
            visibility,
 | 
			
		||||
            name,
 | 
			
		||||
            args,
 | 
			
		||||
            body,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ToTokens for CommandFun {
 | 
			
		||||
    fn to_tokens(&self, stream: &mut TokenStream2) {
 | 
			
		||||
        let Self {
 | 
			
		||||
            attributes: _,
 | 
			
		||||
            cooked,
 | 
			
		||||
            visibility,
 | 
			
		||||
            name,
 | 
			
		||||
            args,
 | 
			
		||||
            body,
 | 
			
		||||
        } = self;
 | 
			
		||||
 | 
			
		||||
        stream.extend(quote! {
 | 
			
		||||
            #(#cooked)*
 | 
			
		||||
            #visibility async fn #name (#(#args),*) -> () {
 | 
			
		||||
                #(#body)*
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub enum PermissionLevel {
 | 
			
		||||
    Unrestricted,
 | 
			
		||||
    Managed,
 | 
			
		||||
    Restricted,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Default for PermissionLevel {
 | 
			
		||||
    fn default() -> Self {
 | 
			
		||||
        Self::Unrestricted
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl PermissionLevel {
 | 
			
		||||
    pub fn from_str(s: &str) -> Option<Self> {
 | 
			
		||||
        Some(match s.to_uppercase().as_str() {
 | 
			
		||||
            "UNRESTRICTED" => Self::Unrestricted,
 | 
			
		||||
            "MANAGED" => Self::Managed,
 | 
			
		||||
            "RESTRICTED" => Self::Restricted,
 | 
			
		||||
            _ => return None,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ToTokens for PermissionLevel {
 | 
			
		||||
    fn to_tokens(&self, stream: &mut TokenStream2) {
 | 
			
		||||
        let path = quote!(crate::framework::PermissionLevel);
 | 
			
		||||
        let variant;
 | 
			
		||||
 | 
			
		||||
        match self {
 | 
			
		||||
            Self::Unrestricted => {
 | 
			
		||||
                variant = quote!(Unrestricted);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            Self::Managed => {
 | 
			
		||||
                variant = quote!(Managed);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            Self::Restricted => {
 | 
			
		||||
                variant = quote!(Restricted);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        stream.extend(quote! {
 | 
			
		||||
            #path::#variant
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Default)]
 | 
			
		||||
pub struct Options {
 | 
			
		||||
    pub permission_level: PermissionLevel,
 | 
			
		||||
    pub supports_dm: bool,
 | 
			
		||||
    pub can_blacklist: bool,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Options {
 | 
			
		||||
    #[inline]
 | 
			
		||||
    pub fn new() -> Self {
 | 
			
		||||
        let mut options = Self::default();
 | 
			
		||||
 | 
			
		||||
        options.can_blacklist = true;
 | 
			
		||||
        options.supports_dm = true;
 | 
			
		||||
 | 
			
		||||
        options
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										3
									
								
								rustfmt.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								rustfmt.toml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
imports_granularity = "Crate"
 | 
			
		||||
group_imports = "StdExternalCrate"
 | 
			
		||||
use_small_heuristics = "Max"
 | 
			
		||||
@@ -1,37 +1,15 @@
 | 
			
		||||
use regex_command_attr::command;
 | 
			
		||||
 | 
			
		||||
use serenity::{client::Context, model::channel::Message};
 | 
			
		||||
 | 
			
		||||
use chrono::offset::Utc;
 | 
			
		||||
use regex_command_attr::command;
 | 
			
		||||
use serenity::{builder::CreateEmbedFooter, client::Context};
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    command_help, consts::DEFAULT_PREFIX, get_ctx_data, language_manager::LanguageManager,
 | 
			
		||||
    models::UserData, FrameworkCtx, THEME_COLOR,
 | 
			
		||||
    framework::{CommandInvoke, CreateGenericResponse},
 | 
			
		||||
    models::CtxData,
 | 
			
		||||
    THEME_COLOR,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::models::CtxGuildData;
 | 
			
		||||
use serenity::builder::CreateEmbedFooter;
 | 
			
		||||
use std::sync::Arc;
 | 
			
		||||
use std::time::{SystemTime, UNIX_EPOCH};
 | 
			
		||||
 | 
			
		||||
#[command]
 | 
			
		||||
#[can_blacklist(false)]
 | 
			
		||||
async fn ping(ctx: &Context, msg: &Message, _args: String) {
 | 
			
		||||
    let now = SystemTime::now();
 | 
			
		||||
    let since_epoch = now
 | 
			
		||||
        .duration_since(UNIX_EPOCH)
 | 
			
		||||
        .expect("Time calculated as going backwards. Very bad");
 | 
			
		||||
 | 
			
		||||
    let delta = since_epoch.as_millis() as i64 - msg.timestamp.timestamp_millis();
 | 
			
		||||
 | 
			
		||||
    let _ = msg
 | 
			
		||||
        .channel_id
 | 
			
		||||
        .say(&ctx, format!("Time taken to receive message: {}ms", delta))
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEmbedFooter {
 | 
			
		||||
    let shard_count = ctx.cache.shard_count().await;
 | 
			
		||||
fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEmbedFooter {
 | 
			
		||||
    let shard_count = ctx.cache.shard_count();
 | 
			
		||||
    let shard = ctx.shard_id;
 | 
			
		||||
 | 
			
		||||
    move |f| {
 | 
			
		||||
@@ -45,174 +23,140 @@ async fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut Cr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command]
 | 
			
		||||
#[can_blacklist(false)]
 | 
			
		||||
async fn help(ctx: &Context, msg: &Message, args: String) {
 | 
			
		||||
    async fn default_help(
 | 
			
		||||
        ctx: &Context,
 | 
			
		||||
        msg: &Message,
 | 
			
		||||
        lm: Arc<LanguageManager>,
 | 
			
		||||
        prefix: &str,
 | 
			
		||||
        language: &str,
 | 
			
		||||
    ) {
 | 
			
		||||
        let desc = lm.get(language, "help/desc").replace("{prefix}", prefix);
 | 
			
		||||
        let footer = footer(ctx).await;
 | 
			
		||||
#[description("Get an overview of the bot commands")]
 | 
			
		||||
async fn help(ctx: &Context, invoke: &mut CommandInvoke) {
 | 
			
		||||
    let footer = footer(ctx);
 | 
			
		||||
 | 
			
		||||
        let _ = msg
 | 
			
		||||
            .channel_id
 | 
			
		||||
            .send_message(ctx, |m| {
 | 
			
		||||
                m.embed(move |e| {
 | 
			
		||||
                    e.title("Help Menu")
 | 
			
		||||
                        .description(desc)
 | 
			
		||||
                        .field(
 | 
			
		||||
                            lm.get(language, "help/setup_title"),
 | 
			
		||||
                            "`lang` `timezone` `meridian`",
 | 
			
		||||
                            true,
 | 
			
		||||
                        )
 | 
			
		||||
                        .field(
 | 
			
		||||
                            lm.get(language, "help/mod_title"),
 | 
			
		||||
                            "`prefix` `blacklist` `restrict` `alias`",
 | 
			
		||||
                            true,
 | 
			
		||||
                        )
 | 
			
		||||
                        .field(
 | 
			
		||||
                            lm.get(language, "help/reminder_title"),
 | 
			
		||||
                            "`remind` `interval` `natural` `look` `countdown`",
 | 
			
		||||
                            true,
 | 
			
		||||
                        )
 | 
			
		||||
                        .field(
 | 
			
		||||
                            lm.get(language, "help/reminder_mod_title"),
 | 
			
		||||
                            "`del` `offset` `pause` `nudge`",
 | 
			
		||||
                            true,
 | 
			
		||||
                        )
 | 
			
		||||
                        .field(
 | 
			
		||||
                            lm.get(language, "help/info_title"),
 | 
			
		||||
                            "`help` `info` `donate` `clock`",
 | 
			
		||||
                            true,
 | 
			
		||||
                        )
 | 
			
		||||
                        .field(
 | 
			
		||||
                            lm.get(language, "help/todo_title"),
 | 
			
		||||
                            "`todo` `todos` `todoc`",
 | 
			
		||||
                            true,
 | 
			
		||||
                        )
 | 
			
		||||
                        .field(lm.get(language, "help/other_title"), "`timer`", true)
 | 
			
		||||
                        .footer(footer)
 | 
			
		||||
                        .color(*THEME_COLOR)
 | 
			
		||||
                })
 | 
			
		||||
            })
 | 
			
		||||
            .await;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let (pool, lm) = get_ctx_data(&ctx).await;
 | 
			
		||||
 | 
			
		||||
    let language = UserData::language_of(&msg.author, &pool);
 | 
			
		||||
    let prefix = ctx.prefix(msg.guild_id);
 | 
			
		||||
 | 
			
		||||
    if !args.is_empty() {
 | 
			
		||||
        let framework = ctx
 | 
			
		||||
            .data
 | 
			
		||||
            .read()
 | 
			
		||||
            .await
 | 
			
		||||
            .get::<FrameworkCtx>()
 | 
			
		||||
            .cloned()
 | 
			
		||||
            .expect("Could not get FrameworkCtx from data");
 | 
			
		||||
 | 
			
		||||
        let matched = framework
 | 
			
		||||
            .commands
 | 
			
		||||
            .get(args.as_str())
 | 
			
		||||
            .map(|inner| inner.name);
 | 
			
		||||
 | 
			
		||||
        if let Some(command_name) = matched {
 | 
			
		||||
            command_help(ctx, msg, lm, &prefix.await, &language.await, command_name).await
 | 
			
		||||
        } else {
 | 
			
		||||
            default_help(ctx, msg, lm, &prefix.await, &language.await).await;
 | 
			
		||||
        }
 | 
			
		||||
    } else {
 | 
			
		||||
        default_help(ctx, msg, lm, &prefix.await, &language.await).await;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command]
 | 
			
		||||
async fn info(ctx: &Context, msg: &Message, _args: String) {
 | 
			
		||||
    let (pool, lm) = get_ctx_data(&ctx).await;
 | 
			
		||||
 | 
			
		||||
    let language = UserData::language_of(&msg.author, &pool);
 | 
			
		||||
    let prefix = ctx.prefix(msg.guild_id);
 | 
			
		||||
    let current_user = ctx.cache.current_user();
 | 
			
		||||
    let footer = footer(ctx).await;
 | 
			
		||||
 | 
			
		||||
    let desc = lm
 | 
			
		||||
        .get(&language.await, "info")
 | 
			
		||||
        .replacen("{user}", ¤t_user.await.name, 1)
 | 
			
		||||
        .replace("{default_prefix}", &*DEFAULT_PREFIX)
 | 
			
		||||
        .replace("{prefix}", &prefix.await);
 | 
			
		||||
 | 
			
		||||
    let _ = msg
 | 
			
		||||
        .channel_id
 | 
			
		||||
        .send_message(ctx, |m| {
 | 
			
		||||
            m.embed(move |e| {
 | 
			
		||||
                e.title("Info")
 | 
			
		||||
                    .description(desc)
 | 
			
		||||
                    .footer(footer)
 | 
			
		||||
                    .color(*THEME_COLOR)
 | 
			
		||||
            })
 | 
			
		||||
        })
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command]
 | 
			
		||||
async fn donate(ctx: &Context, msg: &Message, _args: String) {
 | 
			
		||||
    let (pool, lm) = get_ctx_data(&ctx).await;
 | 
			
		||||
 | 
			
		||||
    let language = UserData::language_of(&msg.author, &pool).await;
 | 
			
		||||
    let desc = lm.get(&language, "donate");
 | 
			
		||||
    let footer = footer(ctx).await;
 | 
			
		||||
 | 
			
		||||
    let _ = msg
 | 
			
		||||
        .channel_id
 | 
			
		||||
        .send_message(ctx, |m| {
 | 
			
		||||
            m.embed(move |e| {
 | 
			
		||||
                e.title("Donate")
 | 
			
		||||
                    .description(desc)
 | 
			
		||||
                    .footer(footer)
 | 
			
		||||
                    .color(*THEME_COLOR)
 | 
			
		||||
            })
 | 
			
		||||
        })
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command]
 | 
			
		||||
async fn dashboard(ctx: &Context, msg: &Message, _args: String) {
 | 
			
		||||
    let footer = footer(ctx).await;
 | 
			
		||||
 | 
			
		||||
    let _ = msg
 | 
			
		||||
        .channel_id
 | 
			
		||||
        .send_message(ctx, |m| {
 | 
			
		||||
            m.embed(move |e| {
 | 
			
		||||
                e.title("Dashboard")
 | 
			
		||||
                    .description("https://reminder-bot.com/dashboard")
 | 
			
		||||
                    .footer(footer)
 | 
			
		||||
                    .color(*THEME_COLOR)
 | 
			
		||||
            })
 | 
			
		||||
        })
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command]
 | 
			
		||||
async fn clock(ctx: &Context, msg: &Message, _args: String) {
 | 
			
		||||
    let (pool, lm) = get_ctx_data(&ctx).await;
 | 
			
		||||
 | 
			
		||||
    let language = UserData::language_of(&msg.author, &pool).await;
 | 
			
		||||
    let timezone = UserData::timezone_of(&msg.author, &pool).await;
 | 
			
		||||
    let meridian = UserData::meridian_of(&msg.author, &pool).await;
 | 
			
		||||
 | 
			
		||||
    let now = Utc::now().with_timezone(&timezone);
 | 
			
		||||
 | 
			
		||||
    let clock_display = lm.get(&language, "clock/time");
 | 
			
		||||
 | 
			
		||||
    let _ = msg
 | 
			
		||||
        .channel_id
 | 
			
		||||
        .say(
 | 
			
		||||
    let _ = invoke
 | 
			
		||||
        .respond(
 | 
			
		||||
            &ctx,
 | 
			
		||||
            clock_display.replacen("{}", &now.format(meridian.fmt_str()).to_string(), 1),
 | 
			
		||||
            CreateGenericResponse::new().embed(|e| {
 | 
			
		||||
                e.title("Help")
 | 
			
		||||
                    .color(*THEME_COLOR)
 | 
			
		||||
                    .description(
 | 
			
		||||
                        "__Info Commands__
 | 
			
		||||
`/help` `/info` `/donate` `/dashboard` `/clock`
 | 
			
		||||
*run these commands with no options*
 | 
			
		||||
 | 
			
		||||
__Reminder Commands__
 | 
			
		||||
`/remind` - Create a new reminder that will send a message at a certain time
 | 
			
		||||
`/timer` - Start a timer from now, that will count time passed. Also used to view and remove timers
 | 
			
		||||
 | 
			
		||||
__Reminder Management__
 | 
			
		||||
`/del` - Delete reminders
 | 
			
		||||
`/look` - View reminders
 | 
			
		||||
`/pause` - Pause all reminders on the channel
 | 
			
		||||
`/offset` - Move all reminders by a certain time
 | 
			
		||||
`/nudge` - Move all new reminders on this channel by a certain time
 | 
			
		||||
 | 
			
		||||
__Todo Commands__
 | 
			
		||||
`/todo` - Add, view and manage the server, channel or user todo lists
 | 
			
		||||
 | 
			
		||||
__Setup Commands__
 | 
			
		||||
`/timezone` - Set your timezone (necessary for `/remind` to work properly)
 | 
			
		||||
 | 
			
		||||
__Advanced Commands__
 | 
			
		||||
`/macro` - Record and replay command sequences
 | 
			
		||||
                    ",
 | 
			
		||||
                    )
 | 
			
		||||
                    .footer(footer)
 | 
			
		||||
            }),
 | 
			
		||||
        )
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command]
 | 
			
		||||
#[aliases("invite")]
 | 
			
		||||
#[description("Get information about the bot")]
 | 
			
		||||
async fn info(ctx: &Context, invoke: &mut CommandInvoke) {
 | 
			
		||||
    let footer = footer(ctx);
 | 
			
		||||
 | 
			
		||||
    let _ = invoke
 | 
			
		||||
        .respond(
 | 
			
		||||
            ctx.http.clone(),
 | 
			
		||||
            CreateGenericResponse::new().embed(|e| {
 | 
			
		||||
                e.title("Info")
 | 
			
		||||
                    .description(format!(
 | 
			
		||||
                        "Help: `/help`
 | 
			
		||||
 | 
			
		||||
**Welcome to Reminder Bot!**
 | 
			
		||||
Developer: <@203532103185465344>
 | 
			
		||||
Icon: <@253202252821430272>
 | 
			
		||||
Find me on https://discord.jellywx.com and on https://github.com/JellyWX :)
 | 
			
		||||
 | 
			
		||||
Invite the bot: https://invite.reminder-bot.com/
 | 
			
		||||
Use our dashboard: https://reminder-bot.com/",
 | 
			
		||||
                    ))
 | 
			
		||||
                    .footer(footer)
 | 
			
		||||
                    .color(*THEME_COLOR)
 | 
			
		||||
            }),
 | 
			
		||||
        )
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command]
 | 
			
		||||
#[description("Details on supporting the bot and Patreon benefits")]
 | 
			
		||||
#[group("Info")]
 | 
			
		||||
async fn donate(ctx: &Context, invoke: &mut CommandInvoke) {
 | 
			
		||||
    let footer = footer(ctx);
 | 
			
		||||
 | 
			
		||||
    let _ = invoke
 | 
			
		||||
        .respond(
 | 
			
		||||
            ctx.http.clone(),
 | 
			
		||||
            CreateGenericResponse::new().embed(|e| {
 | 
			
		||||
                e.title("Donate")
 | 
			
		||||
                    .description("Thinking of adding a monthly contribution? Click below for my Patreon and official bot server :)
 | 
			
		||||
 | 
			
		||||
**https://www.patreon.com/jellywx/**
 | 
			
		||||
**https://discord.jellywx.com/**
 | 
			
		||||
 | 
			
		||||
When you subscribe, Patreon will automatically rank you up on our Discord server (make sure you link your Patreon and Discord accounts!)
 | 
			
		||||
With your new rank, you'll be able to:
 | 
			
		||||
• Set repeating reminders with `interval`, `natural` or the dashboard
 | 
			
		||||
• Use unlimited uploads on SoundFX
 | 
			
		||||
 | 
			
		||||
(Also, members of servers you __own__ will be able to set repeating reminders via commands)
 | 
			
		||||
 | 
			
		||||
Just $2 USD/month!
 | 
			
		||||
 | 
			
		||||
*Please note, you must be in the JellyWX Discord server to receive Patreon features*")
 | 
			
		||||
                    .footer(footer)
 | 
			
		||||
                    .color(*THEME_COLOR)
 | 
			
		||||
            }),
 | 
			
		||||
        )
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command]
 | 
			
		||||
#[description("Get the link to the online dashboard")]
 | 
			
		||||
#[group("Info")]
 | 
			
		||||
async fn dashboard(ctx: &Context, invoke: &mut CommandInvoke) {
 | 
			
		||||
    let footer = footer(ctx);
 | 
			
		||||
 | 
			
		||||
    let _ = invoke
 | 
			
		||||
        .respond(
 | 
			
		||||
            ctx.http.clone(),
 | 
			
		||||
            CreateGenericResponse::new().embed(|e| {
 | 
			
		||||
                e.title("Dashboard")
 | 
			
		||||
                    .description("**https://reminder-bot.com/dashboard**")
 | 
			
		||||
                    .footer(footer)
 | 
			
		||||
                    .color(*THEME_COLOR)
 | 
			
		||||
            }),
 | 
			
		||||
        )
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command]
 | 
			
		||||
#[description("View the current time in your selected timezone")]
 | 
			
		||||
#[group("Info")]
 | 
			
		||||
async fn clock(ctx: &Context, invoke: &mut CommandInvoke) {
 | 
			
		||||
    let ud = ctx.user_data(&invoke.author_id()).await.unwrap();
 | 
			
		||||
    let now = Utc::now().with_timezone(&ud.timezone());
 | 
			
		||||
 | 
			
		||||
    let _ = invoke
 | 
			
		||||
        .respond(
 | 
			
		||||
            ctx.http.clone(),
 | 
			
		||||
            CreateGenericResponse::new().content(format!("Current time: {}", now.format("%H:%M"))),
 | 
			
		||||
        )
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@@ -1,443 +1,260 @@
 | 
			
		||||
use regex_command_attr::command;
 | 
			
		||||
use serenity::client::Context;
 | 
			
		||||
 | 
			
		||||
use serenity::{
 | 
			
		||||
    async_trait,
 | 
			
		||||
    client::Context,
 | 
			
		||||
    constants::MESSAGE_CODE_LIMIT,
 | 
			
		||||
    model::{
 | 
			
		||||
        channel::Message,
 | 
			
		||||
        id::{ChannelId, GuildId, UserId},
 | 
			
		||||
use crate::{
 | 
			
		||||
    component_models::{
 | 
			
		||||
        pager::{Pager, TodoPager},
 | 
			
		||||
        ComponentDataModel, TodoSelector,
 | 
			
		||||
    },
 | 
			
		||||
    consts::{EMBED_DESCRIPTION_MAX_LENGTH, SELECT_MAX_ENTRIES, THEME_COLOR},
 | 
			
		||||
    framework::{CommandInvoke, CommandOptions, CreateGenericResponse},
 | 
			
		||||
    hooks::CHECK_GUILD_PERMISSIONS_HOOK,
 | 
			
		||||
    SQLPool,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use std::fmt;
 | 
			
		||||
#[command]
 | 
			
		||||
#[description("Manage todo lists")]
 | 
			
		||||
#[subcommandgroup("server")]
 | 
			
		||||
#[description("Manage the server todo list")]
 | 
			
		||||
#[subcommand("add")]
 | 
			
		||||
#[description("Add an item to the server todo list")]
 | 
			
		||||
#[arg(
 | 
			
		||||
    name = "task",
 | 
			
		||||
    description = "The task to add to the todo list",
 | 
			
		||||
    kind = "String",
 | 
			
		||||
    required = true
 | 
			
		||||
)]
 | 
			
		||||
#[subcommand("view")]
 | 
			
		||||
#[description("View and remove from the server todo list")]
 | 
			
		||||
#[subcommandgroup("channel")]
 | 
			
		||||
#[description("Manage the channel todo list")]
 | 
			
		||||
#[subcommand("add")]
 | 
			
		||||
#[description("Add to the channel todo list")]
 | 
			
		||||
#[arg(
 | 
			
		||||
    name = "task",
 | 
			
		||||
    description = "The task to add to the todo list",
 | 
			
		||||
    kind = "String",
 | 
			
		||||
    required = true
 | 
			
		||||
)]
 | 
			
		||||
#[subcommand("view")]
 | 
			
		||||
#[description("View and remove from the channel todo list")]
 | 
			
		||||
#[subcommandgroup("user")]
 | 
			
		||||
#[description("Manage your personal todo list")]
 | 
			
		||||
#[subcommand("add")]
 | 
			
		||||
#[description("Add to your personal todo list")]
 | 
			
		||||
#[arg(
 | 
			
		||||
    name = "task",
 | 
			
		||||
    description = "The task to add to the todo list",
 | 
			
		||||
    kind = "String",
 | 
			
		||||
    required = true
 | 
			
		||||
)]
 | 
			
		||||
#[subcommand("view")]
 | 
			
		||||
#[description("View and remove from your personal todo list")]
 | 
			
		||||
#[hook(CHECK_GUILD_PERMISSIONS_HOOK)]
 | 
			
		||||
async fn todo(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions) {
 | 
			
		||||
    if invoke.guild_id().is_none() && args.subcommand_group != Some("user".to_string()) {
 | 
			
		||||
        let _ = invoke
 | 
			
		||||
            .respond(
 | 
			
		||||
                &ctx,
 | 
			
		||||
                CreateGenericResponse::new().content("Please use `/todo user` in direct messages"),
 | 
			
		||||
            )
 | 
			
		||||
            .await;
 | 
			
		||||
    } else {
 | 
			
		||||
        let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
use crate::models::CtxGuildData;
 | 
			
		||||
use crate::{command_help, get_ctx_data, models::UserData};
 | 
			
		||||
use sqlx::MySqlPool;
 | 
			
		||||
use std::convert::TryFrom;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
struct TodoNotFound;
 | 
			
		||||
 | 
			
		||||
impl std::error::Error for TodoNotFound {}
 | 
			
		||||
impl fmt::Display for TodoNotFound {
 | 
			
		||||
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 | 
			
		||||
        write!(f, "Todo not found")
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct Todo {
 | 
			
		||||
    id: u32,
 | 
			
		||||
    value: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct TodoTarget {
 | 
			
		||||
    user: UserId,
 | 
			
		||||
    guild: Option<GuildId>,
 | 
			
		||||
    channel: Option<ChannelId>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl TodoTarget {
 | 
			
		||||
    pub fn command(&self, subcommand_opt: Option<SubCommand>) -> String {
 | 
			
		||||
        let context = if self.channel.is_some() {
 | 
			
		||||
            "channel"
 | 
			
		||||
        } else if self.guild.is_some() {
 | 
			
		||||
            "guild"
 | 
			
		||||
        } else {
 | 
			
		||||
            "user"
 | 
			
		||||
        let keys = match args.subcommand_group.as_ref().unwrap().as_str() {
 | 
			
		||||
            "server" => (None, None, invoke.guild_id().map(|g| g.0)),
 | 
			
		||||
            "channel" => (None, Some(invoke.channel_id().0), invoke.guild_id().map(|g| g.0)),
 | 
			
		||||
            _ => (Some(invoke.author_id().0), None, None),
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        if let Some(subcommand) = subcommand_opt {
 | 
			
		||||
            format!("todo {} {}", context, subcommand.to_string())
 | 
			
		||||
        } else {
 | 
			
		||||
            format!("todo {}", context)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
        match args.get("task") {
 | 
			
		||||
            Some(task) => {
 | 
			
		||||
                let task = task.to_string();
 | 
			
		||||
 | 
			
		||||
    pub fn name(&self) -> String {
 | 
			
		||||
        if self.channel.is_some() {
 | 
			
		||||
            "Channel"
 | 
			
		||||
        } else if self.guild.is_some() {
 | 
			
		||||
            "Guild"
 | 
			
		||||
        } else {
 | 
			
		||||
            "User"
 | 
			
		||||
        }
 | 
			
		||||
        .to_string()
 | 
			
		||||
    }
 | 
			
		||||
                sqlx::query!(
 | 
			
		||||
                    "INSERT INTO todos (user_id, channel_id, guild_id, value) VALUES ((SELECT id FROM users WHERE user = ?), (SELECT id FROM channels WHERE channel = ?), (SELECT id FROM guilds WHERE guild = ?), ?)",
 | 
			
		||||
                    keys.0,
 | 
			
		||||
                    keys.1,
 | 
			
		||||
                    keys.2,
 | 
			
		||||
                    task
 | 
			
		||||
                )
 | 
			
		||||
                .execute(&pool)
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap();
 | 
			
		||||
 | 
			
		||||
    pub async fn view(
 | 
			
		||||
        &self,
 | 
			
		||||
        pool: MySqlPool,
 | 
			
		||||
    ) -> Result<Vec<Todo>, Box<dyn std::error::Error + Send + Sync>> {
 | 
			
		||||
        Ok(if let Some(cid) = self.channel {
 | 
			
		||||
            sqlx::query_as!(
 | 
			
		||||
                Todo,
 | 
			
		||||
                "
 | 
			
		||||
SELECT id, value FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?)
 | 
			
		||||
                ",
 | 
			
		||||
                cid.as_u64()
 | 
			
		||||
            )
 | 
			
		||||
            .fetch_all(&pool)
 | 
			
		||||
            .await?
 | 
			
		||||
        } else if let Some(gid) = self.guild {
 | 
			
		||||
            sqlx::query_as!(
 | 
			
		||||
                Todo,
 | 
			
		||||
                "
 | 
			
		||||
SELECT id, value FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL
 | 
			
		||||
                ",
 | 
			
		||||
                gid.as_u64()
 | 
			
		||||
            )
 | 
			
		||||
            .fetch_all(&pool)
 | 
			
		||||
            .await?
 | 
			
		||||
        } else {
 | 
			
		||||
            sqlx::query_as!(
 | 
			
		||||
                Todo,
 | 
			
		||||
                "
 | 
			
		||||
SELECT id, value FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL
 | 
			
		||||
                ",
 | 
			
		||||
                self.user.as_u64()
 | 
			
		||||
            )
 | 
			
		||||
            .fetch_all(&pool)
 | 
			
		||||
            .await?
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn add(
 | 
			
		||||
        &self,
 | 
			
		||||
        value: String,
 | 
			
		||||
        pool: MySqlPool,
 | 
			
		||||
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
 | 
			
		||||
        if let (Some(cid), Some(gid)) = (self.channel, self.guild) {
 | 
			
		||||
            sqlx::query!(
 | 
			
		||||
                "
 | 
			
		||||
INSERT INTO todos (user_id, guild_id, channel_id, value) VALUES (
 | 
			
		||||
    (SELECT id FROM users WHERE user = ?),
 | 
			
		||||
    (SELECT id FROM guilds WHERE guild = ?),
 | 
			
		||||
    (SELECT id FROM channels WHERE channel = ?),
 | 
			
		||||
    ?
 | 
			
		||||
)
 | 
			
		||||
                ",
 | 
			
		||||
                self.user.as_u64(),
 | 
			
		||||
                gid.as_u64(),
 | 
			
		||||
                cid.as_u64(),
 | 
			
		||||
                value
 | 
			
		||||
            )
 | 
			
		||||
            .execute(&pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        } else if let Some(gid) = self.guild {
 | 
			
		||||
            sqlx::query!(
 | 
			
		||||
                "
 | 
			
		||||
INSERT INTO todos (user_id, guild_id, value) VALUES (
 | 
			
		||||
    (SELECT id FROM users WHERE user = ?),
 | 
			
		||||
    (SELECT id FROM guilds WHERE guild = ?),
 | 
			
		||||
    ?
 | 
			
		||||
)
 | 
			
		||||
                ",
 | 
			
		||||
                self.user.as_u64(),
 | 
			
		||||
                gid.as_u64(),
 | 
			
		||||
                value
 | 
			
		||||
            )
 | 
			
		||||
            .execute(&pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        } else {
 | 
			
		||||
            sqlx::query!(
 | 
			
		||||
                "
 | 
			
		||||
INSERT INTO todos (user_id, value) VALUES (
 | 
			
		||||
    (SELECT id FROM users WHERE user = ?),
 | 
			
		||||
    ?
 | 
			
		||||
)
 | 
			
		||||
                ",
 | 
			
		||||
                self.user.as_u64(),
 | 
			
		||||
                value
 | 
			
		||||
            )
 | 
			
		||||
            .execute(&pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn remove(
 | 
			
		||||
        &self,
 | 
			
		||||
        num: usize,
 | 
			
		||||
        pool: &MySqlPool,
 | 
			
		||||
    ) -> Result<Todo, Box<dyn std::error::Error + Sync + Send>> {
 | 
			
		||||
        let todos = self.view(pool.clone()).await?;
 | 
			
		||||
 | 
			
		||||
        if let Some(removal_todo) = todos.get(num) {
 | 
			
		||||
            let deleting = sqlx::query_as!(
 | 
			
		||||
                Todo,
 | 
			
		||||
                "
 | 
			
		||||
SELECT id, value FROM todos WHERE id = ?
 | 
			
		||||
                ",
 | 
			
		||||
                removal_todo.id
 | 
			
		||||
            )
 | 
			
		||||
            .fetch_one(&pool.clone())
 | 
			
		||||
            .await?;
 | 
			
		||||
 | 
			
		||||
            sqlx::query!(
 | 
			
		||||
                "
 | 
			
		||||
DELETE FROM todos WHERE id = ?
 | 
			
		||||
                ",
 | 
			
		||||
                removal_todo.id
 | 
			
		||||
            )
 | 
			
		||||
            .execute(pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
 | 
			
		||||
            Ok(deleting)
 | 
			
		||||
        } else {
 | 
			
		||||
            Err(Box::new(TodoNotFound))
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn clear(
 | 
			
		||||
        &self,
 | 
			
		||||
        pool: &MySqlPool,
 | 
			
		||||
    ) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
 | 
			
		||||
        if let Some(cid) = self.channel {
 | 
			
		||||
            sqlx::query!(
 | 
			
		||||
                "
 | 
			
		||||
DELETE FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?)
 | 
			
		||||
                ",
 | 
			
		||||
                cid.as_u64()
 | 
			
		||||
            )
 | 
			
		||||
            .execute(pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        } else if let Some(gid) = self.guild {
 | 
			
		||||
            sqlx::query!(
 | 
			
		||||
                "
 | 
			
		||||
DELETE FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL
 | 
			
		||||
                ",
 | 
			
		||||
                gid.as_u64()
 | 
			
		||||
            )
 | 
			
		||||
            .execute(pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        } else {
 | 
			
		||||
            sqlx::query!(
 | 
			
		||||
                "
 | 
			
		||||
DELETE FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL
 | 
			
		||||
                ",
 | 
			
		||||
                self.user.as_u64()
 | 
			
		||||
            )
 | 
			
		||||
            .execute(pool)
 | 
			
		||||
            .await?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn execute(&self, ctx: &Context, msg: &Message, subcommand: SubCommand, extra: String) {
 | 
			
		||||
        let (pool, lm) = get_ctx_data(&ctx).await;
 | 
			
		||||
 | 
			
		||||
        let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
 | 
			
		||||
        let prefix = ctx.prefix(msg.guild_id).await;
 | 
			
		||||
 | 
			
		||||
        match subcommand {
 | 
			
		||||
            SubCommand::View => {
 | 
			
		||||
                let todo_items = self.view(pool).await.unwrap();
 | 
			
		||||
                let mut todo_groups = vec!["".to_string()];
 | 
			
		||||
                let mut char_count = 0;
 | 
			
		||||
 | 
			
		||||
                todo_items.iter().enumerate().for_each(|(count, todo)| {
 | 
			
		||||
                    let display = format!("{}: {}\n", count + 1, todo.value);
 | 
			
		||||
 | 
			
		||||
                    if char_count + display.len() > MESSAGE_CODE_LIMIT as usize {
 | 
			
		||||
                        char_count = display.len();
 | 
			
		||||
 | 
			
		||||
                        todo_groups.push(display);
 | 
			
		||||
                    } else {
 | 
			
		||||
                        char_count += display.len();
 | 
			
		||||
 | 
			
		||||
                        let last_group = todo_groups.pop().unwrap();
 | 
			
		||||
 | 
			
		||||
                        todo_groups.push(format!("{}{}", last_group, display));
 | 
			
		||||
                    }
 | 
			
		||||
                });
 | 
			
		||||
 | 
			
		||||
                for group in todo_groups {
 | 
			
		||||
                    let _ = msg
 | 
			
		||||
                        .channel_id
 | 
			
		||||
                        .send_message(&ctx, |m| {
 | 
			
		||||
                            m.embed(|e| e.title(format!("{} Todo", self.name())).description(group))
 | 
			
		||||
                        })
 | 
			
		||||
                        .await;
 | 
			
		||||
                }
 | 
			
		||||
                let _ = invoke
 | 
			
		||||
                    .respond(&ctx, CreateGenericResponse::new().content("Item added to todo list"))
 | 
			
		||||
                    .await;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            SubCommand::Add => {
 | 
			
		||||
                let content = lm
 | 
			
		||||
                    .get(&user_data.language, "todo/added")
 | 
			
		||||
                    .replacen("{name}", &extra, 1);
 | 
			
		||||
 | 
			
		||||
                self.add(extra, pool).await.unwrap();
 | 
			
		||||
 | 
			
		||||
                let _ = msg.channel_id.say(&ctx, content).await;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            SubCommand::Remove => {
 | 
			
		||||
                if let Ok(num) = extra.parse::<usize>() {
 | 
			
		||||
                    if let Ok(todo) = self.remove(num - 1, &pool).await {
 | 
			
		||||
                        let content = lm.get(&user_data.language, "todo/removed").replacen(
 | 
			
		||||
                            "{}",
 | 
			
		||||
                            &todo.value,
 | 
			
		||||
                            1,
 | 
			
		||||
                        );
 | 
			
		||||
 | 
			
		||||
                        let _ = msg.channel_id.say(&ctx, content).await;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        let _ = msg
 | 
			
		||||
                            .channel_id
 | 
			
		||||
                            .say(&ctx, lm.get(&user_data.language, "todo/error_index"))
 | 
			
		||||
                            .await;
 | 
			
		||||
                    }
 | 
			
		||||
            None => {
 | 
			
		||||
                let values = if let Some(uid) = keys.0 {
 | 
			
		||||
                    sqlx::query!(
 | 
			
		||||
                        "SELECT todos.id, value FROM todos
 | 
			
		||||
INNER JOIN users ON todos.user_id = users.id
 | 
			
		||||
WHERE users.user = ?",
 | 
			
		||||
                        uid,
 | 
			
		||||
                    )
 | 
			
		||||
                    .fetch_all(&pool)
 | 
			
		||||
                    .await
 | 
			
		||||
                    .unwrap()
 | 
			
		||||
                    .iter()
 | 
			
		||||
                    .map(|row| (row.id as usize, row.value.clone()))
 | 
			
		||||
                    .collect::<Vec<(usize, String)>>()
 | 
			
		||||
                } else if let Some(cid) = keys.1 {
 | 
			
		||||
                    sqlx::query!(
 | 
			
		||||
                        "SELECT todos.id, value FROM todos
 | 
			
		||||
INNER JOIN channels ON todos.channel_id = channels.id
 | 
			
		||||
WHERE channels.channel = ?",
 | 
			
		||||
                        cid,
 | 
			
		||||
                    )
 | 
			
		||||
                    .fetch_all(&pool)
 | 
			
		||||
                    .await
 | 
			
		||||
                    .unwrap()
 | 
			
		||||
                    .iter()
 | 
			
		||||
                    .map(|row| (row.id as usize, row.value.clone()))
 | 
			
		||||
                    .collect::<Vec<(usize, String)>>()
 | 
			
		||||
                } else {
 | 
			
		||||
                    let content = lm
 | 
			
		||||
                        .get(&user_data.language, "todo/error_value")
 | 
			
		||||
                        .replacen("{prefix}", &prefix, 1)
 | 
			
		||||
                        .replacen("{command}", &self.command(Some(subcommand)), 1);
 | 
			
		||||
                    sqlx::query!(
 | 
			
		||||
                        "SELECT todos.id, value FROM todos
 | 
			
		||||
INNER JOIN guilds ON todos.guild_id = guilds.id
 | 
			
		||||
WHERE guilds.guild = ?",
 | 
			
		||||
                        keys.2,
 | 
			
		||||
                    )
 | 
			
		||||
                    .fetch_all(&pool)
 | 
			
		||||
                    .await
 | 
			
		||||
                    .unwrap()
 | 
			
		||||
                    .iter()
 | 
			
		||||
                    .map(|row| (row.id as usize, row.value.clone()))
 | 
			
		||||
                    .collect::<Vec<(usize, String)>>()
 | 
			
		||||
                };
 | 
			
		||||
 | 
			
		||||
                    let _ = msg.channel_id.say(&ctx, content).await;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
                let resp = show_todo_page(&values, 0, keys.0, keys.1, keys.2);
 | 
			
		||||
 | 
			
		||||
            SubCommand::Clear => {
 | 
			
		||||
                self.clear(&pool).await.unwrap();
 | 
			
		||||
 | 
			
		||||
                let content = lm.get(&user_data.language, "todo/cleared");
 | 
			
		||||
 | 
			
		||||
                let _ = msg.channel_id.say(&ctx, content).await;
 | 
			
		||||
                invoke.respond(&ctx, resp).await.unwrap();
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
enum SubCommand {
 | 
			
		||||
    View,
 | 
			
		||||
    Add,
 | 
			
		||||
    Remove,
 | 
			
		||||
    Clear,
 | 
			
		||||
}
 | 
			
		||||
pub fn max_todo_page(todo_values: &[(usize, String)]) -> usize {
 | 
			
		||||
    let mut rows = 0;
 | 
			
		||||
    let mut char_count = 0;
 | 
			
		||||
 | 
			
		||||
impl TryFrom<Option<&str>> for SubCommand {
 | 
			
		||||
    type Error = ();
 | 
			
		||||
    todo_values.iter().enumerate().map(|(c, (_, v))| format!("{}: {}", c, v)).fold(
 | 
			
		||||
        1,
 | 
			
		||||
        |mut pages, text| {
 | 
			
		||||
            rows += 1;
 | 
			
		||||
            char_count += text.len();
 | 
			
		||||
 | 
			
		||||
    fn try_from(value: Option<&str>) -> Result<Self, Self::Error> {
 | 
			
		||||
        match value {
 | 
			
		||||
            Some("add") => Ok(SubCommand::Add),
 | 
			
		||||
 | 
			
		||||
            Some("remove") => Ok(SubCommand::Remove),
 | 
			
		||||
 | 
			
		||||
            Some("clear") => Ok(SubCommand::Clear),
 | 
			
		||||
 | 
			
		||||
            None | Some("") => Ok(SubCommand::View),
 | 
			
		||||
 | 
			
		||||
            Some(_unrecognised) => Err(()),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ToString for SubCommand {
 | 
			
		||||
    fn to_string(&self) -> String {
 | 
			
		||||
        match self {
 | 
			
		||||
            SubCommand::View => "",
 | 
			
		||||
            SubCommand::Add => "add",
 | 
			
		||||
            SubCommand::Remove => "remove",
 | 
			
		||||
            SubCommand::Clear => "clear",
 | 
			
		||||
        }
 | 
			
		||||
        .to_string()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
trait Execute {
 | 
			
		||||
    async fn execute(self, ctx: &Context, msg: &Message, extra: String, target: TodoTarget);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl Execute for Result<SubCommand, ()> {
 | 
			
		||||
    async fn execute(self, ctx: &Context, msg: &Message, extra: String, target: TodoTarget) {
 | 
			
		||||
        if let Ok(subcommand) = self {
 | 
			
		||||
            target.execute(ctx, msg, subcommand, extra).await;
 | 
			
		||||
        } else {
 | 
			
		||||
            show_help(&ctx, msg, Some(target)).await;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command("todo")]
 | 
			
		||||
async fn todo_user(ctx: &Context, msg: &Message, args: String) {
 | 
			
		||||
    let mut split = args.split(' ');
 | 
			
		||||
 | 
			
		||||
    let target = TodoTarget {
 | 
			
		||||
        user: msg.author.id,
 | 
			
		||||
        guild: None,
 | 
			
		||||
        channel: None,
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    let subcommand_opt = SubCommand::try_from(split.next());
 | 
			
		||||
 | 
			
		||||
    subcommand_opt
 | 
			
		||||
        .execute(ctx, msg, split.collect::<Vec<&str>>().join(" "), target)
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command("todoc")]
 | 
			
		||||
#[supports_dm(false)]
 | 
			
		||||
#[permission_level(Managed)]
 | 
			
		||||
async fn todo_channel(ctx: &Context, msg: &Message, args: String) {
 | 
			
		||||
    let mut split = args.split(' ');
 | 
			
		||||
 | 
			
		||||
    let target = TodoTarget {
 | 
			
		||||
        user: msg.author.id,
 | 
			
		||||
        guild: msg.guild_id,
 | 
			
		||||
        channel: Some(msg.channel_id),
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    let subcommand_opt = SubCommand::try_from(split.next());
 | 
			
		||||
 | 
			
		||||
    subcommand_opt
 | 
			
		||||
        .execute(ctx, msg, split.collect::<Vec<&str>>().join(" "), target)
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[command("todos")]
 | 
			
		||||
#[supports_dm(false)]
 | 
			
		||||
#[permission_level(Managed)]
 | 
			
		||||
async fn todo_guild(ctx: &Context, msg: &Message, args: String) {
 | 
			
		||||
    let mut split = args.split(' ');
 | 
			
		||||
 | 
			
		||||
    let target = TodoTarget {
 | 
			
		||||
        user: msg.author.id,
 | 
			
		||||
        guild: msg.guild_id,
 | 
			
		||||
        channel: None,
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    let subcommand_opt = SubCommand::try_from(split.next());
 | 
			
		||||
 | 
			
		||||
    subcommand_opt
 | 
			
		||||
        .execute(ctx, msg, split.collect::<Vec<&str>>().join(" "), target)
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn show_help(ctx: &Context, msg: &Message, target: Option<TodoTarget>) {
 | 
			
		||||
    let (pool, lm) = get_ctx_data(&ctx).await;
 | 
			
		||||
 | 
			
		||||
    let language = UserData::language_of(&msg.author, &pool);
 | 
			
		||||
    let prefix = ctx.prefix(msg.guild_id);
 | 
			
		||||
 | 
			
		||||
    let command = match target {
 | 
			
		||||
        None => "todo",
 | 
			
		||||
        Some(t) => {
 | 
			
		||||
            if t.channel.is_some() {
 | 
			
		||||
                "todoc"
 | 
			
		||||
            } else if t.guild.is_some() {
 | 
			
		||||
                "todos"
 | 
			
		||||
            } else {
 | 
			
		||||
                "todo"
 | 
			
		||||
            if char_count > EMBED_DESCRIPTION_MAX_LENGTH || rows > SELECT_MAX_ENTRIES {
 | 
			
		||||
                rows = 1;
 | 
			
		||||
                char_count = text.len();
 | 
			
		||||
                pages += 1;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
            pages
 | 
			
		||||
        },
 | 
			
		||||
    )
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn show_todo_page(
 | 
			
		||||
    todo_values: &[(usize, String)],
 | 
			
		||||
    page: usize,
 | 
			
		||||
    user_id: Option<u64>,
 | 
			
		||||
    channel_id: Option<u64>,
 | 
			
		||||
    guild_id: Option<u64>,
 | 
			
		||||
) -> CreateGenericResponse {
 | 
			
		||||
    let pager = TodoPager::new(page, user_id, channel_id, guild_id);
 | 
			
		||||
 | 
			
		||||
    let pages = max_todo_page(todo_values);
 | 
			
		||||
    let mut page = page;
 | 
			
		||||
    if page >= pages {
 | 
			
		||||
        page = pages - 1;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let mut char_count = 0;
 | 
			
		||||
    let mut rows = 0;
 | 
			
		||||
    let mut skipped_rows = 0;
 | 
			
		||||
    let mut skipped_char_count = 0;
 | 
			
		||||
    let mut first_num = 0;
 | 
			
		||||
 | 
			
		||||
    let mut skipped_pages = 0;
 | 
			
		||||
 | 
			
		||||
    let (todo_ids, display_vec): (Vec<usize>, Vec<String>) = todo_values
 | 
			
		||||
        .iter()
 | 
			
		||||
        .enumerate()
 | 
			
		||||
        .map(|(c, (i, v))| (i, format!("`{}`: {}", c + 1, v)))
 | 
			
		||||
        .skip_while(|(_, p)| {
 | 
			
		||||
            first_num += 1;
 | 
			
		||||
            skipped_rows += 1;
 | 
			
		||||
            skipped_char_count += p.len();
 | 
			
		||||
 | 
			
		||||
            if skipped_char_count > EMBED_DESCRIPTION_MAX_LENGTH
 | 
			
		||||
                || skipped_rows > SELECT_MAX_ENTRIES
 | 
			
		||||
            {
 | 
			
		||||
                skipped_rows = 1;
 | 
			
		||||
                skipped_char_count = p.len();
 | 
			
		||||
                skipped_pages += 1;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            skipped_pages < page
 | 
			
		||||
        })
 | 
			
		||||
        .take_while(|(_, p)| {
 | 
			
		||||
            rows += 1;
 | 
			
		||||
            char_count += p.len();
 | 
			
		||||
 | 
			
		||||
            char_count < EMBED_DESCRIPTION_MAX_LENGTH && rows <= SELECT_MAX_ENTRIES
 | 
			
		||||
        })
 | 
			
		||||
        .unzip();
 | 
			
		||||
 | 
			
		||||
    let display = display_vec.join("\n");
 | 
			
		||||
 | 
			
		||||
    let title = if user_id.is_some() {
 | 
			
		||||
        "Your"
 | 
			
		||||
    } else if channel_id.is_some() {
 | 
			
		||||
        "Channel"
 | 
			
		||||
    } else {
 | 
			
		||||
        "Server"
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    command_help(ctx, msg, lm, &prefix.await, &language.await, command).await;
 | 
			
		||||
    if todo_ids.is_empty() {
 | 
			
		||||
        CreateGenericResponse::new().embed(|e| {
 | 
			
		||||
            e.title(format!("{} Todo List", title))
 | 
			
		||||
                .description("Todo List Empty!")
 | 
			
		||||
                .footer(|f| f.text(format!("Page {} of {}", page + 1, pages)))
 | 
			
		||||
                .color(*THEME_COLOR)
 | 
			
		||||
        })
 | 
			
		||||
    } else {
 | 
			
		||||
        let todo_selector =
 | 
			
		||||
            ComponentDataModel::TodoSelector(TodoSelector { page, user_id, channel_id, guild_id });
 | 
			
		||||
 | 
			
		||||
        CreateGenericResponse::new()
 | 
			
		||||
            .embed(|e| {
 | 
			
		||||
                e.title(format!("{} Todo List", title))
 | 
			
		||||
                    .description(display)
 | 
			
		||||
                    .footer(|f| f.text(format!("Page {} of {}", page + 1, pages)))
 | 
			
		||||
                    .color(*THEME_COLOR)
 | 
			
		||||
            })
 | 
			
		||||
            .components(|comp| {
 | 
			
		||||
                pager.create_button_row(pages, comp);
 | 
			
		||||
 | 
			
		||||
                comp.create_action_row(|row| {
 | 
			
		||||
                    row.create_select_menu(|menu| {
 | 
			
		||||
                        menu.custom_id(todo_selector.to_custom_id()).options(|opt| {
 | 
			
		||||
                            for (count, (id, disp)) in todo_ids.iter().zip(&display_vec).enumerate()
 | 
			
		||||
                            {
 | 
			
		||||
                                opt.create_option(|o| {
 | 
			
		||||
                                    o.label(format!("Mark {} complete", count + first_num))
 | 
			
		||||
                                        .value(id)
 | 
			
		||||
                                        .description(disp.split_once(" ").unwrap_or(("", "")).1)
 | 
			
		||||
                                });
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            opt
 | 
			
		||||
                        })
 | 
			
		||||
                    })
 | 
			
		||||
                })
 | 
			
		||||
            })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										310
									
								
								src/component_models/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										310
									
								
								src/component_models/mod.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,310 @@
 | 
			
		||||
pub(crate) mod pager;
 | 
			
		||||
 | 
			
		||||
use std::io::Cursor;
 | 
			
		||||
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
use num_integer::Integer;
 | 
			
		||||
use rmp_serde::Serializer;
 | 
			
		||||
use serde::{Deserialize, Serialize};
 | 
			
		||||
use serenity::{
 | 
			
		||||
    builder::CreateEmbed,
 | 
			
		||||
    client::Context,
 | 
			
		||||
    model::{
 | 
			
		||||
        channel::Channel,
 | 
			
		||||
        interactions::{message_component::MessageComponentInteraction, InteractionResponseType},
 | 
			
		||||
        prelude::InteractionApplicationCommandCallbackDataFlags,
 | 
			
		||||
    },
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    commands::{
 | 
			
		||||
        moderation_cmds::{max_macro_page, show_macro_page},
 | 
			
		||||
        reminder_cmds::{max_delete_page, show_delete_page},
 | 
			
		||||
        todo_cmds::{max_todo_page, show_todo_page},
 | 
			
		||||
    },
 | 
			
		||||
    component_models::pager::{DelPager, LookPager, MacroPager, Pager, TodoPager},
 | 
			
		||||
    consts::{EMBED_DESCRIPTION_MAX_LENGTH, THEME_COLOR},
 | 
			
		||||
    framework::CommandInvoke,
 | 
			
		||||
    models::{command_macro::CommandMacro, reminder::Reminder},
 | 
			
		||||
    SQLPool,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#[derive(Deserialize, Serialize)]
 | 
			
		||||
#[serde(tag = "type")]
 | 
			
		||||
#[repr(u8)]
 | 
			
		||||
pub enum ComponentDataModel {
 | 
			
		||||
    LookPager(LookPager),
 | 
			
		||||
    DelPager(DelPager),
 | 
			
		||||
    TodoPager(TodoPager),
 | 
			
		||||
    DelSelector(DelSelector),
 | 
			
		||||
    TodoSelector(TodoSelector),
 | 
			
		||||
    MacroPager(MacroPager),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ComponentDataModel {
 | 
			
		||||
    pub fn to_custom_id(&self) -> String {
 | 
			
		||||
        let mut buf = Vec::new();
 | 
			
		||||
        self.serialize(&mut Serializer::new(&mut buf)).unwrap();
 | 
			
		||||
        base64::encode(buf)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn from_custom_id(data: &String) -> Self {
 | 
			
		||||
        let buf = base64::decode(data)
 | 
			
		||||
            .map_err(|e| format!("Could not decode `custom_id' {}: {:?}", data, e))
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        let cur = Cursor::new(buf);
 | 
			
		||||
        rmp_serde::from_read(cur).unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn act(&self, ctx: &Context, component: MessageComponentInteraction) {
 | 
			
		||||
        match self {
 | 
			
		||||
            ComponentDataModel::LookPager(pager) => {
 | 
			
		||||
                let flags = pager.flags;
 | 
			
		||||
 | 
			
		||||
                let channel_opt = component.channel_id.to_channel_cached(&ctx);
 | 
			
		||||
 | 
			
		||||
                let channel_id = if let Some(Channel::Guild(channel)) = channel_opt {
 | 
			
		||||
                    if Some(channel.guild_id) == component.guild_id {
 | 
			
		||||
                        flags.channel_id.unwrap_or(component.channel_id)
 | 
			
		||||
                    } else {
 | 
			
		||||
                        component.channel_id
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                    component.channel_id
 | 
			
		||||
                };
 | 
			
		||||
 | 
			
		||||
                let reminders = Reminder::from_channel(ctx, channel_id, &flags).await;
 | 
			
		||||
 | 
			
		||||
                let pages = reminders
 | 
			
		||||
                    .iter()
 | 
			
		||||
                    .map(|reminder| reminder.display(&flags, &pager.timezone))
 | 
			
		||||
                    .fold(0, |t, r| t + r.len())
 | 
			
		||||
                    .div_ceil(&EMBED_DESCRIPTION_MAX_LENGTH);
 | 
			
		||||
 | 
			
		||||
                let channel_name =
 | 
			
		||||
                    if let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx) {
 | 
			
		||||
                        Some(channel.name)
 | 
			
		||||
                    } else {
 | 
			
		||||
                        None
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                let next_page = pager.next_page(pages);
 | 
			
		||||
 | 
			
		||||
                let mut char_count = 0;
 | 
			
		||||
                let mut skip_char_count = 0;
 | 
			
		||||
 | 
			
		||||
                let display = reminders
 | 
			
		||||
                    .iter()
 | 
			
		||||
                    .map(|reminder| reminder.display(&flags, &pager.timezone))
 | 
			
		||||
                    .skip_while(|p| {
 | 
			
		||||
                        skip_char_count += p.len();
 | 
			
		||||
 | 
			
		||||
                        skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * next_page as usize
 | 
			
		||||
                    })
 | 
			
		||||
                    .take_while(|p| {
 | 
			
		||||
                        char_count += p.len();
 | 
			
		||||
 | 
			
		||||
                        char_count < EMBED_DESCRIPTION_MAX_LENGTH
 | 
			
		||||
                    })
 | 
			
		||||
                    .collect::<Vec<String>>()
 | 
			
		||||
                    .join("\n");
 | 
			
		||||
 | 
			
		||||
                let mut embed = CreateEmbed::default();
 | 
			
		||||
                embed
 | 
			
		||||
                    .title(format!(
 | 
			
		||||
                        "Reminders{}",
 | 
			
		||||
                        channel_name.map_or(String::new(), |n| format!(" on #{}", n))
 | 
			
		||||
                    ))
 | 
			
		||||
                    .description(display)
 | 
			
		||||
                    .footer(|f| f.text(format!("Page {} of {}", next_page + 1, pages)))
 | 
			
		||||
                    .color(*THEME_COLOR);
 | 
			
		||||
 | 
			
		||||
                let _ = component
 | 
			
		||||
                    .create_interaction_response(&ctx, |r| {
 | 
			
		||||
                        r.kind(InteractionResponseType::UpdateMessage).interaction_response_data(
 | 
			
		||||
                            |response| {
 | 
			
		||||
                                response.embeds(vec![embed]).components(|comp| {
 | 
			
		||||
                                    pager.create_button_row(pages, comp);
 | 
			
		||||
 | 
			
		||||
                                    comp
 | 
			
		||||
                                })
 | 
			
		||||
                            },
 | 
			
		||||
                        )
 | 
			
		||||
                    })
 | 
			
		||||
                    .await;
 | 
			
		||||
            }
 | 
			
		||||
            ComponentDataModel::DelPager(pager) => {
 | 
			
		||||
                let reminders =
 | 
			
		||||
                    Reminder::from_guild(ctx, component.guild_id, component.user.id).await;
 | 
			
		||||
 | 
			
		||||
                let max_pages = max_delete_page(&reminders, &pager.timezone);
 | 
			
		||||
 | 
			
		||||
                let resp = show_delete_page(&reminders, pager.next_page(max_pages), pager.timezone);
 | 
			
		||||
 | 
			
		||||
                let mut invoke = CommandInvoke::component(component);
 | 
			
		||||
                let _ = invoke.respond(&ctx, resp).await;
 | 
			
		||||
            }
 | 
			
		||||
            ComponentDataModel::DelSelector(selector) => {
 | 
			
		||||
                let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
                let selected_id = component.data.values.join(",");
 | 
			
		||||
 | 
			
		||||
                sqlx::query!("DELETE FROM reminders WHERE FIND_IN_SET(id, ?)", selected_id)
 | 
			
		||||
                    .execute(&pool)
 | 
			
		||||
                    .await
 | 
			
		||||
                    .unwrap();
 | 
			
		||||
 | 
			
		||||
                let reminders =
 | 
			
		||||
                    Reminder::from_guild(ctx, component.guild_id, component.user.id).await;
 | 
			
		||||
 | 
			
		||||
                let resp = show_delete_page(&reminders, selector.page, selector.timezone);
 | 
			
		||||
 | 
			
		||||
                let mut invoke = CommandInvoke::component(component);
 | 
			
		||||
                let _ = invoke.respond(&ctx, resp).await;
 | 
			
		||||
            }
 | 
			
		||||
            ComponentDataModel::TodoPager(pager) => {
 | 
			
		||||
                if Some(component.user.id.0) == pager.user_id || pager.user_id.is_none() {
 | 
			
		||||
                    let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
                    let values = if let Some(uid) = pager.user_id {
 | 
			
		||||
                        sqlx::query!(
 | 
			
		||||
                            "SELECT todos.id, value FROM todos
 | 
			
		||||
    INNER JOIN users ON todos.user_id = users.id
 | 
			
		||||
    WHERE users.user = ?",
 | 
			
		||||
                            uid,
 | 
			
		||||
                        )
 | 
			
		||||
                        .fetch_all(&pool)
 | 
			
		||||
                        .await
 | 
			
		||||
                        .unwrap()
 | 
			
		||||
                        .iter()
 | 
			
		||||
                        .map(|row| (row.id as usize, row.value.clone()))
 | 
			
		||||
                        .collect::<Vec<(usize, String)>>()
 | 
			
		||||
                    } else if let Some(cid) = pager.channel_id {
 | 
			
		||||
                        sqlx::query!(
 | 
			
		||||
                            "SELECT todos.id, value FROM todos
 | 
			
		||||
    INNER JOIN channels ON todos.channel_id = channels.id
 | 
			
		||||
    WHERE channels.channel = ?",
 | 
			
		||||
                            cid,
 | 
			
		||||
                        )
 | 
			
		||||
                        .fetch_all(&pool)
 | 
			
		||||
                        .await
 | 
			
		||||
                        .unwrap()
 | 
			
		||||
                        .iter()
 | 
			
		||||
                        .map(|row| (row.id as usize, row.value.clone()))
 | 
			
		||||
                        .collect::<Vec<(usize, String)>>()
 | 
			
		||||
                    } else {
 | 
			
		||||
                        sqlx::query!(
 | 
			
		||||
                            "SELECT todos.id, value FROM todos
 | 
			
		||||
    INNER JOIN guilds ON todos.guild_id = guilds.id
 | 
			
		||||
    WHERE guilds.guild = ?",
 | 
			
		||||
                            pager.guild_id,
 | 
			
		||||
                        )
 | 
			
		||||
                        .fetch_all(&pool)
 | 
			
		||||
                        .await
 | 
			
		||||
                        .unwrap()
 | 
			
		||||
                        .iter()
 | 
			
		||||
                        .map(|row| (row.id as usize, row.value.clone()))
 | 
			
		||||
                        .collect::<Vec<(usize, String)>>()
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    let max_pages = max_todo_page(&values);
 | 
			
		||||
 | 
			
		||||
                    let resp = show_todo_page(
 | 
			
		||||
                        &values,
 | 
			
		||||
                        pager.next_page(max_pages),
 | 
			
		||||
                        pager.user_id,
 | 
			
		||||
                        pager.channel_id,
 | 
			
		||||
                        pager.guild_id,
 | 
			
		||||
                    );
 | 
			
		||||
 | 
			
		||||
                    let mut invoke = CommandInvoke::component(component);
 | 
			
		||||
                    let _ = invoke.respond(&ctx, resp).await;
 | 
			
		||||
                } else {
 | 
			
		||||
                    let _ = component
 | 
			
		||||
                        .create_interaction_response(&ctx, |r| {
 | 
			
		||||
                            r.kind(InteractionResponseType::ChannelMessageWithSource)
 | 
			
		||||
                                .interaction_response_data(|d| {
 | 
			
		||||
                                    d.flags(
 | 
			
		||||
                                        InteractionApplicationCommandCallbackDataFlags::EPHEMERAL,
 | 
			
		||||
                                    )
 | 
			
		||||
                                    .content("Only the user who performed the command can use these components")
 | 
			
		||||
                                })
 | 
			
		||||
                        })
 | 
			
		||||
                        .await;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            ComponentDataModel::TodoSelector(selector) => {
 | 
			
		||||
                if Some(component.user.id.0) == selector.user_id || selector.user_id.is_none() {
 | 
			
		||||
                    let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
                    let selected_id = component.data.values.join(",");
 | 
			
		||||
 | 
			
		||||
                    sqlx::query!("DELETE FROM todos WHERE FIND_IN_SET(id, ?)", selected_id)
 | 
			
		||||
                        .execute(&pool)
 | 
			
		||||
                        .await
 | 
			
		||||
                        .unwrap();
 | 
			
		||||
 | 
			
		||||
                    let values = sqlx::query!(
 | 
			
		||||
                    // fucking braindead mysql use <=> instead of = for null comparison
 | 
			
		||||
                    "SELECT id, value FROM todos WHERE user_id <=> ? AND channel_id <=> ? AND guild_id <=> ?",
 | 
			
		||||
                    selector.user_id,
 | 
			
		||||
                    selector.channel_id,
 | 
			
		||||
                    selector.guild_id,
 | 
			
		||||
                )
 | 
			
		||||
                .fetch_all(&pool)
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap()
 | 
			
		||||
                .iter()
 | 
			
		||||
                .map(|row| (row.id as usize, row.value.clone()))
 | 
			
		||||
                .collect::<Vec<(usize, String)>>();
 | 
			
		||||
 | 
			
		||||
                    let resp = show_todo_page(
 | 
			
		||||
                        &values,
 | 
			
		||||
                        selector.page,
 | 
			
		||||
                        selector.user_id,
 | 
			
		||||
                        selector.channel_id,
 | 
			
		||||
                        selector.guild_id,
 | 
			
		||||
                    );
 | 
			
		||||
 | 
			
		||||
                    let mut invoke = CommandInvoke::component(component);
 | 
			
		||||
                    let _ = invoke.respond(&ctx, resp).await;
 | 
			
		||||
                } else {
 | 
			
		||||
                    let _ = component
 | 
			
		||||
                        .create_interaction_response(&ctx, |r| {
 | 
			
		||||
                            r.kind(InteractionResponseType::ChannelMessageWithSource)
 | 
			
		||||
                                .interaction_response_data(|d| {
 | 
			
		||||
                                    d.flags(
 | 
			
		||||
                                        InteractionApplicationCommandCallbackDataFlags::EPHEMERAL,
 | 
			
		||||
                                    )
 | 
			
		||||
                                    .content("Only the user who performed the command can use these components")
 | 
			
		||||
                                })
 | 
			
		||||
                        })
 | 
			
		||||
                        .await;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            ComponentDataModel::MacroPager(pager) => {
 | 
			
		||||
                let mut invoke = CommandInvoke::component(component);
 | 
			
		||||
 | 
			
		||||
                let macros = CommandMacro::from_guild(ctx, invoke.guild_id().unwrap()).await;
 | 
			
		||||
 | 
			
		||||
                let max_page = max_macro_page(¯os);
 | 
			
		||||
                let page = pager.next_page(max_page);
 | 
			
		||||
 | 
			
		||||
                let resp = show_macro_page(¯os, page);
 | 
			
		||||
                let _ = invoke.respond(&ctx, resp).await;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize, Deserialize)]
 | 
			
		||||
pub struct DelSelector {
 | 
			
		||||
    pub page: usize,
 | 
			
		||||
    pub timezone: Tz,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize, Deserialize)]
 | 
			
		||||
pub struct TodoSelector {
 | 
			
		||||
    pub page: usize,
 | 
			
		||||
    pub user_id: Option<u64>,
 | 
			
		||||
    pub channel_id: Option<u64>,
 | 
			
		||||
    pub guild_id: Option<u64>,
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										411
									
								
								src/component_models/pager.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										411
									
								
								src/component_models/pager.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,411 @@
 | 
			
		||||
// todo split pager out into a single struct
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
use serde::{Deserialize, Serialize};
 | 
			
		||||
use serde_repr::*;
 | 
			
		||||
use serenity::{builder::CreateComponents, model::interactions::message_component::ButtonStyle};
 | 
			
		||||
 | 
			
		||||
use crate::{component_models::ComponentDataModel, models::reminder::look_flags::LookFlags};
 | 
			
		||||
 | 
			
		||||
pub trait Pager {
 | 
			
		||||
    fn next_page(&self, max_pages: usize) -> usize;
 | 
			
		||||
 | 
			
		||||
    fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize_repr, Deserialize_repr)]
 | 
			
		||||
#[repr(u8)]
 | 
			
		||||
enum PageAction {
 | 
			
		||||
    First = 0,
 | 
			
		||||
    Previous = 1,
 | 
			
		||||
    Refresh = 2,
 | 
			
		||||
    Next = 3,
 | 
			
		||||
    Last = 4,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize, Deserialize)]
 | 
			
		||||
pub struct LookPager {
 | 
			
		||||
    pub flags: LookFlags,
 | 
			
		||||
    pub page: usize,
 | 
			
		||||
    action: PageAction,
 | 
			
		||||
    pub timezone: Tz,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Pager for LookPager {
 | 
			
		||||
    fn next_page(&self, max_pages: usize) -> usize {
 | 
			
		||||
        match self.action {
 | 
			
		||||
            PageAction::First => 0,
 | 
			
		||||
            PageAction::Previous => 0.max(self.page - 1),
 | 
			
		||||
            PageAction::Refresh => self.page,
 | 
			
		||||
            PageAction::Next => (max_pages - 1).min(self.page + 1),
 | 
			
		||||
            PageAction::Last => max_pages - 1,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
 | 
			
		||||
        let next_page = self.next_page(max_pages);
 | 
			
		||||
 | 
			
		||||
        let (page_first, page_prev, page_refresh, page_next, page_last) =
 | 
			
		||||
            LookPager::buttons(self.flags, next_page, self.timezone);
 | 
			
		||||
 | 
			
		||||
        comp.create_action_row(|row| {
 | 
			
		||||
            row.create_button(|b| {
 | 
			
		||||
                b.label("⏮️")
 | 
			
		||||
                    .style(ButtonStyle::Primary)
 | 
			
		||||
                    .custom_id(page_first.to_custom_id())
 | 
			
		||||
                    .disabled(next_page == 0)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("◀️")
 | 
			
		||||
                    .style(ButtonStyle::Secondary)
 | 
			
		||||
                    .custom_id(page_prev.to_custom_id())
 | 
			
		||||
                    .disabled(next_page == 0)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("▶️")
 | 
			
		||||
                    .style(ButtonStyle::Secondary)
 | 
			
		||||
                    .custom_id(page_next.to_custom_id())
 | 
			
		||||
                    .disabled(next_page + 1 == max_pages)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("⏭️")
 | 
			
		||||
                    .style(ButtonStyle::Primary)
 | 
			
		||||
                    .custom_id(page_last.to_custom_id())
 | 
			
		||||
                    .disabled(next_page + 1 == max_pages)
 | 
			
		||||
            })
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl LookPager {
 | 
			
		||||
    pub fn new(flags: LookFlags, timezone: Tz) -> Self {
 | 
			
		||||
        Self { flags, page: 0, action: PageAction::First, timezone }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn buttons(
 | 
			
		||||
        flags: LookFlags,
 | 
			
		||||
        page: usize,
 | 
			
		||||
        timezone: Tz,
 | 
			
		||||
    ) -> (
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
    ) {
 | 
			
		||||
        (
 | 
			
		||||
            ComponentDataModel::LookPager(LookPager {
 | 
			
		||||
                flags,
 | 
			
		||||
                page,
 | 
			
		||||
                action: PageAction::First,
 | 
			
		||||
                timezone,
 | 
			
		||||
            }),
 | 
			
		||||
            ComponentDataModel::LookPager(LookPager {
 | 
			
		||||
                flags,
 | 
			
		||||
                page,
 | 
			
		||||
                action: PageAction::Previous,
 | 
			
		||||
                timezone,
 | 
			
		||||
            }),
 | 
			
		||||
            ComponentDataModel::LookPager(LookPager {
 | 
			
		||||
                flags,
 | 
			
		||||
                page,
 | 
			
		||||
                action: PageAction::Refresh,
 | 
			
		||||
                timezone,
 | 
			
		||||
            }),
 | 
			
		||||
            ComponentDataModel::LookPager(LookPager {
 | 
			
		||||
                flags,
 | 
			
		||||
                page,
 | 
			
		||||
                action: PageAction::Next,
 | 
			
		||||
                timezone,
 | 
			
		||||
            }),
 | 
			
		||||
            ComponentDataModel::LookPager(LookPager {
 | 
			
		||||
                flags,
 | 
			
		||||
                page,
 | 
			
		||||
                action: PageAction::Last,
 | 
			
		||||
                timezone,
 | 
			
		||||
            }),
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize, Deserialize)]
 | 
			
		||||
pub struct DelPager {
 | 
			
		||||
    pub page: usize,
 | 
			
		||||
    action: PageAction,
 | 
			
		||||
    pub timezone: Tz,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Pager for DelPager {
 | 
			
		||||
    fn next_page(&self, max_pages: usize) -> usize {
 | 
			
		||||
        match self.action {
 | 
			
		||||
            PageAction::First => 0,
 | 
			
		||||
            PageAction::Previous => 0.max(self.page - 1),
 | 
			
		||||
            PageAction::Refresh => self.page,
 | 
			
		||||
            PageAction::Next => (max_pages - 1).min(self.page + 1),
 | 
			
		||||
            PageAction::Last => max_pages - 1,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
 | 
			
		||||
        let next_page = self.next_page(max_pages);
 | 
			
		||||
 | 
			
		||||
        let (page_first, page_prev, page_refresh, page_next, page_last) =
 | 
			
		||||
            DelPager::buttons(next_page, self.timezone);
 | 
			
		||||
 | 
			
		||||
        comp.create_action_row(|row| {
 | 
			
		||||
            row.create_button(|b| {
 | 
			
		||||
                b.label("⏮️")
 | 
			
		||||
                    .style(ButtonStyle::Primary)
 | 
			
		||||
                    .custom_id(page_first.to_custom_id())
 | 
			
		||||
                    .disabled(next_page == 0)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("◀️")
 | 
			
		||||
                    .style(ButtonStyle::Secondary)
 | 
			
		||||
                    .custom_id(page_prev.to_custom_id())
 | 
			
		||||
                    .disabled(next_page == 0)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("▶️")
 | 
			
		||||
                    .style(ButtonStyle::Secondary)
 | 
			
		||||
                    .custom_id(page_next.to_custom_id())
 | 
			
		||||
                    .disabled(next_page + 1 == max_pages)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("⏭️")
 | 
			
		||||
                    .style(ButtonStyle::Primary)
 | 
			
		||||
                    .custom_id(page_last.to_custom_id())
 | 
			
		||||
                    .disabled(next_page + 1 == max_pages)
 | 
			
		||||
            })
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl DelPager {
 | 
			
		||||
    pub fn new(page: usize, timezone: Tz) -> Self {
 | 
			
		||||
        Self { page, action: PageAction::Refresh, timezone }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn buttons(
 | 
			
		||||
        page: usize,
 | 
			
		||||
        timezone: Tz,
 | 
			
		||||
    ) -> (
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
    ) {
 | 
			
		||||
        (
 | 
			
		||||
            ComponentDataModel::DelPager(DelPager { page, action: PageAction::First, timezone }),
 | 
			
		||||
            ComponentDataModel::DelPager(DelPager { page, action: PageAction::Previous, timezone }),
 | 
			
		||||
            ComponentDataModel::DelPager(DelPager { page, action: PageAction::Refresh, timezone }),
 | 
			
		||||
            ComponentDataModel::DelPager(DelPager { page, action: PageAction::Next, timezone }),
 | 
			
		||||
            ComponentDataModel::DelPager(DelPager { page, action: PageAction::Last, timezone }),
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Deserialize, Serialize)]
 | 
			
		||||
pub struct TodoPager {
 | 
			
		||||
    pub page: usize,
 | 
			
		||||
    action: PageAction,
 | 
			
		||||
    pub user_id: Option<u64>,
 | 
			
		||||
    pub channel_id: Option<u64>,
 | 
			
		||||
    pub guild_id: Option<u64>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Pager for TodoPager {
 | 
			
		||||
    fn next_page(&self, max_pages: usize) -> usize {
 | 
			
		||||
        match self.action {
 | 
			
		||||
            PageAction::First => 0,
 | 
			
		||||
            PageAction::Previous => 0.max(self.page - 1),
 | 
			
		||||
            PageAction::Refresh => self.page,
 | 
			
		||||
            PageAction::Next => (max_pages - 1).min(self.page + 1),
 | 
			
		||||
            PageAction::Last => max_pages - 1,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
 | 
			
		||||
        let next_page = self.next_page(max_pages);
 | 
			
		||||
 | 
			
		||||
        let (page_first, page_prev, page_refresh, page_next, page_last) =
 | 
			
		||||
            TodoPager::buttons(next_page, self.user_id, self.channel_id, self.guild_id);
 | 
			
		||||
 | 
			
		||||
        comp.create_action_row(|row| {
 | 
			
		||||
            row.create_button(|b| {
 | 
			
		||||
                b.label("⏮️")
 | 
			
		||||
                    .style(ButtonStyle::Primary)
 | 
			
		||||
                    .custom_id(page_first.to_custom_id())
 | 
			
		||||
                    .disabled(next_page == 0)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("◀️")
 | 
			
		||||
                    .style(ButtonStyle::Secondary)
 | 
			
		||||
                    .custom_id(page_prev.to_custom_id())
 | 
			
		||||
                    .disabled(next_page == 0)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("▶️")
 | 
			
		||||
                    .style(ButtonStyle::Secondary)
 | 
			
		||||
                    .custom_id(page_next.to_custom_id())
 | 
			
		||||
                    .disabled(next_page + 1 == max_pages)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("⏭️")
 | 
			
		||||
                    .style(ButtonStyle::Primary)
 | 
			
		||||
                    .custom_id(page_last.to_custom_id())
 | 
			
		||||
                    .disabled(next_page + 1 == max_pages)
 | 
			
		||||
            })
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl TodoPager {
 | 
			
		||||
    pub fn new(
 | 
			
		||||
        page: usize,
 | 
			
		||||
        user_id: Option<u64>,
 | 
			
		||||
        channel_id: Option<u64>,
 | 
			
		||||
        guild_id: Option<u64>,
 | 
			
		||||
    ) -> Self {
 | 
			
		||||
        Self { page, action: PageAction::Refresh, user_id, channel_id, guild_id }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn buttons(
 | 
			
		||||
        page: usize,
 | 
			
		||||
        user_id: Option<u64>,
 | 
			
		||||
        channel_id: Option<u64>,
 | 
			
		||||
        guild_id: Option<u64>,
 | 
			
		||||
    ) -> (
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
    ) {
 | 
			
		||||
        (
 | 
			
		||||
            ComponentDataModel::TodoPager(TodoPager {
 | 
			
		||||
                page,
 | 
			
		||||
                action: PageAction::First,
 | 
			
		||||
                user_id,
 | 
			
		||||
                channel_id,
 | 
			
		||||
                guild_id,
 | 
			
		||||
            }),
 | 
			
		||||
            ComponentDataModel::TodoPager(TodoPager {
 | 
			
		||||
                page,
 | 
			
		||||
                action: PageAction::Previous,
 | 
			
		||||
                user_id,
 | 
			
		||||
                channel_id,
 | 
			
		||||
                guild_id,
 | 
			
		||||
            }),
 | 
			
		||||
            ComponentDataModel::TodoPager(TodoPager {
 | 
			
		||||
                page,
 | 
			
		||||
                action: PageAction::Refresh,
 | 
			
		||||
                user_id,
 | 
			
		||||
                channel_id,
 | 
			
		||||
                guild_id,
 | 
			
		||||
            }),
 | 
			
		||||
            ComponentDataModel::TodoPager(TodoPager {
 | 
			
		||||
                page,
 | 
			
		||||
                action: PageAction::Next,
 | 
			
		||||
                user_id,
 | 
			
		||||
                channel_id,
 | 
			
		||||
                guild_id,
 | 
			
		||||
            }),
 | 
			
		||||
            ComponentDataModel::TodoPager(TodoPager {
 | 
			
		||||
                page,
 | 
			
		||||
                action: PageAction::Last,
 | 
			
		||||
                user_id,
 | 
			
		||||
                channel_id,
 | 
			
		||||
                guild_id,
 | 
			
		||||
            }),
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize, Deserialize)]
 | 
			
		||||
pub struct MacroPager {
 | 
			
		||||
    pub page: usize,
 | 
			
		||||
    action: PageAction,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Pager for MacroPager {
 | 
			
		||||
    fn next_page(&self, max_pages: usize) -> usize {
 | 
			
		||||
        match self.action {
 | 
			
		||||
            PageAction::First => 0,
 | 
			
		||||
            PageAction::Previous => 0.max(self.page - 1),
 | 
			
		||||
            PageAction::Refresh => self.page,
 | 
			
		||||
            PageAction::Next => (max_pages - 1).min(self.page + 1),
 | 
			
		||||
            PageAction::Last => max_pages - 1,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
 | 
			
		||||
        let next_page = self.next_page(max_pages);
 | 
			
		||||
 | 
			
		||||
        let (page_first, page_prev, page_refresh, page_next, page_last) =
 | 
			
		||||
            MacroPager::buttons(next_page);
 | 
			
		||||
 | 
			
		||||
        comp.create_action_row(|row| {
 | 
			
		||||
            row.create_button(|b| {
 | 
			
		||||
                b.label("⏮️")
 | 
			
		||||
                    .style(ButtonStyle::Primary)
 | 
			
		||||
                    .custom_id(page_first.to_custom_id())
 | 
			
		||||
                    .disabled(next_page == 0)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("◀️")
 | 
			
		||||
                    .style(ButtonStyle::Secondary)
 | 
			
		||||
                    .custom_id(page_prev.to_custom_id())
 | 
			
		||||
                    .disabled(next_page == 0)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("▶️")
 | 
			
		||||
                    .style(ButtonStyle::Secondary)
 | 
			
		||||
                    .custom_id(page_next.to_custom_id())
 | 
			
		||||
                    .disabled(next_page + 1 == max_pages)
 | 
			
		||||
            })
 | 
			
		||||
            .create_button(|b| {
 | 
			
		||||
                b.label("⏭️")
 | 
			
		||||
                    .style(ButtonStyle::Primary)
 | 
			
		||||
                    .custom_id(page_last.to_custom_id())
 | 
			
		||||
                    .disabled(next_page + 1 == max_pages)
 | 
			
		||||
            })
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl MacroPager {
 | 
			
		||||
    pub fn new(page: usize) -> Self {
 | 
			
		||||
        Self { page, action: PageAction::Refresh }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn buttons(
 | 
			
		||||
        page: usize,
 | 
			
		||||
    ) -> (
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
        ComponentDataModel,
 | 
			
		||||
    ) {
 | 
			
		||||
        (
 | 
			
		||||
            ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::First }),
 | 
			
		||||
            ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Previous }),
 | 
			
		||||
            ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Refresh }),
 | 
			
		||||
            ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Next }),
 | 
			
		||||
            ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Last }),
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,6 +1,8 @@
 | 
			
		||||
pub const DAY: u64 = 86_400;
 | 
			
		||||
pub const HOUR: u64 = 3_600;
 | 
			
		||||
pub const MINUTE: u64 = 60;
 | 
			
		||||
pub const EMBED_DESCRIPTION_MAX_LENGTH: usize = 4000;
 | 
			
		||||
pub const SELECT_MAX_ENTRIES: usize = 25;
 | 
			
		||||
 | 
			
		||||
pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
 | 
			
		||||
 | 
			
		||||
@@ -8,83 +10,50 @@ const THEME_COLOR_FALLBACK: u32 = 0x8fb677;
 | 
			
		||||
 | 
			
		||||
use std::{collections::HashSet, env, iter::FromIterator};
 | 
			
		||||
 | 
			
		||||
use regex::{Regex, RegexBuilder};
 | 
			
		||||
use regex::Regex;
 | 
			
		||||
use serenity::http::AttachmentType;
 | 
			
		||||
 | 
			
		||||
lazy_static! {
 | 
			
		||||
    pub static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap();
 | 
			
		||||
 | 
			
		||||
    pub static ref REGEX_ROLE: Regex = Regex::new(r#"<@&(\d+)>"#).unwrap();
 | 
			
		||||
 | 
			
		||||
    pub static ref REGEX_COMMANDS: Regex = Regex::new(r#"([a-z]+)"#).unwrap();
 | 
			
		||||
 | 
			
		||||
    pub static ref REGEX_ALIAS: Regex =
 | 
			
		||||
        Regex::new(r#"(?P<name>[\S]{1,12})(?:(?: (?P<cmd>.*)$)|$)"#).unwrap();
 | 
			
		||||
 | 
			
		||||
    pub static ref REGEX_CONTENT_SUBSTITUTION: Regex = Regex::new(r#"<<((?P<user>\d+)|(?P<role>.{1,100}))>>"#).unwrap();
 | 
			
		||||
 | 
			
		||||
    pub static ref REMIND_INTERVAL: u64 = env::var("REMIND_INTERVAL")
 | 
			
		||||
        .map(|inner| inner.parse::<u64>().ok())
 | 
			
		||||
        .ok()
 | 
			
		||||
        .flatten()
 | 
			
		||||
        .unwrap_or(10);
 | 
			
		||||
    pub static ref DEFAULT_AVATAR: AttachmentType<'static> = (
 | 
			
		||||
        include_bytes!(concat!(
 | 
			
		||||
            env!("CARGO_MANIFEST_DIR"),
 | 
			
		||||
            "/assets/",
 | 
			
		||||
            env!("WEBHOOK_AVATAR", "WEBHOOK_AVATAR not provided for compilation")
 | 
			
		||||
        )) as &[u8],
 | 
			
		||||
        env!("WEBHOOK_AVATAR"),
 | 
			
		||||
    )
 | 
			
		||||
        .into();
 | 
			
		||||
    pub static ref REGEX_CHANNEL_USER: Regex = Regex::new(r#"\s*<(#|@)(?:!)?(\d+)>\s*"#).unwrap();
 | 
			
		||||
 | 
			
		||||
    pub static ref REGEX_REMIND_COMMAND: Regex = RegexBuilder::new(
 | 
			
		||||
    r#"(?P<mentions>(?:<@\d+>\s+|<@!\d+>\s+|<#\d+>\s+)*)(?P<time>(?:(?:\d+)(?:s|m|h|d|:|/|-|))+)(?:\s+(?P<interval>(?:(?:\d+)(?:s|m|h|d|))+))?(?:\s+(?P<expires>(?:(?:\d+)(?:s|m|h|d|:|/|-|))+))?\s+(?P<content>.*)"#
 | 
			
		||||
    )
 | 
			
		||||
        .dot_matches_new_line(true)
 | 
			
		||||
        .build()
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
    pub static ref REGEX_NATURAL_COMMAND_1: Regex = RegexBuilder::new(
 | 
			
		||||
    r#"(?P<time>.*?)(?:\s+)(?:send|say)(?:\s+)(?P<msg>.*?)(?:(?:\s+)to(?:\s+)(?P<mentions>((?:<@\d+>)|(?:<@!\d+>)|(?:<#\d+>)|(?:\s+))+))?$"#
 | 
			
		||||
    )
 | 
			
		||||
        .dot_matches_new_line(true)
 | 
			
		||||
        .build()
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
    pub static ref REGEX_NATURAL_COMMAND_2: Regex = RegexBuilder::new(
 | 
			
		||||
    r#"(?P<msg>.*)(?:\s+)every(?:\s+)(?P<interval>.*?)(?:(?:\s+)(?:until|for)(?:\s+)(?P<expires>.*?))?$"#
 | 
			
		||||
    )
 | 
			
		||||
        .dot_matches_new_line(true)
 | 
			
		||||
        .build()
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
    pub static ref SUBSCRIPTION_ROLES: HashSet<u64> = HashSet::from_iter(
 | 
			
		||||
        env::var("SUBSCRIPTION_ROLES")
 | 
			
		||||
            .map(|var| var
 | 
			
		||||
                .split(',')
 | 
			
		||||
                .filter_map(|item| { item.parse::<u64>().ok() })
 | 
			
		||||
                .collect::<Vec<u64>>())
 | 
			
		||||
            .unwrap_or_else(|_| vec![])
 | 
			
		||||
            .unwrap_or_else(|_| Vec::new())
 | 
			
		||||
    );
 | 
			
		||||
 | 
			
		||||
    pub static ref CNC_GUILD: Option<u64> = env::var("CNC_GUILD")
 | 
			
		||||
        .map(|var| var.parse::<u64>().ok())
 | 
			
		||||
        .ok()
 | 
			
		||||
        .flatten();
 | 
			
		||||
 | 
			
		||||
    pub static ref CNC_GUILD: Option<u64> =
 | 
			
		||||
        env::var("CNC_GUILD").map(|var| var.parse::<u64>().ok()).ok().flatten();
 | 
			
		||||
    pub static ref MIN_INTERVAL: i64 = env::var("MIN_INTERVAL")
 | 
			
		||||
        .ok()
 | 
			
		||||
        .map(|inner| inner.parse::<i64>().ok())
 | 
			
		||||
        .flatten()
 | 
			
		||||
        .unwrap_or(600);
 | 
			
		||||
 | 
			
		||||
    pub static ref MAX_TIME: i64 = env::var("MAX_TIME")
 | 
			
		||||
        .ok()
 | 
			
		||||
        .map(|inner| inner.parse::<i64>().ok())
 | 
			
		||||
        .flatten()
 | 
			
		||||
        .unwrap_or(60 * 60 * 24 * 365 * 50);
 | 
			
		||||
 | 
			
		||||
    pub static ref LOCAL_TIMEZONE: String =
 | 
			
		||||
        env::var("LOCAL_TIMEZONE").unwrap_or_else(|_| "UTC".to_string());
 | 
			
		||||
 | 
			
		||||
    pub static ref LOCAL_LANGUAGE: String =
 | 
			
		||||
        env::var("LOCAL_LANGUAGE").unwrap_or_else(|_| "EN".to_string());
 | 
			
		||||
 | 
			
		||||
    pub static ref DEFAULT_PREFIX: String =
 | 
			
		||||
        env::var("DEFAULT_PREFIX").unwrap_or_else(|_| "$".to_string());
 | 
			
		||||
 | 
			
		||||
    pub static ref THEME_COLOR: u32 = env::var("THEME_COLOR").map_or(
 | 
			
		||||
        THEME_COLOR_FALLBACK,
 | 
			
		||||
        |inner| u32::from_str_radix(&inner, 16).unwrap_or(THEME_COLOR_FALLBACK)
 | 
			
		||||
    );
 | 
			
		||||
 | 
			
		||||
    pub static ref THEME_COLOR: u32 = env::var("THEME_COLOR")
 | 
			
		||||
        .map_or(THEME_COLOR_FALLBACK, |inner| u32::from_str_radix(&inner, 16)
 | 
			
		||||
            .unwrap_or(THEME_COLOR_FALLBACK));
 | 
			
		||||
    pub static ref PYTHON_LOCATION: String =
 | 
			
		||||
        env::var("PYTHON_LOCATION").unwrap_or_else(|_| "venv/bin/python3".to_string());
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										1009
									
								
								src/framework.rs
									
									
									
									
									
								
							
							
						
						
									
										1009
									
								
								src/framework.rs
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										152
									
								
								src/hooks.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										152
									
								
								src/hooks.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,152 @@
 | 
			
		||||
use regex_command_attr::check;
 | 
			
		||||
use serenity::{client::Context, model::channel::Channel};
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    framework::{CommandInvoke, CommandOptions, CreateGenericResponse, HookResult},
 | 
			
		||||
    moderation_cmds, RecordingMacros,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#[check]
 | 
			
		||||
pub async fn guild_only(
 | 
			
		||||
    ctx: &Context,
 | 
			
		||||
    invoke: &mut CommandInvoke,
 | 
			
		||||
    _args: &CommandOptions,
 | 
			
		||||
) -> HookResult {
 | 
			
		||||
    if invoke.guild_id().is_some() {
 | 
			
		||||
        HookResult::Continue
 | 
			
		||||
    } else {
 | 
			
		||||
        let _ = invoke
 | 
			
		||||
            .respond(
 | 
			
		||||
                &ctx,
 | 
			
		||||
                CreateGenericResponse::new().content("This command can only be used in servers"),
 | 
			
		||||
            )
 | 
			
		||||
            .await;
 | 
			
		||||
 | 
			
		||||
        HookResult::Halt
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[check]
 | 
			
		||||
pub async fn macro_check(
 | 
			
		||||
    ctx: &Context,
 | 
			
		||||
    invoke: &mut CommandInvoke,
 | 
			
		||||
    args: &CommandOptions,
 | 
			
		||||
) -> HookResult {
 | 
			
		||||
    if let Some(guild_id) = invoke.guild_id() {
 | 
			
		||||
        if args.command != moderation_cmds::MACRO_CMD_COMMAND.names[0] {
 | 
			
		||||
            let active_recordings =
 | 
			
		||||
                ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap();
 | 
			
		||||
            let mut lock = active_recordings.write().await;
 | 
			
		||||
 | 
			
		||||
            if let Some(command_macro) = lock.get_mut(&(guild_id, invoke.author_id())) {
 | 
			
		||||
                if command_macro.commands.len() >= 5 {
 | 
			
		||||
                    let _ = invoke
 | 
			
		||||
                        .respond(
 | 
			
		||||
                            &ctx,
 | 
			
		||||
                            CreateGenericResponse::new().content("5 commands already recorded. Please use `/macro finish` to end recording."),
 | 
			
		||||
                        )
 | 
			
		||||
                        .await;
 | 
			
		||||
                } else {
 | 
			
		||||
                    command_macro.commands.push(args.clone());
 | 
			
		||||
 | 
			
		||||
                    let _ = invoke
 | 
			
		||||
                        .respond(
 | 
			
		||||
                            &ctx,
 | 
			
		||||
                            CreateGenericResponse::new().content("Command recorded to macro"),
 | 
			
		||||
                        )
 | 
			
		||||
                        .await;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                HookResult::Halt
 | 
			
		||||
            } else {
 | 
			
		||||
                HookResult::Continue
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            HookResult::Continue
 | 
			
		||||
        }
 | 
			
		||||
    } else {
 | 
			
		||||
        HookResult::Continue
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[check]
 | 
			
		||||
pub async fn check_self_permissions(
 | 
			
		||||
    ctx: &Context,
 | 
			
		||||
    invoke: &mut CommandInvoke,
 | 
			
		||||
    _args: &CommandOptions,
 | 
			
		||||
) -> HookResult {
 | 
			
		||||
    if let Some(guild) = invoke.guild(&ctx) {
 | 
			
		||||
        let user_id = ctx.cache.current_user_id();
 | 
			
		||||
 | 
			
		||||
        let manage_webhooks =
 | 
			
		||||
            guild.member_permissions(&ctx, user_id).await.map_or(false, |p| p.manage_webhooks());
 | 
			
		||||
        let (view_channel, send_messages, embed_links) = invoke
 | 
			
		||||
            .channel_id()
 | 
			
		||||
            .to_channel_cached(&ctx)
 | 
			
		||||
            .map(|c| {
 | 
			
		||||
                if let Channel::Guild(channel) = c {
 | 
			
		||||
                    channel.permissions_for_user(ctx, user_id).ok()
 | 
			
		||||
                } else {
 | 
			
		||||
                    None
 | 
			
		||||
                }
 | 
			
		||||
            })
 | 
			
		||||
            .flatten()
 | 
			
		||||
            .map_or((false, false, false), |p| {
 | 
			
		||||
                (p.read_messages(), p.send_messages(), p.embed_links())
 | 
			
		||||
            });
 | 
			
		||||
 | 
			
		||||
        if manage_webhooks && send_messages && embed_links {
 | 
			
		||||
            HookResult::Continue
 | 
			
		||||
        } else {
 | 
			
		||||
            let _ = invoke
 | 
			
		||||
                .respond(
 | 
			
		||||
                    &ctx,
 | 
			
		||||
                    CreateGenericResponse::new().content(format!(
 | 
			
		||||
                        "Please ensure the bot has the correct permissions:
 | 
			
		||||
 | 
			
		||||
{}     **View Channel**
 | 
			
		||||
{}     **Send Message**
 | 
			
		||||
{}     **Embed Links**
 | 
			
		||||
{}     **Manage Webhooks**",
 | 
			
		||||
                        if view_channel { "✅" } else { "❌" },
 | 
			
		||||
                        if send_messages { "✅" } else { "❌" },
 | 
			
		||||
                        if manage_webhooks { "✅" } else { "❌" },
 | 
			
		||||
                        if embed_links { "✅" } else { "❌" },
 | 
			
		||||
                    )),
 | 
			
		||||
                )
 | 
			
		||||
                .await;
 | 
			
		||||
 | 
			
		||||
            HookResult::Halt
 | 
			
		||||
        }
 | 
			
		||||
    } else {
 | 
			
		||||
        HookResult::Continue
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[check]
 | 
			
		||||
pub async fn check_guild_permissions(
 | 
			
		||||
    ctx: &Context,
 | 
			
		||||
    invoke: &mut CommandInvoke,
 | 
			
		||||
    _args: &CommandOptions,
 | 
			
		||||
) -> HookResult {
 | 
			
		||||
    if let Some(guild) = invoke.guild(&ctx) {
 | 
			
		||||
        let permissions = guild.member_permissions(&ctx, invoke.author_id()).await.unwrap();
 | 
			
		||||
 | 
			
		||||
        if !permissions.manage_guild() {
 | 
			
		||||
            let _ = invoke
 | 
			
		||||
                .respond(
 | 
			
		||||
                    &ctx,
 | 
			
		||||
                    CreateGenericResponse::new().content(
 | 
			
		||||
                        "You must have the \"Manage Server\" permission to use this command",
 | 
			
		||||
                    ),
 | 
			
		||||
                )
 | 
			
		||||
                .await;
 | 
			
		||||
 | 
			
		||||
            HookResult::Halt
 | 
			
		||||
        } else {
 | 
			
		||||
            HookResult::Continue
 | 
			
		||||
        }
 | 
			
		||||
    } else {
 | 
			
		||||
        HookResult::Continue
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,65 +0,0 @@
 | 
			
		||||
use serde::Deserialize;
 | 
			
		||||
use serde_json::from_str;
 | 
			
		||||
use serenity::prelude::TypeMapKey;
 | 
			
		||||
 | 
			
		||||
use std::{collections::HashMap, error::Error, sync::Arc};
 | 
			
		||||
 | 
			
		||||
use crate::consts::LOCAL_LANGUAGE;
 | 
			
		||||
 | 
			
		||||
#[derive(Deserialize)]
 | 
			
		||||
pub struct LanguageManager {
 | 
			
		||||
    languages: HashMap<String, String>,
 | 
			
		||||
    strings: HashMap<String, HashMap<String, String>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl LanguageManager {
 | 
			
		||||
    pub fn from_compiled(content: &'static str) -> Result<Self, Box<dyn Error + Send + Sync>> {
 | 
			
		||||
        let new: Self = from_str(content.as_ref())?;
 | 
			
		||||
 | 
			
		||||
        Ok(new)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn get(&self, language: &str, name: &str) -> &str {
 | 
			
		||||
        self.strings
 | 
			
		||||
            .get(language)
 | 
			
		||||
            .map(|sm| sm.get(name))
 | 
			
		||||
            .expect(&format!(r#"Language does not exist: "{}""#, language))
 | 
			
		||||
            .unwrap_or_else(|| {
 | 
			
		||||
                self.strings
 | 
			
		||||
                    .get(&*LOCAL_LANGUAGE)
 | 
			
		||||
                    .map(|sm| {
 | 
			
		||||
                        sm.get(name)
 | 
			
		||||
                            .expect(&format!(r#"String does not exist: "{}""#, name))
 | 
			
		||||
                    })
 | 
			
		||||
                    .expect("LOCAL_LANGUAGE is not available")
 | 
			
		||||
            })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn get_language(&self, language: &str) -> Option<&str> {
 | 
			
		||||
        let language_normal = language.to_lowercase();
 | 
			
		||||
 | 
			
		||||
        self.languages
 | 
			
		||||
            .iter()
 | 
			
		||||
            .filter(|(k, v)| {
 | 
			
		||||
                k.to_lowercase() == language_normal || v.to_lowercase() == language_normal
 | 
			
		||||
            })
 | 
			
		||||
            .map(|(k, _)| k.as_str())
 | 
			
		||||
            .next()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn get_language_by_flag(&self, flag: &str) -> Option<&str> {
 | 
			
		||||
        self.languages
 | 
			
		||||
            .iter()
 | 
			
		||||
            .filter(|(k, _)| self.get(k, "flag") == flag)
 | 
			
		||||
            .map(|(k, _)| k.as_str())
 | 
			
		||||
            .next()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn all_languages(&self) -> impl Iterator<Item = (&str, &str)> {
 | 
			
		||||
        self.languages.iter().map(|(k, v)| (k.as_str(), v.as_str()))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl TypeMapKey for LanguageManager {
 | 
			
		||||
    type Value = Arc<Self>;
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										372
									
								
								src/main.rs
									
									
									
									
									
								
							
							
						
						
									
										372
									
								
								src/main.rs
									
									
									
									
									
								
							@@ -1,59 +1,56 @@
 | 
			
		||||
#![feature(int_roundings)]
 | 
			
		||||
#[macro_use]
 | 
			
		||||
extern crate lazy_static;
 | 
			
		||||
 | 
			
		||||
mod commands;
 | 
			
		||||
mod component_models;
 | 
			
		||||
mod consts;
 | 
			
		||||
mod framework;
 | 
			
		||||
mod language_manager;
 | 
			
		||||
mod hooks;
 | 
			
		||||
mod models;
 | 
			
		||||
mod sender;
 | 
			
		||||
mod time_parser;
 | 
			
		||||
 | 
			
		||||
use std::{
 | 
			
		||||
    collections::HashMap,
 | 
			
		||||
    env,
 | 
			
		||||
    sync::{
 | 
			
		||||
        atomic::{AtomicBool, Ordering},
 | 
			
		||||
        Arc,
 | 
			
		||||
    },
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
use dotenv::dotenv;
 | 
			
		||||
use log::info;
 | 
			
		||||
use serenity::{
 | 
			
		||||
    async_trait,
 | 
			
		||||
    cache::Cache,
 | 
			
		||||
    client::{bridge::gateway::GatewayIntents, Client},
 | 
			
		||||
    http::{client::Http, CacheHttp},
 | 
			
		||||
    model::{
 | 
			
		||||
        channel::GuildChannel,
 | 
			
		||||
        channel::Message,
 | 
			
		||||
        gateway::{Activity, Ready},
 | 
			
		||||
        guild::{Guild, GuildUnavailable},
 | 
			
		||||
        id::{GuildId, UserId},
 | 
			
		||||
        interactions::Interaction,
 | 
			
		||||
    },
 | 
			
		||||
    prelude::{Context, EventHandler, TypeMapKey},
 | 
			
		||||
    utils::shard_id,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use sqlx::mysql::MySqlPool;
 | 
			
		||||
 | 
			
		||||
use dotenv::dotenv;
 | 
			
		||||
 | 
			
		||||
use std::{collections::HashMap, env, sync::Arc};
 | 
			
		||||
use tokio::{
 | 
			
		||||
    sync::RwLock,
 | 
			
		||||
    time::{Duration, Instant},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds},
 | 
			
		||||
    consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR},
 | 
			
		||||
    component_models::ComponentDataModel,
 | 
			
		||||
    consts::{CNC_GUILD, REMIND_INTERVAL, SUBSCRIPTION_ROLES, THEME_COLOR},
 | 
			
		||||
    framework::RegexFramework,
 | 
			
		||||
    language_manager::LanguageManager,
 | 
			
		||||
    models::GuildData,
 | 
			
		||||
    models::command_macro::CommandMacro,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use serenity::futures::TryFutureExt;
 | 
			
		||||
 | 
			
		||||
use inflector::Inflector;
 | 
			
		||||
use log::info;
 | 
			
		||||
 | 
			
		||||
use dashmap::DashMap;
 | 
			
		||||
 | 
			
		||||
use tokio::sync::RwLock;
 | 
			
		||||
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
 | 
			
		||||
struct GuildDataCache;
 | 
			
		||||
 | 
			
		||||
impl TypeMapKey for GuildDataCache {
 | 
			
		||||
    type Value = Arc<DashMap<GuildId, Arc<RwLock<GuildData>>>>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct SQLPool;
 | 
			
		||||
 | 
			
		||||
impl TypeMapKey for SQLPool {
 | 
			
		||||
@@ -66,22 +63,54 @@ impl TypeMapKey for ReqwestClient {
 | 
			
		||||
    type Value = Arc<reqwest::Client>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct FrameworkCtx;
 | 
			
		||||
 | 
			
		||||
impl TypeMapKey for FrameworkCtx {
 | 
			
		||||
    type Value = Arc<RegexFramework>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct PopularTimezones;
 | 
			
		||||
 | 
			
		||||
impl TypeMapKey for PopularTimezones {
 | 
			
		||||
    type Value = Arc<Vec<Tz>>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct Handler;
 | 
			
		||||
struct RecordingMacros;
 | 
			
		||||
 | 
			
		||||
impl TypeMapKey for RecordingMacros {
 | 
			
		||||
    type Value = Arc<RwLock<HashMap<(GuildId, UserId), CommandMacro>>>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct Handler {
 | 
			
		||||
    is_loop_running: AtomicBool,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl EventHandler for Handler {
 | 
			
		||||
    async fn cache_ready(&self, ctx_base: Context, _guilds: Vec<GuildId>) {
 | 
			
		||||
        info!("Cache Ready!");
 | 
			
		||||
        info!("Preparing to send reminders");
 | 
			
		||||
 | 
			
		||||
        if !self.is_loop_running.load(Ordering::Relaxed) {
 | 
			
		||||
            let ctx = ctx_base.clone();
 | 
			
		||||
 | 
			
		||||
            tokio::spawn(async move {
 | 
			
		||||
                let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
                loop {
 | 
			
		||||
                    let sleep_until = Instant::now() + Duration::from_secs(*REMIND_INTERVAL);
 | 
			
		||||
                    let reminders = sender::Reminder::fetch_reminders(&pool).await;
 | 
			
		||||
 | 
			
		||||
                    if reminders.len() > 0 {
 | 
			
		||||
                        info!("Preparing to send {} reminders.", reminders.len());
 | 
			
		||||
 | 
			
		||||
                        for reminder in reminders {
 | 
			
		||||
                            reminder.send(pool.clone(), ctx.clone()).await;
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    tokio::time::sleep_until(sleep_until).await;
 | 
			
		||||
                }
 | 
			
		||||
            });
 | 
			
		||||
 | 
			
		||||
            self.is_loop_running.swap(true, Ordering::Relaxed);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn channel_delete(&self, ctx: Context, channel: &GuildChannel) {
 | 
			
		||||
        let pool = ctx
 | 
			
		||||
            .data
 | 
			
		||||
@@ -107,28 +136,20 @@ DELETE FROM channels WHERE channel = ?
 | 
			
		||||
            let guild_id = guild.id.as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
            {
 | 
			
		||||
                let pool = ctx
 | 
			
		||||
                    .data
 | 
			
		||||
                    .read()
 | 
			
		||||
                    .await
 | 
			
		||||
                    .get::<SQLPool>()
 | 
			
		||||
                    .cloned()
 | 
			
		||||
                    .expect("Could not get SQLPool from data");
 | 
			
		||||
                let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
                GuildData::from_guild(guild, &pool).await.expect(&format!(
 | 
			
		||||
                    "Failed to create new guild object for {}",
 | 
			
		||||
                    guild_id
 | 
			
		||||
                ));
 | 
			
		||||
                let _ = sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id)
 | 
			
		||||
                    .execute(&pool)
 | 
			
		||||
                    .await;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if let Ok(token) = env::var("DISCORDBOTS_TOKEN") {
 | 
			
		||||
                let shard_count = ctx.cache.shard_count().await;
 | 
			
		||||
                let shard_count = ctx.cache.shard_count();
 | 
			
		||||
                let current_shard_id = shard_id(guild_id, shard_count);
 | 
			
		||||
 | 
			
		||||
                let guild_count = ctx
 | 
			
		||||
                    .cache
 | 
			
		||||
                    .guilds()
 | 
			
		||||
                    .await
 | 
			
		||||
                    .iter()
 | 
			
		||||
                    .filter(|g| shard_id(g.as_u64().to_owned(), shard_count) == current_shard_id)
 | 
			
		||||
                    .count() as u64;
 | 
			
		||||
@@ -150,7 +171,7 @@ DELETE FROM channels WHERE channel = ?
 | 
			
		||||
                    .post(
 | 
			
		||||
                        format!(
 | 
			
		||||
                            "https://top.gg/api/bots/{}/stats",
 | 
			
		||||
                            ctx.cache.current_user_id().await.as_u64()
 | 
			
		||||
                            ctx.cache.current_user_id().as_u64()
 | 
			
		||||
                        )
 | 
			
		||||
                        .as_str(),
 | 
			
		||||
                    )
 | 
			
		||||
@@ -166,33 +187,36 @@ DELETE FROM channels WHERE channel = ?
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn guild_delete(&self, ctx: Context, guild: GuildUnavailable, _guild: Option<Guild>) {
 | 
			
		||||
        let pool = ctx
 | 
			
		||||
            .data
 | 
			
		||||
            .read()
 | 
			
		||||
            .await
 | 
			
		||||
            .get::<SQLPool>()
 | 
			
		||||
            .cloned()
 | 
			
		||||
            .expect("Could not get SQLPool from data");
 | 
			
		||||
    async fn guild_delete(&self, ctx: Context, incomplete: GuildUnavailable, _full: Option<Guild>) {
 | 
			
		||||
        let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
        let _ = sqlx::query!("DELETE FROM guilds WHERE guild = ?", incomplete.id.0)
 | 
			
		||||
            .execute(&pool)
 | 
			
		||||
            .await;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
        let guild_data_cache = ctx
 | 
			
		||||
            .data
 | 
			
		||||
            .read()
 | 
			
		||||
            .await
 | 
			
		||||
            .get::<GuildDataCache>()
 | 
			
		||||
            .cloned()
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        guild_data_cache.remove(&guild.id);
 | 
			
		||||
    async fn ready(&self, ctx: Context, _: Ready) {
 | 
			
		||||
        ctx.set_activity(Activity::watching("for /remind")).await;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
DELETE FROM guilds WHERE guild = ?
 | 
			
		||||
            ",
 | 
			
		||||
            guild.id.as_u64()
 | 
			
		||||
        )
 | 
			
		||||
        .execute(&pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
    async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
 | 
			
		||||
        match interaction {
 | 
			
		||||
            Interaction::ApplicationCommand(application_command) => {
 | 
			
		||||
                let framework = ctx
 | 
			
		||||
                    .data
 | 
			
		||||
                    .read()
 | 
			
		||||
                    .await
 | 
			
		||||
                    .get::<RegexFramework>()
 | 
			
		||||
                    .cloned()
 | 
			
		||||
                    .expect("RegexFramework not found in context");
 | 
			
		||||
 | 
			
		||||
                framework.execute(ctx, application_command).await;
 | 
			
		||||
            }
 | 
			
		||||
            Interaction::MessageComponent(component) => {
 | 
			
		||||
                let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id);
 | 
			
		||||
                component_model.act(&ctx, component).await;
 | 
			
		||||
            }
 | 
			
		||||
            _ => {}
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -204,98 +228,59 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
 | 
			
		||||
 | 
			
		||||
    let token = env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment");
 | 
			
		||||
 | 
			
		||||
    let http = Http::new_with_token(&token);
 | 
			
		||||
    let application_id = {
 | 
			
		||||
        let http = Http::new_with_token(&token);
 | 
			
		||||
 | 
			
		||||
    let logged_in_id = http
 | 
			
		||||
        .get_current_user()
 | 
			
		||||
        .map_ok(|user| user.id.as_u64().to_owned())
 | 
			
		||||
        .await?;
 | 
			
		||||
        http.get_current_application_info().await?.id
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    let dm_enabled = env::var("DM_ENABLED").map_or(true, |var| var == "1");
 | 
			
		||||
 | 
			
		||||
    let framework = RegexFramework::new(logged_in_id)
 | 
			
		||||
        .default_prefix(DEFAULT_PREFIX.clone())
 | 
			
		||||
        .case_insensitive(env::var("CASE_INSENSITIVE").map_or(true, |var| var == "1"))
 | 
			
		||||
    let framework = RegexFramework::new()
 | 
			
		||||
        .ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1"))
 | 
			
		||||
        .debug_guild(env::var("DEBUG_GUILD").map_or(None, |g| {
 | 
			
		||||
            Some(GuildId(g.parse::<u64>().expect("DEBUG_GUILD must be a guild ID")))
 | 
			
		||||
        }))
 | 
			
		||||
        .dm_enabled(dm_enabled)
 | 
			
		||||
        // info commands
 | 
			
		||||
        .add_command("ping", &info_cmds::PING_COMMAND)
 | 
			
		||||
        .add_command("help", &info_cmds::HELP_COMMAND)
 | 
			
		||||
        .add_command("info", &info_cmds::INFO_COMMAND)
 | 
			
		||||
        .add_command("invite", &info_cmds::INFO_COMMAND)
 | 
			
		||||
        .add_command("donate", &info_cmds::DONATE_COMMAND)
 | 
			
		||||
        .add_command("dashboard", &info_cmds::DASHBOARD_COMMAND)
 | 
			
		||||
        .add_command("clock", &info_cmds::CLOCK_COMMAND)
 | 
			
		||||
        .add_command(&info_cmds::HELP_COMMAND)
 | 
			
		||||
        .add_command(&info_cmds::INFO_COMMAND)
 | 
			
		||||
        .add_command(&info_cmds::DONATE_COMMAND)
 | 
			
		||||
        .add_command(&info_cmds::DASHBOARD_COMMAND)
 | 
			
		||||
        .add_command(&info_cmds::CLOCK_COMMAND)
 | 
			
		||||
        // reminder commands
 | 
			
		||||
        .add_command("timer", &reminder_cmds::TIMER_COMMAND)
 | 
			
		||||
        .add_command("remind", &reminder_cmds::REMIND_COMMAND)
 | 
			
		||||
        .add_command("r", &reminder_cmds::REMIND_COMMAND)
 | 
			
		||||
        .add_command("interval", &reminder_cmds::INTERVAL_COMMAND)
 | 
			
		||||
        .add_command("i", &reminder_cmds::INTERVAL_COMMAND)
 | 
			
		||||
        .add_command("natural", &reminder_cmds::NATURAL_COMMAND)
 | 
			
		||||
        .add_command("n", &reminder_cmds::NATURAL_COMMAND)
 | 
			
		||||
        .add_command("", &reminder_cmds::NATURAL_COMMAND)
 | 
			
		||||
        .add_command("countdown", &reminder_cmds::COUNTDOWN_COMMAND)
 | 
			
		||||
        .add_command(&reminder_cmds::TIMER_COMMAND)
 | 
			
		||||
        .add_command(&reminder_cmds::REMIND_COMMAND)
 | 
			
		||||
        // management commands
 | 
			
		||||
        .add_command("look", &reminder_cmds::LOOK_COMMAND)
 | 
			
		||||
        .add_command("del", &reminder_cmds::DELETE_COMMAND)
 | 
			
		||||
        .add_command(&reminder_cmds::DELETE_COMMAND)
 | 
			
		||||
        .add_command(&reminder_cmds::LOOK_COMMAND)
 | 
			
		||||
        .add_command(&reminder_cmds::PAUSE_COMMAND)
 | 
			
		||||
        .add_command(&reminder_cmds::OFFSET_COMMAND)
 | 
			
		||||
        .add_command(&reminder_cmds::NUDGE_COMMAND)
 | 
			
		||||
        // to-do commands
 | 
			
		||||
        .add_command("todo", &todo_cmds::TODO_USER_COMMAND)
 | 
			
		||||
        .add_command("todo user", &todo_cmds::TODO_USER_COMMAND)
 | 
			
		||||
        .add_command("todoc", &todo_cmds::TODO_CHANNEL_COMMAND)
 | 
			
		||||
        .add_command("todo channel", &todo_cmds::TODO_CHANNEL_COMMAND)
 | 
			
		||||
        .add_command("todos", &todo_cmds::TODO_GUILD_COMMAND)
 | 
			
		||||
        .add_command("todo server", &todo_cmds::TODO_GUILD_COMMAND)
 | 
			
		||||
        .add_command("todo guild", &todo_cmds::TODO_GUILD_COMMAND)
 | 
			
		||||
        .add_command(&todo_cmds::TODO_COMMAND)
 | 
			
		||||
        // moderation commands
 | 
			
		||||
        .add_command("blacklist", &moderation_cmds::BLACKLIST_COMMAND)
 | 
			
		||||
        .add_command("restrict", &moderation_cmds::RESTRICT_COMMAND)
 | 
			
		||||
        .add_command("timezone", &moderation_cmds::TIMEZONE_COMMAND)
 | 
			
		||||
        .add_command("meridian", &moderation_cmds::CHANGE_MERIDIAN_COMMAND)
 | 
			
		||||
        .add_command("prefix", &moderation_cmds::PREFIX_COMMAND)
 | 
			
		||||
        .add_command("lang", &moderation_cmds::LANGUAGE_COMMAND)
 | 
			
		||||
        .add_command("pause", &reminder_cmds::PAUSE_COMMAND)
 | 
			
		||||
        .add_command("offset", &reminder_cmds::OFFSET_COMMAND)
 | 
			
		||||
        .add_command("nudge", &reminder_cmds::NUDGE_COMMAND)
 | 
			
		||||
        .add_command("alias", &moderation_cmds::ALIAS_COMMAND)
 | 
			
		||||
        .add_command("a", &moderation_cmds::ALIAS_COMMAND)
 | 
			
		||||
        .build();
 | 
			
		||||
        .add_command(&moderation_cmds::TIMEZONE_COMMAND)
 | 
			
		||||
        .add_command(&moderation_cmds::MACRO_CMD_COMMAND)
 | 
			
		||||
        .add_hook(&hooks::CHECK_SELF_PERMISSIONS_HOOK)
 | 
			
		||||
        .add_hook(&hooks::MACRO_CHECK_HOOK);
 | 
			
		||||
 | 
			
		||||
    let framework_arc = Arc::new(framework);
 | 
			
		||||
 | 
			
		||||
    let mut client = Client::builder(&token)
 | 
			
		||||
        .intents(if dm_enabled {
 | 
			
		||||
            GatewayIntents::GUILD_MESSAGES
 | 
			
		||||
                | GatewayIntents::GUILDS
 | 
			
		||||
                | GatewayIntents::GUILD_MESSAGE_REACTIONS
 | 
			
		||||
                | GatewayIntents::DIRECT_MESSAGES
 | 
			
		||||
                | GatewayIntents::DIRECT_MESSAGE_REACTIONS
 | 
			
		||||
        } else {
 | 
			
		||||
            GatewayIntents::GUILD_MESSAGES
 | 
			
		||||
                | GatewayIntents::GUILDS
 | 
			
		||||
                | GatewayIntents::GUILD_MESSAGE_REACTIONS
 | 
			
		||||
        })
 | 
			
		||||
        .event_handler(Handler)
 | 
			
		||||
        .framework_arc(framework_arc.clone())
 | 
			
		||||
        .intents(GatewayIntents::GUILDS)
 | 
			
		||||
        .application_id(application_id.0)
 | 
			
		||||
        .event_handler(Handler { is_loop_running: AtomicBool::from(false) })
 | 
			
		||||
        .await
 | 
			
		||||
        .expect("Error occurred creating client");
 | 
			
		||||
 | 
			
		||||
    {
 | 
			
		||||
        let guild_data_cache = dashmap::DashMap::new();
 | 
			
		||||
 | 
			
		||||
        let pool = MySqlPool::connect(
 | 
			
		||||
            &env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment"),
 | 
			
		||||
        )
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
        let language_manager = LanguageManager::from_compiled(include_str!(concat!(
 | 
			
		||||
            env!("CARGO_MANIFEST_DIR"),
 | 
			
		||||
            "/assets/",
 | 
			
		||||
            env!("STRINGS_FILE")
 | 
			
		||||
        )))
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
        let popular_timezones = sqlx::query!(
 | 
			
		||||
            "SELECT timezone FROM users GROUP BY timezone ORDER BY COUNT(timezone) DESC LIMIT 21"
 | 
			
		||||
        )
 | 
			
		||||
@@ -308,19 +293,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
 | 
			
		||||
 | 
			
		||||
        let mut data = client.data.write().await;
 | 
			
		||||
 | 
			
		||||
        data.insert::<GuildDataCache>(Arc::new(guild_data_cache));
 | 
			
		||||
 | 
			
		||||
        data.insert::<SQLPool>(pool);
 | 
			
		||||
        data.insert::<PopularTimezones>(Arc::new(popular_timezones));
 | 
			
		||||
        data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new()));
 | 
			
		||||
        data.insert::<FrameworkCtx>(framework_arc.clone());
 | 
			
		||||
        data.insert::<LanguageManager>(Arc::new(language_manager))
 | 
			
		||||
        data.insert::<RegexFramework>(framework_arc.clone());
 | 
			
		||||
        data.insert::<RecordingMacros>(Arc::new(RwLock::new(HashMap::new())));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    framework_arc.build_slash(&client.cache_and_http.http).await;
 | 
			
		||||
 | 
			
		||||
    if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| {
 | 
			
		||||
        let mut split = sr
 | 
			
		||||
            .split(',')
 | 
			
		||||
            .map(|val| val.parse::<u64>().expect("SHARD_RANGE not an integer"));
 | 
			
		||||
        let mut split =
 | 
			
		||||
            sr.split(',').map(|val| val.parse::<u64>().expect("SHARD_RANGE not an integer"));
 | 
			
		||||
 | 
			
		||||
        (split.next(), split.next())
 | 
			
		||||
    }) {
 | 
			
		||||
@@ -330,24 +314,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
 | 
			
		||||
            .flatten()
 | 
			
		||||
            .expect("No SHARD_COUNT provided, but SHARD_RANGE was provided");
 | 
			
		||||
 | 
			
		||||
        assert!(
 | 
			
		||||
            lower < upper,
 | 
			
		||||
            "SHARD_RANGE lower limit is not less than the upper limit"
 | 
			
		||||
        );
 | 
			
		||||
        assert!(lower < upper, "SHARD_RANGE lower limit is not less than the upper limit");
 | 
			
		||||
 | 
			
		||||
        info!(
 | 
			
		||||
            "Starting client fragment with shards {}-{}/{}",
 | 
			
		||||
            lower, upper, total_shards
 | 
			
		||||
        );
 | 
			
		||||
        info!("Starting client fragment with shards {}-{}/{}", lower, upper, total_shards);
 | 
			
		||||
 | 
			
		||||
        client
 | 
			
		||||
            .start_shard_range([lower, upper], total_shards)
 | 
			
		||||
            .await?;
 | 
			
		||||
    } else if let Ok(total_shards) = env::var("SHARD_COUNT").map(|shard_count| {
 | 
			
		||||
        shard_count
 | 
			
		||||
            .parse::<u64>()
 | 
			
		||||
            .expect("SHARD_COUNT not an integer")
 | 
			
		||||
    }) {
 | 
			
		||||
        client.start_shard_range([lower, upper], total_shards).await?;
 | 
			
		||||
    } else if let Ok(total_shards) = env::var("SHARD_COUNT")
 | 
			
		||||
        .map(|shard_count| shard_count.parse::<u64>().expect("SHARD_COUNT not an integer"))
 | 
			
		||||
    {
 | 
			
		||||
        info!("Starting client with {} shards", total_shards);
 | 
			
		||||
 | 
			
		||||
        client.start_shards(total_shards).await?;
 | 
			
		||||
@@ -362,9 +336,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
 | 
			
		||||
 | 
			
		||||
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
 | 
			
		||||
    if let Some(subscription_guild) = *CNC_GUILD {
 | 
			
		||||
        let guild_member = GuildId(subscription_guild)
 | 
			
		||||
            .member(cache_http, user_id)
 | 
			
		||||
            .await;
 | 
			
		||||
        let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await;
 | 
			
		||||
 | 
			
		||||
        if let Ok(member) = guild_member {
 | 
			
		||||
            for role in member.roles {
 | 
			
		||||
@@ -380,65 +352,15 @@ pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<U
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub async fn check_subscription_on_message(
 | 
			
		||||
    cache_http: impl CacheHttp + AsRef<Cache>,
 | 
			
		||||
    msg: &Message,
 | 
			
		||||
pub async fn check_guild_subscription(
 | 
			
		||||
    cache_http: impl CacheHttp,
 | 
			
		||||
    guild_id: impl Into<GuildId>,
 | 
			
		||||
) -> bool {
 | 
			
		||||
    check_subscription(&cache_http, &msg.author).await
 | 
			
		||||
        || if let Some(guild) = msg.guild(&cache_http).await {
 | 
			
		||||
            check_subscription(&cache_http, guild.owner_id).await
 | 
			
		||||
        } else {
 | 
			
		||||
            false
 | 
			
		||||
        }
 | 
			
		||||
}
 | 
			
		||||
    if let Some(guild) = cache_http.cache().unwrap().guild(guild_id) {
 | 
			
		||||
        let owner = guild.owner_id;
 | 
			
		||||
 | 
			
		||||
pub async fn get_ctx_data(ctx: &&Context) -> (MySqlPool, Arc<LanguageManager>) {
 | 
			
		||||
    let pool;
 | 
			
		||||
    let lm;
 | 
			
		||||
 | 
			
		||||
    {
 | 
			
		||||
        let data = ctx.data.read().await;
 | 
			
		||||
 | 
			
		||||
        pool = data
 | 
			
		||||
            .get::<SQLPool>()
 | 
			
		||||
            .cloned()
 | 
			
		||||
            .expect("Could not get SQLPool");
 | 
			
		||||
 | 
			
		||||
        lm = data
 | 
			
		||||
            .get::<LanguageManager>()
 | 
			
		||||
            .cloned()
 | 
			
		||||
            .expect("Could not get LanguageManager");
 | 
			
		||||
        check_subscription(&cache_http, owner).await
 | 
			
		||||
    } else {
 | 
			
		||||
        false
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    (pool, lm)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn command_help(
 | 
			
		||||
    ctx: &Context,
 | 
			
		||||
    msg: &Message,
 | 
			
		||||
    lm: Arc<LanguageManager>,
 | 
			
		||||
    prefix: &str,
 | 
			
		||||
    language: &str,
 | 
			
		||||
    command_name: &str,
 | 
			
		||||
) {
 | 
			
		||||
    let _ = msg
 | 
			
		||||
        .channel_id
 | 
			
		||||
        .send_message(ctx, |m| {
 | 
			
		||||
            m.embed(move |e| {
 | 
			
		||||
                e.title(format!("{} Help", command_name.to_title_case()))
 | 
			
		||||
                    .description(
 | 
			
		||||
                        lm.get(&language, &format!("help/{}", command_name))
 | 
			
		||||
                            .replace("{prefix}", &prefix),
 | 
			
		||||
                    )
 | 
			
		||||
                    .footer(|f| {
 | 
			
		||||
                        f.text(concat!(
 | 
			
		||||
                            env!("CARGO_PKG_NAME"),
 | 
			
		||||
                            " ver ",
 | 
			
		||||
                            env!("CARGO_PKG_VERSION")
 | 
			
		||||
                        ))
 | 
			
		||||
                    })
 | 
			
		||||
                    .color(*THEME_COLOR)
 | 
			
		||||
            })
 | 
			
		||||
        })
 | 
			
		||||
        .await;
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										452
									
								
								src/models.rs
									
									
									
									
									
								
							
							
						
						
									
										452
									
								
								src/models.rs
									
									
									
									
									
								
							@@ -1,452 +0,0 @@
 | 
			
		||||
use serenity::{
 | 
			
		||||
    async_trait,
 | 
			
		||||
    http::CacheHttp,
 | 
			
		||||
    model::{
 | 
			
		||||
        channel::Channel,
 | 
			
		||||
        guild::Guild,
 | 
			
		||||
        id::{GuildId, UserId},
 | 
			
		||||
        user::User,
 | 
			
		||||
    },
 | 
			
		||||
    prelude::Context,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use sqlx::MySqlPool;
 | 
			
		||||
 | 
			
		||||
use chrono::NaiveDateTime;
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
 | 
			
		||||
use log::error;
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    consts::{DEFAULT_PREFIX, LOCAL_LANGUAGE, LOCAL_TIMEZONE},
 | 
			
		||||
    GuildDataCache, SQLPool,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use std::sync::Arc;
 | 
			
		||||
use tokio::sync::RwLock;
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
pub trait CtxGuildData {
 | 
			
		||||
    async fn guild_data<G: Into<GuildId> + Send + Sync>(
 | 
			
		||||
        &self,
 | 
			
		||||
        guild_id: G,
 | 
			
		||||
    ) -> Result<Arc<RwLock<GuildData>>, sqlx::Error>;
 | 
			
		||||
 | 
			
		||||
    async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl CtxGuildData for Context {
 | 
			
		||||
    async fn guild_data<G: Into<GuildId> + Send + Sync>(
 | 
			
		||||
        &self,
 | 
			
		||||
        guild_id: G,
 | 
			
		||||
    ) -> Result<Arc<RwLock<GuildData>>, sqlx::Error> {
 | 
			
		||||
        let guild_id = guild_id.into();
 | 
			
		||||
 | 
			
		||||
        let guild = guild_id.to_guild_cached(&self.cache).await.unwrap();
 | 
			
		||||
 | 
			
		||||
        let guild_cache = self
 | 
			
		||||
            .data
 | 
			
		||||
            .read()
 | 
			
		||||
            .await
 | 
			
		||||
            .get::<GuildDataCache>()
 | 
			
		||||
            .cloned()
 | 
			
		||||
            .unwrap();
 | 
			
		||||
 | 
			
		||||
        let x = if let Some(guild_data) = guild_cache.get(&guild_id) {
 | 
			
		||||
            Ok(guild_data.clone())
 | 
			
		||||
        } else {
 | 
			
		||||
            let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
            match GuildData::from_guild(guild, &pool).await {
 | 
			
		||||
                Ok(d) => {
 | 
			
		||||
                    let lock = Arc::new(RwLock::new(d));
 | 
			
		||||
 | 
			
		||||
                    guild_cache.insert(guild_id, lock.clone());
 | 
			
		||||
 | 
			
		||||
                    Ok(lock)
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                Err(e) => Err(e),
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        x
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String {
 | 
			
		||||
        if let Some(guild_id) = guild_id {
 | 
			
		||||
            self.guild_data(guild_id)
 | 
			
		||||
                .await
 | 
			
		||||
                .unwrap()
 | 
			
		||||
                .read()
 | 
			
		||||
                .await
 | 
			
		||||
                .prefix
 | 
			
		||||
                .clone()
 | 
			
		||||
        } else {
 | 
			
		||||
            DEFAULT_PREFIX.clone()
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct GuildData {
 | 
			
		||||
    pub id: u32,
 | 
			
		||||
    pub name: Option<String>,
 | 
			
		||||
    pub prefix: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl GuildData {
 | 
			
		||||
    pub async fn from_guild(guild: Guild, pool: &MySqlPool) -> Result<Self, sqlx::Error> {
 | 
			
		||||
        let guild_id = guild.id.as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
        match sqlx::query_as!(
 | 
			
		||||
            Self,
 | 
			
		||||
            "
 | 
			
		||||
SELECT id, name, prefix FROM guilds WHERE guild = ?
 | 
			
		||||
            ",
 | 
			
		||||
            guild_id
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        {
 | 
			
		||||
            Ok(mut g) => {
 | 
			
		||||
                g.name = Some(guild.name);
 | 
			
		||||
 | 
			
		||||
                Ok(g)
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            Err(sqlx::Error::RowNotFound) => {
 | 
			
		||||
                sqlx::query!(
 | 
			
		||||
                    "
 | 
			
		||||
INSERT INTO guilds (guild, name, prefix) VALUES (?, ?, ?)
 | 
			
		||||
                    ",
 | 
			
		||||
                    guild_id,
 | 
			
		||||
                    guild.name,
 | 
			
		||||
                    *DEFAULT_PREFIX
 | 
			
		||||
                )
 | 
			
		||||
                .execute(&pool.clone())
 | 
			
		||||
                .await?;
 | 
			
		||||
 | 
			
		||||
                Ok(sqlx::query_as!(
 | 
			
		||||
                    Self,
 | 
			
		||||
                    "
 | 
			
		||||
SELECT id, name, prefix FROM guilds WHERE guild = ?
 | 
			
		||||
                    ",
 | 
			
		||||
                    guild_id
 | 
			
		||||
                )
 | 
			
		||||
                .fetch_one(pool)
 | 
			
		||||
                .await?)
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            Err(e) => {
 | 
			
		||||
                error!("Unexpected error in guild query: {:?}", e);
 | 
			
		||||
 | 
			
		||||
                Err(e)
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn commit_changes(&self, pool: &MySqlPool) {
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
UPDATE guilds SET name = ?, prefix = ? WHERE id = ?
 | 
			
		||||
            ",
 | 
			
		||||
            self.name,
 | 
			
		||||
            self.prefix,
 | 
			
		||||
            self.id
 | 
			
		||||
        )
 | 
			
		||||
        .execute(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct ChannelData {
 | 
			
		||||
    pub id: u32,
 | 
			
		||||
    pub name: Option<String>,
 | 
			
		||||
    pub nudge: i16,
 | 
			
		||||
    pub blacklisted: bool,
 | 
			
		||||
    pub webhook_id: Option<u64>,
 | 
			
		||||
    pub webhook_token: Option<String>,
 | 
			
		||||
    pub paused: bool,
 | 
			
		||||
    pub paused_until: Option<NaiveDateTime>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ChannelData {
 | 
			
		||||
    pub async fn from_channel(
 | 
			
		||||
        channel: Channel,
 | 
			
		||||
        pool: &MySqlPool,
 | 
			
		||||
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
 | 
			
		||||
        let channel_id = channel.id().as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
        if let Ok(c) = sqlx::query_as_unchecked!(Self,
 | 
			
		||||
            "
 | 
			
		||||
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
 | 
			
		||||
            ", channel_id)
 | 
			
		||||
            .fetch_one(pool)
 | 
			
		||||
            .await {
 | 
			
		||||
 | 
			
		||||
            Ok(c)
 | 
			
		||||
        }
 | 
			
		||||
        else {
 | 
			
		||||
            let props = channel.guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name));
 | 
			
		||||
 | 
			
		||||
            let (guild_id, channel_name) = if let Some((a, b)) = props {
 | 
			
		||||
                (Some(a), Some(b))
 | 
			
		||||
            } else {
 | 
			
		||||
                (None, None)
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            sqlx::query!(
 | 
			
		||||
                "
 | 
			
		||||
INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?))
 | 
			
		||||
                ", channel_id, channel_name, guild_id)
 | 
			
		||||
                .execute(&pool.clone())
 | 
			
		||||
                .await?;
 | 
			
		||||
 | 
			
		||||
            Ok(sqlx::query_as_unchecked!(Self,
 | 
			
		||||
                "
 | 
			
		||||
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
 | 
			
		||||
                ", channel_id)
 | 
			
		||||
                .fetch_one(pool)
 | 
			
		||||
                .await?)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn commit_changes(&self, pool: &MySqlPool) {
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until = ? WHERE id = ?
 | 
			
		||||
            ", self.name, self.nudge, self.blacklisted, self.webhook_id, self.webhook_token, self.paused, self.paused_until, self.id)
 | 
			
		||||
            .execute(pool)
 | 
			
		||||
            .await.unwrap();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct UserData {
 | 
			
		||||
    pub id: u32,
 | 
			
		||||
    pub user: u64,
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub dm_channel: u32,
 | 
			
		||||
    pub language: String,
 | 
			
		||||
    pub timezone: String,
 | 
			
		||||
    pub meridian_time: bool,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct MeridianType(bool);
 | 
			
		||||
 | 
			
		||||
impl MeridianType {
 | 
			
		||||
    pub fn fmt_str(&self) -> &str {
 | 
			
		||||
        if self.0 {
 | 
			
		||||
            "%Y-%m-%d %I:%M:%S %p"
 | 
			
		||||
        } else {
 | 
			
		||||
            "%Y-%m-%d %H:%M:%S"
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn fmt_str_short(&self) -> &str {
 | 
			
		||||
        if self.0 {
 | 
			
		||||
            "%I:%M %p"
 | 
			
		||||
        } else {
 | 
			
		||||
            "%H:%M"
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl UserData {
 | 
			
		||||
    pub async fn language_of<U>(user: U, pool: &MySqlPool) -> String
 | 
			
		||||
    where
 | 
			
		||||
        U: Into<UserId>,
 | 
			
		||||
    {
 | 
			
		||||
        let user_id = user.into().as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
        match sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
SELECT language FROM users WHERE user = ?
 | 
			
		||||
            ",
 | 
			
		||||
            user_id
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        {
 | 
			
		||||
            Ok(r) => r.language,
 | 
			
		||||
 | 
			
		||||
            Err(_) => LOCAL_LANGUAGE.clone(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn timezone_of<U>(user: U, pool: &MySqlPool) -> Tz
 | 
			
		||||
    where
 | 
			
		||||
        U: Into<UserId>,
 | 
			
		||||
    {
 | 
			
		||||
        let user_id = user.into().as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
        match sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
SELECT timezone FROM users WHERE user = ?
 | 
			
		||||
            ",
 | 
			
		||||
            user_id
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        {
 | 
			
		||||
            Ok(r) => r.timezone,
 | 
			
		||||
 | 
			
		||||
            Err(_) => LOCAL_TIMEZONE.clone(),
 | 
			
		||||
        }
 | 
			
		||||
        .parse()
 | 
			
		||||
        .unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn meridian_of<U>(user: U, pool: &MySqlPool) -> MeridianType
 | 
			
		||||
    where
 | 
			
		||||
        U: Into<UserId>,
 | 
			
		||||
    {
 | 
			
		||||
        let user_id = user.into().as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
        match sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
SELECT meridian_time FROM users WHERE user = ?
 | 
			
		||||
            ",
 | 
			
		||||
            user_id
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        {
 | 
			
		||||
            Ok(r) => MeridianType(r.meridian_time != 0),
 | 
			
		||||
 | 
			
		||||
            Err(_) => MeridianType(false),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn from_user(
 | 
			
		||||
        user: &User,
 | 
			
		||||
        ctx: impl CacheHttp,
 | 
			
		||||
        pool: &MySqlPool,
 | 
			
		||||
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
 | 
			
		||||
        let user_id = user.id.as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
        match sqlx::query_as_unchecked!(
 | 
			
		||||
            Self,
 | 
			
		||||
            "
 | 
			
		||||
SELECT id, user, name, dm_channel, IF(language IS NULL, ?, language) AS language, IF(timezone IS NULL, ?, timezone) AS timezone, meridian_time FROM users WHERE user = ?
 | 
			
		||||
            ",
 | 
			
		||||
            *LOCAL_LANGUAGE, *LOCAL_TIMEZONE, user_id
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        {
 | 
			
		||||
            Ok(c) => Ok(c),
 | 
			
		||||
 | 
			
		||||
            Err(sqlx::Error::RowNotFound) => {
 | 
			
		||||
                let dm_channel = user.create_dm_channel(ctx).await?;
 | 
			
		||||
                let dm_id = dm_channel.id.as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
                let pool_c = pool.clone();
 | 
			
		||||
 | 
			
		||||
                sqlx::query!(
 | 
			
		||||
                    "
 | 
			
		||||
INSERT IGNORE INTO channels (channel) VALUES (?)
 | 
			
		||||
                    ",
 | 
			
		||||
                    dm_id
 | 
			
		||||
                )
 | 
			
		||||
                .execute(&pool_c)
 | 
			
		||||
                .await?;
 | 
			
		||||
 | 
			
		||||
                sqlx::query!(
 | 
			
		||||
                    "
 | 
			
		||||
INSERT INTO users (user, name, dm_channel, language, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?, ?)
 | 
			
		||||
                    ", user_id, user.name, dm_id, *LOCAL_LANGUAGE, *LOCAL_TIMEZONE)
 | 
			
		||||
                    .execute(&pool_c)
 | 
			
		||||
                    .await?;
 | 
			
		||||
 | 
			
		||||
                Ok(sqlx::query_as_unchecked!(
 | 
			
		||||
                    Self,
 | 
			
		||||
                    "
 | 
			
		||||
SELECT id, user, name, dm_channel, language, timezone, meridian_time FROM users WHERE user = ?
 | 
			
		||||
                    ",
 | 
			
		||||
                    user_id
 | 
			
		||||
                )
 | 
			
		||||
                .fetch_one(pool)
 | 
			
		||||
                .await?)
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            Err(e) => {
 | 
			
		||||
                error!("Error querying for user: {:?}", e);
 | 
			
		||||
 | 
			
		||||
                Err(Box::new(e))
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn commit_changes(&self, pool: &MySqlPool) {
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
UPDATE users SET name = ?, language = ?, timezone = ?, meridian_time = ? WHERE id = ?
 | 
			
		||||
            ",
 | 
			
		||||
            self.name,
 | 
			
		||||
            self.language,
 | 
			
		||||
            self.timezone,
 | 
			
		||||
            self.meridian_time,
 | 
			
		||||
            self.id
 | 
			
		||||
        )
 | 
			
		||||
        .execute(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn timezone(&self) -> Tz {
 | 
			
		||||
        self.timezone.parse().unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn meridian(&self) -> MeridianType {
 | 
			
		||||
        MeridianType(self.meridian_time)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct Timer {
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub start_time: NaiveDateTime,
 | 
			
		||||
    pub owner: u64,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Timer {
 | 
			
		||||
    pub async fn from_owner(owner: u64, pool: &MySqlPool) -> Vec<Self> {
 | 
			
		||||
        sqlx::query_as_unchecked!(
 | 
			
		||||
            Timer,
 | 
			
		||||
            "
 | 
			
		||||
SELECT name, start_time, owner FROM timers WHERE owner = ?
 | 
			
		||||
            ",
 | 
			
		||||
            owner
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_all(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn count_from_owner(owner: u64, pool: &MySqlPool) -> u32 {
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
SELECT COUNT(1) as count FROM timers WHERE owner = ?
 | 
			
		||||
            ",
 | 
			
		||||
            owner
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap()
 | 
			
		||||
        .count as u32
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn create(name: &str, owner: u64, pool: &MySqlPool) {
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
INSERT INTO timers (name, owner) VALUES (?, ?)
 | 
			
		||||
            ",
 | 
			
		||||
            name,
 | 
			
		||||
            owner
 | 
			
		||||
        )
 | 
			
		||||
        .execute(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										81
									
								
								src/models/channel_data.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								src/models/channel_data.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,81 @@
 | 
			
		||||
use chrono::NaiveDateTime;
 | 
			
		||||
use serenity::model::channel::Channel;
 | 
			
		||||
use sqlx::MySqlPool;
 | 
			
		||||
 | 
			
		||||
pub struct ChannelData {
 | 
			
		||||
    pub id: u32,
 | 
			
		||||
    pub name: Option<String>,
 | 
			
		||||
    pub nudge: i16,
 | 
			
		||||
    pub blacklisted: bool,
 | 
			
		||||
    pub webhook_id: Option<u64>,
 | 
			
		||||
    pub webhook_token: Option<String>,
 | 
			
		||||
    pub paused: bool,
 | 
			
		||||
    pub paused_until: Option<NaiveDateTime>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ChannelData {
 | 
			
		||||
    pub async fn from_channel(
 | 
			
		||||
        channel: &Channel,
 | 
			
		||||
        pool: &MySqlPool,
 | 
			
		||||
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
 | 
			
		||||
        let channel_id = channel.id().as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
        if let Ok(c) = sqlx::query_as_unchecked!(
 | 
			
		||||
            Self,
 | 
			
		||||
            "
 | 
			
		||||
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
 | 
			
		||||
            ",
 | 
			
		||||
            channel_id
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        {
 | 
			
		||||
            Ok(c)
 | 
			
		||||
        } else {
 | 
			
		||||
            let props = channel.to_owned().guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name));
 | 
			
		||||
 | 
			
		||||
            let (guild_id, channel_name) = if let Some((a, b)) = props { (Some(a), Some(b)) } else { (None, None) };
 | 
			
		||||
 | 
			
		||||
            sqlx::query!(
 | 
			
		||||
                "
 | 
			
		||||
INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?))
 | 
			
		||||
                ",
 | 
			
		||||
                channel_id,
 | 
			
		||||
                channel_name,
 | 
			
		||||
                guild_id
 | 
			
		||||
            )
 | 
			
		||||
            .execute(&pool.clone())
 | 
			
		||||
            .await?;
 | 
			
		||||
 | 
			
		||||
            Ok(sqlx::query_as_unchecked!(
 | 
			
		||||
                Self,
 | 
			
		||||
                "
 | 
			
		||||
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
 | 
			
		||||
                ",
 | 
			
		||||
                channel_id
 | 
			
		||||
            )
 | 
			
		||||
            .fetch_one(pool)
 | 
			
		||||
            .await?)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn commit_changes(&self, pool: &MySqlPool) {
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until \
 | 
			
		||||
             = ? WHERE id = ?
 | 
			
		||||
            ",
 | 
			
		||||
            self.name,
 | 
			
		||||
            self.nudge,
 | 
			
		||||
            self.blacklisted,
 | 
			
		||||
            self.webhook_id,
 | 
			
		||||
            self.webhook_token,
 | 
			
		||||
            self.paused,
 | 
			
		||||
            self.paused_until,
 | 
			
		||||
            self.id
 | 
			
		||||
        )
 | 
			
		||||
        .execute(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										33
									
								
								src/models/command_macro.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								src/models/command_macro.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,33 @@
 | 
			
		||||
use serenity::{client::Context, model::id::GuildId};
 | 
			
		||||
 | 
			
		||||
use crate::{framework::CommandOptions, SQLPool};
 | 
			
		||||
 | 
			
		||||
pub struct CommandMacro {
 | 
			
		||||
    pub guild_id: GuildId,
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub description: Option<String>,
 | 
			
		||||
    pub commands: Vec<CommandOptions>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl CommandMacro {
 | 
			
		||||
    pub async fn from_guild(ctx: &Context, guild_id: impl Into<GuildId>) -> Vec<Self> {
 | 
			
		||||
        let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
        let guild_id = guild_id.into();
 | 
			
		||||
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "SELECT * FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)",
 | 
			
		||||
            guild_id.0
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_all(&pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap()
 | 
			
		||||
        .iter()
 | 
			
		||||
        .map(|row| Self {
 | 
			
		||||
            guild_id,
 | 
			
		||||
            name: row.name.clone(),
 | 
			
		||||
            description: row.description.clone(),
 | 
			
		||||
            commands: serde_json::from_str(&row.commands).unwrap(),
 | 
			
		||||
        })
 | 
			
		||||
        .collect::<Vec<Self>>()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										66
									
								
								src/models/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								src/models/mod.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,66 @@
 | 
			
		||||
pub mod channel_data;
 | 
			
		||||
pub mod command_macro;
 | 
			
		||||
pub mod reminder;
 | 
			
		||||
pub mod timer;
 | 
			
		||||
pub mod user_data;
 | 
			
		||||
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
use serenity::{
 | 
			
		||||
    async_trait,
 | 
			
		||||
    model::id::{ChannelId, UserId},
 | 
			
		||||
    prelude::Context,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    models::{channel_data::ChannelData, user_data::UserData},
 | 
			
		||||
    SQLPool,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
pub trait CtxData {
 | 
			
		||||
    async fn user_data<U: Into<UserId> + Send + Sync>(
 | 
			
		||||
        &self,
 | 
			
		||||
        user_id: U,
 | 
			
		||||
    ) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>>;
 | 
			
		||||
 | 
			
		||||
    async fn timezone<U: Into<UserId> + Send + Sync>(&self, user_id: U) -> Tz;
 | 
			
		||||
 | 
			
		||||
    async fn channel_data<C: Into<ChannelId> + Send + Sync>(
 | 
			
		||||
        &self,
 | 
			
		||||
        channel_id: C,
 | 
			
		||||
    ) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl CtxData for Context {
 | 
			
		||||
    async fn user_data<U: Into<UserId> + Send + Sync>(
 | 
			
		||||
        &self,
 | 
			
		||||
        user_id: U,
 | 
			
		||||
    ) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>> {
 | 
			
		||||
        let user_id = user_id.into();
 | 
			
		||||
        let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
        let user = user_id.to_user(self).await.unwrap();
 | 
			
		||||
 | 
			
		||||
        UserData::from_user(&user, &self, &pool).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn timezone<U: Into<UserId> + Send + Sync>(&self, user_id: U) -> Tz {
 | 
			
		||||
        let user_id = user_id.into();
 | 
			
		||||
        let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
        UserData::timezone_of(user_id, &pool).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn channel_data<C: Into<ChannelId> + Send + Sync>(
 | 
			
		||||
        &self,
 | 
			
		||||
        channel_id: C,
 | 
			
		||||
    ) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>> {
 | 
			
		||||
        let channel_id = channel_id.into();
 | 
			
		||||
        let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
        let channel = channel_id.to_channel_cached(&self).unwrap();
 | 
			
		||||
 | 
			
		||||
        ChannelData::from_channel(&channel, &pool).await
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										305
									
								
								src/models/reminder/builder.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										305
									
								
								src/models/reminder/builder.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,305 @@
 | 
			
		||||
use std::{collections::HashSet, fmt::Display};
 | 
			
		||||
 | 
			
		||||
use chrono::{Duration, NaiveDateTime, Utc};
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
use serenity::{
 | 
			
		||||
    client::Context,
 | 
			
		||||
    http::CacheHttp,
 | 
			
		||||
    model::{
 | 
			
		||||
        channel::GuildChannel,
 | 
			
		||||
        id::{ChannelId, GuildId, UserId},
 | 
			
		||||
        webhook::Webhook,
 | 
			
		||||
    },
 | 
			
		||||
    Result as SerenityResult,
 | 
			
		||||
};
 | 
			
		||||
use sqlx::MySqlPool;
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    consts,
 | 
			
		||||
    consts::{MAX_TIME, MIN_INTERVAL},
 | 
			
		||||
    models::{
 | 
			
		||||
        channel_data::ChannelData,
 | 
			
		||||
        reminder::{content::Content, errors::ReminderError, helper::generate_uid, Reminder},
 | 
			
		||||
        user_data::UserData,
 | 
			
		||||
    },
 | 
			
		||||
    SQLPool,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
async fn create_webhook(
 | 
			
		||||
    ctx: impl CacheHttp,
 | 
			
		||||
    channel: GuildChannel,
 | 
			
		||||
    name: impl Display,
 | 
			
		||||
) -> SerenityResult<Webhook> {
 | 
			
		||||
    channel.create_webhook_with_avatar(ctx.http(), name, consts::DEFAULT_AVATAR.clone()).await
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Hash, PartialEq, Eq)]
 | 
			
		||||
pub enum ReminderScope {
 | 
			
		||||
    User(u64),
 | 
			
		||||
    Channel(u64),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ReminderScope {
 | 
			
		||||
    pub fn mention(&self) -> String {
 | 
			
		||||
        match self {
 | 
			
		||||
            Self::User(id) => format!("<@{}>", id),
 | 
			
		||||
            Self::Channel(id) => format!("<#{}>", id),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct ReminderBuilder {
 | 
			
		||||
    pool: MySqlPool,
 | 
			
		||||
    uid: String,
 | 
			
		||||
    channel: u32,
 | 
			
		||||
    utc_time: NaiveDateTime,
 | 
			
		||||
    timezone: String,
 | 
			
		||||
    interval: Option<i64>,
 | 
			
		||||
    expires: Option<NaiveDateTime>,
 | 
			
		||||
    content: String,
 | 
			
		||||
    tts: bool,
 | 
			
		||||
    attachment_name: Option<String>,
 | 
			
		||||
    attachment: Option<Vec<u8>>,
 | 
			
		||||
    set_by: Option<u32>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ReminderBuilder {
 | 
			
		||||
    pub async fn build(self) -> Result<Reminder, ReminderError> {
 | 
			
		||||
        let queried_time = sqlx::query!(
 | 
			
		||||
            "SELECT DATE_ADD(?, INTERVAL (SELECT nudge FROM channels WHERE id = ?) SECOND) AS `utc_time`",
 | 
			
		||||
            self.utc_time,
 | 
			
		||||
            self.channel,
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(&self.pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
        match queried_time.utc_time {
 | 
			
		||||
            Some(utc_time) => {
 | 
			
		||||
                if utc_time < (Utc::now() - Duration::seconds(60)).naive_local() {
 | 
			
		||||
                    Err(ReminderError::PastTime)
 | 
			
		||||
                } else {
 | 
			
		||||
                    sqlx::query!(
 | 
			
		||||
                        "
 | 
			
		||||
INSERT INTO reminders (
 | 
			
		||||
    `uid`,
 | 
			
		||||
    `channel_id`,
 | 
			
		||||
    `utc_time`,
 | 
			
		||||
    `timezone`,
 | 
			
		||||
    `interval`,
 | 
			
		||||
    `expires`,
 | 
			
		||||
    `content`,
 | 
			
		||||
    `tts`,
 | 
			
		||||
    `attachment_name`,
 | 
			
		||||
    `attachment`,
 | 
			
		||||
    `set_by`
 | 
			
		||||
) VALUES (
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?,
 | 
			
		||||
    ?
 | 
			
		||||
)
 | 
			
		||||
            ",
 | 
			
		||||
                        self.uid,
 | 
			
		||||
                        self.channel,
 | 
			
		||||
                        utc_time,
 | 
			
		||||
                        self.timezone,
 | 
			
		||||
                        self.interval,
 | 
			
		||||
                        self.expires,
 | 
			
		||||
                        self.content,
 | 
			
		||||
                        self.tts,
 | 
			
		||||
                        self.attachment_name,
 | 
			
		||||
                        self.attachment,
 | 
			
		||||
                        self.set_by
 | 
			
		||||
                    )
 | 
			
		||||
                    .execute(&self.pool)
 | 
			
		||||
                    .await
 | 
			
		||||
                    .unwrap();
 | 
			
		||||
 | 
			
		||||
                    Ok(Reminder::from_uid(&self.pool, self.uid).await.unwrap())
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            None => Err(ReminderError::LongTime),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct MultiReminderBuilder<'a> {
 | 
			
		||||
    scopes: Vec<ReminderScope>,
 | 
			
		||||
    utc_time: NaiveDateTime,
 | 
			
		||||
    timezone: Tz,
 | 
			
		||||
    interval: Option<i64>,
 | 
			
		||||
    expires: Option<NaiveDateTime>,
 | 
			
		||||
    content: Content,
 | 
			
		||||
    set_by: Option<u32>,
 | 
			
		||||
    ctx: &'a Context,
 | 
			
		||||
    guild_id: Option<GuildId>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> MultiReminderBuilder<'a> {
 | 
			
		||||
    pub fn new(ctx: &'a Context, guild_id: Option<GuildId>) -> Self {
 | 
			
		||||
        MultiReminderBuilder {
 | 
			
		||||
            scopes: vec![],
 | 
			
		||||
            utc_time: Utc::now().naive_utc(),
 | 
			
		||||
            timezone: Tz::UTC,
 | 
			
		||||
            interval: None,
 | 
			
		||||
            expires: None,
 | 
			
		||||
            content: Content::new(),
 | 
			
		||||
            set_by: None,
 | 
			
		||||
            ctx,
 | 
			
		||||
            guild_id,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn content(mut self, content: Content) -> Self {
 | 
			
		||||
        self.content = content;
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn time<T: Into<i64>>(mut self, time: T) -> Self {
 | 
			
		||||
        self.utc_time = NaiveDateTime::from_timestamp(time.into(), 0);
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn expires<T: Into<i64>>(mut self, time: Option<T>) -> Self {
 | 
			
		||||
        if let Some(t) = time {
 | 
			
		||||
            self.expires = Some(NaiveDateTime::from_timestamp(t.into(), 0));
 | 
			
		||||
        } else {
 | 
			
		||||
            self.expires = None;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn author(mut self, user: UserData) -> Self {
 | 
			
		||||
        self.set_by = Some(user.id);
 | 
			
		||||
        self.timezone = user.timezone();
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn interval(mut self, interval: Option<i64>) -> Self {
 | 
			
		||||
        self.interval = interval;
 | 
			
		||||
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn set_scopes(&mut self, scopes: Vec<ReminderScope>) {
 | 
			
		||||
        self.scopes = scopes;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn build(self) -> (HashSet<ReminderError>, HashSet<ReminderScope>) {
 | 
			
		||||
        let pool = self.ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
        let mut errors = HashSet::new();
 | 
			
		||||
 | 
			
		||||
        let mut ok_locs = HashSet::new();
 | 
			
		||||
 | 
			
		||||
        if self.interval.map_or(false, |i| (i as i64) < *MIN_INTERVAL) {
 | 
			
		||||
            errors.insert(ReminderError::ShortInterval);
 | 
			
		||||
        } else if self.interval.map_or(false, |i| (i as i64) > *MAX_TIME) {
 | 
			
		||||
            errors.insert(ReminderError::LongInterval);
 | 
			
		||||
        } else {
 | 
			
		||||
            for scope in self.scopes {
 | 
			
		||||
                let db_channel_id = match scope {
 | 
			
		||||
                    ReminderScope::User(user_id) => {
 | 
			
		||||
                        if let Ok(user) = UserId(user_id).to_user(&self.ctx).await {
 | 
			
		||||
                            let user_data =
 | 
			
		||||
                                UserData::from_user(&user, &self.ctx, &pool).await.unwrap();
 | 
			
		||||
 | 
			
		||||
                            if let Some(guild_id) = self.guild_id {
 | 
			
		||||
                                if guild_id.member(&self.ctx, user).await.is_err() {
 | 
			
		||||
                                    Err(ReminderError::InvalidTag)
 | 
			
		||||
                                } else {
 | 
			
		||||
                                    Ok(user_data.dm_channel)
 | 
			
		||||
                                }
 | 
			
		||||
                            } else {
 | 
			
		||||
                                Ok(user_data.dm_channel)
 | 
			
		||||
                            }
 | 
			
		||||
                        } else {
 | 
			
		||||
                            Err(ReminderError::InvalidTag)
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                    ReminderScope::Channel(channel_id) => {
 | 
			
		||||
                        let channel = ChannelId(channel_id).to_channel(&self.ctx).await.unwrap();
 | 
			
		||||
 | 
			
		||||
                        if let Some(guild_channel) = channel.clone().guild() {
 | 
			
		||||
                            if Some(guild_channel.guild_id) != self.guild_id {
 | 
			
		||||
                                Err(ReminderError::InvalidTag)
 | 
			
		||||
                            } else {
 | 
			
		||||
                                let mut channel_data =
 | 
			
		||||
                                    ChannelData::from_channel(&channel, &pool).await.unwrap();
 | 
			
		||||
 | 
			
		||||
                                if channel_data.webhook_id.is_none()
 | 
			
		||||
                                    || channel_data.webhook_token.is_none()
 | 
			
		||||
                                {
 | 
			
		||||
                                    match create_webhook(&self.ctx, guild_channel, "Reminder").await
 | 
			
		||||
                                    {
 | 
			
		||||
                                        Ok(webhook) => {
 | 
			
		||||
                                            channel_data.webhook_id =
 | 
			
		||||
                                                Some(webhook.id.as_u64().to_owned());
 | 
			
		||||
                                            channel_data.webhook_token = webhook.token;
 | 
			
		||||
 | 
			
		||||
                                            channel_data.commit_changes(&pool).await;
 | 
			
		||||
 | 
			
		||||
                                            Ok(channel_data.id)
 | 
			
		||||
                                        }
 | 
			
		||||
 | 
			
		||||
                                        Err(e) => Err(ReminderError::DiscordError(e.to_string())),
 | 
			
		||||
                                    }
 | 
			
		||||
                                } else {
 | 
			
		||||
                                    Ok(channel_data.id)
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
                        } else {
 | 
			
		||||
                            Err(ReminderError::InvalidTag)
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                };
 | 
			
		||||
 | 
			
		||||
                match db_channel_id {
 | 
			
		||||
                    Ok(c) => {
 | 
			
		||||
                        let builder = ReminderBuilder {
 | 
			
		||||
                            pool: pool.clone(),
 | 
			
		||||
                            uid: generate_uid(),
 | 
			
		||||
                            channel: c,
 | 
			
		||||
                            utc_time: self.utc_time,
 | 
			
		||||
                            timezone: self.timezone.to_string(),
 | 
			
		||||
                            interval: self.interval,
 | 
			
		||||
                            expires: self.expires,
 | 
			
		||||
                            content: self.content.content.clone(),
 | 
			
		||||
                            tts: self.content.tts,
 | 
			
		||||
                            attachment_name: self.content.attachment_name.clone(),
 | 
			
		||||
                            attachment: self.content.attachment.clone(),
 | 
			
		||||
                            set_by: self.set_by,
 | 
			
		||||
                        };
 | 
			
		||||
 | 
			
		||||
                        match builder.build().await {
 | 
			
		||||
                            Ok(_) => {
 | 
			
		||||
                                ok_locs.insert(scope);
 | 
			
		||||
                            }
 | 
			
		||||
                            Err(e) => {
 | 
			
		||||
                                errors.insert(e);
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                    Err(e) => {
 | 
			
		||||
                        errors.insert(e);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        (errors, ok_locs)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										12
									
								
								src/models/reminder/content.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								src/models/reminder/content.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,12 @@
 | 
			
		||||
pub struct Content {
 | 
			
		||||
    pub content: String,
 | 
			
		||||
    pub tts: bool,
 | 
			
		||||
    pub attachment: Option<Vec<u8>>,
 | 
			
		||||
    pub attachment_name: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Content {
 | 
			
		||||
    pub fn new() -> Self {
 | 
			
		||||
        Self { content: "".to_string(), tts: false, attachment: None, attachment_name: None }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										36
									
								
								src/models/reminder/errors.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								src/models/reminder/errors.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,36 @@
 | 
			
		||||
use crate::consts::{MAX_TIME, MIN_INTERVAL};
 | 
			
		||||
 | 
			
		||||
#[derive(PartialEq, Eq, Hash, Debug)]
 | 
			
		||||
pub enum ReminderError {
 | 
			
		||||
    LongTime,
 | 
			
		||||
    LongInterval,
 | 
			
		||||
    PastTime,
 | 
			
		||||
    ShortInterval,
 | 
			
		||||
    InvalidTag,
 | 
			
		||||
    DiscordError(String),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ToString for ReminderError {
 | 
			
		||||
    fn to_string(&self) -> String {
 | 
			
		||||
        match self {
 | 
			
		||||
            ReminderError::LongTime => {
 | 
			
		||||
                "That time is too far in the future. Please specify a shorter time.".to_string()
 | 
			
		||||
            }
 | 
			
		||||
            ReminderError::LongInterval => format!(
 | 
			
		||||
                "Please ensure the interval specified is less than {max_time} days",
 | 
			
		||||
                max_time = *MAX_TIME / 86_400
 | 
			
		||||
            ),
 | 
			
		||||
            ReminderError::PastTime => {
 | 
			
		||||
                "Please ensure the time provided is in the future. If the time should be in the future, please be more specific with the definition.".to_string()
 | 
			
		||||
            }
 | 
			
		||||
            ReminderError::ShortInterval => format!(
 | 
			
		||||
                "Please ensure the interval provided is longer than {min_interval} seconds",
 | 
			
		||||
                min_interval = *MIN_INTERVAL
 | 
			
		||||
            ),
 | 
			
		||||
            ReminderError::InvalidTag => {
 | 
			
		||||
                "Couldn't find a location by your tag. Your tag must be either a channel or a user (not a role)".to_string()
 | 
			
		||||
            }
 | 
			
		||||
            ReminderError::DiscordError(s) => format!("A Discord error occurred: **{}**", s),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										31
									
								
								src/models/reminder/helper.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								src/models/reminder/helper.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,31 @@
 | 
			
		||||
use num_integer::Integer;
 | 
			
		||||
use rand::{rngs::OsRng, seq::IteratorRandom};
 | 
			
		||||
 | 
			
		||||
use crate::consts::{CHARACTERS, DAY, HOUR, MINUTE};
 | 
			
		||||
 | 
			
		||||
pub fn longhand_displacement(seconds: u64) -> String {
 | 
			
		||||
    let (days, seconds) = seconds.div_rem(&DAY);
 | 
			
		||||
    let (hours, seconds) = seconds.div_rem(&HOUR);
 | 
			
		||||
    let (minutes, seconds) = seconds.div_rem(&MINUTE);
 | 
			
		||||
 | 
			
		||||
    let mut sections = vec![];
 | 
			
		||||
 | 
			
		||||
    for (var, name) in
 | 
			
		||||
        [days, hours, minutes, seconds].iter().zip(["days", "hours", "minutes", "seconds"].iter())
 | 
			
		||||
    {
 | 
			
		||||
        if *var > 0 {
 | 
			
		||||
            sections.push(format!("{} {}", var, name));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    sections.join(", ")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn generate_uid() -> String {
 | 
			
		||||
    let mut generator: OsRng = Default::default();
 | 
			
		||||
 | 
			
		||||
    (0..64)
 | 
			
		||||
        .map(|_| CHARACTERS.chars().choose(&mut generator).unwrap().to_owned().to_string())
 | 
			
		||||
        .collect::<Vec<String>>()
 | 
			
		||||
        .join("")
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										23
									
								
								src/models/reminder/look_flags.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								src/models/reminder/look_flags.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,23 @@
 | 
			
		||||
use serde::{Deserialize, Serialize};
 | 
			
		||||
use serde_repr::*;
 | 
			
		||||
use serenity::model::id::ChannelId;
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize_repr, Deserialize_repr, Copy, Clone, Debug)]
 | 
			
		||||
#[repr(u8)]
 | 
			
		||||
pub enum TimeDisplayType {
 | 
			
		||||
    Absolute = 0,
 | 
			
		||||
    Relative = 1,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize, Deserialize, Copy, Clone, Debug)]
 | 
			
		||||
pub struct LookFlags {
 | 
			
		||||
    pub show_disabled: bool,
 | 
			
		||||
    pub channel_id: Option<ChannelId>,
 | 
			
		||||
    pub time_display: TimeDisplayType,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Default for LookFlags {
 | 
			
		||||
    fn default() -> Self {
 | 
			
		||||
        Self { show_disabled: true, channel_id: None, time_display: TimeDisplayType::Relative }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										284
									
								
								src/models/reminder/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										284
									
								
								src/models/reminder/mod.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,284 @@
 | 
			
		||||
pub mod builder;
 | 
			
		||||
pub mod content;
 | 
			
		||||
pub mod errors;
 | 
			
		||||
mod helper;
 | 
			
		||||
pub mod look_flags;
 | 
			
		||||
 | 
			
		||||
use chrono::{NaiveDateTime, TimeZone};
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
use serenity::{
 | 
			
		||||
    client::Context,
 | 
			
		||||
    model::id::{ChannelId, GuildId, UserId},
 | 
			
		||||
};
 | 
			
		||||
use sqlx::MySqlPool;
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    models::reminder::{
 | 
			
		||||
        helper::longhand_displacement,
 | 
			
		||||
        look_flags::{LookFlags, TimeDisplayType},
 | 
			
		||||
    },
 | 
			
		||||
    SQLPool,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
pub struct Reminder {
 | 
			
		||||
    pub id: u32,
 | 
			
		||||
    pub uid: String,
 | 
			
		||||
    pub channel: u64,
 | 
			
		||||
    pub utc_time: NaiveDateTime,
 | 
			
		||||
    pub interval: Option<u32>,
 | 
			
		||||
    pub expires: Option<NaiveDateTime>,
 | 
			
		||||
    pub enabled: bool,
 | 
			
		||||
    pub content: String,
 | 
			
		||||
    pub embed_description: String,
 | 
			
		||||
    pub set_by: Option<u64>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Reminder {
 | 
			
		||||
    pub async fn from_uid(pool: &MySqlPool, uid: String) -> Option<Self> {
 | 
			
		||||
        sqlx::query_as_unchecked!(
 | 
			
		||||
            Self,
 | 
			
		||||
            "
 | 
			
		||||
SELECT
 | 
			
		||||
    reminders.id,
 | 
			
		||||
    reminders.uid,
 | 
			
		||||
    channels.channel,
 | 
			
		||||
    reminders.utc_time,
 | 
			
		||||
    reminders.interval,
 | 
			
		||||
    reminders.expires,
 | 
			
		||||
    reminders.enabled,
 | 
			
		||||
    reminders.content,
 | 
			
		||||
    reminders.embed_description,
 | 
			
		||||
    users.user AS set_by
 | 
			
		||||
FROM
 | 
			
		||||
    reminders
 | 
			
		||||
INNER JOIN
 | 
			
		||||
    channels
 | 
			
		||||
ON
 | 
			
		||||
    reminders.channel_id = channels.id
 | 
			
		||||
LEFT JOIN
 | 
			
		||||
    users
 | 
			
		||||
ON
 | 
			
		||||
    reminders.set_by = users.id
 | 
			
		||||
WHERE
 | 
			
		||||
    reminders.uid = ?
 | 
			
		||||
            ",
 | 
			
		||||
            uid
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .ok()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn from_channel<C: Into<ChannelId>>(
 | 
			
		||||
        ctx: &Context,
 | 
			
		||||
        channel_id: C,
 | 
			
		||||
        flags: &LookFlags,
 | 
			
		||||
    ) -> Vec<Self> {
 | 
			
		||||
        let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
        let enabled = if flags.show_disabled { "0,1" } else { "1" };
 | 
			
		||||
        let channel_id = channel_id.into();
 | 
			
		||||
 | 
			
		||||
        sqlx::query_as_unchecked!(
 | 
			
		||||
            Self,
 | 
			
		||||
            "
 | 
			
		||||
SELECT
 | 
			
		||||
    reminders.id,
 | 
			
		||||
    reminders.uid,
 | 
			
		||||
    channels.channel,
 | 
			
		||||
    reminders.utc_time,
 | 
			
		||||
    reminders.interval,
 | 
			
		||||
    reminders.expires,
 | 
			
		||||
    reminders.enabled,
 | 
			
		||||
    reminders.content,
 | 
			
		||||
    reminders.embed_description,
 | 
			
		||||
    users.user AS set_by
 | 
			
		||||
FROM
 | 
			
		||||
    reminders
 | 
			
		||||
INNER JOIN
 | 
			
		||||
    channels
 | 
			
		||||
ON
 | 
			
		||||
    reminders.channel_id = channels.id
 | 
			
		||||
LEFT JOIN
 | 
			
		||||
    users
 | 
			
		||||
ON
 | 
			
		||||
    reminders.set_by = users.id
 | 
			
		||||
WHERE
 | 
			
		||||
    channels.channel = ? AND
 | 
			
		||||
    FIND_IN_SET(reminders.enabled, ?)
 | 
			
		||||
ORDER BY
 | 
			
		||||
    reminders.utc_time
 | 
			
		||||
            ",
 | 
			
		||||
            channel_id.as_u64(),
 | 
			
		||||
            enabled,
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_all(&pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn from_guild(ctx: &Context, guild_id: Option<GuildId>, user: UserId) -> Vec<Self> {
 | 
			
		||||
        let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
 | 
			
		||||
 | 
			
		||||
        if let Some(guild_id) = guild_id {
 | 
			
		||||
            let guild_opt = guild_id.to_guild_cached(&ctx);
 | 
			
		||||
 | 
			
		||||
            if let Some(guild) = guild_opt {
 | 
			
		||||
                let channels = guild
 | 
			
		||||
                    .channels
 | 
			
		||||
                    .keys()
 | 
			
		||||
                    .into_iter()
 | 
			
		||||
                    .map(|k| k.as_u64().to_string())
 | 
			
		||||
                    .collect::<Vec<String>>()
 | 
			
		||||
                    .join(",");
 | 
			
		||||
 | 
			
		||||
                sqlx::query_as_unchecked!(
 | 
			
		||||
                    Self,
 | 
			
		||||
                    "
 | 
			
		||||
SELECT
 | 
			
		||||
    reminders.id,
 | 
			
		||||
    reminders.uid,
 | 
			
		||||
    channels.channel,
 | 
			
		||||
    reminders.utc_time,
 | 
			
		||||
    reminders.interval,
 | 
			
		||||
    reminders.expires,
 | 
			
		||||
    reminders.enabled,
 | 
			
		||||
    reminders.content,
 | 
			
		||||
    reminders.embed_description,
 | 
			
		||||
    users.user AS set_by
 | 
			
		||||
FROM
 | 
			
		||||
    reminders
 | 
			
		||||
LEFT JOIN
 | 
			
		||||
    channels
 | 
			
		||||
ON
 | 
			
		||||
    channels.id = reminders.channel_id
 | 
			
		||||
LEFT JOIN
 | 
			
		||||
    users
 | 
			
		||||
ON
 | 
			
		||||
    reminders.set_by = users.id
 | 
			
		||||
WHERE
 | 
			
		||||
    FIND_IN_SET(channels.channel, ?)
 | 
			
		||||
                ",
 | 
			
		||||
                    channels
 | 
			
		||||
                )
 | 
			
		||||
                .fetch_all(&pool)
 | 
			
		||||
                .await
 | 
			
		||||
            } else {
 | 
			
		||||
                sqlx::query_as_unchecked!(
 | 
			
		||||
                    Self,
 | 
			
		||||
                    "
 | 
			
		||||
SELECT
 | 
			
		||||
    reminders.id,
 | 
			
		||||
    reminders.uid,
 | 
			
		||||
    channels.channel,
 | 
			
		||||
    reminders.utc_time,
 | 
			
		||||
    reminders.interval,
 | 
			
		||||
    reminders.expires,
 | 
			
		||||
    reminders.enabled,
 | 
			
		||||
    reminders.content,
 | 
			
		||||
    reminders.embed_description,
 | 
			
		||||
    users.user AS set_by
 | 
			
		||||
FROM
 | 
			
		||||
    reminders
 | 
			
		||||
LEFT JOIN
 | 
			
		||||
    channels
 | 
			
		||||
ON
 | 
			
		||||
    channels.id = reminders.channel_id
 | 
			
		||||
LEFT JOIN
 | 
			
		||||
    users
 | 
			
		||||
ON
 | 
			
		||||
    reminders.set_by = users.id
 | 
			
		||||
WHERE
 | 
			
		||||
    channels.guild_id = (SELECT id FROM guilds WHERE guild = ?)
 | 
			
		||||
                ",
 | 
			
		||||
                    guild_id.as_u64()
 | 
			
		||||
                )
 | 
			
		||||
                .fetch_all(&pool)
 | 
			
		||||
                .await
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            sqlx::query_as_unchecked!(
 | 
			
		||||
                Self,
 | 
			
		||||
                "
 | 
			
		||||
SELECT
 | 
			
		||||
    reminders.id,
 | 
			
		||||
    reminders.uid,
 | 
			
		||||
    channels.channel,
 | 
			
		||||
    reminders.utc_time,
 | 
			
		||||
    reminders.interval,
 | 
			
		||||
    reminders.expires,
 | 
			
		||||
    reminders.enabled,
 | 
			
		||||
    reminders.content,
 | 
			
		||||
    reminders.embed_description,
 | 
			
		||||
    users.user AS set_by
 | 
			
		||||
FROM
 | 
			
		||||
    reminders
 | 
			
		||||
INNER JOIN
 | 
			
		||||
    channels
 | 
			
		||||
ON
 | 
			
		||||
    channels.id = reminders.channel_id
 | 
			
		||||
LEFT JOIN
 | 
			
		||||
    users
 | 
			
		||||
ON
 | 
			
		||||
    reminders.set_by = users.id
 | 
			
		||||
WHERE
 | 
			
		||||
    channels.id = (SELECT dm_channel FROM users WHERE user = ?)
 | 
			
		||||
            ",
 | 
			
		||||
                user.as_u64()
 | 
			
		||||
            )
 | 
			
		||||
            .fetch_all(&pool)
 | 
			
		||||
            .await
 | 
			
		||||
        }
 | 
			
		||||
        .unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn display_content(&self) -> &str {
 | 
			
		||||
        if self.content.is_empty() {
 | 
			
		||||
            &self.embed_description
 | 
			
		||||
        } else {
 | 
			
		||||
            &self.content
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn display_del(&self, count: usize, timezone: &Tz) -> String {
 | 
			
		||||
        format!(
 | 
			
		||||
            "**{}**: '{}' *<#{}>* at **{}**",
 | 
			
		||||
            count + 1,
 | 
			
		||||
            self.display_content(),
 | 
			
		||||
            self.channel,
 | 
			
		||||
            timezone
 | 
			
		||||
                .timestamp(self.utc_time.timestamp(), 0)
 | 
			
		||||
                .format("%Y-%m-%d %H:%M:%S")
 | 
			
		||||
                .to_string()
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn display(&self, flags: &LookFlags, timezone: &Tz) -> String {
 | 
			
		||||
        let time_display = match flags.time_display {
 | 
			
		||||
            TimeDisplayType::Absolute => timezone
 | 
			
		||||
                .timestamp(self.utc_time.timestamp(), 0)
 | 
			
		||||
                .format("%Y-%m-%d %H:%M:%S")
 | 
			
		||||
                .to_string(),
 | 
			
		||||
 | 
			
		||||
            TimeDisplayType::Relative => format!("<t:{}:R>", self.utc_time.timestamp()),
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        if let Some(interval) = self.interval {
 | 
			
		||||
            format!(
 | 
			
		||||
                "'{}' *occurs next at* **{}**, repeating every **{}** (set by {})",
 | 
			
		||||
                self.display_content(),
 | 
			
		||||
                time_display,
 | 
			
		||||
                longhand_displacement(interval as u64),
 | 
			
		||||
                self.set_by.map(|i| format!("<@{}>", i)).unwrap_or_else(|| "unknown".to_string())
 | 
			
		||||
            )
 | 
			
		||||
        } else {
 | 
			
		||||
            format!(
 | 
			
		||||
                "'{}' *occurs next at* **{}** (set by {})",
 | 
			
		||||
                self.display_content(),
 | 
			
		||||
                time_display,
 | 
			
		||||
                self.set_by.map(|i| format!("<@{}>", i)).unwrap_or_else(|| "unknown".to_string())
 | 
			
		||||
            )
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										49
									
								
								src/models/timer.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								src/models/timer.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,49 @@
 | 
			
		||||
use chrono::NaiveDateTime;
 | 
			
		||||
use sqlx::MySqlPool;
 | 
			
		||||
 | 
			
		||||
pub struct Timer {
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub start_time: NaiveDateTime,
 | 
			
		||||
    pub owner: u64,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Timer {
 | 
			
		||||
    pub async fn from_owner(owner: u64, pool: &MySqlPool) -> Vec<Self> {
 | 
			
		||||
        sqlx::query_as_unchecked!(
 | 
			
		||||
            Timer,
 | 
			
		||||
            "
 | 
			
		||||
SELECT name, start_time, owner FROM timers WHERE owner = ?
 | 
			
		||||
            ",
 | 
			
		||||
            owner
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_all(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn count_from_owner(owner: u64, pool: &MySqlPool) -> u32 {
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
SELECT COUNT(1) as count FROM timers WHERE owner = ?
 | 
			
		||||
            ",
 | 
			
		||||
            owner
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap()
 | 
			
		||||
        .count as u32
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn create(name: &str, owner: u64, pool: &MySqlPool) {
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
INSERT INTO timers (name, owner) VALUES (?, ?)
 | 
			
		||||
            ",
 | 
			
		||||
            name,
 | 
			
		||||
            owner
 | 
			
		||||
        )
 | 
			
		||||
        .execute(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										126
									
								
								src/models/user_data.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								src/models/user_data.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,126 @@
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
use log::error;
 | 
			
		||||
use serenity::{
 | 
			
		||||
    http::CacheHttp,
 | 
			
		||||
    model::{id::UserId, user::User},
 | 
			
		||||
};
 | 
			
		||||
use sqlx::MySqlPool;
 | 
			
		||||
 | 
			
		||||
use crate::consts::LOCAL_TIMEZONE;
 | 
			
		||||
 | 
			
		||||
pub struct UserData {
 | 
			
		||||
    pub id: u32,
 | 
			
		||||
    pub user: u64,
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub dm_channel: u32,
 | 
			
		||||
    pub timezone: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl UserData {
 | 
			
		||||
    pub async fn timezone_of<U>(user: U, pool: &MySqlPool) -> Tz
 | 
			
		||||
    where
 | 
			
		||||
        U: Into<UserId>,
 | 
			
		||||
    {
 | 
			
		||||
        let user_id = user.into().as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
        match sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
SELECT timezone FROM users WHERE user = ?
 | 
			
		||||
            ",
 | 
			
		||||
            user_id
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        {
 | 
			
		||||
            Ok(r) => r.timezone,
 | 
			
		||||
 | 
			
		||||
            Err(_) => LOCAL_TIMEZONE.clone(),
 | 
			
		||||
        }
 | 
			
		||||
        .parse()
 | 
			
		||||
        .unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn from_user(
 | 
			
		||||
        user: &User,
 | 
			
		||||
        ctx: impl CacheHttp,
 | 
			
		||||
        pool: &MySqlPool,
 | 
			
		||||
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
 | 
			
		||||
        let user_id = user.id.as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
        match sqlx::query_as_unchecked!(
 | 
			
		||||
            Self,
 | 
			
		||||
            "
 | 
			
		||||
SELECT id, user, name, dm_channel, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ?
 | 
			
		||||
            ",
 | 
			
		||||
            *LOCAL_TIMEZONE,
 | 
			
		||||
            user_id
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        {
 | 
			
		||||
            Ok(c) => Ok(c),
 | 
			
		||||
 | 
			
		||||
            Err(sqlx::Error::RowNotFound) => {
 | 
			
		||||
                let dm_channel = user.create_dm_channel(ctx).await?;
 | 
			
		||||
                let dm_id = dm_channel.id.as_u64().to_owned();
 | 
			
		||||
 | 
			
		||||
                let pool_c = pool.clone();
 | 
			
		||||
 | 
			
		||||
                sqlx::query!(
 | 
			
		||||
                    "
 | 
			
		||||
INSERT IGNORE INTO channels (channel) VALUES (?)
 | 
			
		||||
                    ",
 | 
			
		||||
                    dm_id
 | 
			
		||||
                )
 | 
			
		||||
                .execute(&pool_c)
 | 
			
		||||
                .await?;
 | 
			
		||||
 | 
			
		||||
                sqlx::query!(
 | 
			
		||||
                    "
 | 
			
		||||
INSERT INTO users (user, name, dm_channel, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?)
 | 
			
		||||
                    ",
 | 
			
		||||
                    user_id,
 | 
			
		||||
                    user.name,
 | 
			
		||||
                    dm_id,
 | 
			
		||||
                    *LOCAL_TIMEZONE
 | 
			
		||||
                )
 | 
			
		||||
                .execute(&pool_c)
 | 
			
		||||
                .await?;
 | 
			
		||||
 | 
			
		||||
                Ok(sqlx::query_as_unchecked!(
 | 
			
		||||
                    Self,
 | 
			
		||||
                    "
 | 
			
		||||
SELECT id, user, name, dm_channel, timezone FROM users WHERE user = ?
 | 
			
		||||
                    ",
 | 
			
		||||
                    user_id
 | 
			
		||||
                )
 | 
			
		||||
                .fetch_one(pool)
 | 
			
		||||
                .await?)
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            Err(e) => {
 | 
			
		||||
                error!("Error querying for user: {:?}", e);
 | 
			
		||||
 | 
			
		||||
                Err(Box::new(e))
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn commit_changes(&self, pool: &MySqlPool) {
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
UPDATE users SET name = ?, timezone = ? WHERE id = ?
 | 
			
		||||
            ",
 | 
			
		||||
            self.name,
 | 
			
		||||
            self.timezone,
 | 
			
		||||
            self.id
 | 
			
		||||
        )
 | 
			
		||||
        .execute(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn timezone(&self) -> Tz {
 | 
			
		||||
        self.timezone.parse().unwrap()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										552
									
								
								src/sender.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										552
									
								
								src/sender.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,552 @@
 | 
			
		||||
use chrono::Duration;
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
use log::{error, info, warn};
 | 
			
		||||
use num_integer::Integer;
 | 
			
		||||
use regex::{Captures, Regex};
 | 
			
		||||
use serenity::{
 | 
			
		||||
    builder::CreateEmbed,
 | 
			
		||||
    http::{CacheHttp, Http, StatusCode},
 | 
			
		||||
    model::{
 | 
			
		||||
        channel::{Channel, Embed as SerenityEmbed},
 | 
			
		||||
        id::ChannelId,
 | 
			
		||||
        webhook::Webhook,
 | 
			
		||||
    },
 | 
			
		||||
    Error, Result,
 | 
			
		||||
};
 | 
			
		||||
use sqlx::{
 | 
			
		||||
    types::chrono::{NaiveDateTime, Utc},
 | 
			
		||||
    MySqlPool,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
lazy_static! {
 | 
			
		||||
    pub static ref TIMEFROM_REGEX: Regex =
 | 
			
		||||
        Regex::new(r#"<<timefrom:(?P<time>\d+):(?P<format>.+)?>>"#).unwrap();
 | 
			
		||||
    pub static ref TIMENOW_REGEX: Regex =
 | 
			
		||||
        Regex::new(r#"<<timenow:(?P<timezone>(?:\w|/|_)+):(?P<format>.+)?>>"#).unwrap();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn fmt_displacement(format: &str, seconds: u64) -> String {
 | 
			
		||||
    let mut seconds = seconds;
 | 
			
		||||
    let mut days: u64 = 0;
 | 
			
		||||
    let mut hours: u64 = 0;
 | 
			
		||||
    let mut minutes: u64 = 0;
 | 
			
		||||
 | 
			
		||||
    for (rep, time_type, div) in
 | 
			
		||||
        [("%d", &mut days, 86400), ("%h", &mut hours, 3600), ("%m", &mut minutes, 60)].iter_mut()
 | 
			
		||||
    {
 | 
			
		||||
        if format.contains(*rep) {
 | 
			
		||||
            let (divided, new_seconds) = seconds.div_rem(&div);
 | 
			
		||||
 | 
			
		||||
            **time_type = divided;
 | 
			
		||||
            seconds = new_seconds;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    format
 | 
			
		||||
        .replace("%s", &seconds.to_string())
 | 
			
		||||
        .replace("%m", &minutes.to_string())
 | 
			
		||||
        .replace("%h", &hours.to_string())
 | 
			
		||||
        .replace("%d", &days.to_string())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn substitute(string: &str) -> String {
 | 
			
		||||
    let new = TIMEFROM_REGEX.replace(string, |caps: &Captures| {
 | 
			
		||||
        let final_time = caps.name("time").unwrap().as_str();
 | 
			
		||||
        let format = caps.name("format").unwrap().as_str();
 | 
			
		||||
 | 
			
		||||
        if let Ok(final_time) = final_time.parse::<i64>() {
 | 
			
		||||
            let dt = NaiveDateTime::from_timestamp(final_time, 0);
 | 
			
		||||
            let now = Utc::now().naive_utc();
 | 
			
		||||
 | 
			
		||||
            let difference = {
 | 
			
		||||
                if now < dt {
 | 
			
		||||
                    dt - Utc::now().naive_utc()
 | 
			
		||||
                } else {
 | 
			
		||||
                    Utc::now().naive_utc() - dt
 | 
			
		||||
                }
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            fmt_displacement(format, difference.num_seconds() as u64)
 | 
			
		||||
        } else {
 | 
			
		||||
            String::new()
 | 
			
		||||
        }
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    TIMENOW_REGEX
 | 
			
		||||
        .replace(&new, |caps: &Captures| {
 | 
			
		||||
            let timezone = caps.name("timezone").unwrap().as_str();
 | 
			
		||||
 | 
			
		||||
            println!("{}", timezone);
 | 
			
		||||
 | 
			
		||||
            if let Ok(tz) = timezone.parse::<Tz>() {
 | 
			
		||||
                let format = caps.name("format").unwrap().as_str();
 | 
			
		||||
                let now = Utc::now().with_timezone(&tz);
 | 
			
		||||
 | 
			
		||||
                now.format(format).to_string()
 | 
			
		||||
            } else {
 | 
			
		||||
                String::new()
 | 
			
		||||
            }
 | 
			
		||||
        })
 | 
			
		||||
        .to_string()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct Embed {
 | 
			
		||||
    inner: EmbedInner,
 | 
			
		||||
    fields: Vec<EmbedField>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct EmbedInner {
 | 
			
		||||
    title: String,
 | 
			
		||||
    description: String,
 | 
			
		||||
    image_url: Option<String>,
 | 
			
		||||
    thumbnail_url: Option<String>,
 | 
			
		||||
    footer: String,
 | 
			
		||||
    footer_url: Option<String>,
 | 
			
		||||
    author: String,
 | 
			
		||||
    author_url: Option<String>,
 | 
			
		||||
    color: u32,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct EmbedField {
 | 
			
		||||
    title: String,
 | 
			
		||||
    value: String,
 | 
			
		||||
    inline: bool,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Embed {
 | 
			
		||||
    pub async fn from_id(pool: &MySqlPool, id: u32) -> Option<Self> {
 | 
			
		||||
        let mut inner = sqlx::query_as_unchecked!(
 | 
			
		||||
            EmbedInner,
 | 
			
		||||
            "
 | 
			
		||||
SELECT
 | 
			
		||||
    `embed_title` AS title,
 | 
			
		||||
    `embed_description` AS description,
 | 
			
		||||
    `embed_image_url` AS image_url,
 | 
			
		||||
    `embed_thumbnail_url` AS thumbnail_url,
 | 
			
		||||
    `embed_footer` AS footer,
 | 
			
		||||
    `embed_footer_url` AS footer_url,
 | 
			
		||||
    `embed_author` AS author,
 | 
			
		||||
    `embed_author_url` AS author_url,
 | 
			
		||||
    `embed_color` AS color
 | 
			
		||||
FROM
 | 
			
		||||
    reminders
 | 
			
		||||
WHERE
 | 
			
		||||
    `id` = ?
 | 
			
		||||
            ",
 | 
			
		||||
            id
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_one(&pool.clone())
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
        inner.title = substitute(&inner.title);
 | 
			
		||||
        inner.description = substitute(&inner.description);
 | 
			
		||||
        inner.footer = substitute(&inner.footer);
 | 
			
		||||
 | 
			
		||||
        let mut fields = sqlx::query_as_unchecked!(
 | 
			
		||||
            EmbedField,
 | 
			
		||||
            "
 | 
			
		||||
SELECT
 | 
			
		||||
    title,
 | 
			
		||||
    value,
 | 
			
		||||
    inline
 | 
			
		||||
FROM
 | 
			
		||||
    embed_fields
 | 
			
		||||
WHERE
 | 
			
		||||
    reminder_id = ?
 | 
			
		||||
            ",
 | 
			
		||||
            id
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_all(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap();
 | 
			
		||||
 | 
			
		||||
        fields.iter_mut().for_each(|mut field| {
 | 
			
		||||
            field.title = substitute(&field.title);
 | 
			
		||||
            field.value = substitute(&field.value);
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        let e = Embed { inner, fields };
 | 
			
		||||
 | 
			
		||||
        if e.has_content() {
 | 
			
		||||
            Some(e)
 | 
			
		||||
        } else {
 | 
			
		||||
            None
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn has_content(&self) -> bool {
 | 
			
		||||
        if self.inner.title.is_empty()
 | 
			
		||||
            && self.inner.description.is_empty()
 | 
			
		||||
            && self.inner.image_url.is_none()
 | 
			
		||||
            && self.inner.thumbnail_url.is_none()
 | 
			
		||||
            && self.inner.footer.is_empty()
 | 
			
		||||
            && self.inner.footer_url.is_none()
 | 
			
		||||
            && self.inner.author.is_empty()
 | 
			
		||||
            && self.inner.author_url.is_none()
 | 
			
		||||
            && self.fields.is_empty()
 | 
			
		||||
        {
 | 
			
		||||
            false
 | 
			
		||||
        } else {
 | 
			
		||||
            true
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Into<CreateEmbed> for Embed {
 | 
			
		||||
    fn into(self) -> CreateEmbed {
 | 
			
		||||
        let mut c = CreateEmbed::default();
 | 
			
		||||
 | 
			
		||||
        c.title(&self.inner.title)
 | 
			
		||||
            .description(&self.inner.description)
 | 
			
		||||
            .color(self.inner.color)
 | 
			
		||||
            .author(|a| {
 | 
			
		||||
                a.name(&self.inner.author);
 | 
			
		||||
 | 
			
		||||
                if let Some(author_icon) = &self.inner.author_url {
 | 
			
		||||
                    a.icon_url(author_icon);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                a
 | 
			
		||||
            })
 | 
			
		||||
            .footer(|f| {
 | 
			
		||||
                f.text(&self.inner.footer);
 | 
			
		||||
 | 
			
		||||
                if let Some(footer_icon) = &self.inner.footer_url {
 | 
			
		||||
                    f.icon_url(footer_icon);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                f
 | 
			
		||||
            });
 | 
			
		||||
 | 
			
		||||
        for field in &self.fields {
 | 
			
		||||
            c.field(&field.title, &field.value, field.inline);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if let Some(image_url) = &self.inner.image_url {
 | 
			
		||||
            c.image(image_url);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if let Some(thumbnail_url) = &self.inner.thumbnail_url {
 | 
			
		||||
            c.thumbnail(thumbnail_url);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        c
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct Reminder {
 | 
			
		||||
    id: u32,
 | 
			
		||||
 | 
			
		||||
    channel_id: u64,
 | 
			
		||||
    webhook_id: Option<u64>,
 | 
			
		||||
    webhook_token: Option<String>,
 | 
			
		||||
 | 
			
		||||
    channel_paused: bool,
 | 
			
		||||
    channel_paused_until: Option<NaiveDateTime>,
 | 
			
		||||
    enabled: bool,
 | 
			
		||||
 | 
			
		||||
    tts: bool,
 | 
			
		||||
    pin: bool,
 | 
			
		||||
    content: String,
 | 
			
		||||
    attachment: Option<Vec<u8>>,
 | 
			
		||||
    attachment_name: Option<String>,
 | 
			
		||||
 | 
			
		||||
    utc_time: NaiveDateTime,
 | 
			
		||||
    timezone: String,
 | 
			
		||||
    restartable: bool,
 | 
			
		||||
    expires: Option<NaiveDateTime>,
 | 
			
		||||
    interval: Option<u32>,
 | 
			
		||||
 | 
			
		||||
    avatar: Option<String>,
 | 
			
		||||
    username: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Reminder {
 | 
			
		||||
    pub async fn fetch_reminders(pool: &MySqlPool) -> Vec<Self> {
 | 
			
		||||
        sqlx::query_as_unchecked!(
 | 
			
		||||
            Reminder,
 | 
			
		||||
            "
 | 
			
		||||
SELECT
 | 
			
		||||
    reminders.`id` AS id,
 | 
			
		||||
 | 
			
		||||
    channels.`channel` AS channel_id,
 | 
			
		||||
    channels.`webhook_id` AS webhook_id,
 | 
			
		||||
    channels.`webhook_token` AS webhook_token,
 | 
			
		||||
 | 
			
		||||
    channels.`paused` AS channel_paused,
 | 
			
		||||
    channels.`paused_until` AS channel_paused_until,
 | 
			
		||||
    reminders.`enabled` AS enabled,
 | 
			
		||||
 | 
			
		||||
    reminders.`tts` AS tts,
 | 
			
		||||
    reminders.`pin` AS pin,
 | 
			
		||||
    reminders.`content` AS content,
 | 
			
		||||
    reminders.`attachment` AS attachment,
 | 
			
		||||
    reminders.`attachment_name` AS attachment_name,
 | 
			
		||||
 | 
			
		||||
    reminders.`utc_time` AS 'utc_time',
 | 
			
		||||
    reminders.`timezone` AS timezone,
 | 
			
		||||
    reminders.`restartable` AS restartable,
 | 
			
		||||
    reminders.`expires` AS expires,
 | 
			
		||||
    reminders.`interval` AS 'interval',
 | 
			
		||||
 | 
			
		||||
    reminders.`avatar` AS avatar,
 | 
			
		||||
    reminders.`username` AS username
 | 
			
		||||
FROM
 | 
			
		||||
    reminders
 | 
			
		||||
INNER JOIN
 | 
			
		||||
    channels
 | 
			
		||||
ON
 | 
			
		||||
    reminders.channel_id = channels.id
 | 
			
		||||
WHERE
 | 
			
		||||
    reminders.`utc_time` < NOW()
 | 
			
		||||
            ",
 | 
			
		||||
        )
 | 
			
		||||
        .fetch_all(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .unwrap()
 | 
			
		||||
        .into_iter()
 | 
			
		||||
        .map(|mut rem| {
 | 
			
		||||
            rem.content = substitute(&rem.content);
 | 
			
		||||
 | 
			
		||||
            rem
 | 
			
		||||
        })
 | 
			
		||||
        .collect::<Vec<Self>>()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn reset_webhook(&self, pool: &MySqlPool) {
 | 
			
		||||
        let _ = sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
UPDATE channels SET webhook_id = NULL, webhook_token = NULL WHERE channel = ?
 | 
			
		||||
            ",
 | 
			
		||||
            self.channel_id
 | 
			
		||||
        )
 | 
			
		||||
        .execute(pool)
 | 
			
		||||
        .await;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn refresh(&self, pool: &MySqlPool) {
 | 
			
		||||
        if let Some(interval) = self.interval {
 | 
			
		||||
            let now = Utc::now().naive_local();
 | 
			
		||||
            let mut updated_reminder_time = self.utc_time;
 | 
			
		||||
 | 
			
		||||
            while updated_reminder_time < now {
 | 
			
		||||
                updated_reminder_time += Duration::seconds(interval as i64);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if self.expires.map_or(false, |expires| {
 | 
			
		||||
                NaiveDateTime::from_timestamp(updated_reminder_time.timestamp(), 0) > expires
 | 
			
		||||
            }) {
 | 
			
		||||
                self.force_delete(pool).await;
 | 
			
		||||
            } else {
 | 
			
		||||
                sqlx::query!(
 | 
			
		||||
                    "
 | 
			
		||||
UPDATE reminders SET `utc_time` = ? WHERE `id` = ?
 | 
			
		||||
                    ",
 | 
			
		||||
                    updated_reminder_time,
 | 
			
		||||
                    self.id
 | 
			
		||||
                )
 | 
			
		||||
                .execute(pool)
 | 
			
		||||
                .await
 | 
			
		||||
                .expect(&format!("Could not update time on Reminder {}", self.id));
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            self.force_delete(pool).await;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn force_delete(&self, pool: &MySqlPool) {
 | 
			
		||||
        sqlx::query!(
 | 
			
		||||
            "
 | 
			
		||||
DELETE FROM reminders WHERE `id` = ?
 | 
			
		||||
            ",
 | 
			
		||||
            self.id
 | 
			
		||||
        )
 | 
			
		||||
        .execute(pool)
 | 
			
		||||
        .await
 | 
			
		||||
        .expect(&format!("Could not delete Reminder {}", self.id));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn pin_message<M: Into<u64>>(&self, message_id: M, http: impl AsRef<Http>) {
 | 
			
		||||
        let _ = http.as_ref().pin_message(self.channel_id, message_id.into(), None).await;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn send(&self, pool: MySqlPool, cache_http: impl CacheHttp) {
 | 
			
		||||
        async fn send_to_channel(
 | 
			
		||||
            cache_http: impl CacheHttp,
 | 
			
		||||
            reminder: &Reminder,
 | 
			
		||||
            embed: Option<CreateEmbed>,
 | 
			
		||||
        ) -> Result<()> {
 | 
			
		||||
            let channel = ChannelId(reminder.channel_id).to_channel(&cache_http).await;
 | 
			
		||||
 | 
			
		||||
            match channel {
 | 
			
		||||
                Ok(Channel::Guild(channel)) => {
 | 
			
		||||
                    match channel
 | 
			
		||||
                        .send_message(&cache_http, |m| {
 | 
			
		||||
                            m.content(&reminder.content).tts(reminder.tts);
 | 
			
		||||
 | 
			
		||||
                            if let (Some(attachment), Some(name)) =
 | 
			
		||||
                                (&reminder.attachment, &reminder.attachment_name)
 | 
			
		||||
                            {
 | 
			
		||||
                                m.add_file((attachment as &[u8], name.as_str()));
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            if let Some(embed) = embed {
 | 
			
		||||
                                m.set_embed(embed);
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            m
 | 
			
		||||
                        })
 | 
			
		||||
                        .await
 | 
			
		||||
                    {
 | 
			
		||||
                        Ok(m) => {
 | 
			
		||||
                            if reminder.pin {
 | 
			
		||||
                                reminder.pin_message(m.id, cache_http.http()).await;
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            Ok(())
 | 
			
		||||
                        }
 | 
			
		||||
                        Err(e) => Err(e),
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                Ok(Channel::Private(channel)) => {
 | 
			
		||||
                    match channel
 | 
			
		||||
                        .send_message(&cache_http.http(), |m| {
 | 
			
		||||
                            m.content(&reminder.content).tts(reminder.tts);
 | 
			
		||||
 | 
			
		||||
                            if let (Some(attachment), Some(name)) =
 | 
			
		||||
                                (&reminder.attachment, &reminder.attachment_name)
 | 
			
		||||
                            {
 | 
			
		||||
                                m.add_file((attachment as &[u8], name.as_str()));
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            if let Some(embed) = embed {
 | 
			
		||||
                                m.set_embed(embed);
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            m
 | 
			
		||||
                        })
 | 
			
		||||
                        .await
 | 
			
		||||
                    {
 | 
			
		||||
                        Ok(m) => {
 | 
			
		||||
                            if reminder.pin {
 | 
			
		||||
                                reminder.pin_message(m.id, cache_http.http()).await;
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            Ok(())
 | 
			
		||||
                        }
 | 
			
		||||
                        Err(e) => Err(e),
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                Err(e) => Err(e),
 | 
			
		||||
                _ => Err(Error::Other("Channel not of valid type")),
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        async fn send_to_webhook(
 | 
			
		||||
            cache_http: impl CacheHttp,
 | 
			
		||||
            reminder: &Reminder,
 | 
			
		||||
            webhook: Webhook,
 | 
			
		||||
            embed: Option<CreateEmbed>,
 | 
			
		||||
        ) -> Result<()> {
 | 
			
		||||
            match webhook
 | 
			
		||||
                .execute(&cache_http.http(), reminder.pin || reminder.restartable, |w| {
 | 
			
		||||
                    w.content(&reminder.content).tts(reminder.tts);
 | 
			
		||||
 | 
			
		||||
                    if let Some(username) = &reminder.username {
 | 
			
		||||
                        w.username(username);
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    if let Some(avatar) = &reminder.avatar {
 | 
			
		||||
                        w.avatar_url(avatar);
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    if let (Some(attachment), Some(name)) =
 | 
			
		||||
                        (&reminder.attachment, &reminder.attachment_name)
 | 
			
		||||
                    {
 | 
			
		||||
                        w.add_file((attachment as &[u8], name.as_str()));
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    if let Some(embed) = embed {
 | 
			
		||||
                        w.embeds(vec![SerenityEmbed::fake(|c| {
 | 
			
		||||
                            *c = embed;
 | 
			
		||||
                            c
 | 
			
		||||
                        })]);
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    w
 | 
			
		||||
                })
 | 
			
		||||
                .await
 | 
			
		||||
            {
 | 
			
		||||
                Ok(m) => {
 | 
			
		||||
                    if reminder.pin {
 | 
			
		||||
                        if let Some(message) = m {
 | 
			
		||||
                            reminder.pin_message(message.id, cache_http.http()).await;
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    Ok(())
 | 
			
		||||
                }
 | 
			
		||||
                Err(e) => Err(e),
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if self.enabled
 | 
			
		||||
            && !(self.channel_paused
 | 
			
		||||
                && self
 | 
			
		||||
                    .channel_paused_until
 | 
			
		||||
                    .map_or(true, |inner| inner >= Utc::now().naive_local()))
 | 
			
		||||
        {
 | 
			
		||||
            let _ = sqlx::query!(
 | 
			
		||||
                "
 | 
			
		||||
UPDATE `channels` SET paused = 0, paused_until = NULL WHERE `channel` = ?
 | 
			
		||||
                ",
 | 
			
		||||
                self.channel_id
 | 
			
		||||
            )
 | 
			
		||||
            .execute(&pool.clone())
 | 
			
		||||
            .await;
 | 
			
		||||
 | 
			
		||||
            let embed = Embed::from_id(&pool.clone(), self.id).await.map(|e| e.into());
 | 
			
		||||
 | 
			
		||||
            let result = if let (Some(webhook_id), Some(webhook_token)) =
 | 
			
		||||
                (self.webhook_id, &self.webhook_token)
 | 
			
		||||
            {
 | 
			
		||||
                let webhook_res =
 | 
			
		||||
                    cache_http.http().get_webhook_with_token(webhook_id, webhook_token).await;
 | 
			
		||||
 | 
			
		||||
                if let Ok(webhook) = webhook_res {
 | 
			
		||||
                    send_to_webhook(cache_http, &self, webhook, embed).await
 | 
			
		||||
                } else {
 | 
			
		||||
                    warn!("Webhook vanished: {:?}", webhook_res);
 | 
			
		||||
 | 
			
		||||
                    self.reset_webhook(&pool.clone()).await;
 | 
			
		||||
                    send_to_channel(cache_http, &self, embed).await
 | 
			
		||||
                }
 | 
			
		||||
            } else {
 | 
			
		||||
                send_to_channel(cache_http, &self, embed).await
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            if let Err(e) = result {
 | 
			
		||||
                error!("Error sending {:?}: {:?}", self, e);
 | 
			
		||||
 | 
			
		||||
                if let Error::Http(error) = e {
 | 
			
		||||
                    if error.status_code() == Some(StatusCode::from_u16(404).unwrap()) {
 | 
			
		||||
                        error!("Seeing channel is deleted. Removing reminder");
 | 
			
		||||
                        self.force_delete(&pool).await;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        self.refresh(&pool).await;
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                    self.refresh(&pool).await;
 | 
			
		||||
                }
 | 
			
		||||
            } else {
 | 
			
		||||
                self.refresh(&pool).await;
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            info!("Reminder {} is paused", self.id);
 | 
			
		||||
 | 
			
		||||
            self.refresh(&pool).await;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,15 +1,16 @@
 | 
			
		||||
use std::time::{SystemTime, UNIX_EPOCH};
 | 
			
		||||
 | 
			
		||||
use std::fmt::{Display, Formatter, Result as FmtResult};
 | 
			
		||||
 | 
			
		||||
use crate::consts::{LOCAL_TIMEZONE, PYTHON_LOCATION};
 | 
			
		||||
use std::{
 | 
			
		||||
    convert::TryFrom,
 | 
			
		||||
    fmt::{Display, Formatter, Result as FmtResult},
 | 
			
		||||
    str::from_utf8,
 | 
			
		||||
    time::{SystemTime, UNIX_EPOCH},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use chrono::{DateTime, Datelike, Timelike, Utc};
 | 
			
		||||
use chrono_tz::Tz;
 | 
			
		||||
use std::convert::TryFrom;
 | 
			
		||||
use std::str::from_utf8;
 | 
			
		||||
use tokio::process::Command;
 | 
			
		||||
 | 
			
		||||
use crate::consts::{LOCAL_TIMEZONE, PYTHON_LOCATION};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub enum InvalidTime {
 | 
			
		||||
    ParseErrorDMY,
 | 
			
		||||
@@ -26,11 +27,13 @@ impl Display for InvalidTime {
 | 
			
		||||
 | 
			
		||||
impl std::error::Error for InvalidTime {}
 | 
			
		||||
 | 
			
		||||
#[derive(Copy, Clone)]
 | 
			
		||||
enum ParseType {
 | 
			
		||||
    Explicit,
 | 
			
		||||
    Displacement,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Clone)]
 | 
			
		||||
pub struct TimeParser {
 | 
			
		||||
    timezone: Tz,
 | 
			
		||||
    inverted: bool,
 | 
			
		||||
@@ -95,10 +98,7 @@ impl TimeParser {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn process_explicit(&self) -> Result<i64, InvalidTime> {
 | 
			
		||||
        let mut time = Utc::now()
 | 
			
		||||
            .with_timezone(&self.timezone)
 | 
			
		||||
            .with_second(0)
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        let mut time = Utc::now().with_timezone(&self.timezone).with_second(0).unwrap();
 | 
			
		||||
 | 
			
		||||
        let mut segments = self.time_string.rsplit('-');
 | 
			
		||||
        // this segment will always exist even if split fails
 | 
			
		||||
@@ -106,13 +106,11 @@ impl TimeParser {
 | 
			
		||||
 | 
			
		||||
        let h_m_s = hms.split(':');
 | 
			
		||||
 | 
			
		||||
        for (t, setter) in h_m_s.take(3).zip(&[
 | 
			
		||||
            DateTime::with_hour,
 | 
			
		||||
            DateTime::with_minute,
 | 
			
		||||
            DateTime::with_second,
 | 
			
		||||
        ]) {
 | 
			
		||||
        for (t, setter) in
 | 
			
		||||
            h_m_s.take(3).zip(&[DateTime::with_hour, DateTime::with_minute, DateTime::with_second])
 | 
			
		||||
        {
 | 
			
		||||
            time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorHMS)?)
 | 
			
		||||
                .map_or_else(|| Err(InvalidTime::ParseErrorHMS), |inner| Ok(inner))?;
 | 
			
		||||
                .map_or_else(|| Err(InvalidTime::ParseErrorHMS), Ok)?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if let Some(dmy) = segments.next() {
 | 
			
		||||
@@ -122,13 +120,11 @@ impl TimeParser {
 | 
			
		||||
            let month = d_m_y.next();
 | 
			
		||||
            let year = d_m_y.next();
 | 
			
		||||
 | 
			
		||||
            for (t, setter) in [day, month]
 | 
			
		||||
                .iter()
 | 
			
		||||
                .zip(&[DateTime::with_day, DateTime::with_month])
 | 
			
		||||
            for (t, setter) in [day, month].iter().zip(&[DateTime::with_day, DateTime::with_month])
 | 
			
		||||
            {
 | 
			
		||||
                if let Some(t) = t {
 | 
			
		||||
                    time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorDMY)?)
 | 
			
		||||
                        .map_or_else(|| Err(InvalidTime::ParseErrorDMY), |inner| Ok(inner))?;
 | 
			
		||||
                        .map_or_else(|| Err(InvalidTime::ParseErrorDMY), Ok)?;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
@@ -136,7 +132,7 @@ impl TimeParser {
 | 
			
		||||
                if year.len() == 4 {
 | 
			
		||||
                    time = time
 | 
			
		||||
                        .with_year(year.parse().map_err(|_| InvalidTime::ParseErrorDMY)?)
 | 
			
		||||
                        .map_or_else(|| Err(InvalidTime::ParseErrorDMY), |inner| Ok(inner))?;
 | 
			
		||||
                        .map_or_else(|| Err(InvalidTime::ParseErrorDMY), Ok)?;
 | 
			
		||||
                } else if year.len() == 2 {
 | 
			
		||||
                    time = time
 | 
			
		||||
                        .with_year(
 | 
			
		||||
@@ -144,9 +140,9 @@ impl TimeParser {
 | 
			
		||||
                                .parse()
 | 
			
		||||
                                .map_err(|_| InvalidTime::ParseErrorDMY)?,
 | 
			
		||||
                        )
 | 
			
		||||
                        .map_or_else(|| Err(InvalidTime::ParseErrorDMY), |inner| Ok(inner))?;
 | 
			
		||||
                        .map_or_else(|| Err(InvalidTime::ParseErrorDMY), Ok)?;
 | 
			
		||||
                } else {
 | 
			
		||||
                    Err(InvalidTime::ParseErrorDMY)?;
 | 
			
		||||
                    return Err(InvalidTime::ParseErrorDMY);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
@@ -157,10 +153,10 @@ impl TimeParser {
 | 
			
		||||
    fn process_displacement(&self) -> Result<i64, InvalidTime> {
 | 
			
		||||
        let mut current_buffer = "0".to_string();
 | 
			
		||||
 | 
			
		||||
        let mut seconds = 0 as i64;
 | 
			
		||||
        let mut minutes = 0 as i64;
 | 
			
		||||
        let mut hours = 0 as i64;
 | 
			
		||||
        let mut days = 0 as i64;
 | 
			
		||||
        let mut seconds = 0_i64;
 | 
			
		||||
        let mut minutes = 0_i64;
 | 
			
		||||
        let mut hours = 0_i64;
 | 
			
		||||
        let mut days = 0_i64;
 | 
			
		||||
 | 
			
		||||
        for character in self.time_string.chars() {
 | 
			
		||||
            match character {
 | 
			
		||||
@@ -205,7 +201,7 @@ impl TimeParser {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub(crate) async fn natural_parser(time: &str, timezone: &str) -> Option<i64> {
 | 
			
		||||
pub async fn natural_parser(time: &str, timezone: &str) -> Option<i64> {
 | 
			
		||||
    Command::new(&*PYTHON_LOCATION)
 | 
			
		||||
        .arg("-c")
 | 
			
		||||
        .arg(include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/dp.py")))
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user