37 Commits

Author SHA1 Message Date
e5ab99f67b removed some log messages. rustfmt 2021-12-21 13:46:10 +00:00
e47715917e integrate reminder sender 2021-12-20 13:48:18 +00:00
4f9eb58c16 made the missing perms send a message (since the webhook responses bypass perms) 2021-11-18 21:05:49 +00:00
c953bc0cd3 various todo fixes 2021-11-15 08:09:48 +00:00
610779a293 added mention blocker to everything 2021-11-15 07:51:38 +00:00
ebd1efa990 added check for guild only commands 2021-11-13 22:30:18 +00:00
5230101a8d beta0 2021-11-13 14:12:37 +00:00
d8f42c1b25 fixed an issue with utc time. removed intents 2021-11-07 13:23:41 +00:00
23c6b3869e patreon gated repeat argument 2021-11-06 23:30:38 +00:00
a21f518b21 removed framework impl 2021-11-02 20:19:29 +00:00
f1bfc11160 removed all remaining restriction code 2021-11-02 20:10:10 +00:00
72228911f2 Readded some guild data code. fixed some weird cases with macro command. removed restrict command. changed db to be 'as it was'. removed execution limiters since commands are quite heavily ratelimited anyway 2021-10-30 21:02:11 +01:00
db7cca6296 added to the migration file somewhat. added some checks to components 2021-10-26 22:13:51 +01:00
e36e718f28 removed all guild data related code 2021-10-26 21:10:14 +01:00
44debf93c5 removed dead code 2021-10-26 20:54:22 +01:00
9b54fba5e5 Revert "turned pager into a single type"
This reverts commit 4490f19c
2021-10-26 20:11:19 +01:00
6cf660c7ee macro stuff 2021-10-16 19:18:16 +01:00
4490f19c04 turned pager into a single type 2021-10-13 17:23:50 +01:00
a362a24cfc changed a bunch of types so the macro run command works nicely 2021-10-13 16:37:15 +01:00
903daf65e6 ... 2021-10-12 21:52:43 +01:00
b310e99085 todo pager and selector 2021-10-11 21:19:08 +01:00
ebabe0e85a todo stuff 2021-10-02 22:54:34 +01:00
6b5d6ae288 fixed del pager. todo stuff 2021-09-27 17:34:13 +01:00
379e488f7a subcommand group syntax 2021-09-24 12:55:35 +01:00
d84d7ab62b added functionality for reusable hook functions that will execute on commands 2021-09-22 21:12:29 +01:00
a0974795e1 ... 2021-09-18 13:40:30 +01:00
a9c91bee93 pager improvements. deleting working 2021-09-16 18:30:16 +01:00
b2207e308a optimized packing slightly. restrict interactions 2021-09-16 15:42:50 +01:00
3c1eeed92f look command pager 2021-09-16 14:48:29 +01:00
395a8481f1 typing 2021-09-12 16:59:19 +01:00
bae0433bd9 framework now supports subcommands. timer cmd working 2021-09-12 16:09:57 +01:00
3e547861ea components 2021-09-11 20:40:58 +01:00
9b5333dc87 more commands. fixed an issue with text only commands 2021-09-11 00:14:23 +01:00
471948bed3 linked everything together 2021-09-10 18:09:25 +01:00
c148cdf556 removed language_manager.rs. framework reworked for slash commands. updated info commands for new framework 2021-09-06 13:46:16 +01:00
98aed91d21 revert some usages of discord builtin timestamp formatting 2021-09-02 23:59:30 +01:00
40630c0014 restructured all the reminder creation stuff into builders 2021-09-02 23:38:12 +01:00
41 changed files with 5532 additions and 4661 deletions

956
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,10 @@
[package] [package]
name = "reminder_rs" name = "reminder_rs"
version = "1.5.2" version = "1.6.0-beta3"
authors = ["jellywx <judesouthworth@pm.me>"] authors = ["jellywx <judesouthworth@pm.me>"]
edition = "2018" edition = "2018"
[dependencies] [dependencies]
dashmap = "4.0"
dotenv = "0.15" dotenv = "0.15"
humantime = "2.1" humantime = "2.1"
tokio = { version = "1", features = ["process", "full"] } tokio = { version = "1", features = ["process", "full"] }
@ -14,16 +13,34 @@ regex = "1.4"
log = "0.4" log = "0.4"
env_logger = "0.8" env_logger = "0.8"
chrono = "0.4" chrono = "0.4"
chrono-tz = "0.5" chrono-tz = { version = "0.5", features = ["serde"] }
lazy_static = "1.4" lazy_static = "1.4"
num-integer = "0.1" num-integer = "0.1"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
serde_repr = "0.1"
rmp-serde = "0.15"
rand = "0.7" rand = "0.7"
Inflector = "0.11"
levenshtein = "1.0" levenshtein = "1.0"
serenity = { git = "https://github.com/jellywx/serenity", branch = "jellywx-attachment_option", features = ["collector", "unstable_discord_api"] }
sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono"]} sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono"]}
base64 = "0.13.0"
[dependencies.regex_command_attr] [dependencies.regex_command_attr]
path = "./regex_command_attr" path = "command_attributes"
[dependencies.serenity]
git = "https://github.com/serenity-rs/serenity"
branch = "next"
default-features = false
features = [
"builder",
"client",
"cache",
"gateway",
"http",
"model",
"utils",
"rustls_backend",
"collector",
"unstable_discord_api"
]

View File

@ -1,6 +1,5 @@
# reminder-rs # reminder-rs
Reminder Bot for Discord, now in Rust. Reminder Bot for Discord.
Old Python version: https://github.com/reminder-bot/bot
## How do I use it? ## How do I use it?
We offer a hosted version of the bot. You can invite it with: **https://invite.reminder-bot.com**. The catch is that repeating We offer a hosted version of the bot. You can invite it with: **https://invite.reminder-bot.com**. The catch is that repeating
@ -15,7 +14,6 @@ Reminder Bot can be built by running `cargo build --release` in the top level di
These environment variables must be provided when compiling the bot These environment variables must be provided when compiling the bot
* `DATABASE_URL` - the URL of your MySQL database (`mysql://user[:password]@domain/database`) * `DATABASE_URL` - the URL of your MySQL database (`mysql://user[:password]@domain/database`)
* `WEBHOOK_AVATAR` - accepts the name of an image file located in `$CARGO_MANIFEST_DIR/assets/` to be used as the avatar when creating webhooks. **IMPORTANT: image file must be 128x128 or smaller in size** * `WEBHOOK_AVATAR` - accepts the name of an image file located in `$CARGO_MANIFEST_DIR/assets/` to be used as the avatar when creating webhooks. **IMPORTANT: image file must be 128x128 or smaller in size**
* `STRINGS_FILE` - accepts the name of a compiled strings file located in `$CARGO_MANIFEST_DIR/assets/` to be used for creating messages. Compiled string files can be generated with `compile.py` at https://github.com/reminder-bot/languages
### Setting up Python ### Setting up Python
Reminder Bot by default looks for a venv within it's working directory to run Python out of. To set up a venv, install `python3-venv` and run `python3 -m venv venv`. Then, run `source venv/bin/activate` to activate the venv, and do `pip install dateparser` to install the required library Reminder Bot by default looks for a venv within it's working directory to run Python out of. To set up a venv, install `python3-venv` and run `python3 -m venv venv`. Then, run `source venv/bin/activate` to activate the venv, and do `pip install dateparser` to install the required library
@ -29,16 +27,17 @@ __Required Variables__
__Other Variables__ __Other Variables__
* `MIN_INTERVAL` - default `600`, defines the shortest interval the bot should accept * `MIN_INTERVAL` - default `600`, defines the shortest interval the bot should accept
* `MAX_TIME` - default `1576800000`, defines the maximum time ahead that reminders can be set for
* `LOCAL_TIMEZONE` - default `UTC`, necessary for calculations in the natural language processor * `LOCAL_TIMEZONE` - default `UTC`, necessary for calculations in the natural language processor
* `DEFAULT_PREFIX` - default `$`, used for the default prefix on new guilds
* `SUBSCRIPTION_ROLES` - default `None`, accepts a list of Discord role IDs that are given to subscribed users * `SUBSCRIPTION_ROLES` - default `None`, accepts a list of Discord role IDs that are given to subscribed users
* `CNC_GUILD` - default `None`, accepts a single Discord guild ID for the server that the subscription roles belong to * `CNC_GUILD` - default `None`, accepts a single Discord guild ID for the server that the subscription roles belong to
* `IGNORE_BOTS` - default `1`, if `1`, Reminder Bot will ignore all other bots * `IGNORE_BOTS` - default `1`, if `1`, Reminder Bot will ignore all other bots
* `PYTHON_LOCATION` - default `venv/bin/python3`. Can be changed if your Python executable is located somewhere else * `PYTHON_LOCATION` - default `venv/bin/python3`. Can be changed if your Python executable is located somewhere else
* `LOCAL_LANGUAGE` - default `EN`. Specifies the string set to fall back to if a string cannot be found (and to be used with new users)
* `THEME_COLOR` - default `8fb677`. Specifies the hex value of the color to use on info message embeds * `THEME_COLOR` - default `8fb677`. Specifies the hex value of the color to use on info message embeds
* `CASE_INSENSITIVE` - default `1`, if `1`, commands will be treated with case insensitivity (so both `$help` and `$HELP` will work)
* `SHARD_COUNT` - default `None`, accepts the number of shards that are being ran * `SHARD_COUNT` - default `None`, accepts the number of shards that are being ran
* `SHARD_RANGE` - default `None`, if `SHARD_COUNT` is specified, specifies what range of shards to start on this process * `SHARD_RANGE` - default `None`, if `SHARD_COUNT` is specified, specifies what range of shards to start on this process
* `DM_ENABLED` - default `1`, if `1`, Reminder Bot will respond to direct messages * `DM_ENABLED` - default `1`, if `1`, Reminder Bot will respond to direct messages
### Todo List
* Convert aliases to macros
* Help command

View File

@ -1,9 +1,10 @@
[package] [package]
name = "regex_command_attr" name = "regex_command_attr"
version = "0.2.0" version = "0.3.6"
authors = ["acdenisSK <acdenissk69@gmail.com>", "jellywx <judesouthworth@pm.me>"] authors = ["acdenisSK <acdenissk69@gmail.com>", "jellywx <judesouthworth@pm.me>"]
edition = "2018" edition = "2018"
description = "Procedural macros for command creation for the RegexFramework for serenity." description = "Procedural macros for command creation for the Serenity library."
license = "ISC"
[lib] [lib]
proc-macro = true proc-macro = true
@ -12,3 +13,4 @@ proc-macro = true
quote = "^1.0" quote = "^1.0"
syn = { version = "^1.0", features = ["full", "derive", "extra-traits"] } syn = { version = "^1.0", features = ["full", "derive", "extra-traits"] }
proc-macro2 = "1.0" proc-macro2 = "1.0"
uuid = { version = "0.8", features = ["v4"] }

View File

@ -1,13 +1,17 @@
use proc_macro2::Span;
use syn::parse::{Error, Result};
use syn::spanned::Spanned;
use syn::{Attribute, Ident, Lit, LitStr, Meta, NestedMeta, Path};
use crate::structures::PermissionLevel;
use crate::util::{AsOption, LitExt};
use std::fmt::{self, Write}; use std::fmt::{self, Write};
use proc_macro2::Span;
use syn::{
parse::{Error, Result},
spanned::Spanned,
Attribute, Ident, Lit, LitStr, Meta, NestedMeta, Path,
};
use crate::{
structures::{ApplicationCommandOptionType, Arg},
util::{AsOption, LitExt},
};
#[derive(Debug, Clone, Copy, PartialEq)] #[derive(Debug, Clone, Copy, PartialEq)]
pub enum ValueKind { pub enum ValueKind {
// #[<name>] // #[<name>]
@ -19,6 +23,9 @@ pub enum ValueKind {
// #[<name>([<value>, <value>, <value>, ...])] // #[<name>([<value>, <value>, <value>, ...])]
List, List,
// #[<name>([<prop> = <value>, <prop> = <value>, ...])]
EqualsList,
// #[<name>(<value>)] // #[<name>(<value>)]
SingleList, SingleList,
} }
@ -29,6 +36,9 @@ impl fmt::Display for ValueKind {
ValueKind::Name => f.pad("`#[<name>]`"), ValueKind::Name => f.pad("`#[<name>]`"),
ValueKind::Equals => f.pad("`#[<name> = <value>]`"), ValueKind::Equals => f.pad("`#[<name> = <value>]`"),
ValueKind::List => f.pad("`#[<name>([<value>, <value>, <value>, ...])]`"), ValueKind::List => f.pad("`#[<name>([<value>, <value>, <value>, ...])]`"),
ValueKind::EqualsList => {
f.pad("`#[<name>([<prop> = <value>, <prop> = <value>, ...])]`")
}
ValueKind::SingleList => f.pad("`#[<name>(<value>)]`"), ValueKind::SingleList => f.pad("`#[<name>(<value>)]`"),
} }
} }
@ -36,24 +46,15 @@ impl fmt::Display for ValueKind {
fn to_ident(p: Path) -> Result<Ident> { fn to_ident(p: Path) -> Result<Ident> {
if p.segments.is_empty() { if p.segments.is_empty() {
return Err(Error::new( return Err(Error::new(p.span(), "cannot convert an empty path to an identifier"));
p.span(),
"cannot convert an empty path to an identifier",
));
} }
if p.segments.len() > 1 { if p.segments.len() > 1 {
return Err(Error::new( return Err(Error::new(p.span(), "the path must not have more than one segment"));
p.span(),
"the path must not have more than one segment",
));
} }
if !p.segments[0].arguments.is_empty() { if !p.segments[0].arguments.is_empty() {
return Err(Error::new( return Err(Error::new(p.span(), "the singular path segment must not have any arguments"));
p.span(),
"the singular path segment must not have any arguments",
));
} }
Ok(p.segments[0].ident.clone()) Ok(p.segments[0].ident.clone())
@ -62,24 +63,37 @@ fn to_ident(p: Path) -> Result<Ident> {
#[derive(Debug)] #[derive(Debug)]
pub struct Values { pub struct Values {
pub name: Ident, pub name: Ident,
pub literals: Vec<Lit>, pub literals: Vec<(Option<String>, Lit)>,
pub kind: ValueKind, pub kind: ValueKind,
pub span: Span, pub span: Span,
} }
impl Values { impl Values {
#[inline] #[inline]
pub fn new(name: Ident, kind: ValueKind, literals: Vec<Lit>, span: Span) -> Self { pub fn new(
Values { name: Ident,
name, kind: ValueKind,
literals, literals: Vec<(Option<String>, Lit)>,
kind, span: Span,
span, ) -> Self {
} Values { name, literals, kind, span }
} }
} }
pub fn parse_values(attr: &Attribute) -> Result<Values> { pub fn parse_values(attr: &Attribute) -> Result<Values> {
fn is_list_or_named_list(meta: &NestedMeta) -> ValueKind {
match meta {
// catch if the nested value is a literal value
NestedMeta::Lit(_) => ValueKind::List,
// catch if the nested value is a meta value
NestedMeta::Meta(m) => match m {
// path => some quoted value
Meta::Path(_) => ValueKind::List,
Meta::List(_) | Meta::NameValue(_) => ValueKind::EqualsList,
},
}
}
let meta = attr.parse_meta()?; let meta = attr.parse_meta()?;
match meta { match meta {
@ -96,36 +110,62 @@ pub fn parse_values(attr: &Attribute) -> Result<Values> {
return Err(Error::new(attr.span(), "list cannot be empty")); return Err(Error::new(attr.span(), "list cannot be empty"));
} }
let mut lits = Vec::with_capacity(nested.len()); if is_list_or_named_list(nested.first().unwrap()) == ValueKind::List {
let mut lits = Vec::with_capacity(nested.len());
for meta in nested { for meta in nested {
match meta { match meta {
NestedMeta::Lit(l) => lits.push(l), // catch if the nested value is a literal value
NestedMeta::Meta(m) => match m { NestedMeta::Lit(l) => lits.push((None, l)),
Meta::Path(path) => { // catch if the nested value is a meta value
let i = to_ident(path)?; NestedMeta::Meta(m) => match m {
lits.push(Lit::Str(LitStr::new(&i.to_string(), i.span()))) // path => some quoted value
} Meta::Path(path) => {
Meta::List(_) | Meta::NameValue(_) => { let i = to_ident(path)?;
return Err(Error::new(attr.span(), "cannot nest a list; only accept literals and identifiers at this level")) lits.push((None, Lit::Str(LitStr::new(&i.to_string(), i.span()))))
} }
}, Meta::List(_) | Meta::NameValue(_) => {
return Err(Error::new(attr.span(), "cannot nest a list; only accept literals and identifiers at this level"))
}
},
}
} }
}
let kind = if lits.len() == 1 { let kind = if lits.len() == 1 { ValueKind::SingleList } else { ValueKind::List };
ValueKind::SingleList
Ok(Values::new(name, kind, lits, attr.span()))
} else { } else {
ValueKind::List let mut lits = Vec::with_capacity(nested.len());
};
Ok(Values::new(name, kind, lits, attr.span())) for meta in nested {
match meta {
// catch if the nested value is a literal value
NestedMeta::Lit(_) => {
return Err(Error::new(attr.span(), "key-value pairs expected"))
}
// catch if the nested value is a meta value
NestedMeta::Meta(m) => match m {
Meta::NameValue(n) => {
let name = to_ident(n.path)?.to_string();
let value = n.lit;
lits.push((Some(name), value));
}
Meta::List(_) | Meta::Path(_) => {
return Err(Error::new(attr.span(), "key-value pairs expected"))
}
},
}
}
Ok(Values::new(name, ValueKind::EqualsList, lits, attr.span()))
}
} }
Meta::NameValue(meta) => { Meta::NameValue(meta) => {
let name = to_ident(meta.path)?; let name = to_ident(meta.path)?;
let lit = meta.lit; let lit = meta.lit;
Ok(Values::new(name, ValueKind::Equals, vec![lit], attr.span())) Ok(Values::new(name, ValueKind::Equals, vec![(None, lit)], attr.span()))
} }
} }
} }
@ -168,10 +208,7 @@ fn validate(values: &Values, forms: &[ValueKind]) -> Result<()> {
return Err(Error::new( return Err(Error::new(
values.span, values.span,
// Using the `_args` version here to avoid an allocation. // Using the `_args` version here to avoid an allocation.
format_args!( format_args!("the attribute must be in of these forms:\n{}", DisplaySlice(forms)),
"the attribute must be in of these forms:\n{}",
DisplaySlice(forms)
),
)); ));
} }
@ -191,11 +228,7 @@ impl AttributeOption for Vec<String> {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::List])?; validate(&values, &[ValueKind::List])?;
Ok(values Ok(values.literals.into_iter().map(|(_, l)| l.to_str()).collect())
.literals
.into_iter()
.map(|lit| lit.to_str())
.collect())
} }
} }
@ -204,7 +237,7 @@ impl AttributeOption for String {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::Equals, ValueKind::SingleList])?; validate(&values, &[ValueKind::Equals, ValueKind::SingleList])?;
Ok(values.literals[0].to_str()) Ok(values.literals[0].1.to_str())
} }
} }
@ -213,7 +246,7 @@ impl AttributeOption for bool {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::Name, ValueKind::SingleList])?; validate(&values, &[ValueKind::Name, ValueKind::SingleList])?;
Ok(values.literals.get(0).map_or(true, |l| l.to_bool())) Ok(values.literals.get(0).map_or(true, |(_, l)| l.to_bool()))
} }
} }
@ -222,7 +255,7 @@ impl AttributeOption for Ident {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::SingleList])?; validate(&values, &[ValueKind::SingleList])?;
Ok(values.literals[0].to_ident()) Ok(values.literals[0].1.to_ident())
} }
} }
@ -231,7 +264,7 @@ impl AttributeOption for Vec<Ident> {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::List])?; validate(&values, &[ValueKind::List])?;
Ok(values.literals.into_iter().map(|l| l.to_ident()).collect()) Ok(values.literals.into_iter().map(|(_, l)| l.to_ident()).collect())
} }
} }
@ -239,15 +272,40 @@ impl AttributeOption for Option<String> {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList])?; validate(&values, &[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList])?;
Ok(values.literals.get(0).map(|l| l.to_str())) Ok(values.literals.get(0).map(|(_, l)| l.to_str()))
} }
} }
impl AttributeOption for PermissionLevel { impl AttributeOption for Arg {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::SingleList])?; validate(&values, &[ValueKind::EqualsList])?;
Ok(values.literals.get(0).map(|l| PermissionLevel::from_str(&*l.to_str()).unwrap()).unwrap()) let mut arg: Arg = Default::default();
for (key, value) in &values.literals {
match key {
Some(s) => match s.as_str() {
"name" => {
arg.name = value.to_str();
}
"description" => {
arg.description = value.to_str();
}
"required" => {
arg.required = value.to_bool();
}
"kind" => arg.kind = ApplicationCommandOptionType::from_str(value.to_str()),
_ => {
return Err(Error::new(key.span(), "unexpected attribute"));
}
},
_ => {
return Err(Error::new(key.span(), "unnamed attribute"));
}
}
}
Ok(arg)
} }
} }
@ -265,7 +323,7 @@ macro_rules! attr_option_num {
fn parse(values: Values) -> Result<Self> { fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::SingleList])?; validate(&values, &[ValueKind::SingleList])?;
Ok(match &values.literals[0] { Ok(match &values.literals[0].1 {
Lit::Int(l) => l.base10_parse::<$n>()?, Lit::Int(l) => l.base10_parse::<$n>()?,
l => { l => {
let s = l.to_str(); let s = l.to_str();

View File

@ -0,0 +1,10 @@
pub mod suffixes {
pub const COMMAND: &str = "COMMAND";
pub const ARG: &str = "ARG";
pub const SUBCOMMAND: &str = "SUBCOMMAND";
pub const SUBCOMMAND_GROUP: &str = "GROUP";
pub const CHECK: &str = "CHECK";
pub const HOOK: &str = "HOOK";
}
pub use self::suffixes::*;

View File

@ -0,0 +1,321 @@
#![deny(rust_2018_idioms)]
#![deny(broken_intra_doc_links)]
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use syn::{parse::Error, parse_macro_input, parse_quote, spanned::Spanned, Lit, Type};
use uuid::Uuid;
pub(crate) mod attributes;
pub(crate) mod consts;
pub(crate) mod structures;
#[macro_use]
pub(crate) mod util;
use attributes::*;
use consts::*;
use structures::*;
use util::*;
macro_rules! match_options {
($v:expr, $values:ident, $options:ident, $span:expr => [$($name:ident);*]) => {
match $v {
$(
stringify!($name) => $options.$name = propagate_err!($crate::attributes::parse($values)),
)*
_ => {
return Error::new($span, format_args!("invalid attribute: {:?}", $v))
.to_compile_error()
.into();
},
}
};
}
#[proc_macro_attribute]
pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
enum LastItem {
Fun,
SubFun,
SubGroup,
SubGroupFun,
}
let mut fun = parse_macro_input!(input as CommandFun);
let _name = if !attr.is_empty() {
parse_macro_input!(attr as Lit).to_str()
} else {
fun.name.to_string()
};
let mut hooks: Vec<Ident> = Vec::new();
let mut options = Options::new();
let mut last_desc = LastItem::Fun;
for attribute in &fun.attributes {
let span = attribute.span();
let values = propagate_err!(parse_values(attribute));
let name = values.name.to_string();
let name = &name[..];
match name {
"subcommand" => {
let new_subcommand = Subcommand::new(propagate_err!(attributes::parse(values)));
if let Some(subcommand_group) = options.subcommand_groups.last_mut() {
last_desc = LastItem::SubGroupFun;
subcommand_group.subcommands.push(new_subcommand);
} else {
last_desc = LastItem::SubFun;
options.subcommands.push(new_subcommand);
}
}
"subcommandgroup" => {
let new_group = SubcommandGroup::new(propagate_err!(attributes::parse(values)));
last_desc = LastItem::SubGroup;
options.subcommand_groups.push(new_group);
}
"arg" => {
let arg = propagate_err!(attributes::parse(values));
match last_desc {
LastItem::Fun => {
options.cmd_args.push(arg);
}
LastItem::SubFun => {
options.subcommands.last_mut().unwrap().cmd_args.push(arg);
}
LastItem::SubGroup => {
panic!("Argument not expected under subcommand group");
}
LastItem::SubGroupFun => {
options
.subcommand_groups
.last_mut()
.unwrap()
.subcommands
.last_mut()
.unwrap()
.cmd_args
.push(arg);
}
}
}
"example" => {
options.examples.push(propagate_err!(attributes::parse(values)));
}
"description" => {
let line: String = propagate_err!(attributes::parse(values));
match last_desc {
LastItem::Fun => {
util::append_line(&mut options.description, line);
}
LastItem::SubFun => {
util::append_line(
&mut options.subcommands.last_mut().unwrap().description,
line,
);
}
LastItem::SubGroup => {
util::append_line(
&mut options.subcommand_groups.last_mut().unwrap().description,
line,
);
}
LastItem::SubGroupFun => {
util::append_line(
&mut options
.subcommand_groups
.last_mut()
.unwrap()
.subcommands
.last_mut()
.unwrap()
.description,
line,
);
}
}
}
"hook" => {
hooks.push(propagate_err!(attributes::parse(values)));
}
_ => {
match_options!(name, values, options, span => [
aliases;
group;
can_blacklist;
supports_dm
]);
}
}
}
let Options {
aliases,
description,
group,
examples,
can_blacklist,
supports_dm,
mut cmd_args,
mut subcommands,
mut subcommand_groups,
} = options;
let visibility = fun.visibility;
let name = fun.name.clone();
let body = fun.body;
let root_ident = name.with_suffix(COMMAND);
let command_path = quote!(crate::framework::Command);
populate_fut_lifetimes_on_refs(&mut fun.args);
let mut subcommand_group_idents = subcommand_groups
.iter()
.map(|subcommand| {
root_ident
.with_suffix(subcommand.name.replace("-", "_").as_str())
.with_suffix(SUBCOMMAND_GROUP)
})
.collect::<Vec<Ident>>();
let mut subcommand_idents = subcommands
.iter()
.map(|subcommand| {
root_ident
.with_suffix(subcommand.name.replace("-", "_").as_str())
.with_suffix(SUBCOMMAND)
})
.collect::<Vec<Ident>>();
let mut arg_idents = cmd_args
.iter()
.map(|arg| root_ident.with_suffix(arg.name.replace("-", "_").as_str()).with_suffix(ARG))
.collect::<Vec<Ident>>();
let mut tokens = quote! {};
tokens.extend(
subcommand_groups
.iter_mut()
.zip(subcommand_group_idents.iter())
.map(|(group, group_ident)| group.as_tokens(group_ident))
.fold(quote! {}, |mut a, b| {
a.extend(b);
a
}),
);
tokens.extend(
subcommands
.iter_mut()
.zip(subcommand_idents.iter())
.map(|(subcommand, sc_ident)| subcommand.as_tokens(sc_ident))
.fold(quote! {}, |mut a, b| {
a.extend(b);
a
}),
);
tokens.extend(
cmd_args.iter_mut().zip(arg_idents.iter()).map(|(arg, ident)| arg.as_tokens(ident)).fold(
quote! {},
|mut a, b| {
a.extend(b);
a
},
),
);
arg_idents.append(&mut subcommand_group_idents);
arg_idents.append(&mut subcommand_idents);
let args = fun.args;
let variant = if args.len() == 2 {
quote!(crate::framework::CommandFnType::Multi)
} else {
let string: Type = parse_quote!(String);
let final_arg = args.get(2).unwrap();
if final_arg.kind == string {
quote!(crate::framework::CommandFnType::Text)
} else {
quote!(crate::framework::CommandFnType::Slash)
}
};
tokens.extend(quote! {
#[allow(missing_docs)]
pub static #root_ident: #command_path = #command_path {
fun: #variant(#name),
names: &[#_name, #(#aliases),*],
desc: #description,
group: #group,
examples: &[#(#examples),*],
can_blacklist: #can_blacklist,
supports_dm: #supports_dm,
args: &[#(&#arg_idents),*],
hooks: &[#(&#hooks),*],
};
#[allow(missing_docs)]
#visibility fn #name<'fut> (#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, ()> {
use ::serenity::futures::future::FutureExt;
async move {
#(#body)*;
}.boxed()
}
});
tokens.into()
}
#[proc_macro_attribute]
pub fn check(_attr: TokenStream, input: TokenStream) -> TokenStream {
let mut fun = parse_macro_input!(input as CommandFun);
let n = fun.name.clone();
let name = n.with_suffix(HOOK);
let fn_name = n.with_suffix(CHECK);
let visibility = fun.visibility;
let body = fun.body;
let ret = fun.ret;
populate_fut_lifetimes_on_refs(&mut fun.args);
let args = fun.args;
let hook_path = quote!(crate::framework::Hook);
let uuid = Uuid::new_v4().as_u128();
(quote! {
#[allow(missing_docs)]
#visibility fn #fn_name<'fut>(#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, #ret> {
use ::serenity::futures::future::FutureExt;
async move {
let _output: #ret = { #(#body)* };
#[allow(unreachable_code)]
_output
}.boxed()
}
#[allow(missing_docs)]
pub static #name: #hook_path = #hook_path {
fun: #fn_name,
uuid: #uuid,
};
})
.into()
}

View File

@ -0,0 +1,331 @@
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use syn::{
braced,
parse::{Error, Parse, ParseStream, Result},
spanned::Spanned,
Attribute, Block, FnArg, Ident, Pat, ReturnType, Stmt, Token, Type, Visibility,
};
use crate::{
consts::{ARG, SUBCOMMAND},
util::{Argument, IdentExt2, Parenthesised},
};
fn parse_argument(arg: FnArg) -> Result<Argument> {
match arg {
FnArg::Typed(typed) => {
let pat = typed.pat;
let kind = typed.ty;
match *pat {
Pat::Ident(id) => {
let name = id.ident;
let mutable = id.mutability;
Ok(Argument { mutable, name, kind: *kind })
}
Pat::Wild(wild) => {
let token = wild.underscore_token;
let name = Ident::new("_", token.spans[0]);
Ok(Argument { mutable: None, name, kind: *kind })
}
_ => Err(Error::new(pat.span(), format_args!("unsupported pattern: {:?}", pat))),
}
}
FnArg::Receiver(_) => {
Err(Error::new(arg.span(), format_args!("`self` arguments are prohibited: {:?}", arg)))
}
}
}
#[derive(Debug)]
pub struct CommandFun {
/// `#[...]`-style attributes.
pub attributes: Vec<Attribute>,
/// Populated cooked attributes. These are attributes outside of the realm of this crate's procedural macros
/// and will appear in generated output.
pub visibility: Visibility,
pub name: Ident,
pub args: Vec<Argument>,
pub ret: Type,
pub body: Vec<Stmt>,
}
impl Parse for CommandFun {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let attributes = input.call(Attribute::parse_outer)?;
let visibility = input.parse::<Visibility>()?;
input.parse::<Token![async]>()?;
input.parse::<Token![fn]>()?;
let name = input.parse()?;
// (...)
let Parenthesised(args) = input.parse::<Parenthesised<FnArg>>()?;
let ret = match input.parse::<ReturnType>()? {
ReturnType::Type(_, t) => (*t).clone(),
ReturnType::Default => Type::Verbatim(quote!(())),
};
// { ... }
let bcont;
braced!(bcont in input);
let body = bcont.call(Block::parse_within)?;
let args = args.into_iter().map(parse_argument).collect::<Result<Vec<_>>>()?;
Ok(Self { attributes, visibility, name, args, ret, body })
}
}
impl ToTokens for CommandFun {
fn to_tokens(&self, stream: &mut TokenStream2) {
let Self { attributes: _, visibility, name, args, ret, body } = self;
stream.extend(quote! {
#visibility async fn #name (#(#args),*) -> #ret {
#(#body)*
}
});
}
}
#[derive(Debug)]
pub(crate) enum ApplicationCommandOptionType {
SubCommand,
SubCommandGroup,
String,
Integer,
Boolean,
User,
Channel,
Role,
Mentionable,
Number,
Unknown,
}
impl ApplicationCommandOptionType {
pub fn from_str(s: String) -> Self {
match s.as_str() {
"SubCommand" => Self::SubCommand,
"SubCommandGroup" => Self::SubCommandGroup,
"String" => Self::String,
"Integer" => Self::Integer,
"Boolean" => Self::Boolean,
"User" => Self::User,
"Channel" => Self::Channel,
"Role" => Self::Role,
"Mentionable" => Self::Mentionable,
"Number" => Self::Number,
_ => Self::Unknown,
}
}
}
impl ToTokens for ApplicationCommandOptionType {
fn to_tokens(&self, stream: &mut TokenStream2) {
let path = quote!(
serenity::model::interactions::application_command::ApplicationCommandOptionType
);
let variant = match self {
ApplicationCommandOptionType::SubCommand => quote!(SubCommand),
ApplicationCommandOptionType::SubCommandGroup => quote!(SubCommandGroup),
ApplicationCommandOptionType::String => quote!(String),
ApplicationCommandOptionType::Integer => quote!(Integer),
ApplicationCommandOptionType::Boolean => quote!(Boolean),
ApplicationCommandOptionType::User => quote!(User),
ApplicationCommandOptionType::Channel => quote!(Channel),
ApplicationCommandOptionType::Role => quote!(Role),
ApplicationCommandOptionType::Mentionable => quote!(Mentionable),
ApplicationCommandOptionType::Number => quote!(Number),
ApplicationCommandOptionType::Unknown => quote!(Unknown),
};
stream.extend(quote! {
#path::#variant
});
}
}
#[derive(Debug)]
pub(crate) struct Arg {
pub name: String,
pub description: String,
pub kind: ApplicationCommandOptionType,
pub required: bool,
}
impl Arg {
pub fn as_tokens(&self, ident: &Ident) -> TokenStream2 {
let arg_path = quote!(crate::framework::Arg);
let Arg { name, description, kind, required } = self;
quote! {
#[allow(missing_docs)]
pub static #ident: #arg_path = #arg_path {
name: #name,
description: #description,
kind: #kind,
required: #required,
options: &[]
};
}
}
}
impl Default for Arg {
fn default() -> Self {
Self {
name: String::new(),
description: String::new(),
kind: ApplicationCommandOptionType::String,
required: false,
}
}
}
#[derive(Debug)]
pub(crate) struct Subcommand {
pub name: String,
pub description: String,
pub cmd_args: Vec<Arg>,
}
impl Subcommand {
pub fn as_tokens(&mut self, ident: &Ident) -> TokenStream2 {
let arg_path = quote!(crate::framework::Arg);
let subcommand_path = ApplicationCommandOptionType::SubCommand;
let arg_idents = self
.cmd_args
.iter()
.map(|arg| ident.with_suffix(arg.name.as_str()).with_suffix(ARG))
.collect::<Vec<Ident>>();
let mut tokens = self
.cmd_args
.iter_mut()
.zip(arg_idents.iter())
.map(|(arg, ident)| arg.as_tokens(ident))
.fold(quote! {}, |mut a, b| {
a.extend(b);
a
});
let Subcommand { name, description, .. } = self;
tokens.extend(quote! {
#[allow(missing_docs)]
pub static #ident: #arg_path = #arg_path {
name: #name,
description: #description,
kind: #subcommand_path,
required: false,
options: &[#(&#arg_idents),*],
};
});
tokens
}
}
impl Default for Subcommand {
fn default() -> Self {
Self { name: String::new(), description: String::new(), cmd_args: vec![] }
}
}
impl Subcommand {
pub(crate) fn new(name: String) -> Self {
Self { name, ..Default::default() }
}
}
#[derive(Debug)]
pub(crate) struct SubcommandGroup {
pub name: String,
pub description: String,
pub subcommands: Vec<Subcommand>,
}
impl SubcommandGroup {
pub fn as_tokens(&mut self, ident: &Ident) -> TokenStream2 {
let arg_path = quote!(crate::framework::Arg);
let subcommand_group_path = ApplicationCommandOptionType::SubCommandGroup;
let arg_idents = self
.subcommands
.iter()
.map(|arg| {
ident
.with_suffix(self.name.as_str())
.with_suffix(arg.name.as_str())
.with_suffix(SUBCOMMAND)
})
.collect::<Vec<Ident>>();
let mut tokens = self
.subcommands
.iter_mut()
.zip(arg_idents.iter())
.map(|(subcommand, ident)| subcommand.as_tokens(ident))
.fold(quote! {}, |mut a, b| {
a.extend(b);
a
});
let SubcommandGroup { name, description, .. } = self;
tokens.extend(quote! {
#[allow(missing_docs)]
pub static #ident: #arg_path = #arg_path {
name: #name,
description: #description,
kind: #subcommand_group_path,
required: false,
options: &[#(&#arg_idents),*],
};
});
tokens
}
}
impl Default for SubcommandGroup {
fn default() -> Self {
Self { name: String::new(), description: String::new(), subcommands: vec![] }
}
}
impl SubcommandGroup {
pub(crate) fn new(name: String) -> Self {
Self { name, ..Default::default() }
}
}
#[derive(Debug, Default)]
pub(crate) struct Options {
pub aliases: Vec<String>,
pub description: String,
pub group: String,
pub examples: Vec<String>,
pub can_blacklist: bool,
pub supports_dm: bool,
pub cmd_args: Vec<Arg>,
pub subcommands: Vec<Subcommand>,
pub subcommand_groups: Vec<SubcommandGroup>,
}
impl Options {
#[inline]
pub fn new() -> Self {
Self { group: "None".to_string(), ..Default::default() }
}
}

View File

@ -1,6 +1,5 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
use proc_macro2::Span; use proc_macro2::{Span, TokenStream as TokenStream2};
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens}; use quote::{format_ident, quote, ToTokens};
use syn::{ use syn::{
braced, bracketed, parenthesized, braced, bracketed, parenthesized,
@ -158,3 +157,20 @@ pub fn populate_fut_lifetimes_on_refs(args: &mut Vec<Argument>) {
} }
} }
} }
pub fn append_line(desc: &mut String, mut line: String) {
if line.starts_with(' ') {
line.remove(0);
}
match line.rfind("\\$") {
Some(i) => {
desc.push_str(line[..i].trim_end());
desc.push(' ');
}
None => {
desc.push_str(&line);
desc.push('\n');
}
}
}

View File

@ -56,8 +56,7 @@ CREATE TABLE reminders_new (
-- , CONSTRAINT interval_enabled_mutin CHECK (`enabled` = 1 OR `interval` IS NULL) -- , CONSTRAINT interval_enabled_mutin CHECK (`enabled` = 1 OR `interval` IS NULL)
# disallow an expiry time if interval is unspecified # disallow an expiry time if interval is unspecified
-- , CONSTRAINT interval_expires_mutin CHECK (`expires` IS NULL OR `interval` IS NOT NULL) -- , CONSTRAINT interval_expires_mutin CHECK (`expires` IS NULL OR `interval` IS NOT NULL)
) );
COLLATE utf8mb4_unicode_ci;
# import data from other tables # import data from other tables
INSERT INTO reminders_new ( INSERT INTO reminders_new (

13
migration/02-macro.sql Normal file
View File

@ -0,0 +1,13 @@
USE reminders;
CREATE TABLE macro (
id INT UNSIGNED AUTO_INCREMENT,
guild_id INT UNSIGNED NOT NULL,
name VARCHAR(100) NOT NULL,
description VARCHAR(100),
commands TEXT NOT NULL,
FOREIGN KEY (guild_id) REFERENCES guilds(id) ON DELETE CASCADE,
PRIMARY KEY (id)
);

View File

@ -1,5 +0,0 @@
pub mod suffixes {
pub const COMMAND: &str = "COMMAND";
}
pub use self::suffixes::*;

View File

@ -1,102 +0,0 @@
#![deny(rust_2018_idioms)]
// FIXME: Remove this in a foreseeable future.
// Currently exists for backwards compatibility to previous Rust versions.
#![recursion_limit = "128"]
#[allow(unused_extern_crates)]
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse::Error, parse_macro_input, spanned::Spanned, Lit};
pub(crate) mod attributes;
pub(crate) mod consts;
pub(crate) mod structures;
#[macro_use]
pub(crate) mod util;
use attributes::*;
use consts::*;
use structures::*;
use util::*;
macro_rules! match_options {
($v:expr, $values:ident, $options:ident, $span:expr => [$($name:ident);*]) => {
match $v {
$(
stringify!($name) => $options.$name = propagate_err!($crate::attributes::parse($values)),
)*
_ => {
return Error::new($span, format_args!("invalid attribute: {:?}", $v))
.to_compile_error()
.into();
},
}
};
}
#[proc_macro_attribute]
pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
let mut fun = parse_macro_input!(input as CommandFun);
let lit_name = if !attr.is_empty() {
parse_macro_input!(attr as Lit).to_str()
} else {
fun.name.to_string()
};
let mut options = Options::new();
for attribute in &fun.attributes {
let span = attribute.span();
let values = propagate_err!(parse_values(attribute));
let name = values.name.to_string();
let name = &name[..];
match_options!(name, values, options, span => [
permission_level;
supports_dm;
can_blacklist
]);
}
let Options {
permission_level,
supports_dm,
can_blacklist,
} = options;
let visibility = fun.visibility;
let name = fun.name.clone();
let body = fun.body;
let n = name.with_suffix(COMMAND);
let cooked = fun.cooked.clone();
let command_path = quote!(crate::framework::Command);
populate_fut_lifetimes_on_refs(&mut fun.args);
let args = fun.args;
(quote! {
#(#cooked)*
pub static #n: #command_path = #command_path {
func: #name,
name: #lit_name,
required_perms: #permission_level,
supports_dm: #supports_dm,
can_blacklist: #can_blacklist,
};
#visibility fn #name<'fut> (#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, ()> {
use ::serenity::futures::future::FutureExt;
async move { #(#body)* }.boxed()
}
})
.into()
}

View File

@ -1,231 +0,0 @@
use crate::util::{Argument, Parenthesised};
use proc_macro2::Span;
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use syn::{
braced,
parse::{Error, Parse, ParseStream, Result},
spanned::Spanned,
Attribute, Block, FnArg, Ident, Pat, Path, PathSegment, Stmt, Token, Visibility,
};
fn parse_argument(arg: FnArg) -> Result<Argument> {
match arg {
FnArg::Typed(typed) => {
let pat = typed.pat;
let kind = typed.ty;
match *pat {
Pat::Ident(id) => {
let name = id.ident;
let mutable = id.mutability;
Ok(Argument {
mutable,
name,
kind: *kind,
})
}
Pat::Wild(wild) => {
let token = wild.underscore_token;
let name = Ident::new("_", token.spans[0]);
Ok(Argument {
mutable: None,
name,
kind: *kind,
})
}
_ => Err(Error::new(
pat.span(),
format_args!("unsupported pattern: {:?}", pat),
)),
}
}
FnArg::Receiver(_) => Err(Error::new(
arg.span(),
format_args!("`self` arguments are prohibited: {:?}", arg),
)),
}
}
/// Test if the attribute is cooked.
fn is_cooked(attr: &Attribute) -> bool {
const COOKED_ATTRIBUTE_NAMES: &[&str] = &[
"cfg", "cfg_attr", "doc", "derive", "inline", "allow", "warn", "deny", "forbid",
];
COOKED_ATTRIBUTE_NAMES.iter().any(|n| attr.path.is_ident(n))
}
/// Removes cooked attributes from a vector of attributes. Uncooked attributes are left in the vector.
///
/// # Return
///
/// Returns a vector of cooked attributes that have been removed from the input vector.
fn remove_cooked(attrs: &mut Vec<Attribute>) -> Vec<Attribute> {
let mut cooked = Vec::new();
// FIXME: Replace with `Vec::drain_filter` once it is stable.
let mut i = 0;
while i < attrs.len() {
if !is_cooked(&attrs[i]) {
i += 1;
continue;
}
cooked.push(attrs.remove(i));
}
cooked
}
#[derive(Debug)]
pub struct CommandFun {
/// `#[...]`-style attributes.
pub attributes: Vec<Attribute>,
/// Populated cooked attributes. These are attributes outside of the realm of this crate's procedural macros
/// and will appear in generated output.
pub cooked: Vec<Attribute>,
pub visibility: Visibility,
pub name: Ident,
pub args: Vec<Argument>,
pub body: Vec<Stmt>,
}
impl Parse for CommandFun {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let mut attributes = input.call(Attribute::parse_outer)?;
// `#[doc = "..."]` is a cooked attribute but it is special-cased for commands.
for attr in &mut attributes {
// Rename documentation comment attributes (`#[doc = "..."]`) to `#[description = "..."]`.
if attr.path.is_ident("doc") {
attr.path = Path::from(PathSegment::from(Ident::new(
"description",
Span::call_site(),
)));
}
}
let cooked = remove_cooked(&mut attributes);
let visibility = input.parse::<Visibility>()?;
input.parse::<Token![async]>()?;
input.parse::<Token![fn]>()?;
let name = input.parse()?;
// (...)
let Parenthesised(args) = input.parse::<Parenthesised<FnArg>>()?;
// { ... }
let bcont;
braced!(bcont in input);
let body = bcont.call(Block::parse_within)?;
let args = args
.into_iter()
.map(parse_argument)
.collect::<Result<Vec<_>>>()?;
Ok(Self {
attributes,
cooked,
visibility,
name,
args,
body,
})
}
}
impl ToTokens for CommandFun {
fn to_tokens(&self, stream: &mut TokenStream2) {
let Self {
attributes: _,
cooked,
visibility,
name,
args,
body,
} = self;
stream.extend(quote! {
#(#cooked)*
#visibility async fn #name (#(#args),*) -> () {
#(#body)*
}
});
}
}
#[derive(Debug)]
pub enum PermissionLevel {
Unrestricted,
Managed,
Restricted,
}
impl Default for PermissionLevel {
fn default() -> Self {
Self::Unrestricted
}
}
impl PermissionLevel {
pub fn from_str(s: &str) -> Option<Self> {
Some(match s.to_uppercase().as_str() {
"UNRESTRICTED" => Self::Unrestricted,
"MANAGED" => Self::Managed,
"RESTRICTED" => Self::Restricted,
_ => return None,
})
}
}
impl ToTokens for PermissionLevel {
fn to_tokens(&self, stream: &mut TokenStream2) {
let path = quote!(crate::framework::PermissionLevel);
let variant;
match self {
Self::Unrestricted => {
variant = quote!(Unrestricted);
}
Self::Managed => {
variant = quote!(Managed);
}
Self::Restricted => {
variant = quote!(Restricted);
}
}
stream.extend(quote! {
#path::#variant
});
}
}
#[derive(Debug, Default)]
pub struct Options {
pub permission_level: PermissionLevel,
pub supports_dm: bool,
pub can_blacklist: bool,
}
impl Options {
#[inline]
pub fn new() -> Self {
let mut options = Self::default();
options.can_blacklist = true;
options.supports_dm = true;
options
}
}

3
rustfmt.toml Normal file
View File

@ -0,0 +1,3 @@
imports_granularity = "Crate"
group_imports = "StdExternalCrate"
use_small_heuristics = "Max"

View File

@ -1,40 +1,14 @@
use regex_command_attr::command;
use serenity::{builder::CreateEmbedFooter, client::Context, model::channel::Message};
use chrono::offset::Utc; use chrono::offset::Utc;
use regex_command_attr::command;
use serenity::{builder::CreateEmbedFooter, client::Context};
use crate::{ use crate::{
command_help, framework::{CommandInvoke, CreateGenericResponse},
consts::DEFAULT_PREFIX, models::CtxData,
get_ctx_data, THEME_COLOR,
language_manager::LanguageManager,
models::{user_data::UserData, CtxGuildData},
FrameworkCtx, THEME_COLOR,
}; };
use std::{ fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEmbedFooter {
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
#[command]
#[can_blacklist(false)]
async fn ping(ctx: &Context, msg: &Message, _args: String) {
let now = SystemTime::now();
let since_epoch = now
.duration_since(UNIX_EPOCH)
.expect("Time calculated as going backwards. Very bad");
let delta = since_epoch.as_millis() as i64 - msg.timestamp.timestamp_millis();
let _ = msg
.channel_id
.say(&ctx, format!("Time taken to receive message: {}ms", delta))
.await;
}
async fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEmbedFooter {
let shard_count = ctx.cache.shard_count(); let shard_count = ctx.cache.shard_count();
let shard = ctx.shard_id; let shard = ctx.shard_id;
@ -49,173 +23,140 @@ async fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut Cr
} }
#[command] #[command]
#[can_blacklist(false)] #[description("Get an overview of the bot commands")]
async fn help(ctx: &Context, msg: &Message, args: String) { async fn help(ctx: &Context, invoke: &mut CommandInvoke) {
async fn default_help( let footer = footer(ctx);
ctx: &Context,
msg: &Message,
lm: Arc<LanguageManager>,
prefix: &str,
language: &str,
) {
let desc = lm.get(language, "help/desc").replace("{prefix}", prefix);
let footer = footer(ctx).await;
let _ = msg let _ = invoke
.channel_id .respond(
.send_message(ctx, |m| {
m.embed(move |e| {
e.title("Help Menu")
.description(desc)
.field(
lm.get(language, "help/setup_title"),
"`lang` `timezone` `meridian`",
true,
)
.field(
lm.get(language, "help/mod_title"),
"`prefix` `blacklist` `restrict` `alias`",
true,
)
.field(
lm.get(language, "help/reminder_title"),
"`remind` `interval` `natural` `look` `countdown`",
true,
)
.field(
lm.get(language, "help/reminder_mod_title"),
"`del` `offset` `pause` `nudge`",
true,
)
.field(
lm.get(language, "help/info_title"),
"`help` `info` `donate` `clock`",
true,
)
.field(
lm.get(language, "help/todo_title"),
"`todo` `todos` `todoc`",
true,
)
.field(lm.get(language, "help/other_title"), "`timer`", true)
.footer(footer)
.color(*THEME_COLOR)
})
})
.await;
}
let (pool, lm) = get_ctx_data(&ctx).await;
let language = UserData::language_of(&msg.author, &pool);
let prefix = ctx.prefix(msg.guild_id);
if !args.is_empty() {
let framework = ctx
.data
.read()
.await
.get::<FrameworkCtx>()
.cloned()
.expect("Could not get FrameworkCtx from data");
let matched = framework
.commands
.get(args.as_str())
.map(|inner| inner.name);
if let Some(command_name) = matched {
command_help(ctx, msg, lm, &prefix.await, &language.await, command_name).await
} else {
default_help(ctx, msg, lm, &prefix.await, &language.await).await;
}
} else {
default_help(ctx, msg, lm, &prefix.await, &language.await).await;
}
}
#[command]
async fn info(ctx: &Context, msg: &Message, _args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let language = UserData::language_of(&msg.author, &pool);
let prefix = ctx.prefix(msg.guild_id);
let current_user = ctx.cache.current_user();
let footer = footer(ctx).await;
let desc = lm
.get(&language.await, "info")
.replacen("{user}", &current_user.name, 1)
.replace("{default_prefix}", &*DEFAULT_PREFIX)
.replace("{prefix}", &prefix.await);
let _ = msg
.channel_id
.send_message(ctx, |m| {
m.embed(move |e| {
e.title("Info")
.description(desc)
.footer(footer)
.color(*THEME_COLOR)
})
})
.await;
}
#[command]
async fn donate(ctx: &Context, msg: &Message, _args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let language = UserData::language_of(&msg.author, &pool).await;
let desc = lm.get(&language, "donate");
let footer = footer(ctx).await;
let _ = msg
.channel_id
.send_message(ctx, |m| {
m.embed(move |e| {
e.title("Donate")
.description(desc)
.footer(footer)
.color(*THEME_COLOR)
})
})
.await;
}
#[command]
async fn dashboard(ctx: &Context, msg: &Message, _args: String) {
let footer = footer(ctx).await;
let _ = msg
.channel_id
.send_message(ctx, |m| {
m.embed(move |e| {
e.title("Dashboard")
.description("https://reminder-bot.com/dashboard")
.footer(footer)
.color(*THEME_COLOR)
})
})
.await;
}
#[command]
async fn clock(ctx: &Context, msg: &Message, _args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let language = UserData::language_of(&msg.author, &pool).await;
let timezone = UserData::timezone_of(&msg.author, &pool).await;
let now = Utc::now().with_timezone(&timezone);
let clock_display = lm.get(&language, "clock/time");
let _ = msg
.channel_id
.say(
&ctx, &ctx,
clock_display.replacen("{}", &now.format("%H:%M").to_string(), 1), CreateGenericResponse::new().embed(|e| {
e.title("Help")
.color(*THEME_COLOR)
.description(
"__Info Commands__
`/help` `/info` `/donate` `/dashboard` `/clock`
*run these commands with no options*
__Reminder Commands__
`/remind` - Create a new reminder that will send a message at a certain time
`/timer` - Start a timer from now, that will count time passed. Also used to view and remove timers
__Reminder Management__
`/del` - Delete reminders
`/look` - View reminders
`/pause` - Pause all reminders on the channel
`/offset` - Move all reminders by a certain time
`/nudge` - Move all new reminders on this channel by a certain time
__Todo Commands__
`/todo` - Add, view and manage the server, channel or user todo lists
__Setup Commands__
`/timezone` - Set your timezone (necessary for `/remind` to work properly)
__Advanced Commands__
`/macro` - Record and replay command sequences
",
)
.footer(footer)
}),
)
.await;
}
#[command]
#[aliases("invite")]
#[description("Get information about the bot")]
async fn info(ctx: &Context, invoke: &mut CommandInvoke) {
let footer = footer(ctx);
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
e.title("Info")
.description(format!(
"Help: `/help`
**Welcome to Reminder Bot!**
Developer: <@203532103185465344>
Icon: <@253202252821430272>
Find me on https://discord.jellywx.com and on https://github.com/JellyWX :)
Invite the bot: https://invite.reminder-bot.com/
Use our dashboard: https://reminder-bot.com/",
))
.footer(footer)
.color(*THEME_COLOR)
}),
)
.await;
}
#[command]
#[description("Details on supporting the bot and Patreon benefits")]
#[group("Info")]
async fn donate(ctx: &Context, invoke: &mut CommandInvoke) {
let footer = footer(ctx);
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
e.title("Donate")
.description("Thinking of adding a monthly contribution? Click below for my Patreon and official bot server :)
**https://www.patreon.com/jellywx/**
**https://discord.jellywx.com/**
When you subscribe, Patreon will automatically rank you up on our Discord server (make sure you link your Patreon and Discord accounts!)
With your new rank, you'll be able to:
• Set repeating reminders with `interval`, `natural` or the dashboard
• Use unlimited uploads on SoundFX
(Also, members of servers you __own__ will be able to set repeating reminders via commands)
Just $2 USD/month!
*Please note, you must be in the JellyWX Discord server to receive Patreon features*")
.footer(footer)
.color(*THEME_COLOR)
}),
)
.await;
}
#[command]
#[description("Get the link to the online dashboard")]
#[group("Info")]
async fn dashboard(ctx: &Context, invoke: &mut CommandInvoke) {
let footer = footer(ctx);
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
e.title("Dashboard")
.description("**https://reminder-bot.com/dashboard**")
.footer(footer)
.color(*THEME_COLOR)
}),
)
.await;
}
#[command]
#[description("View the current time in your selected timezone")]
#[group("Info")]
async fn clock(ctx: &Context, invoke: &mut CommandInvoke) {
let ud = ctx.user_data(&invoke.author_id()).await.unwrap();
let now = Utc::now().with_timezone(&ud.timezone());
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().content(format!("Current time: {}", now.format("%H:%M"))),
) )
.await; .await;
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,455 +1,260 @@
use regex_command_attr::command; use regex_command_attr::command;
use serenity::client::Context;
use serenity::{
async_trait,
client::Context,
constants::MESSAGE_CODE_LIMIT,
model::{
channel::Message,
id::{ChannelId, GuildId, UserId},
},
};
use std::fmt;
use crate::{ use crate::{
command_help, get_ctx_data, component_models::{
models::{user_data::UserData, CtxGuildData}, pager::{Pager, TodoPager},
ComponentDataModel, TodoSelector,
},
consts::{EMBED_DESCRIPTION_MAX_LENGTH, SELECT_MAX_ENTRIES, THEME_COLOR},
framework::{CommandInvoke, CommandOptions, CreateGenericResponse},
hooks::CHECK_GUILD_PERMISSIONS_HOOK,
SQLPool,
}; };
use sqlx::MySqlPool;
use std::convert::TryFrom;
#[derive(Debug)] #[command]
struct TodoNotFound; #[description("Manage todo lists")]
#[subcommandgroup("server")]
#[description("Manage the server todo list")]
#[subcommand("add")]
#[description("Add an item to the server todo list")]
#[arg(
name = "task",
description = "The task to add to the todo list",
kind = "String",
required = true
)]
#[subcommand("view")]
#[description("View and remove from the server todo list")]
#[subcommandgroup("channel")]
#[description("Manage the channel todo list")]
#[subcommand("add")]
#[description("Add to the channel todo list")]
#[arg(
name = "task",
description = "The task to add to the todo list",
kind = "String",
required = true
)]
#[subcommand("view")]
#[description("View and remove from the channel todo list")]
#[subcommandgroup("user")]
#[description("Manage your personal todo list")]
#[subcommand("add")]
#[description("Add to your personal todo list")]
#[arg(
name = "task",
description = "The task to add to the todo list",
kind = "String",
required = true
)]
#[subcommand("view")]
#[description("View and remove from your personal todo list")]
#[hook(CHECK_GUILD_PERMISSIONS_HOOK)]
async fn todo(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions) {
if invoke.guild_id().is_none() && args.subcommand_group != Some("user".to_string()) {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content("Please use `/todo user` in direct messages"),
)
.await;
} else {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
impl std::error::Error for TodoNotFound {} let keys = match args.subcommand_group.as_ref().unwrap().as_str() {
impl fmt::Display for TodoNotFound { "server" => (None, None, invoke.guild_id().map(|g| g.0)),
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { "channel" => (None, Some(invoke.channel_id().0), invoke.guild_id().map(|g| g.0)),
write!(f, "Todo not found") _ => (Some(invoke.author_id().0), None, None),
}
}
struct Todo {
id: u32,
value: String,
}
struct TodoTarget {
user: UserId,
guild: Option<GuildId>,
channel: Option<ChannelId>,
}
impl TodoTarget {
pub fn command(&self, subcommand_opt: Option<SubCommand>) -> String {
let context = if self.channel.is_some() {
"channel"
} else if self.guild.is_some() {
"guild"
} else {
"user"
}; };
if let Some(subcommand) = subcommand_opt { match args.get("task") {
format!("todo {} {}", context, subcommand.to_string()) Some(task) => {
} else { let task = task.to_string();
format!("todo {}", context)
}
}
pub fn name(&self) -> String { sqlx::query!(
if self.channel.is_some() { "INSERT INTO todos (user_id, channel_id, guild_id, value) VALUES ((SELECT id FROM users WHERE user = ?), (SELECT id FROM channels WHERE channel = ?), (SELECT id FROM guilds WHERE guild = ?), ?)",
"Channel" keys.0,
} else if self.guild.is_some() { keys.1,
"Guild" keys.2,
} else { task
"User" )
} .execute(&pool)
.to_string() .await
} .unwrap();
pub async fn view( let _ = invoke
&self, .respond(&ctx, CreateGenericResponse::new().content("Item added to todo list"))
pool: MySqlPool,
) -> Result<Vec<Todo>, Box<dyn std::error::Error + Send + Sync>> {
Ok(if let Some(cid) = self.channel {
sqlx::query_as!(
Todo,
"
SELECT id, value FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?)
",
cid.as_u64()
)
.fetch_all(&pool)
.await?
} else if let Some(gid) = self.guild {
sqlx::query_as!(
Todo,
"
SELECT id, value FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL
",
gid.as_u64()
)
.fetch_all(&pool)
.await?
} else {
sqlx::query_as!(
Todo,
"
SELECT id, value FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL
",
self.user.as_u64()
)
.fetch_all(&pool)
.await?
})
}
pub async fn add(
&self,
value: String,
pool: MySqlPool,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if let (Some(cid), Some(gid)) = (self.channel, self.guild) {
sqlx::query!(
"
INSERT INTO todos (user_id, guild_id, channel_id, value) VALUES (
(SELECT id FROM users WHERE user = ?),
(SELECT id FROM guilds WHERE guild = ?),
(SELECT id FROM channels WHERE channel = ?),
?
)
",
self.user.as_u64(),
gid.as_u64(),
cid.as_u64(),
value
)
.execute(&pool)
.await?;
} else if let Some(gid) = self.guild {
sqlx::query!(
"
INSERT INTO todos (user_id, guild_id, value) VALUES (
(SELECT id FROM users WHERE user = ?),
(SELECT id FROM guilds WHERE guild = ?),
?
)
",
self.user.as_u64(),
gid.as_u64(),
value
)
.execute(&pool)
.await?;
} else {
sqlx::query!(
"
INSERT INTO todos (user_id, value) VALUES (
(SELECT id FROM users WHERE user = ?),
?
)
",
self.user.as_u64(),
value
)
.execute(&pool)
.await?;
}
Ok(())
}
pub async fn remove(
&self,
num: usize,
pool: &MySqlPool,
) -> Result<Todo, Box<dyn std::error::Error + Sync + Send>> {
let todos = self.view(pool.clone()).await?;
if let Some(removal_todo) = todos.get(num) {
let deleting = sqlx::query_as!(
Todo,
"
SELECT id, value FROM todos WHERE id = ?
",
removal_todo.id
)
.fetch_one(&pool.clone())
.await?;
sqlx::query!(
"
DELETE FROM todos WHERE id = ?
",
removal_todo.id
)
.execute(pool)
.await?;
Ok(deleting)
} else {
Err(Box::new(TodoNotFound))
}
}
pub async fn clear(
&self,
pool: &MySqlPool,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
if let Some(cid) = self.channel {
sqlx::query!(
"
DELETE FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?)
",
cid.as_u64()
)
.execute(pool)
.await?;
} else if let Some(gid) = self.guild {
sqlx::query!(
"
DELETE FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL
",
gid.as_u64()
)
.execute(pool)
.await?;
} else {
sqlx::query!(
"
DELETE FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL
",
self.user.as_u64()
)
.execute(pool)
.await?;
}
Ok(())
}
async fn execute(&self, ctx: &Context, msg: &Message, subcommand: SubCommand, extra: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let prefix = ctx.prefix(msg.guild_id).await;
match subcommand {
SubCommand::View => {
let todo_items = self.view(pool).await.unwrap();
let mut todo_groups = vec!["".to_string()];
let mut char_count = 0;
todo_items.iter().enumerate().for_each(|(count, todo)| {
let display = format!("{}: {}\n", count + 1, todo.value);
if char_count + display.len() > MESSAGE_CODE_LIMIT as usize {
char_count = display.len();
todo_groups.push(display);
} else {
char_count += display.len();
let last_group = todo_groups.pop().unwrap();
todo_groups.push(format!("{}{}", last_group, display));
}
});
for group in todo_groups {
let _ = msg
.channel_id
.send_message(&ctx, |m| {
m.embed(|e| e.title(format!("{} Todo", self.name())).description(group))
})
.await;
}
}
SubCommand::Add => {
let content = lm
.get(&user_data.language, "todo/added")
.replacen("{name}", &extra, 1);
self.add(extra, pool).await.unwrap();
let _ = msg
.channel_id
.send_message(&ctx, |m| {
m.content(content).allowed_mentions(|m| m.empty_parse())
})
.await; .await;
} }
None => {
SubCommand::Remove => { let values = if let Some(uid) = keys.0 {
if let Ok(num) = extra.parse::<usize>() { sqlx::query!(
if let Ok(todo) = self.remove(num - 1, &pool).await { "SELECT todos.id, value FROM todos
let content = lm.get(&user_data.language, "todo/removed").replacen( INNER JOIN users ON todos.user_id = users.id
"{}", WHERE users.user = ?",
&todo.value, uid,
1, )
); .fetch_all(&pool)
.await
let _ = msg .unwrap()
.channel_id .iter()
.send_message(&ctx, |m| { .map(|row| (row.id as usize, row.value.clone()))
m.content(content).allowed_mentions(|m| m.empty_parse()) .collect::<Vec<(usize, String)>>()
}) } else if let Some(cid) = keys.1 {
.await; sqlx::query!(
} else { "SELECT todos.id, value FROM todos
let _ = msg INNER JOIN channels ON todos.channel_id = channels.id
.channel_id WHERE channels.channel = ?",
.say(&ctx, lm.get(&user_data.language, "todo/error_index")) cid,
.await; )
} .fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>()
} else { } else {
let content = lm sqlx::query!(
.get(&user_data.language, "todo/error_value") "SELECT todos.id, value FROM todos
.replacen("{prefix}", &prefix, 1) INNER JOIN guilds ON todos.guild_id = guilds.id
.replacen("{command}", &self.command(Some(subcommand)), 1); WHERE guilds.guild = ?",
keys.2,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>()
};
let _ = msg.channel_id.say(&ctx, content).await; let resp = show_todo_page(&values, 0, keys.0, keys.1, keys.2);
}
}
SubCommand::Clear => { invoke.respond(&ctx, resp).await.unwrap();
self.clear(&pool).await.unwrap();
let content = lm.get(&user_data.language, "todo/cleared");
let _ = msg.channel_id.say(&ctx, content).await;
} }
} }
} }
} }
enum SubCommand { pub fn max_todo_page(todo_values: &[(usize, String)]) -> usize {
View, let mut rows = 0;
Add, let mut char_count = 0;
Remove,
Clear,
}
impl TryFrom<Option<&str>> for SubCommand { todo_values.iter().enumerate().map(|(c, (_, v))| format!("{}: {}", c, v)).fold(
type Error = (); 1,
|mut pages, text| {
rows += 1;
char_count += text.len();
fn try_from(value: Option<&str>) -> Result<Self, Self::Error> { if char_count > EMBED_DESCRIPTION_MAX_LENGTH || rows > SELECT_MAX_ENTRIES {
match value { rows = 1;
Some("add") => Ok(SubCommand::Add), char_count = text.len();
pages += 1;
Some("remove") => Ok(SubCommand::Remove),
Some("clear") => Ok(SubCommand::Clear),
None | Some("") => Ok(SubCommand::View),
Some(_unrecognised) => Err(()),
}
}
}
impl ToString for SubCommand {
fn to_string(&self) -> String {
match self {
SubCommand::View => "",
SubCommand::Add => "add",
SubCommand::Remove => "remove",
SubCommand::Clear => "clear",
}
.to_string()
}
}
#[async_trait]
trait Execute {
async fn execute(self, ctx: &Context, msg: &Message, extra: String, target: TodoTarget);
}
#[async_trait]
impl Execute for Result<SubCommand, ()> {
async fn execute(self, ctx: &Context, msg: &Message, extra: String, target: TodoTarget) {
if let Ok(subcommand) = self {
target.execute(ctx, msg, subcommand, extra).await;
} else {
show_help(&ctx, msg, Some(target)).await;
}
}
}
#[command("todo")]
async fn todo_user(ctx: &Context, msg: &Message, args: String) {
let mut split = args.split(' ');
let target = TodoTarget {
user: msg.author.id,
guild: None,
channel: None,
};
let subcommand_opt = SubCommand::try_from(split.next());
subcommand_opt
.execute(ctx, msg, split.collect::<Vec<&str>>().join(" "), target)
.await;
}
#[command("todoc")]
#[supports_dm(false)]
#[permission_level(Managed)]
async fn todo_channel(ctx: &Context, msg: &Message, args: String) {
let mut split = args.split(' ');
let target = TodoTarget {
user: msg.author.id,
guild: msg.guild_id,
channel: Some(msg.channel_id),
};
let subcommand_opt = SubCommand::try_from(split.next());
subcommand_opt
.execute(ctx, msg, split.collect::<Vec<&str>>().join(" "), target)
.await;
}
#[command("todos")]
#[supports_dm(false)]
#[permission_level(Managed)]
async fn todo_guild(ctx: &Context, msg: &Message, args: String) {
let mut split = args.split(' ');
let target = TodoTarget {
user: msg.author.id,
guild: msg.guild_id,
channel: None,
};
let subcommand_opt = SubCommand::try_from(split.next());
subcommand_opt
.execute(ctx, msg, split.collect::<Vec<&str>>().join(" "), target)
.await;
}
async fn show_help(ctx: &Context, msg: &Message, target: Option<TodoTarget>) {
let (pool, lm) = get_ctx_data(&ctx).await;
let language = UserData::language_of(&msg.author, &pool);
let prefix = ctx.prefix(msg.guild_id);
let command = match target {
None => "todo",
Some(t) => {
if t.channel.is_some() {
"todoc"
} else if t.guild.is_some() {
"todos"
} else {
"todo"
} }
}
pages
},
)
}
pub fn show_todo_page(
todo_values: &[(usize, String)],
page: usize,
user_id: Option<u64>,
channel_id: Option<u64>,
guild_id: Option<u64>,
) -> CreateGenericResponse {
let pager = TodoPager::new(page, user_id, channel_id, guild_id);
let pages = max_todo_page(todo_values);
let mut page = page;
if page >= pages {
page = pages - 1;
}
let mut char_count = 0;
let mut rows = 0;
let mut skipped_rows = 0;
let mut skipped_char_count = 0;
let mut first_num = 0;
let mut skipped_pages = 0;
let (todo_ids, display_vec): (Vec<usize>, Vec<String>) = todo_values
.iter()
.enumerate()
.map(|(c, (i, v))| (i, format!("`{}`: {}", c + 1, v)))
.skip_while(|(_, p)| {
first_num += 1;
skipped_rows += 1;
skipped_char_count += p.len();
if skipped_char_count > EMBED_DESCRIPTION_MAX_LENGTH
|| skipped_rows > SELECT_MAX_ENTRIES
{
skipped_rows = 1;
skipped_char_count = p.len();
skipped_pages += 1;
}
skipped_pages < page
})
.take_while(|(_, p)| {
rows += 1;
char_count += p.len();
char_count < EMBED_DESCRIPTION_MAX_LENGTH && rows <= SELECT_MAX_ENTRIES
})
.unzip();
let display = display_vec.join("\n");
let title = if user_id.is_some() {
"Your"
} else if channel_id.is_some() {
"Channel"
} else {
"Server"
}; };
command_help(ctx, msg, lm, &prefix.await, &language.await, command).await; if todo_ids.is_empty() {
CreateGenericResponse::new().embed(|e| {
e.title(format!("{} Todo List", title))
.description("Todo List Empty!")
.footer(|f| f.text(format!("Page {} of {}", page + 1, pages)))
.color(*THEME_COLOR)
})
} else {
let todo_selector =
ComponentDataModel::TodoSelector(TodoSelector { page, user_id, channel_id, guild_id });
CreateGenericResponse::new()
.embed(|e| {
e.title(format!("{} Todo List", title))
.description(display)
.footer(|f| f.text(format!("Page {} of {}", page + 1, pages)))
.color(*THEME_COLOR)
})
.components(|comp| {
pager.create_button_row(pages, comp);
comp.create_action_row(|row| {
row.create_select_menu(|menu| {
menu.custom_id(todo_selector.to_custom_id()).options(|opt| {
for (count, (id, disp)) in todo_ids.iter().zip(&display_vec).enumerate()
{
opt.create_option(|o| {
o.label(format!("Mark {} complete", count + first_num))
.value(id)
.description(disp.split_once(" ").unwrap_or(("", "")).1)
});
}
opt
})
})
})
})
}
} }

310
src/component_models/mod.rs Normal file
View File

@ -0,0 +1,310 @@
pub(crate) mod pager;
use std::io::Cursor;
use chrono_tz::Tz;
use num_integer::Integer;
use rmp_serde::Serializer;
use serde::{Deserialize, Serialize};
use serenity::{
builder::CreateEmbed,
client::Context,
model::{
channel::Channel,
interactions::{message_component::MessageComponentInteraction, InteractionResponseType},
prelude::InteractionApplicationCommandCallbackDataFlags,
},
};
use crate::{
commands::{
moderation_cmds::{max_macro_page, show_macro_page},
reminder_cmds::{max_delete_page, show_delete_page},
todo_cmds::{max_todo_page, show_todo_page},
},
component_models::pager::{DelPager, LookPager, MacroPager, Pager, TodoPager},
consts::{EMBED_DESCRIPTION_MAX_LENGTH, THEME_COLOR},
framework::CommandInvoke,
models::{command_macro::CommandMacro, reminder::Reminder},
SQLPool,
};
#[derive(Deserialize, Serialize)]
#[serde(tag = "type")]
#[repr(u8)]
pub enum ComponentDataModel {
LookPager(LookPager),
DelPager(DelPager),
TodoPager(TodoPager),
DelSelector(DelSelector),
TodoSelector(TodoSelector),
MacroPager(MacroPager),
}
impl ComponentDataModel {
pub fn to_custom_id(&self) -> String {
let mut buf = Vec::new();
self.serialize(&mut Serializer::new(&mut buf)).unwrap();
base64::encode(buf)
}
pub fn from_custom_id(data: &String) -> Self {
let buf = base64::decode(data)
.map_err(|e| format!("Could not decode `custom_id' {}: {:?}", data, e))
.unwrap();
let cur = Cursor::new(buf);
rmp_serde::from_read(cur).unwrap()
}
pub async fn act(&self, ctx: &Context, component: MessageComponentInteraction) {
match self {
ComponentDataModel::LookPager(pager) => {
let flags = pager.flags;
let channel_opt = component.channel_id.to_channel_cached(&ctx);
let channel_id = if let Some(Channel::Guild(channel)) = channel_opt {
if Some(channel.guild_id) == component.guild_id {
flags.channel_id.unwrap_or(component.channel_id)
} else {
component.channel_id
}
} else {
component.channel_id
};
let reminders = Reminder::from_channel(ctx, channel_id, &flags).await;
let pages = reminders
.iter()
.map(|reminder| reminder.display(&flags, &pager.timezone))
.fold(0, |t, r| t + r.len())
.div_ceil(&EMBED_DESCRIPTION_MAX_LENGTH);
let channel_name =
if let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx) {
Some(channel.name)
} else {
None
};
let next_page = pager.next_page(pages);
let mut char_count = 0;
let mut skip_char_count = 0;
let display = reminders
.iter()
.map(|reminder| reminder.display(&flags, &pager.timezone))
.skip_while(|p| {
skip_char_count += p.len();
skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * next_page as usize
})
.take_while(|p| {
char_count += p.len();
char_count < EMBED_DESCRIPTION_MAX_LENGTH
})
.collect::<Vec<String>>()
.join("\n");
let mut embed = CreateEmbed::default();
embed
.title(format!(
"Reminders{}",
channel_name.map_or(String::new(), |n| format!(" on #{}", n))
))
.description(display)
.footer(|f| f.text(format!("Page {} of {}", next_page + 1, pages)))
.color(*THEME_COLOR);
let _ = component
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::UpdateMessage).interaction_response_data(
|response| {
response.embeds(vec![embed]).components(|comp| {
pager.create_button_row(pages, comp);
comp
})
},
)
})
.await;
}
ComponentDataModel::DelPager(pager) => {
let reminders =
Reminder::from_guild(ctx, component.guild_id, component.user.id).await;
let max_pages = max_delete_page(&reminders, &pager.timezone);
let resp = show_delete_page(&reminders, pager.next_page(max_pages), pager.timezone);
let mut invoke = CommandInvoke::component(component);
let _ = invoke.respond(&ctx, resp).await;
}
ComponentDataModel::DelSelector(selector) => {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let selected_id = component.data.values.join(",");
sqlx::query!("DELETE FROM reminders WHERE FIND_IN_SET(id, ?)", selected_id)
.execute(&pool)
.await
.unwrap();
let reminders =
Reminder::from_guild(ctx, component.guild_id, component.user.id).await;
let resp = show_delete_page(&reminders, selector.page, selector.timezone);
let mut invoke = CommandInvoke::component(component);
let _ = invoke.respond(&ctx, resp).await;
}
ComponentDataModel::TodoPager(pager) => {
if Some(component.user.id.0) == pager.user_id || pager.user_id.is_none() {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let values = if let Some(uid) = pager.user_id {
sqlx::query!(
"SELECT todos.id, value FROM todos
INNER JOIN users ON todos.user_id = users.id
WHERE users.user = ?",
uid,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>()
} else if let Some(cid) = pager.channel_id {
sqlx::query!(
"SELECT todos.id, value FROM todos
INNER JOIN channels ON todos.channel_id = channels.id
WHERE channels.channel = ?",
cid,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>()
} else {
sqlx::query!(
"SELECT todos.id, value FROM todos
INNER JOIN guilds ON todos.guild_id = guilds.id
WHERE guilds.guild = ?",
pager.guild_id,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>()
};
let max_pages = max_todo_page(&values);
let resp = show_todo_page(
&values,
pager.next_page(max_pages),
pager.user_id,
pager.channel_id,
pager.guild_id,
);
let mut invoke = CommandInvoke::component(component);
let _ = invoke.respond(&ctx, resp).await;
} else {
let _ = component
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.flags(
InteractionApplicationCommandCallbackDataFlags::EPHEMERAL,
)
.content("Only the user who performed the command can use these components")
})
})
.await;
}
}
ComponentDataModel::TodoSelector(selector) => {
if Some(component.user.id.0) == selector.user_id || selector.user_id.is_none() {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let selected_id = component.data.values.join(",");
sqlx::query!("DELETE FROM todos WHERE FIND_IN_SET(id, ?)", selected_id)
.execute(&pool)
.await
.unwrap();
let values = sqlx::query!(
// fucking braindead mysql use <=> instead of = for null comparison
"SELECT id, value FROM todos WHERE user_id <=> ? AND channel_id <=> ? AND guild_id <=> ?",
selector.user_id,
selector.channel_id,
selector.guild_id,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>();
let resp = show_todo_page(
&values,
selector.page,
selector.user_id,
selector.channel_id,
selector.guild_id,
);
let mut invoke = CommandInvoke::component(component);
let _ = invoke.respond(&ctx, resp).await;
} else {
let _ = component
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.flags(
InteractionApplicationCommandCallbackDataFlags::EPHEMERAL,
)
.content("Only the user who performed the command can use these components")
})
})
.await;
}
}
ComponentDataModel::MacroPager(pager) => {
let mut invoke = CommandInvoke::component(component);
let macros = CommandMacro::from_guild(ctx, invoke.guild_id().unwrap()).await;
let max_page = max_macro_page(&macros);
let page = pager.next_page(max_page);
let resp = show_macro_page(&macros, page);
let _ = invoke.respond(&ctx, resp).await;
}
}
}
}
#[derive(Serialize, Deserialize)]
pub struct DelSelector {
pub page: usize,
pub timezone: Tz,
}
#[derive(Serialize, Deserialize)]
pub struct TodoSelector {
pub page: usize,
pub user_id: Option<u64>,
pub channel_id: Option<u64>,
pub guild_id: Option<u64>,
}

View File

@ -0,0 +1,411 @@
// todo split pager out into a single struct
use chrono_tz::Tz;
use serde::{Deserialize, Serialize};
use serde_repr::*;
use serenity::{builder::CreateComponents, model::interactions::message_component::ButtonStyle};
use crate::{component_models::ComponentDataModel, models::reminder::look_flags::LookFlags};
pub trait Pager {
fn next_page(&self, max_pages: usize) -> usize;
fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents);
}
#[derive(Serialize_repr, Deserialize_repr)]
#[repr(u8)]
enum PageAction {
First = 0,
Previous = 1,
Refresh = 2,
Next = 3,
Last = 4,
}
#[derive(Serialize, Deserialize)]
pub struct LookPager {
pub flags: LookFlags,
pub page: usize,
action: PageAction,
pub timezone: Tz,
}
impl Pager for LookPager {
fn next_page(&self, max_pages: usize) -> usize {
match self.action {
PageAction::First => 0,
PageAction::Previous => 0.max(self.page - 1),
PageAction::Refresh => self.page,
PageAction::Next => (max_pages - 1).min(self.page + 1),
PageAction::Last => max_pages - 1,
}
}
fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
let next_page = self.next_page(max_pages);
let (page_first, page_prev, page_refresh, page_next, page_last) =
LookPager::buttons(self.flags, next_page, self.timezone);
comp.create_action_row(|row| {
row.create_button(|b| {
b.label("⏮️")
.style(ButtonStyle::Primary)
.custom_id(page_first.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("◀️")
.style(ButtonStyle::Secondary)
.custom_id(page_prev.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
})
.create_button(|b| {
b.label("▶️")
.style(ButtonStyle::Secondary)
.custom_id(page_next.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
.create_button(|b| {
b.label("⏭️")
.style(ButtonStyle::Primary)
.custom_id(page_last.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
});
}
}
impl LookPager {
pub fn new(flags: LookFlags, timezone: Tz) -> Self {
Self { flags, page: 0, action: PageAction::First, timezone }
}
pub fn buttons(
flags: LookFlags,
page: usize,
timezone: Tz,
) -> (
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
) {
(
ComponentDataModel::LookPager(LookPager {
flags,
page,
action: PageAction::First,
timezone,
}),
ComponentDataModel::LookPager(LookPager {
flags,
page,
action: PageAction::Previous,
timezone,
}),
ComponentDataModel::LookPager(LookPager {
flags,
page,
action: PageAction::Refresh,
timezone,
}),
ComponentDataModel::LookPager(LookPager {
flags,
page,
action: PageAction::Next,
timezone,
}),
ComponentDataModel::LookPager(LookPager {
flags,
page,
action: PageAction::Last,
timezone,
}),
)
}
}
#[derive(Serialize, Deserialize)]
pub struct DelPager {
pub page: usize,
action: PageAction,
pub timezone: Tz,
}
impl Pager for DelPager {
fn next_page(&self, max_pages: usize) -> usize {
match self.action {
PageAction::First => 0,
PageAction::Previous => 0.max(self.page - 1),
PageAction::Refresh => self.page,
PageAction::Next => (max_pages - 1).min(self.page + 1),
PageAction::Last => max_pages - 1,
}
}
fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
let next_page = self.next_page(max_pages);
let (page_first, page_prev, page_refresh, page_next, page_last) =
DelPager::buttons(next_page, self.timezone);
comp.create_action_row(|row| {
row.create_button(|b| {
b.label("⏮️")
.style(ButtonStyle::Primary)
.custom_id(page_first.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("◀️")
.style(ButtonStyle::Secondary)
.custom_id(page_prev.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
})
.create_button(|b| {
b.label("▶️")
.style(ButtonStyle::Secondary)
.custom_id(page_next.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
.create_button(|b| {
b.label("⏭️")
.style(ButtonStyle::Primary)
.custom_id(page_last.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
});
}
}
impl DelPager {
pub fn new(page: usize, timezone: Tz) -> Self {
Self { page, action: PageAction::Refresh, timezone }
}
pub fn buttons(
page: usize,
timezone: Tz,
) -> (
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
) {
(
ComponentDataModel::DelPager(DelPager { page, action: PageAction::First, timezone }),
ComponentDataModel::DelPager(DelPager { page, action: PageAction::Previous, timezone }),
ComponentDataModel::DelPager(DelPager { page, action: PageAction::Refresh, timezone }),
ComponentDataModel::DelPager(DelPager { page, action: PageAction::Next, timezone }),
ComponentDataModel::DelPager(DelPager { page, action: PageAction::Last, timezone }),
)
}
}
#[derive(Deserialize, Serialize)]
pub struct TodoPager {
pub page: usize,
action: PageAction,
pub user_id: Option<u64>,
pub channel_id: Option<u64>,
pub guild_id: Option<u64>,
}
impl Pager for TodoPager {
fn next_page(&self, max_pages: usize) -> usize {
match self.action {
PageAction::First => 0,
PageAction::Previous => 0.max(self.page - 1),
PageAction::Refresh => self.page,
PageAction::Next => (max_pages - 1).min(self.page + 1),
PageAction::Last => max_pages - 1,
}
}
fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
let next_page = self.next_page(max_pages);
let (page_first, page_prev, page_refresh, page_next, page_last) =
TodoPager::buttons(next_page, self.user_id, self.channel_id, self.guild_id);
comp.create_action_row(|row| {
row.create_button(|b| {
b.label("⏮️")
.style(ButtonStyle::Primary)
.custom_id(page_first.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("◀️")
.style(ButtonStyle::Secondary)
.custom_id(page_prev.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
})
.create_button(|b| {
b.label("▶️")
.style(ButtonStyle::Secondary)
.custom_id(page_next.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
.create_button(|b| {
b.label("⏭️")
.style(ButtonStyle::Primary)
.custom_id(page_last.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
});
}
}
impl TodoPager {
pub fn new(
page: usize,
user_id: Option<u64>,
channel_id: Option<u64>,
guild_id: Option<u64>,
) -> Self {
Self { page, action: PageAction::Refresh, user_id, channel_id, guild_id }
}
pub fn buttons(
page: usize,
user_id: Option<u64>,
channel_id: Option<u64>,
guild_id: Option<u64>,
) -> (
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
) {
(
ComponentDataModel::TodoPager(TodoPager {
page,
action: PageAction::First,
user_id,
channel_id,
guild_id,
}),
ComponentDataModel::TodoPager(TodoPager {
page,
action: PageAction::Previous,
user_id,
channel_id,
guild_id,
}),
ComponentDataModel::TodoPager(TodoPager {
page,
action: PageAction::Refresh,
user_id,
channel_id,
guild_id,
}),
ComponentDataModel::TodoPager(TodoPager {
page,
action: PageAction::Next,
user_id,
channel_id,
guild_id,
}),
ComponentDataModel::TodoPager(TodoPager {
page,
action: PageAction::Last,
user_id,
channel_id,
guild_id,
}),
)
}
}
#[derive(Serialize, Deserialize)]
pub struct MacroPager {
pub page: usize,
action: PageAction,
}
impl Pager for MacroPager {
fn next_page(&self, max_pages: usize) -> usize {
match self.action {
PageAction::First => 0,
PageAction::Previous => 0.max(self.page - 1),
PageAction::Refresh => self.page,
PageAction::Next => (max_pages - 1).min(self.page + 1),
PageAction::Last => max_pages - 1,
}
}
fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
let next_page = self.next_page(max_pages);
let (page_first, page_prev, page_refresh, page_next, page_last) =
MacroPager::buttons(next_page);
comp.create_action_row(|row| {
row.create_button(|b| {
b.label("⏮️")
.style(ButtonStyle::Primary)
.custom_id(page_first.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("◀️")
.style(ButtonStyle::Secondary)
.custom_id(page_prev.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
})
.create_button(|b| {
b.label("▶️")
.style(ButtonStyle::Secondary)
.custom_id(page_next.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
.create_button(|b| {
b.label("⏭️")
.style(ButtonStyle::Primary)
.custom_id(page_last.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
});
}
}
impl MacroPager {
pub fn new(page: usize) -> Self {
Self { page, action: PageAction::Refresh }
}
pub fn buttons(
page: usize,
) -> (
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
) {
(
ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::First }),
ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Previous }),
ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Refresh }),
ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Next }),
ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Last }),
)
}
}

View File

@ -1,6 +1,8 @@
pub const DAY: u64 = 86_400; pub const DAY: u64 = 86_400;
pub const HOUR: u64 = 3_600; pub const HOUR: u64 = 3_600;
pub const MINUTE: u64 = 60; pub const MINUTE: u64 = 60;
pub const EMBED_DESCRIPTION_MAX_LENGTH: usize = 4000;
pub const SELECT_MAX_ENTRIES: usize = 25;
pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"; pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
@ -8,43 +10,25 @@ const THEME_COLOR_FALLBACK: u32 = 0x8fb677;
use std::{collections::HashSet, env, iter::FromIterator}; use std::{collections::HashSet, env, iter::FromIterator};
use regex::{Regex, RegexBuilder}; use regex::Regex;
use serenity::http::AttachmentType;
lazy_static! { lazy_static! {
pub static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap(); pub static ref REMIND_INTERVAL: u64 = env::var("REMIND_INTERVAL")
.map(|inner| inner.parse::<u64>().ok())
pub static ref REGEX_ROLE: Regex = Regex::new(r#"<@&(\d+)>"#).unwrap(); .ok()
.flatten()
pub static ref REGEX_COMMANDS: Regex = Regex::new(r#"([a-z]+)"#).unwrap(); .unwrap_or(10);
pub static ref DEFAULT_AVATAR: AttachmentType<'static> = (
pub static ref REGEX_ALIAS: Regex = include_bytes!(concat!(
Regex::new(r#"(?P<name>[\S]{1,12})(?:(?: (?P<cmd>.*)$)|$)"#).unwrap(); env!("CARGO_MANIFEST_DIR"),
"/assets/",
pub static ref REGEX_CONTENT_SUBSTITUTION: Regex = Regex::new(r#"<<((?P<user>\d+)|(?P<role>.{1,100}))>>"#).unwrap(); env!("WEBHOOK_AVATAR", "WEBHOOK_AVATAR not provided for compilation")
)) as &[u8],
env!("WEBHOOK_AVATAR"),
)
.into();
pub static ref REGEX_CHANNEL_USER: Regex = Regex::new(r#"\s*<(#|@)(?:!)?(\d+)>\s*"#).unwrap(); pub static ref REGEX_CHANNEL_USER: Regex = Regex::new(r#"\s*<(#|@)(?:!)?(\d+)>\s*"#).unwrap();
pub static ref REGEX_REMIND_COMMAND: Regex = RegexBuilder::new(
r#"(?P<mentions>(?:<@\d+>\s+|<@!\d+>\s+|<#\d+>\s+)*)(?P<time>(?:(?:\d+)(?:s|m|h|d|:|/|-|))+)(?:\s+(?P<interval>(?:(?:\d+)(?:s|m|h|d|))+))?(?:\s+(?P<expires>(?:(?:\d+)(?:s|m|h|d|:|/|-|))+))?\s+(?P<content>.*)"#
)
.dot_matches_new_line(true)
.build()
.unwrap();
pub static ref REGEX_NATURAL_COMMAND_1: Regex = RegexBuilder::new(
r#"(?P<time>.*?)(?:\s+)(?:send|say)(?:\s+)(?P<msg>.*?)(?:(?:\s+)to(?:\s+)(?P<mentions>((?:<@\d+>)|(?:<@!\d+>)|(?:<#\d+>)|(?:\s+))+))?$"#
)
.dot_matches_new_line(true)
.build()
.unwrap();
pub static ref REGEX_NATURAL_COMMAND_2: Regex = RegexBuilder::new(
r#"(?P<msg>.*)(?:\s+)every(?:\s+)(?P<interval>.*?)(?:(?:\s+)(?:until|for)(?:\s+)(?P<expires>.*?))?$"#
)
.dot_matches_new_line(true)
.build()
.unwrap();
pub static ref SUBSCRIPTION_ROLES: HashSet<u64> = HashSet::from_iter( pub static ref SUBSCRIPTION_ROLES: HashSet<u64> = HashSet::from_iter(
env::var("SUBSCRIPTION_ROLES") env::var("SUBSCRIPTION_ROLES")
.map(|var| var .map(|var| var
@ -53,38 +37,23 @@ lazy_static! {
.collect::<Vec<u64>>()) .collect::<Vec<u64>>())
.unwrap_or_else(|_| Vec::new()) .unwrap_or_else(|_| Vec::new())
); );
pub static ref CNC_GUILD: Option<u64> =
pub static ref CNC_GUILD: Option<u64> = env::var("CNC_GUILD") env::var("CNC_GUILD").map(|var| var.parse::<u64>().ok()).ok().flatten();
.map(|var| var.parse::<u64>().ok())
.ok()
.flatten();
pub static ref MIN_INTERVAL: i64 = env::var("MIN_INTERVAL") pub static ref MIN_INTERVAL: i64 = env::var("MIN_INTERVAL")
.ok() .ok()
.map(|inner| inner.parse::<i64>().ok()) .map(|inner| inner.parse::<i64>().ok())
.flatten() .flatten()
.unwrap_or(600); .unwrap_or(600);
pub static ref MAX_TIME: i64 = env::var("MAX_TIME") pub static ref MAX_TIME: i64 = env::var("MAX_TIME")
.ok() .ok()
.map(|inner| inner.parse::<i64>().ok()) .map(|inner| inner.parse::<i64>().ok())
.flatten() .flatten()
.unwrap_or(60 * 60 * 24 * 365 * 50); .unwrap_or(60 * 60 * 24 * 365 * 50);
pub static ref LOCAL_TIMEZONE: String = pub static ref LOCAL_TIMEZONE: String =
env::var("LOCAL_TIMEZONE").unwrap_or_else(|_| "UTC".to_string()); env::var("LOCAL_TIMEZONE").unwrap_or_else(|_| "UTC".to_string());
pub static ref THEME_COLOR: u32 = env::var("THEME_COLOR")
pub static ref LOCAL_LANGUAGE: String = .map_or(THEME_COLOR_FALLBACK, |inner| u32::from_str_radix(&inner, 16)
env::var("LOCAL_LANGUAGE").unwrap_or_else(|_| "EN".to_string()); .unwrap_or(THEME_COLOR_FALLBACK));
pub static ref DEFAULT_PREFIX: String =
env::var("DEFAULT_PREFIX").unwrap_or_else(|_| "$".to_string());
pub static ref THEME_COLOR: u32 = env::var("THEME_COLOR").map_or(
THEME_COLOR_FALLBACK,
|inner| u32::from_str_radix(&inner, 16).unwrap_or(THEME_COLOR_FALLBACK)
);
pub static ref PYTHON_LOCATION: String = pub static ref PYTHON_LOCATION: String =
env::var("PYTHON_LOCATION").unwrap_or_else(|_| "venv/bin/python3".to_string()); env::var("PYTHON_LOCATION").unwrap_or_else(|_| "venv/bin/python3".to_string());
} }

File diff suppressed because it is too large Load Diff

152
src/hooks.rs Normal file
View File

@ -0,0 +1,152 @@
use regex_command_attr::check;
use serenity::{client::Context, model::channel::Channel};
use crate::{
framework::{CommandInvoke, CommandOptions, CreateGenericResponse, HookResult},
moderation_cmds, RecordingMacros,
};
#[check]
pub async fn guild_only(
ctx: &Context,
invoke: &mut CommandInvoke,
_args: &CommandOptions,
) -> HookResult {
if invoke.guild_id().is_some() {
HookResult::Continue
} else {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content("This command can only be used in servers"),
)
.await;
HookResult::Halt
}
}
#[check]
pub async fn macro_check(
ctx: &Context,
invoke: &mut CommandInvoke,
args: &CommandOptions,
) -> HookResult {
if let Some(guild_id) = invoke.guild_id() {
if args.command != moderation_cmds::MACRO_CMD_COMMAND.names[0] {
let active_recordings =
ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap();
let mut lock = active_recordings.write().await;
if let Some(command_macro) = lock.get_mut(&(guild_id, invoke.author_id())) {
if command_macro.commands.len() >= 5 {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content("5 commands already recorded. Please use `/macro finish` to end recording."),
)
.await;
} else {
command_macro.commands.push(args.clone());
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content("Command recorded to macro"),
)
.await;
}
HookResult::Halt
} else {
HookResult::Continue
}
} else {
HookResult::Continue
}
} else {
HookResult::Continue
}
}
#[check]
pub async fn check_self_permissions(
ctx: &Context,
invoke: &mut CommandInvoke,
_args: &CommandOptions,
) -> HookResult {
if let Some(guild) = invoke.guild(&ctx) {
let user_id = ctx.cache.current_user_id();
let manage_webhooks =
guild.member_permissions(&ctx, user_id).await.map_or(false, |p| p.manage_webhooks());
let (view_channel, send_messages, embed_links) = invoke
.channel_id()
.to_channel_cached(&ctx)
.map(|c| {
if let Channel::Guild(channel) = c {
channel.permissions_for_user(ctx, user_id).ok()
} else {
None
}
})
.flatten()
.map_or((false, false, false), |p| {
(p.read_messages(), p.send_messages(), p.embed_links())
});
if manage_webhooks && send_messages && embed_links {
HookResult::Continue
} else {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content(format!(
"Please ensure the bot has the correct permissions:
{} **View Channel**
{} **Send Message**
{} **Embed Links**
{} **Manage Webhooks**",
if view_channel { "" } else { "" },
if send_messages { "" } else { "" },
if manage_webhooks { "" } else { "" },
if embed_links { "" } else { "" },
)),
)
.await;
HookResult::Halt
}
} else {
HookResult::Continue
}
}
#[check]
pub async fn check_guild_permissions(
ctx: &Context,
invoke: &mut CommandInvoke,
_args: &CommandOptions,
) -> HookResult {
if let Some(guild) = invoke.guild(&ctx) {
let permissions = guild.member_permissions(&ctx, invoke.author_id()).await.unwrap();
if !permissions.manage_guild() {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content(
"You must have the \"Manage Server\" permission to use this command",
),
)
.await;
HookResult::Halt
} else {
HookResult::Continue
}
} else {
HookResult::Continue
}
}

View File

@ -1,65 +0,0 @@
use serde::Deserialize;
use serde_json::from_str;
use serenity::prelude::TypeMapKey;
use std::{collections::HashMap, error::Error, sync::Arc};
use crate::consts::LOCAL_LANGUAGE;
#[derive(Deserialize)]
pub struct LanguageManager {
languages: HashMap<String, String>,
strings: HashMap<String, HashMap<String, String>>,
}
impl LanguageManager {
pub fn from_compiled(content: &'static str) -> Result<Self, Box<dyn Error + Send + Sync>> {
let new: Self = from_str(content)?;
Ok(new)
}
pub fn get(&self, language: &str, name: &str) -> &str {
self.strings
.get(language)
.map(|sm| sm.get(name))
.unwrap_or_else(|| panic!(r#"Language does not exist: "{}""#, language))
.unwrap_or_else(|| {
self.strings
.get(&*LOCAL_LANGUAGE)
.map(|sm| {
sm.get(name)
.unwrap_or_else(|| panic!(r#"String does not exist: "{}""#, name))
})
.expect("LOCAL_LANGUAGE is not available")
})
}
pub fn get_language(&self, language: &str) -> Option<&str> {
let language_normal = language.to_lowercase();
self.languages
.iter()
.filter(|(k, v)| {
k.to_lowercase() == language_normal || v.to_lowercase() == language_normal
})
.map(|(k, _)| k.as_str())
.next()
}
pub fn get_language_by_flag(&self, flag: &str) -> Option<&str> {
self.languages
.iter()
.filter(|(k, _)| self.get(k, "flag") == flag)
.map(|(k, _)| k.as_str())
.next()
}
pub fn all_languages(&self) -> impl Iterator<Item = (&str, &str)> {
self.languages.iter().map(|(k, v)| (k.as_str(), v.as_str()))
}
}
impl TypeMapKey for LanguageManager {
type Value = Arc<Self>;
}

View File

@ -1,65 +1,56 @@
#![feature(int_roundings)]
#[macro_use] #[macro_use]
extern crate lazy_static; extern crate lazy_static;
mod commands; mod commands;
mod component_models;
mod consts; mod consts;
mod framework; mod framework;
mod language_manager; mod hooks;
mod models; mod models;
mod sender;
mod time_parser; mod time_parser;
use std::{
collections::HashMap,
env,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use chrono_tz::Tz;
use dotenv::dotenv;
use log::info;
use serenity::{ use serenity::{
async_trait, async_trait,
cache::Cache, client::{bridge::gateway::GatewayIntents, Client},
client::Client,
futures::TryFutureExt,
http::{client::Http, CacheHttp}, http::{client::Http, CacheHttp},
model::{ model::{
channel::GuildChannel, channel::GuildChannel,
channel::Message, gateway::{Activity, Ready},
guild::Guild, guild::{Guild, GuildUnavailable},
id::{GuildId, UserId}, id::{GuildId, UserId},
interactions::Interaction, interactions::Interaction,
}, },
prelude::{Context, EventHandler, TypeMapKey}, prelude::{Context, EventHandler, TypeMapKey},
utils::shard_id, utils::shard_id,
}; };
use sqlx::mysql::MySqlPool; use sqlx::mysql::MySqlPool;
use tokio::{
use dotenv::dotenv; sync::RwLock,
time::{Duration, Instant},
use std::{collections::HashMap, env, sync::Arc, time::Instant}; };
use crate::{ use crate::{
commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds}, commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds},
consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR}, component_models::ComponentDataModel,
consts::{CNC_GUILD, REMIND_INTERVAL, SUBSCRIPTION_ROLES, THEME_COLOR},
framework::RegexFramework, framework::RegexFramework,
language_manager::LanguageManager, models::command_macro::CommandMacro,
models::{guild_data::GuildData, user_data::UserData},
}; };
use inflector::Inflector;
use log::info;
use dashmap::DashMap;
use tokio::sync::RwLock;
use chrono::Utc;
use chrono_tz::Tz;
use serenity::model::gateway::GatewayIntents;
use serenity::model::guild::UnavailableGuild;
use serenity::model::prelude::{
InteractionApplicationCommandCallbackDataFlags, InteractionResponseType,
};
struct GuildDataCache;
impl TypeMapKey for GuildDataCache {
type Value = Arc<DashMap<GuildId, Arc<RwLock<GuildData>>>>;
}
struct SQLPool; struct SQLPool;
impl TypeMapKey for SQLPool { impl TypeMapKey for SQLPool {
@ -72,81 +63,54 @@ impl TypeMapKey for ReqwestClient {
type Value = Arc<reqwest::Client>; type Value = Arc<reqwest::Client>;
} }
struct FrameworkCtx;
impl TypeMapKey for FrameworkCtx {
type Value = Arc<RegexFramework>;
}
struct PopularTimezones; struct PopularTimezones;
impl TypeMapKey for PopularTimezones { impl TypeMapKey for PopularTimezones {
type Value = Arc<Vec<Tz>>; type Value = Arc<Vec<Tz>>;
} }
struct CurrentlyExecuting; struct RecordingMacros;
impl TypeMapKey for CurrentlyExecuting { impl TypeMapKey for RecordingMacros {
type Value = Arc<RwLock<HashMap<UserId, Instant>>>; type Value = Arc<RwLock<HashMap<(GuildId, UserId), CommandMacro>>>;
} }
#[async_trait] struct Handler {
trait LimitExecutors { is_loop_running: AtomicBool,
async fn check_executing(&self, user: UserId) -> bool;
async fn set_executing(&self, user: UserId);
async fn drop_executing(&self, user: UserId);
} }
#[async_trait]
impl LimitExecutors for Context {
async fn check_executing(&self, user: UserId) -> bool {
let currently_executing = self
.data
.read()
.await
.get::<CurrentlyExecuting>()
.cloned()
.unwrap();
let lock = currently_executing.read().await;
lock.get(&user)
.map_or(false, |now| now.elapsed().as_secs() < 4)
}
async fn set_executing(&self, user: UserId) {
let currently_executing = self
.data
.read()
.await
.get::<CurrentlyExecuting>()
.cloned()
.unwrap();
let mut lock = currently_executing.write().await;
lock.insert(user, Instant::now());
}
async fn drop_executing(&self, user: UserId) {
let currently_executing = self
.data
.read()
.await
.get::<CurrentlyExecuting>()
.cloned()
.unwrap();
let mut lock = currently_executing.write().await;
lock.remove(&user);
}
}
struct Handler;
#[async_trait] #[async_trait]
impl EventHandler for Handler { impl EventHandler for Handler {
async fn cache_ready(&self, ctx_base: Context, _guilds: Vec<GuildId>) {
info!("Cache Ready!");
info!("Preparing to send reminders");
if !self.is_loop_running.load(Ordering::Relaxed) {
let ctx = ctx_base.clone();
tokio::spawn(async move {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
loop {
let sleep_until = Instant::now() + Duration::from_secs(*REMIND_INTERVAL);
let reminders = sender::Reminder::fetch_reminders(&pool).await;
if reminders.len() > 0 {
info!("Preparing to send {} reminders.", reminders.len());
for reminder in reminders {
reminder.send(pool.clone(), ctx.clone()).await;
}
}
tokio::time::sleep_until(sleep_until).await;
}
});
self.is_loop_running.swap(true, Ordering::Relaxed);
}
}
async fn channel_delete(&self, ctx: Context, channel: &GuildChannel) { async fn channel_delete(&self, ctx: Context, channel: &GuildChannel) {
let pool = ctx let pool = ctx
.data .data
@ -172,19 +136,11 @@ DELETE FROM channels WHERE channel = ?
let guild_id = guild.id.as_u64().to_owned(); let guild_id = guild.id.as_u64().to_owned();
{ {
let pool = ctx let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
GuildData::from_guild(guild, &pool) let _ = sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id)
.await .execute(&pool)
.unwrap_or_else(|_| { .await;
panic!("Failed to create new guild object for {}", guild_id)
});
} }
if let Ok(token) = env::var("DISCORDBOTS_TOKEN") { if let Ok(token) = env::var("DISCORDBOTS_TOKEN") {
@ -231,118 +187,33 @@ DELETE FROM channels WHERE channel = ?
} }
} }
async fn guild_delete( async fn guild_delete(&self, ctx: Context, incomplete: GuildUnavailable, _full: Option<Guild>) {
&self, let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
ctx: Context, let _ = sqlx::query!("DELETE FROM guilds WHERE guild = ?", incomplete.id.0)
deleted_guild: UnavailableGuild, .execute(&pool)
_guild: Option<Guild>, .await;
) { }
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let guild_data_cache = ctx async fn ready(&self, ctx: Context, _: Ready) {
.data ctx.set_activity(Activity::watching("for /remind")).await;
.read()
.await
.get::<GuildDataCache>()
.cloned()
.unwrap();
guild_data_cache.remove(&deleted_guild.id);
sqlx::query!(
"
DELETE FROM guilds WHERE guild = ?
",
deleted_guild.id.as_u64()
)
.execute(&pool)
.await
.unwrap();
} }
async fn interaction_create(&self, ctx: Context, interaction: Interaction) { async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
let (pool, lm) = get_ctx_data(&&ctx).await;
match interaction { match interaction {
Interaction::MessageComponent(interaction) => { Interaction::ApplicationCommand(application_command) => {
if let Some(member) = interaction.clone().member { let framework = ctx
let data = interaction.data.clone(); .data
.read()
.await
.get::<RegexFramework>()
.cloned()
.expect("RegexFramework not found in context");
if data.custom_id.starts_with("timezone:") { framework.execute(ctx, application_command).await;
let mut user_data = UserData::from_user(&member.user, &ctx, &pool) }
.await Interaction::MessageComponent(component) => {
.unwrap(); let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id);
let new_timezone = data.custom_id.replace("timezone:", "").parse::<Tz>(); component_model.act(&ctx, component).await;
if let Ok(timezone) = new_timezone {
user_data.timezone = timezone.to_string();
user_data.commit_changes(&pool).await;
let _ = interaction.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
let footer_text = lm.get(&user_data.language, "timezone/footer").replacen(
"{timezone}",
&user_data.timezone,
1,
);
let now = Utc::now().with_timezone(&user_data.timezone());
let content = lm
.get(&user_data.language, "timezone/set_p")
.replacen("{timezone}", &user_data.timezone, 1)
.replacen(
"{time}",
&now.format("%H:%M").to_string(),
1,
);
d.create_embed(|e| e.title(lm.get(&user_data.language, "timezone/set_p_title"))
.color(*THEME_COLOR)
.description(content)
.footer(|f| f.text(footer_text)))
.flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL);
d
})
}).await;
}
} else if data.custom_id.starts_with("lang:") {
let mut user_data = UserData::from_user(&member.user, &ctx, &pool)
.await
.unwrap();
let lang_code = data.custom_id.replace("lang:", "");
if let Some(lang) = lm.get_language(&lang_code) {
user_data.language = lang.to_string();
user_data.commit_changes(&pool).await;
let _ = interaction
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.create_embed(|e| {
e.title(
lm.get(&user_data.language, "lang/set_p_title"),
)
.color(*THEME_COLOR)
.description(
lm.get(&user_data.language, "lang/set_p"),
)
})
.flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
})
})
.await;
}
}
}
} }
_ => {} _ => {}
} }
@ -357,99 +228,59 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let token = env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment"); let token = env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment");
let http = Http::new_with_token(&token); let application_id = {
let http = Http::new_with_token(&token);
let logged_in_id = http http.get_current_application_info().await?.id
.get_current_user() };
.map_ok(|user| user.id.as_u64().to_owned())
.await?;
let application_id = http.get_current_application_info().await?.id;
let dm_enabled = env::var("DM_ENABLED").map_or(true, |var| var == "1"); let dm_enabled = env::var("DM_ENABLED").map_or(true, |var| var == "1");
let framework = RegexFramework::new(logged_in_id) let framework = RegexFramework::new()
.default_prefix(DEFAULT_PREFIX.clone())
.case_insensitive(env::var("CASE_INSENSITIVE").map_or(true, |var| var == "1"))
.ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1")) .ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1"))
.debug_guild(env::var("DEBUG_GUILD").map_or(None, |g| {
Some(GuildId(g.parse::<u64>().expect("DEBUG_GUILD must be a guild ID")))
}))
.dm_enabled(dm_enabled) .dm_enabled(dm_enabled)
// info commands // info commands
.add_command("ping", &info_cmds::PING_COMMAND) .add_command(&info_cmds::HELP_COMMAND)
.add_command("help", &info_cmds::HELP_COMMAND) .add_command(&info_cmds::INFO_COMMAND)
.add_command("info", &info_cmds::INFO_COMMAND) .add_command(&info_cmds::DONATE_COMMAND)
.add_command("invite", &info_cmds::INFO_COMMAND) .add_command(&info_cmds::DASHBOARD_COMMAND)
.add_command("donate", &info_cmds::DONATE_COMMAND) .add_command(&info_cmds::CLOCK_COMMAND)
.add_command("dashboard", &info_cmds::DASHBOARD_COMMAND)
.add_command("clock", &info_cmds::CLOCK_COMMAND)
// reminder commands // reminder commands
.add_command("timer", &reminder_cmds::TIMER_COMMAND) .add_command(&reminder_cmds::TIMER_COMMAND)
.add_command("remind", &reminder_cmds::REMIND_COMMAND) .add_command(&reminder_cmds::REMIND_COMMAND)
.add_command("r", &reminder_cmds::REMIND_COMMAND)
.add_command("interval", &reminder_cmds::INTERVAL_COMMAND)
.add_command("i", &reminder_cmds::INTERVAL_COMMAND)
.add_command("natural", &reminder_cmds::NATURAL_COMMAND)
.add_command("n", &reminder_cmds::NATURAL_COMMAND)
.add_command("", &reminder_cmds::NATURAL_COMMAND)
.add_command("countdown", &reminder_cmds::COUNTDOWN_COMMAND)
// management commands // management commands
.add_command("look", &reminder_cmds::LOOK_COMMAND) .add_command(&reminder_cmds::DELETE_COMMAND)
.add_command("del", &reminder_cmds::DELETE_COMMAND) .add_command(&reminder_cmds::LOOK_COMMAND)
.add_command(&reminder_cmds::PAUSE_COMMAND)
.add_command(&reminder_cmds::OFFSET_COMMAND)
.add_command(&reminder_cmds::NUDGE_COMMAND)
// to-do commands // to-do commands
.add_command("todo", &todo_cmds::TODO_USER_COMMAND) .add_command(&todo_cmds::TODO_COMMAND)
.add_command("todo user", &todo_cmds::TODO_USER_COMMAND)
.add_command("todoc", &todo_cmds::TODO_CHANNEL_COMMAND)
.add_command("todo channel", &todo_cmds::TODO_CHANNEL_COMMAND)
.add_command("todos", &todo_cmds::TODO_GUILD_COMMAND)
.add_command("todo server", &todo_cmds::TODO_GUILD_COMMAND)
.add_command("todo guild", &todo_cmds::TODO_GUILD_COMMAND)
// moderation commands // moderation commands
.add_command("blacklist", &moderation_cmds::BLACKLIST_COMMAND) .add_command(&moderation_cmds::TIMEZONE_COMMAND)
.add_command("restrict", &moderation_cmds::RESTRICT_COMMAND) .add_command(&moderation_cmds::MACRO_CMD_COMMAND)
.add_command("timezone", &moderation_cmds::TIMEZONE_COMMAND) .add_hook(&hooks::CHECK_SELF_PERMISSIONS_HOOK)
.add_command("prefix", &moderation_cmds::PREFIX_COMMAND) .add_hook(&hooks::MACRO_CHECK_HOOK);
.add_command("lang", &moderation_cmds::LANGUAGE_COMMAND)
.add_command("pause", &reminder_cmds::PAUSE_COMMAND)
.add_command("offset", &reminder_cmds::OFFSET_COMMAND)
.add_command("nudge", &reminder_cmds::NUDGE_COMMAND)
.add_command("alias", &moderation_cmds::ALIAS_COMMAND)
.add_command("a", &moderation_cmds::ALIAS_COMMAND)
.build();
let framework_arc = Arc::new(framework); let framework_arc = Arc::new(framework);
let mut client = Client::builder(&token) let mut client = Client::builder(&token)
.intents(if dm_enabled { .intents(GatewayIntents::GUILDS)
GatewayIntents::GUILD_MESSAGES
| GatewayIntents::GUILDS
| GatewayIntents::GUILD_MESSAGE_REACTIONS
| GatewayIntents::DIRECT_MESSAGES
| GatewayIntents::DIRECT_MESSAGE_REACTIONS
} else {
GatewayIntents::GUILD_MESSAGES
| GatewayIntents::GUILDS
| GatewayIntents::GUILD_MESSAGE_REACTIONS
})
.application_id(application_id.0) .application_id(application_id.0)
.event_handler(Handler) .event_handler(Handler { is_loop_running: AtomicBool::from(false) })
.framework_arc(framework_arc.clone())
.await .await
.expect("Error occurred creating client"); .expect("Error occurred creating client");
{ {
let guild_data_cache = dashmap::DashMap::new();
let pool = MySqlPool::connect( let pool = MySqlPool::connect(
&env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment"), &env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment"),
) )
.await .await
.unwrap(); .unwrap();
let language_manager = LanguageManager::from_compiled(include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/assets/",
env!("STRINGS_FILE")
)))
.unwrap();
let popular_timezones = sqlx::query!( let popular_timezones = sqlx::query!(
"SELECT timezone FROM users GROUP BY timezone ORDER BY COUNT(timezone) DESC LIMIT 21" "SELECT timezone FROM users GROUP BY timezone ORDER BY COUNT(timezone) DESC LIMIT 21"
) )
@ -462,19 +293,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut data = client.data.write().await; let mut data = client.data.write().await;
data.insert::<GuildDataCache>(Arc::new(guild_data_cache));
data.insert::<CurrentlyExecuting>(Arc::new(RwLock::new(HashMap::new())));
data.insert::<SQLPool>(pool); data.insert::<SQLPool>(pool);
data.insert::<PopularTimezones>(Arc::new(popular_timezones)); data.insert::<PopularTimezones>(Arc::new(popular_timezones));
data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new())); data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new()));
data.insert::<FrameworkCtx>(framework_arc.clone()); data.insert::<RegexFramework>(framework_arc.clone());
data.insert::<LanguageManager>(Arc::new(language_manager)) data.insert::<RecordingMacros>(Arc::new(RwLock::new(HashMap::new())));
} }
framework_arc.build_slash(&client.cache_and_http.http).await;
if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| { if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| {
let mut split = sr let mut split =
.split(',') sr.split(',').map(|val| val.parse::<u64>().expect("SHARD_RANGE not an integer"));
.map(|val| val.parse::<u64>().expect("SHARD_RANGE not an integer"));
(split.next(), split.next()) (split.next(), split.next())
}) { }) {
@ -484,24 +314,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.flatten() .flatten()
.expect("No SHARD_COUNT provided, but SHARD_RANGE was provided"); .expect("No SHARD_COUNT provided, but SHARD_RANGE was provided");
assert!( assert!(lower < upper, "SHARD_RANGE lower limit is not less than the upper limit");
lower < upper,
"SHARD_RANGE lower limit is not less than the upper limit"
);
info!( info!("Starting client fragment with shards {}-{}/{}", lower, upper, total_shards);
"Starting client fragment with shards {}-{}/{}",
lower, upper, total_shards
);
client client.start_shard_range([lower, upper], total_shards).await?;
.start_shard_range([lower, upper], total_shards) } else if let Ok(total_shards) = env::var("SHARD_COUNT")
.await?; .map(|shard_count| shard_count.parse::<u64>().expect("SHARD_COUNT not an integer"))
} else if let Ok(total_shards) = env::var("SHARD_COUNT").map(|shard_count| { {
shard_count
.parse::<u64>()
.expect("SHARD_COUNT not an integer")
}) {
info!("Starting client with {} shards", total_shards); info!("Starting client with {} shards", total_shards);
client.start_shards(total_shards).await?; client.start_shards(total_shards).await?;
@ -516,9 +336,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool { pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
if let Some(subscription_guild) = *CNC_GUILD { if let Some(subscription_guild) = *CNC_GUILD {
let guild_member = GuildId(subscription_guild) let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await;
.member(cache_http, user_id)
.await;
if let Ok(member) = guild_member { if let Ok(member) = guild_member {
for role in member.roles { for role in member.roles {
@ -534,65 +352,15 @@ pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<U
} }
} }
pub async fn check_subscription_on_message( pub async fn check_guild_subscription(
cache_http: impl CacheHttp + AsRef<Cache>, cache_http: impl CacheHttp,
msg: &Message, guild_id: impl Into<GuildId>,
) -> bool { ) -> bool {
check_subscription(&cache_http, &msg.author).await if let Some(guild) = cache_http.cache().unwrap().guild(guild_id) {
|| if let Some(guild) = msg.guild(&cache_http) { let owner = guild.owner_id;
check_subscription(&cache_http, guild.owner_id).await
} else {
false
}
}
pub async fn get_ctx_data(ctx: &&Context) -> (MySqlPool, Arc<LanguageManager>) { check_subscription(&cache_http, owner).await
let pool; } else {
let lm; false
{
let data = ctx.data.read().await;
pool = data
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool");
lm = data
.get::<LanguageManager>()
.cloned()
.expect("Could not get LanguageManager");
} }
(pool, lm)
}
async fn command_help(
ctx: &Context,
msg: &Message,
lm: Arc<LanguageManager>,
prefix: &str,
language: &str,
command_name: &str,
) {
let _ = msg
.channel_id
.send_message(ctx, |m| {
m.embed(move |e| {
e.title(format!("{} Help", command_name.to_title_case()))
.description(
lm.get(&language, &format!("help/{}", command_name))
.replace("{prefix}", &prefix),
)
.footer(|f| {
f.text(concat!(
env!("CARGO_PKG_NAME"),
" ver ",
env!("CARGO_PKG_VERSION")
))
})
.color(*THEME_COLOR)
})
})
.await;
} }

View File

@ -1,8 +1,6 @@
use serenity::model::channel::Channel;
use sqlx::MySqlPool;
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use serenity::model::channel::Channel;
use sqlx::MySqlPool;
pub struct ChannelData { pub struct ChannelData {
pub id: u32, pub id: u32,
@ -17,51 +15,67 @@ pub struct ChannelData {
impl ChannelData { impl ChannelData {
pub async fn from_channel( pub async fn from_channel(
channel: Channel, channel: &Channel,
pool: &MySqlPool, pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> { ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let channel_id = channel.id().as_u64().to_owned(); let channel_id = channel.id().as_u64().to_owned();
if let Ok(c) = sqlx::query_as_unchecked!(Self, if let Ok(c) = sqlx::query_as_unchecked!(
Self,
" "
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ? SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
", channel_id) ",
.fetch_one(pool) channel_id
.await { )
.fetch_one(pool)
.await
{
Ok(c) Ok(c)
} } else {
else { let props = channel.to_owned().guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name));
let props = channel.guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name));
let (guild_id, channel_name) = if let Some((a, b)) = props { let (guild_id, channel_name) = if let Some((a, b)) = props { (Some(a), Some(b)) } else { (None, None) };
(Some(a), Some(b))
} else {
(None, None)
};
sqlx::query!( sqlx::query!(
" "
INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?)) INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?))
", channel_id, channel_name, guild_id) ",
.execute(&pool.clone()) channel_id,
.await?; channel_name,
guild_id
)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as_unchecked!(Self, Ok(sqlx::query_as_unchecked!(
Self,
" "
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ? SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
", channel_id) ",
.fetch_one(pool) channel_id
.await?) )
.fetch_one(pool)
.await?)
} }
} }
pub async fn commit_changes(&self, pool: &MySqlPool) { pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!( sqlx::query!(
" "
UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until = ? WHERE id = ? UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until \
", self.name, self.nudge, self.blacklisted, self.webhook_id, self.webhook_token, self.paused, self.paused_until, self.id) = ? WHERE id = ?
.execute(pool) ",
.await.unwrap(); self.name,
self.nudge,
self.blacklisted,
self.webhook_id,
self.webhook_token,
self.paused,
self.paused_until,
self.id
)
.execute(pool)
.await
.unwrap();
} }
} }

View File

@ -0,0 +1,33 @@
use serenity::{client::Context, model::id::GuildId};
use crate::{framework::CommandOptions, SQLPool};
pub struct CommandMacro {
pub guild_id: GuildId,
pub name: String,
pub description: Option<String>,
pub commands: Vec<CommandOptions>,
}
impl CommandMacro {
pub async fn from_guild(ctx: &Context, guild_id: impl Into<GuildId>) -> Vec<Self> {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let guild_id = guild_id.into();
sqlx::query!(
"SELECT * FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)",
guild_id.0
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| Self {
guild_id,
name: row.name.clone(),
description: row.description.clone(),
commands: serde_json::from_str(&row.commands).unwrap(),
})
.collect::<Vec<Self>>()
}
}

View File

@ -1,79 +0,0 @@
use serenity::model::guild::Guild;
use sqlx::MySqlPool;
use log::error;
use crate::consts::DEFAULT_PREFIX;
pub struct GuildData {
pub id: u32,
pub name: Option<String>,
pub prefix: String,
}
impl GuildData {
pub async fn from_guild(guild: Guild, pool: &MySqlPool) -> Result<Self, sqlx::Error> {
let guild_id = guild.id.as_u64().to_owned();
match sqlx::query_as!(
Self,
"
SELECT id, name, prefix FROM guilds WHERE guild = ?
",
guild_id
)
.fetch_one(pool)
.await
{
Ok(mut g) => {
g.name = Some(guild.name);
Ok(g)
}
Err(sqlx::Error::RowNotFound) => {
sqlx::query!(
"
INSERT INTO guilds (guild, name, prefix) VALUES (?, ?, ?)
",
guild_id,
guild.name,
*DEFAULT_PREFIX
)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as!(
Self,
"
SELECT id, name, prefix FROM guilds WHERE guild = ?
",
guild_id
)
.fetch_one(pool)
.await?)
}
Err(e) => {
error!("Unexpected error in guild query: {:?}", e);
Err(e)
}
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE guilds SET name = ?, prefix = ? WHERE id = ?
",
self.name,
self.prefix,
self.id
)
.execute(pool)
.await
.unwrap();
}
}

View File

@ -1,78 +1,66 @@
pub mod channel_data; pub mod channel_data;
pub mod guild_data; pub mod command_macro;
pub mod reminder; pub mod reminder;
pub mod timer; pub mod timer;
pub mod user_data; pub mod user_data;
use serenity::{async_trait, model::id::GuildId, prelude::Context}; use chrono_tz::Tz;
use serenity::{
async_trait,
model::id::{ChannelId, UserId},
prelude::Context,
};
use crate::{consts::DEFAULT_PREFIX, GuildDataCache, SQLPool}; use crate::{
models::{channel_data::ChannelData, user_data::UserData},
use guild_data::GuildData; SQLPool,
};
use std::sync::Arc;
use tokio::sync::RwLock;
#[async_trait] #[async_trait]
pub trait CtxGuildData { pub trait CtxData {
async fn guild_data<G: Into<GuildId> + Send + Sync>( async fn user_data<U: Into<UserId> + Send + Sync>(
&self, &self,
guild_id: G, user_id: U,
) -> Result<Arc<RwLock<GuildData>>, sqlx::Error>; ) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>>;
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String; async fn timezone<U: Into<UserId> + Send + Sync>(&self, user_id: U) -> Tz;
async fn channel_data<C: Into<ChannelId> + Send + Sync>(
&self,
channel_id: C,
) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>>;
} }
#[async_trait] #[async_trait]
impl CtxGuildData for Context { impl CtxData for Context {
async fn guild_data<G: Into<GuildId> + Send + Sync>( async fn user_data<U: Into<UserId> + Send + Sync>(
&self, &self,
guild_id: G, user_id: U,
) -> Result<Arc<RwLock<GuildData>>, sqlx::Error> { ) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>> {
let guild_id = guild_id.into(); let user_id = user_id.into();
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
let guild = guild_id.to_guild_cached(&self.cache).unwrap(); let user = user_id.to_user(self).await.unwrap();
let guild_cache = self UserData::from_user(&user, &self, &pool).await
.data
.read()
.await
.get::<GuildDataCache>()
.cloned()
.unwrap();
let x = if let Some(guild_data) = guild_cache.get(&guild_id) {
Ok(guild_data.clone())
} else {
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
match GuildData::from_guild(guild, &pool).await {
Ok(d) => {
let lock = Arc::new(RwLock::new(d));
guild_cache.insert(guild_id, lock.clone());
Ok(lock)
}
Err(e) => Err(e),
}
};
x
} }
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String { async fn timezone<U: Into<UserId> + Send + Sync>(&self, user_id: U) -> Tz {
if let Some(guild_id) = guild_id { let user_id = user_id.into();
self.guild_data(guild_id) let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
.await
.unwrap() UserData::timezone_of(user_id, &pool).await
.read() }
.await
.prefix async fn channel_data<C: Into<ChannelId> + Send + Sync>(
.clone() &self,
} else { channel_id: C,
DEFAULT_PREFIX.clone() ) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>> {
} let channel_id = channel_id.into();
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
let channel = channel_id.to_channel_cached(&self).unwrap();
ChannelData::from_channel(&channel, &pool).await
} }
} }

View File

@ -0,0 +1,305 @@
use std::{collections::HashSet, fmt::Display};
use chrono::{Duration, NaiveDateTime, Utc};
use chrono_tz::Tz;
use serenity::{
client::Context,
http::CacheHttp,
model::{
channel::GuildChannel,
id::{ChannelId, GuildId, UserId},
webhook::Webhook,
},
Result as SerenityResult,
};
use sqlx::MySqlPool;
use crate::{
consts,
consts::{MAX_TIME, MIN_INTERVAL},
models::{
channel_data::ChannelData,
reminder::{content::Content, errors::ReminderError, helper::generate_uid, Reminder},
user_data::UserData,
},
SQLPool,
};
async fn create_webhook(
ctx: impl CacheHttp,
channel: GuildChannel,
name: impl Display,
) -> SerenityResult<Webhook> {
channel.create_webhook_with_avatar(ctx.http(), name, consts::DEFAULT_AVATAR.clone()).await
}
#[derive(Hash, PartialEq, Eq)]
pub enum ReminderScope {
User(u64),
Channel(u64),
}
impl ReminderScope {
pub fn mention(&self) -> String {
match self {
Self::User(id) => format!("<@{}>", id),
Self::Channel(id) => format!("<#{}>", id),
}
}
}
pub struct ReminderBuilder {
pool: MySqlPool,
uid: String,
channel: u32,
utc_time: NaiveDateTime,
timezone: String,
interval: Option<i64>,
expires: Option<NaiveDateTime>,
content: String,
tts: bool,
attachment_name: Option<String>,
attachment: Option<Vec<u8>>,
set_by: Option<u32>,
}
impl ReminderBuilder {
pub async fn build(self) -> Result<Reminder, ReminderError> {
let queried_time = sqlx::query!(
"SELECT DATE_ADD(?, INTERVAL (SELECT nudge FROM channels WHERE id = ?) SECOND) AS `utc_time`",
self.utc_time,
self.channel,
)
.fetch_one(&self.pool)
.await
.unwrap();
match queried_time.utc_time {
Some(utc_time) => {
if utc_time < (Utc::now() - Duration::seconds(60)).naive_local() {
Err(ReminderError::PastTime)
} else {
sqlx::query!(
"
INSERT INTO reminders (
`uid`,
`channel_id`,
`utc_time`,
`timezone`,
`interval`,
`expires`,
`content`,
`tts`,
`attachment_name`,
`attachment`,
`set_by`
) VALUES (
?,
?,
?,
?,
?,
?,
?,
?,
?,
?,
?
)
",
self.uid,
self.channel,
utc_time,
self.timezone,
self.interval,
self.expires,
self.content,
self.tts,
self.attachment_name,
self.attachment,
self.set_by
)
.execute(&self.pool)
.await
.unwrap();
Ok(Reminder::from_uid(&self.pool, self.uid).await.unwrap())
}
}
None => Err(ReminderError::LongTime),
}
}
}
pub struct MultiReminderBuilder<'a> {
scopes: Vec<ReminderScope>,
utc_time: NaiveDateTime,
timezone: Tz,
interval: Option<i64>,
expires: Option<NaiveDateTime>,
content: Content,
set_by: Option<u32>,
ctx: &'a Context,
guild_id: Option<GuildId>,
}
impl<'a> MultiReminderBuilder<'a> {
pub fn new(ctx: &'a Context, guild_id: Option<GuildId>) -> Self {
MultiReminderBuilder {
scopes: vec![],
utc_time: Utc::now().naive_utc(),
timezone: Tz::UTC,
interval: None,
expires: None,
content: Content::new(),
set_by: None,
ctx,
guild_id,
}
}
pub fn content(mut self, content: Content) -> Self {
self.content = content;
self
}
pub fn time<T: Into<i64>>(mut self, time: T) -> Self {
self.utc_time = NaiveDateTime::from_timestamp(time.into(), 0);
self
}
pub fn expires<T: Into<i64>>(mut self, time: Option<T>) -> Self {
if let Some(t) = time {
self.expires = Some(NaiveDateTime::from_timestamp(t.into(), 0));
} else {
self.expires = None;
}
self
}
pub fn author(mut self, user: UserData) -> Self {
self.set_by = Some(user.id);
self.timezone = user.timezone();
self
}
pub fn interval(mut self, interval: Option<i64>) -> Self {
self.interval = interval;
self
}
pub fn set_scopes(&mut self, scopes: Vec<ReminderScope>) {
self.scopes = scopes;
}
pub async fn build(self) -> (HashSet<ReminderError>, HashSet<ReminderScope>) {
let pool = self.ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let mut errors = HashSet::new();
let mut ok_locs = HashSet::new();
if self.interval.map_or(false, |i| (i as i64) < *MIN_INTERVAL) {
errors.insert(ReminderError::ShortInterval);
} else if self.interval.map_or(false, |i| (i as i64) > *MAX_TIME) {
errors.insert(ReminderError::LongInterval);
} else {
for scope in self.scopes {
let db_channel_id = match scope {
ReminderScope::User(user_id) => {
if let Ok(user) = UserId(user_id).to_user(&self.ctx).await {
let user_data =
UserData::from_user(&user, &self.ctx, &pool).await.unwrap();
if let Some(guild_id) = self.guild_id {
if guild_id.member(&self.ctx, user).await.is_err() {
Err(ReminderError::InvalidTag)
} else {
Ok(user_data.dm_channel)
}
} else {
Ok(user_data.dm_channel)
}
} else {
Err(ReminderError::InvalidTag)
}
}
ReminderScope::Channel(channel_id) => {
let channel = ChannelId(channel_id).to_channel(&self.ctx).await.unwrap();
if let Some(guild_channel) = channel.clone().guild() {
if Some(guild_channel.guild_id) != self.guild_id {
Err(ReminderError::InvalidTag)
} else {
let mut channel_data =
ChannelData::from_channel(&channel, &pool).await.unwrap();
if channel_data.webhook_id.is_none()
|| channel_data.webhook_token.is_none()
{
match create_webhook(&self.ctx, guild_channel, "Reminder").await
{
Ok(webhook) => {
channel_data.webhook_id =
Some(webhook.id.as_u64().to_owned());
channel_data.webhook_token = webhook.token;
channel_data.commit_changes(&pool).await;
Ok(channel_data.id)
}
Err(e) => Err(ReminderError::DiscordError(e.to_string())),
}
} else {
Ok(channel_data.id)
}
}
} else {
Err(ReminderError::InvalidTag)
}
}
};
match db_channel_id {
Ok(c) => {
let builder = ReminderBuilder {
pool: pool.clone(),
uid: generate_uid(),
channel: c,
utc_time: self.utc_time,
timezone: self.timezone.to_string(),
interval: self.interval,
expires: self.expires,
content: self.content.content.clone(),
tts: self.content.tts,
attachment_name: self.content.attachment_name.clone(),
attachment: self.content.attachment.clone(),
set_by: self.set_by,
};
match builder.build().await {
Ok(_) => {
ok_locs.insert(scope);
}
Err(e) => {
errors.insert(e);
}
}
}
Err(e) => {
errors.insert(e);
}
}
}
}
(errors, ok_locs)
}
}

View File

@ -0,0 +1,12 @@
pub struct Content {
pub content: String,
pub tts: bool,
pub attachment: Option<Vec<u8>>,
pub attachment_name: Option<String>,
}
impl Content {
pub fn new() -> Self {
Self { content: "".to_string(), tts: false, attachment: None, attachment_name: None }
}
}

View File

@ -0,0 +1,36 @@
use crate::consts::{MAX_TIME, MIN_INTERVAL};
#[derive(PartialEq, Eq, Hash, Debug)]
pub enum ReminderError {
LongTime,
LongInterval,
PastTime,
ShortInterval,
InvalidTag,
DiscordError(String),
}
impl ToString for ReminderError {
fn to_string(&self) -> String {
match self {
ReminderError::LongTime => {
"That time is too far in the future. Please specify a shorter time.".to_string()
}
ReminderError::LongInterval => format!(
"Please ensure the interval specified is less than {max_time} days",
max_time = *MAX_TIME / 86_400
),
ReminderError::PastTime => {
"Please ensure the time provided is in the future. If the time should be in the future, please be more specific with the definition.".to_string()
}
ReminderError::ShortInterval => format!(
"Please ensure the interval provided is longer than {min_interval} seconds",
min_interval = *MIN_INTERVAL
),
ReminderError::InvalidTag => {
"Couldn't find a location by your tag. Your tag must be either a channel or a user (not a role)".to_string()
}
ReminderError::DiscordError(s) => format!("A Discord error occurred: **{}**", s),
}
}
}

View File

@ -0,0 +1,31 @@
use num_integer::Integer;
use rand::{rngs::OsRng, seq::IteratorRandom};
use crate::consts::{CHARACTERS, DAY, HOUR, MINUTE};
pub fn longhand_displacement(seconds: u64) -> String {
let (days, seconds) = seconds.div_rem(&DAY);
let (hours, seconds) = seconds.div_rem(&HOUR);
let (minutes, seconds) = seconds.div_rem(&MINUTE);
let mut sections = vec![];
for (var, name) in
[days, hours, minutes, seconds].iter().zip(["days", "hours", "minutes", "seconds"].iter())
{
if *var > 0 {
sections.push(format!("{} {}", var, name));
}
}
sections.join(", ")
}
pub fn generate_uid() -> String {
let mut generator: OsRng = Default::default();
(0..64)
.map(|_| CHARACTERS.chars().choose(&mut generator).unwrap().to_owned().to_string())
.collect::<Vec<String>>()
.join("")
}

View File

@ -0,0 +1,23 @@
use serde::{Deserialize, Serialize};
use serde_repr::*;
use serenity::model::id::ChannelId;
#[derive(Serialize_repr, Deserialize_repr, Copy, Clone, Debug)]
#[repr(u8)]
pub enum TimeDisplayType {
Absolute = 0,
Relative = 1,
}
#[derive(Serialize, Deserialize, Copy, Clone, Debug)]
pub struct LookFlags {
pub show_disabled: bool,
pub channel_id: Option<ChannelId>,
pub time_display: TimeDisplayType,
}
impl Default for LookFlags {
fn default() -> Self {
Self { show_disabled: true, channel_id: None, time_display: TimeDisplayType::Relative }
}
}

View File

@ -1,43 +1,32 @@
pub mod builder;
pub mod content;
pub mod errors;
mod helper;
pub mod look_flags;
use chrono::{NaiveDateTime, TimeZone};
use chrono_tz::Tz;
use serenity::{ use serenity::{
client::Context, client::Context,
model::id::{ChannelId, GuildId, UserId}, model::id::{ChannelId, GuildId, UserId},
}; };
use sqlx::MySqlPool;
use chrono::NaiveDateTime;
use crate::{ use crate::{
consts::{DAY, HOUR, MINUTE, REGEX_CHANNEL}, models::reminder::{
helper::longhand_displacement,
look_flags::{LookFlags, TimeDisplayType},
},
SQLPool, SQLPool,
}; };
use num_integer::Integer; #[derive(Debug, Clone)]
fn longhand_displacement(seconds: u64) -> String {
let (days, seconds) = seconds.div_rem(&DAY);
let (hours, seconds) = seconds.div_rem(&HOUR);
let (minutes, seconds) = seconds.div_rem(&MINUTE);
let mut sections = vec![];
for (var, name) in [days, hours, minutes, seconds]
.iter()
.zip(["days", "hours", "minutes", "seconds"].iter())
{
if *var > 0 {
sections.push(format!("{} {}", var, name));
}
}
sections.join(", ")
}
#[derive(Debug)]
pub struct Reminder { pub struct Reminder {
pub id: u32, pub id: u32,
pub uid: String, pub uid: String,
pub channel: u64, pub channel: u64,
pub utc_time: NaiveDateTime, pub utc_time: NaiveDateTime,
pub interval_seconds: Option<u32>, pub interval: Option<u32>,
pub expires: Option<NaiveDateTime>, pub expires: Option<NaiveDateTime>,
pub enabled: bool, pub enabled: bool,
pub content: String, pub content: String,
@ -46,9 +35,7 @@ pub struct Reminder {
} }
impl Reminder { impl Reminder {
pub async fn from_uid(ctx: &Context, uid: String) -> Option<Self> { pub async fn from_uid(pool: &MySqlPool, uid: String) -> Option<Self> {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
sqlx::query_as_unchecked!( sqlx::query_as_unchecked!(
Self, Self,
" "
@ -57,7 +44,7 @@ SELECT
reminders.uid, reminders.uid,
channels.channel, channels.channel,
reminders.utc_time, reminders.utc_time,
reminders.interval_seconds, reminders.interval,
reminders.expires, reminders.expires,
reminders.enabled, reminders.enabled,
reminders.content, reminders.content,
@ -78,7 +65,7 @@ WHERE
", ",
uid uid
) )
.fetch_one(&pool) .fetch_one(pool)
.await .await
.ok() .ok()
} }
@ -101,7 +88,7 @@ SELECT
reminders.uid, reminders.uid,
channels.channel, channels.channel,
reminders.utc_time, reminders.utc_time,
reminders.interval_seconds, reminders.interval,
reminders.expires, reminders.expires,
reminders.enabled, reminders.enabled,
reminders.content, reminders.content,
@ -122,12 +109,9 @@ WHERE
FIND_IN_SET(reminders.enabled, ?) FIND_IN_SET(reminders.enabled, ?)
ORDER BY ORDER BY
reminders.utc_time reminders.utc_time
LIMIT
?
", ",
channel_id.as_u64(), channel_id.as_u64(),
enabled, enabled,
flags.limit
) )
.fetch_all(&pool) .fetch_all(&pool)
.await .await
@ -157,7 +141,7 @@ SELECT
reminders.uid, reminders.uid,
channels.channel, channels.channel,
reminders.utc_time, reminders.utc_time,
reminders.interval_seconds, reminders.interval,
reminders.expires, reminders.expires,
reminders.enabled, reminders.enabled,
reminders.content, reminders.content,
@ -189,7 +173,7 @@ SELECT
reminders.uid, reminders.uid,
channels.channel, channels.channel,
reminders.utc_time, reminders.utc_time,
reminders.interval_seconds, reminders.interval,
reminders.expires, reminders.expires,
reminders.enabled, reminders.enabled,
reminders.content, reminders.content,
@ -222,7 +206,7 @@ SELECT
reminders.uid, reminders.uid,
channels.channel, channels.channel,
reminders.utc_time, reminders.utc_time,
reminders.interval_seconds, reminders.interval,
reminders.expires, reminders.expires,
reminders.enabled, reminders.enabled,
reminders.content, reminders.content,
@ -257,90 +241,44 @@ WHERE
} }
} }
pub fn display(&self, flags: &LookFlags, inter: &str) -> String { pub fn display_del(&self, count: usize, timezone: &Tz) -> String {
format!(
"**{}**: '{}' *<#{}>* at **{}**",
count + 1,
self.display_content(),
self.channel,
timezone
.timestamp(self.utc_time.timestamp(), 0)
.format("%Y-%m-%d %H:%M:%S")
.to_string()
)
}
pub fn display(&self, flags: &LookFlags, timezone: &Tz) -> String {
let time_display = match flags.time_display { let time_display = match flags.time_display {
TimeDisplayType::Absolute => format!("<t:{}>", self.utc_time.timestamp()), TimeDisplayType::Absolute => timezone
.timestamp(self.utc_time.timestamp(), 0)
.format("%Y-%m-%d %H:%M:%S")
.to_string(),
TimeDisplayType::Relative => format!("<t:{}:R>", self.utc_time.timestamp()), TimeDisplayType::Relative => format!("<t:{}:R>", self.utc_time.timestamp()),
}; };
if let Some(interval) = self.interval_seconds { if let Some(interval) = self.interval {
format!( format!(
"'{}' *{}* **{}**, repeating every **{}** (set by {})", "'{}' *occurs next at* **{}**, repeating every **{}** (set by {})",
self.display_content(), self.display_content(),
&inter,
time_display, time_display,
longhand_displacement(interval as u64), longhand_displacement(interval as u64),
self.set_by self.set_by.map(|i| format!("<@{}>", i)).unwrap_or_else(|| "unknown".to_string())
.map(|i| format!("<@{}>", i))
.unwrap_or_else(|| "unknown".to_string())
) )
} else { } else {
format!( format!(
"'{}' *{}* **{}** (set by {})", "'{}' *occurs next at* **{}** (set by {})",
self.display_content(), self.display_content(),
&inter,
time_display, time_display,
self.set_by self.set_by.map(|i| format!("<@{}>", i)).unwrap_or_else(|| "unknown".to_string())
.map(|i| format!("<@{}>", i))
.unwrap_or_else(|| "unknown".to_string())
) )
} }
} }
} }
enum TimeDisplayType {
Absolute,
Relative,
}
pub struct LookFlags {
pub limit: u16,
pub show_disabled: bool,
pub channel_id: Option<ChannelId>,
time_display: TimeDisplayType,
}
impl Default for LookFlags {
fn default() -> Self {
Self {
limit: u16::MAX,
show_disabled: true,
channel_id: None,
time_display: TimeDisplayType::Relative,
}
}
}
impl LookFlags {
pub fn from_string(args: &str) -> Self {
let mut new_flags: Self = Default::default();
for arg in args.split(' ') {
match arg {
"enabled" => {
new_flags.show_disabled = false;
}
"time" => {
new_flags.time_display = TimeDisplayType::Absolute;
}
param => {
if let Ok(val) = param.parse::<u16>() {
new_flags.limit = val;
} else if let Some(channel) = REGEX_CHANNEL
.captures(&arg)
.map(|cap| cap.get(1))
.flatten()
.map(|c| c.as_str().parse::<u64>().unwrap())
{
new_flags.channel_id = Some(ChannelId(channel));
}
}
}
}
new_flags
}
}

View File

@ -1,6 +1,5 @@
use sqlx::MySqlPool;
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use sqlx::MySqlPool;
pub struct Timer { pub struct Timer {
pub name: String, pub name: String,

View File

@ -1,47 +1,22 @@
use chrono_tz::Tz;
use log::error;
use serenity::{ use serenity::{
http::CacheHttp, http::CacheHttp,
model::{id::UserId, user::User}, model::{id::UserId, user::User},
}; };
use sqlx::MySqlPool; use sqlx::MySqlPool;
use chrono_tz::Tz; use crate::consts::LOCAL_TIMEZONE;
use log::error;
use crate::consts::{LOCAL_LANGUAGE, LOCAL_TIMEZONE};
pub struct UserData { pub struct UserData {
pub id: u32, pub id: u32,
pub user: u64, pub user: u64,
pub name: String, pub name: String,
pub dm_channel: u32, pub dm_channel: u32,
pub language: String,
pub timezone: String, pub timezone: String,
} }
impl UserData { impl UserData {
pub async fn language_of<U>(user: U, pool: &MySqlPool) -> String
where
U: Into<UserId>,
{
let user_id = user.into().as_u64().to_owned();
match sqlx::query!(
"
SELECT language FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await
{
Ok(r) => r.language,
Err(_) => LOCAL_LANGUAGE.clone(),
}
}
pub async fn timezone_of<U>(user: U, pool: &MySqlPool) -> Tz pub async fn timezone_of<U>(user: U, pool: &MySqlPool) -> Tz
where where
U: Into<UserId>, U: Into<UserId>,
@ -75,9 +50,10 @@ SELECT timezone FROM users WHERE user = ?
match sqlx::query_as_unchecked!( match sqlx::query_as_unchecked!(
Self, Self,
" "
SELECT id, user, name, dm_channel, IF(language IS NULL, ?, language) AS language, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ? SELECT id, user, name, dm_channel, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ?
", ",
*LOCAL_LANGUAGE, *LOCAL_TIMEZONE, user_id *LOCAL_TIMEZONE,
user_id
) )
.fetch_one(pool) .fetch_one(pool)
.await .await
@ -101,15 +77,20 @@ INSERT IGNORE INTO channels (channel) VALUES (?)
sqlx::query!( sqlx::query!(
" "
INSERT INTO users (user, name, dm_channel, language, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?, ?) INSERT INTO users (user, name, dm_channel, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?)
", user_id, user.name, dm_id, *LOCAL_LANGUAGE, *LOCAL_TIMEZONE) ",
.execute(&pool_c) user_id,
.await?; user.name,
dm_id,
*LOCAL_TIMEZONE
)
.execute(&pool_c)
.await?;
Ok(sqlx::query_as_unchecked!( Ok(sqlx::query_as_unchecked!(
Self, Self,
" "
SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ? SELECT id, user, name, dm_channel, timezone FROM users WHERE user = ?
", ",
user_id user_id
) )
@ -121,17 +102,16 @@ SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ?
error!("Error querying for user: {:?}", e); error!("Error querying for user: {:?}", e);
Err(Box::new(e)) Err(Box::new(e))
}, }
} }
} }
pub async fn commit_changes(&self, pool: &MySqlPool) { pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!( sqlx::query!(
" "
UPDATE users SET name = ?, language = ?, timezone = ? WHERE id = ? UPDATE users SET name = ?, timezone = ? WHERE id = ?
", ",
self.name, self.name,
self.language,
self.timezone, self.timezone,
self.id self.id
) )

552
src/sender.rs Normal file
View File

@ -0,0 +1,552 @@
use chrono::Duration;
use chrono_tz::Tz;
use log::{error, info, warn};
use num_integer::Integer;
use regex::{Captures, Regex};
use serenity::{
builder::CreateEmbed,
http::{CacheHttp, Http, StatusCode},
model::{
channel::{Channel, Embed as SerenityEmbed},
id::ChannelId,
webhook::Webhook,
},
Error, Result,
};
use sqlx::{
types::chrono::{NaiveDateTime, Utc},
MySqlPool,
};
lazy_static! {
pub static ref TIMEFROM_REGEX: Regex =
Regex::new(r#"<<timefrom:(?P<time>\d+):(?P<format>.+)?>>"#).unwrap();
pub static ref TIMENOW_REGEX: Regex =
Regex::new(r#"<<timenow:(?P<timezone>(?:\w|/|_)+):(?P<format>.+)?>>"#).unwrap();
}
fn fmt_displacement(format: &str, seconds: u64) -> String {
let mut seconds = seconds;
let mut days: u64 = 0;
let mut hours: u64 = 0;
let mut minutes: u64 = 0;
for (rep, time_type, div) in
[("%d", &mut days, 86400), ("%h", &mut hours, 3600), ("%m", &mut minutes, 60)].iter_mut()
{
if format.contains(*rep) {
let (divided, new_seconds) = seconds.div_rem(&div);
**time_type = divided;
seconds = new_seconds;
}
}
format
.replace("%s", &seconds.to_string())
.replace("%m", &minutes.to_string())
.replace("%h", &hours.to_string())
.replace("%d", &days.to_string())
}
pub fn substitute(string: &str) -> String {
let new = TIMEFROM_REGEX.replace(string, |caps: &Captures| {
let final_time = caps.name("time").unwrap().as_str();
let format = caps.name("format").unwrap().as_str();
if let Ok(final_time) = final_time.parse::<i64>() {
let dt = NaiveDateTime::from_timestamp(final_time, 0);
let now = Utc::now().naive_utc();
let difference = {
if now < dt {
dt - Utc::now().naive_utc()
} else {
Utc::now().naive_utc() - dt
}
};
fmt_displacement(format, difference.num_seconds() as u64)
} else {
String::new()
}
});
TIMENOW_REGEX
.replace(&new, |caps: &Captures| {
let timezone = caps.name("timezone").unwrap().as_str();
println!("{}", timezone);
if let Ok(tz) = timezone.parse::<Tz>() {
let format = caps.name("format").unwrap().as_str();
let now = Utc::now().with_timezone(&tz);
now.format(format).to_string()
} else {
String::new()
}
})
.to_string()
}
struct Embed {
inner: EmbedInner,
fields: Vec<EmbedField>,
}
struct EmbedInner {
title: String,
description: String,
image_url: Option<String>,
thumbnail_url: Option<String>,
footer: String,
footer_url: Option<String>,
author: String,
author_url: Option<String>,
color: u32,
}
struct EmbedField {
title: String,
value: String,
inline: bool,
}
impl Embed {
pub async fn from_id(pool: &MySqlPool, id: u32) -> Option<Self> {
let mut inner = sqlx::query_as_unchecked!(
EmbedInner,
"
SELECT
`embed_title` AS title,
`embed_description` AS description,
`embed_image_url` AS image_url,
`embed_thumbnail_url` AS thumbnail_url,
`embed_footer` AS footer,
`embed_footer_url` AS footer_url,
`embed_author` AS author,
`embed_author_url` AS author_url,
`embed_color` AS color
FROM
reminders
WHERE
`id` = ?
",
id
)
.fetch_one(&pool.clone())
.await
.unwrap();
inner.title = substitute(&inner.title);
inner.description = substitute(&inner.description);
inner.footer = substitute(&inner.footer);
let mut fields = sqlx::query_as_unchecked!(
EmbedField,
"
SELECT
title,
value,
inline
FROM
embed_fields
WHERE
reminder_id = ?
",
id
)
.fetch_all(pool)
.await
.unwrap();
fields.iter_mut().for_each(|mut field| {
field.title = substitute(&field.title);
field.value = substitute(&field.value);
});
let e = Embed { inner, fields };
if e.has_content() {
Some(e)
} else {
None
}
}
pub fn has_content(&self) -> bool {
if self.inner.title.is_empty()
&& self.inner.description.is_empty()
&& self.inner.image_url.is_none()
&& self.inner.thumbnail_url.is_none()
&& self.inner.footer.is_empty()
&& self.inner.footer_url.is_none()
&& self.inner.author.is_empty()
&& self.inner.author_url.is_none()
&& self.fields.is_empty()
{
false
} else {
true
}
}
}
impl Into<CreateEmbed> for Embed {
fn into(self) -> CreateEmbed {
let mut c = CreateEmbed::default();
c.title(&self.inner.title)
.description(&self.inner.description)
.color(self.inner.color)
.author(|a| {
a.name(&self.inner.author);
if let Some(author_icon) = &self.inner.author_url {
a.icon_url(author_icon);
}
a
})
.footer(|f| {
f.text(&self.inner.footer);
if let Some(footer_icon) = &self.inner.footer_url {
f.icon_url(footer_icon);
}
f
});
for field in &self.fields {
c.field(&field.title, &field.value, field.inline);
}
if let Some(image_url) = &self.inner.image_url {
c.image(image_url);
}
if let Some(thumbnail_url) = &self.inner.thumbnail_url {
c.thumbnail(thumbnail_url);
}
c
}
}
#[derive(Debug)]
pub struct Reminder {
id: u32,
channel_id: u64,
webhook_id: Option<u64>,
webhook_token: Option<String>,
channel_paused: bool,
channel_paused_until: Option<NaiveDateTime>,
enabled: bool,
tts: bool,
pin: bool,
content: String,
attachment: Option<Vec<u8>>,
attachment_name: Option<String>,
utc_time: NaiveDateTime,
timezone: String,
restartable: bool,
expires: Option<NaiveDateTime>,
interval: Option<u32>,
avatar: Option<String>,
username: Option<String>,
}
impl Reminder {
pub async fn fetch_reminders(pool: &MySqlPool) -> Vec<Self> {
sqlx::query_as_unchecked!(
Reminder,
"
SELECT
reminders.`id` AS id,
channels.`channel` AS channel_id,
channels.`webhook_id` AS webhook_id,
channels.`webhook_token` AS webhook_token,
channels.`paused` AS channel_paused,
channels.`paused_until` AS channel_paused_until,
reminders.`enabled` AS enabled,
reminders.`tts` AS tts,
reminders.`pin` AS pin,
reminders.`content` AS content,
reminders.`attachment` AS attachment,
reminders.`attachment_name` AS attachment_name,
reminders.`utc_time` AS 'utc_time',
reminders.`timezone` AS timezone,
reminders.`restartable` AS restartable,
reminders.`expires` AS expires,
reminders.`interval` AS 'interval',
reminders.`avatar` AS avatar,
reminders.`username` AS username
FROM
reminders
INNER JOIN
channels
ON
reminders.channel_id = channels.id
WHERE
reminders.`utc_time` < NOW()
",
)
.fetch_all(pool)
.await
.unwrap()
.into_iter()
.map(|mut rem| {
rem.content = substitute(&rem.content);
rem
})
.collect::<Vec<Self>>()
}
async fn reset_webhook(&self, pool: &MySqlPool) {
let _ = sqlx::query!(
"
UPDATE channels SET webhook_id = NULL, webhook_token = NULL WHERE channel = ?
",
self.channel_id
)
.execute(pool)
.await;
}
async fn refresh(&self, pool: &MySqlPool) {
if let Some(interval) = self.interval {
let now = Utc::now().naive_local();
let mut updated_reminder_time = self.utc_time;
while updated_reminder_time < now {
updated_reminder_time += Duration::seconds(interval as i64);
}
if self.expires.map_or(false, |expires| {
NaiveDateTime::from_timestamp(updated_reminder_time.timestamp(), 0) > expires
}) {
self.force_delete(pool).await;
} else {
sqlx::query!(
"
UPDATE reminders SET `utc_time` = ? WHERE `id` = ?
",
updated_reminder_time,
self.id
)
.execute(pool)
.await
.expect(&format!("Could not update time on Reminder {}", self.id));
}
} else {
self.force_delete(pool).await;
}
}
async fn force_delete(&self, pool: &MySqlPool) {
sqlx::query!(
"
DELETE FROM reminders WHERE `id` = ?
",
self.id
)
.execute(pool)
.await
.expect(&format!("Could not delete Reminder {}", self.id));
}
async fn pin_message<M: Into<u64>>(&self, message_id: M, http: impl AsRef<Http>) {
let _ = http.as_ref().pin_message(self.channel_id, message_id.into(), None).await;
}
pub async fn send(&self, pool: MySqlPool, cache_http: impl CacheHttp) {
async fn send_to_channel(
cache_http: impl CacheHttp,
reminder: &Reminder,
embed: Option<CreateEmbed>,
) -> Result<()> {
let channel = ChannelId(reminder.channel_id).to_channel(&cache_http).await;
match channel {
Ok(Channel::Guild(channel)) => {
match channel
.send_message(&cache_http, |m| {
m.content(&reminder.content).tts(reminder.tts);
if let (Some(attachment), Some(name)) =
(&reminder.attachment, &reminder.attachment_name)
{
m.add_file((attachment as &[u8], name.as_str()));
}
if let Some(embed) = embed {
m.set_embed(embed);
}
m
})
.await
{
Ok(m) => {
if reminder.pin {
reminder.pin_message(m.id, cache_http.http()).await;
}
Ok(())
}
Err(e) => Err(e),
}
}
Ok(Channel::Private(channel)) => {
match channel
.send_message(&cache_http.http(), |m| {
m.content(&reminder.content).tts(reminder.tts);
if let (Some(attachment), Some(name)) =
(&reminder.attachment, &reminder.attachment_name)
{
m.add_file((attachment as &[u8], name.as_str()));
}
if let Some(embed) = embed {
m.set_embed(embed);
}
m
})
.await
{
Ok(m) => {
if reminder.pin {
reminder.pin_message(m.id, cache_http.http()).await;
}
Ok(())
}
Err(e) => Err(e),
}
}
Err(e) => Err(e),
_ => Err(Error::Other("Channel not of valid type")),
}
}
async fn send_to_webhook(
cache_http: impl CacheHttp,
reminder: &Reminder,
webhook: Webhook,
embed: Option<CreateEmbed>,
) -> Result<()> {
match webhook
.execute(&cache_http.http(), reminder.pin || reminder.restartable, |w| {
w.content(&reminder.content).tts(reminder.tts);
if let Some(username) = &reminder.username {
w.username(username);
}
if let Some(avatar) = &reminder.avatar {
w.avatar_url(avatar);
}
if let (Some(attachment), Some(name)) =
(&reminder.attachment, &reminder.attachment_name)
{
w.add_file((attachment as &[u8], name.as_str()));
}
if let Some(embed) = embed {
w.embeds(vec![SerenityEmbed::fake(|c| {
*c = embed;
c
})]);
}
w
})
.await
{
Ok(m) => {
if reminder.pin {
if let Some(message) = m {
reminder.pin_message(message.id, cache_http.http()).await;
}
}
Ok(())
}
Err(e) => Err(e),
}
}
if self.enabled
&& !(self.channel_paused
&& self
.channel_paused_until
.map_or(true, |inner| inner >= Utc::now().naive_local()))
{
let _ = sqlx::query!(
"
UPDATE `channels` SET paused = 0, paused_until = NULL WHERE `channel` = ?
",
self.channel_id
)
.execute(&pool.clone())
.await;
let embed = Embed::from_id(&pool.clone(), self.id).await.map(|e| e.into());
let result = if let (Some(webhook_id), Some(webhook_token)) =
(self.webhook_id, &self.webhook_token)
{
let webhook_res =
cache_http.http().get_webhook_with_token(webhook_id, webhook_token).await;
if let Ok(webhook) = webhook_res {
send_to_webhook(cache_http, &self, webhook, embed).await
} else {
warn!("Webhook vanished: {:?}", webhook_res);
self.reset_webhook(&pool.clone()).await;
send_to_channel(cache_http, &self, embed).await
}
} else {
send_to_channel(cache_http, &self, embed).await
};
if let Err(e) = result {
error!("Error sending {:?}: {:?}", self, e);
if let Error::Http(error) = e {
if error.status_code() == Some(StatusCode::from_u16(404).unwrap()) {
error!("Seeing channel is deleted. Removing reminder");
self.force_delete(&pool).await;
} else {
self.refresh(&pool).await;
}
} else {
self.refresh(&pool).await;
}
} else {
self.refresh(&pool).await;
}
} else {
info!("Reminder {} is paused", self.id);
self.refresh(&pool).await;
}
}
}

View File

@ -1,15 +1,16 @@
use std::time::{SystemTime, UNIX_EPOCH}; use std::{
convert::TryFrom,
use std::fmt::{Display, Formatter, Result as FmtResult}; fmt::{Display, Formatter, Result as FmtResult},
str::from_utf8,
use crate::consts::{LOCAL_TIMEZONE, PYTHON_LOCATION}; time::{SystemTime, UNIX_EPOCH},
};
use chrono::{DateTime, Datelike, Timelike, Utc}; use chrono::{DateTime, Datelike, Timelike, Utc};
use chrono_tz::Tz; use chrono_tz::Tz;
use std::convert::TryFrom;
use std::str::from_utf8;
use tokio::process::Command; use tokio::process::Command;
use crate::consts::{LOCAL_TIMEZONE, PYTHON_LOCATION};
#[derive(Debug)] #[derive(Debug)]
pub enum InvalidTime { pub enum InvalidTime {
ParseErrorDMY, ParseErrorDMY,
@ -26,11 +27,13 @@ impl Display for InvalidTime {
impl std::error::Error for InvalidTime {} impl std::error::Error for InvalidTime {}
#[derive(Copy, Clone)]
enum ParseType { enum ParseType {
Explicit, Explicit,
Displacement, Displacement,
} }
#[derive(Clone)]
pub struct TimeParser { pub struct TimeParser {
timezone: Tz, timezone: Tz,
inverted: bool, inverted: bool,
@ -95,10 +98,7 @@ impl TimeParser {
} }
fn process_explicit(&self) -> Result<i64, InvalidTime> { fn process_explicit(&self) -> Result<i64, InvalidTime> {
let mut time = Utc::now() let mut time = Utc::now().with_timezone(&self.timezone).with_second(0).unwrap();
.with_timezone(&self.timezone)
.with_second(0)
.unwrap();
let mut segments = self.time_string.rsplit('-'); let mut segments = self.time_string.rsplit('-');
// this segment will always exist even if split fails // this segment will always exist even if split fails
@ -106,11 +106,9 @@ impl TimeParser {
let h_m_s = hms.split(':'); let h_m_s = hms.split(':');
for (t, setter) in h_m_s.take(3).zip(&[ for (t, setter) in
DateTime::with_hour, h_m_s.take(3).zip(&[DateTime::with_hour, DateTime::with_minute, DateTime::with_second])
DateTime::with_minute, {
DateTime::with_second,
]) {
time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorHMS)?) time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorHMS)?)
.map_or_else(|| Err(InvalidTime::ParseErrorHMS), Ok)?; .map_or_else(|| Err(InvalidTime::ParseErrorHMS), Ok)?;
} }
@ -122,9 +120,7 @@ impl TimeParser {
let month = d_m_y.next(); let month = d_m_y.next();
let year = d_m_y.next(); let year = d_m_y.next();
for (t, setter) in [day, month] for (t, setter) in [day, month].iter().zip(&[DateTime::with_day, DateTime::with_month])
.iter()
.zip(&[DateTime::with_day, DateTime::with_month])
{ {
if let Some(t) = t { if let Some(t) = t {
time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorDMY)?) time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorDMY)?)