3 Commits

Author SHA1 Message Date
bf34721e55 stub of new remind command 2021-05-13 18:19:14 +01:00
2c91a72640 slash commands 2021-04-30 00:13:14 +01:00
4a64238ee4 database migration 2021-04-17 16:57:46 +01:00
41 changed files with 5486 additions and 6034 deletions

966
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,11 @@
[package]
name = "reminder_rs"
version = "1.6.0-beta3"
version = "1.5.0"
authors = ["jellywx <judesouthworth@pm.me>"]
edition = "2018"
[dependencies]
dashmap = "4.0"
dotenv = "0.15"
humantime = "2.1"
tokio = { version = "1", features = ["process", "full"] }
@ -13,34 +14,17 @@ regex = "1.4"
log = "0.4"
env_logger = "0.8"
chrono = "0.4"
chrono-tz = { version = "0.5", features = ["serde"] }
chrono-tz = "0.5"
lazy_static = "1.4"
num-integer = "0.1"
serde = "1.0"
serde_json = "1.0"
serde_repr = "0.1"
rmp-serde = "0.15"
rand = "0.7"
Inflector = "0.11"
levenshtein = "1.0"
# serenity = { version = "0.10", features = ["collector"] }
serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["collector", "unstable_discord_api"] }
sqlx = { version = "0.5", features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal", "chrono"]}
base64 = "0.13.0"
[dependencies.regex_command_attr]
path = "command_attributes"
[dependencies.serenity]
git = "https://github.com/serenity-rs/serenity"
branch = "next"
default-features = false
features = [
"builder",
"client",
"cache",
"gateway",
"http",
"model",
"utils",
"rustls_backend",
"collector",
"unstable_discord_api"
]
path = "./regex_command_attr"

View File

@ -1,5 +1,6 @@
# reminder-rs
Reminder Bot for Discord.
Reminder Bot for Discord, now in Rust.
Old Python version: https://github.com/reminder-bot/bot
## How do I use it?
We offer a hosted version of the bot. You can invite it with: **https://invite.reminder-bot.com**. The catch is that repeating
@ -14,6 +15,7 @@ Reminder Bot can be built by running `cargo build --release` in the top level di
These environment variables must be provided when compiling the bot
* `DATABASE_URL` - the URL of your MySQL database (`mysql://user[:password]@domain/database`)
* `WEBHOOK_AVATAR` - accepts the name of an image file located in `$CARGO_MANIFEST_DIR/assets/` to be used as the avatar when creating webhooks. **IMPORTANT: image file must be 128x128 or smaller in size**
* `STRINGS_FILE` - accepts the name of a compiled strings file located in `$CARGO_MANIFEST_DIR/assets/` to be used for creating messages. Compiled string files can be generated with `compile.py` at https://github.com/reminder-bot/languages
### Setting up Python
Reminder Bot by default looks for a venv within it's working directory to run Python out of. To set up a venv, install `python3-venv` and run `python3 -m venv venv`. Then, run `source venv/bin/activate` to activate the venv, and do `pip install dateparser` to install the required library
@ -27,17 +29,16 @@ __Required Variables__
__Other Variables__
* `MIN_INTERVAL` - default `600`, defines the shortest interval the bot should accept
* `MAX_TIME` - default `1576800000`, defines the maximum time ahead that reminders can be set for
* `LOCAL_TIMEZONE` - default `UTC`, necessary for calculations in the natural language processor
* `DEFAULT_PREFIX` - default `$`, used for the default prefix on new guilds
* `SUBSCRIPTION_ROLES` - default `None`, accepts a list of Discord role IDs that are given to subscribed users
* `CNC_GUILD` - default `None`, accepts a single Discord guild ID for the server that the subscription roles belong to
* `IGNORE_BOTS` - default `1`, if `1`, Reminder Bot will ignore all other bots
* `PYTHON_LOCATION` - default `venv/bin/python3`. Can be changed if your Python executable is located somewhere else
* `LOCAL_LANGUAGE` - default `EN`. Specifies the string set to fall back to if a string cannot be found (and to be used with new users)
* `THEME_COLOR` - default `8fb677`. Specifies the hex value of the color to use on info message embeds
* `CASE_INSENSITIVE` - default `1`, if `1`, commands will be treated with case insensitivity (so both `$help` and `$HELP` will work)
* `SHARD_COUNT` - default `None`, accepts the number of shards that are being ran
* `SHARD_RANGE` - default `None`, if `SHARD_COUNT` is specified, specifies what range of shards to start on this process
* `DM_ENABLED` - default `1`, if `1`, Reminder Bot will respond to direct messages
### Todo List
* Convert aliases to macros
* Help command

View File

@ -1,10 +0,0 @@
pub mod suffixes {
pub const COMMAND: &str = "COMMAND";
pub const ARG: &str = "ARG";
pub const SUBCOMMAND: &str = "SUBCOMMAND";
pub const SUBCOMMAND_GROUP: &str = "GROUP";
pub const CHECK: &str = "CHECK";
pub const HOOK: &str = "HOOK";
}
pub use self::suffixes::*;

View File

@ -1,321 +0,0 @@
#![deny(rust_2018_idioms)]
#![deny(broken_intra_doc_links)]
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use syn::{parse::Error, parse_macro_input, parse_quote, spanned::Spanned, Lit, Type};
use uuid::Uuid;
pub(crate) mod attributes;
pub(crate) mod consts;
pub(crate) mod structures;
#[macro_use]
pub(crate) mod util;
use attributes::*;
use consts::*;
use structures::*;
use util::*;
macro_rules! match_options {
($v:expr, $values:ident, $options:ident, $span:expr => [$($name:ident);*]) => {
match $v {
$(
stringify!($name) => $options.$name = propagate_err!($crate::attributes::parse($values)),
)*
_ => {
return Error::new($span, format_args!("invalid attribute: {:?}", $v))
.to_compile_error()
.into();
},
}
};
}
#[proc_macro_attribute]
pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
enum LastItem {
Fun,
SubFun,
SubGroup,
SubGroupFun,
}
let mut fun = parse_macro_input!(input as CommandFun);
let _name = if !attr.is_empty() {
parse_macro_input!(attr as Lit).to_str()
} else {
fun.name.to_string()
};
let mut hooks: Vec<Ident> = Vec::new();
let mut options = Options::new();
let mut last_desc = LastItem::Fun;
for attribute in &fun.attributes {
let span = attribute.span();
let values = propagate_err!(parse_values(attribute));
let name = values.name.to_string();
let name = &name[..];
match name {
"subcommand" => {
let new_subcommand = Subcommand::new(propagate_err!(attributes::parse(values)));
if let Some(subcommand_group) = options.subcommand_groups.last_mut() {
last_desc = LastItem::SubGroupFun;
subcommand_group.subcommands.push(new_subcommand);
} else {
last_desc = LastItem::SubFun;
options.subcommands.push(new_subcommand);
}
}
"subcommandgroup" => {
let new_group = SubcommandGroup::new(propagate_err!(attributes::parse(values)));
last_desc = LastItem::SubGroup;
options.subcommand_groups.push(new_group);
}
"arg" => {
let arg = propagate_err!(attributes::parse(values));
match last_desc {
LastItem::Fun => {
options.cmd_args.push(arg);
}
LastItem::SubFun => {
options.subcommands.last_mut().unwrap().cmd_args.push(arg);
}
LastItem::SubGroup => {
panic!("Argument not expected under subcommand group");
}
LastItem::SubGroupFun => {
options
.subcommand_groups
.last_mut()
.unwrap()
.subcommands
.last_mut()
.unwrap()
.cmd_args
.push(arg);
}
}
}
"example" => {
options.examples.push(propagate_err!(attributes::parse(values)));
}
"description" => {
let line: String = propagate_err!(attributes::parse(values));
match last_desc {
LastItem::Fun => {
util::append_line(&mut options.description, line);
}
LastItem::SubFun => {
util::append_line(
&mut options.subcommands.last_mut().unwrap().description,
line,
);
}
LastItem::SubGroup => {
util::append_line(
&mut options.subcommand_groups.last_mut().unwrap().description,
line,
);
}
LastItem::SubGroupFun => {
util::append_line(
&mut options
.subcommand_groups
.last_mut()
.unwrap()
.subcommands
.last_mut()
.unwrap()
.description,
line,
);
}
}
}
"hook" => {
hooks.push(propagate_err!(attributes::parse(values)));
}
_ => {
match_options!(name, values, options, span => [
aliases;
group;
can_blacklist;
supports_dm
]);
}
}
}
let Options {
aliases,
description,
group,
examples,
can_blacklist,
supports_dm,
mut cmd_args,
mut subcommands,
mut subcommand_groups,
} = options;
let visibility = fun.visibility;
let name = fun.name.clone();
let body = fun.body;
let root_ident = name.with_suffix(COMMAND);
let command_path = quote!(crate::framework::Command);
populate_fut_lifetimes_on_refs(&mut fun.args);
let mut subcommand_group_idents = subcommand_groups
.iter()
.map(|subcommand| {
root_ident
.with_suffix(subcommand.name.replace("-", "_").as_str())
.with_suffix(SUBCOMMAND_GROUP)
})
.collect::<Vec<Ident>>();
let mut subcommand_idents = subcommands
.iter()
.map(|subcommand| {
root_ident
.with_suffix(subcommand.name.replace("-", "_").as_str())
.with_suffix(SUBCOMMAND)
})
.collect::<Vec<Ident>>();
let mut arg_idents = cmd_args
.iter()
.map(|arg| root_ident.with_suffix(arg.name.replace("-", "_").as_str()).with_suffix(ARG))
.collect::<Vec<Ident>>();
let mut tokens = quote! {};
tokens.extend(
subcommand_groups
.iter_mut()
.zip(subcommand_group_idents.iter())
.map(|(group, group_ident)| group.as_tokens(group_ident))
.fold(quote! {}, |mut a, b| {
a.extend(b);
a
}),
);
tokens.extend(
subcommands
.iter_mut()
.zip(subcommand_idents.iter())
.map(|(subcommand, sc_ident)| subcommand.as_tokens(sc_ident))
.fold(quote! {}, |mut a, b| {
a.extend(b);
a
}),
);
tokens.extend(
cmd_args.iter_mut().zip(arg_idents.iter()).map(|(arg, ident)| arg.as_tokens(ident)).fold(
quote! {},
|mut a, b| {
a.extend(b);
a
},
),
);
arg_idents.append(&mut subcommand_group_idents);
arg_idents.append(&mut subcommand_idents);
let args = fun.args;
let variant = if args.len() == 2 {
quote!(crate::framework::CommandFnType::Multi)
} else {
let string: Type = parse_quote!(String);
let final_arg = args.get(2).unwrap();
if final_arg.kind == string {
quote!(crate::framework::CommandFnType::Text)
} else {
quote!(crate::framework::CommandFnType::Slash)
}
};
tokens.extend(quote! {
#[allow(missing_docs)]
pub static #root_ident: #command_path = #command_path {
fun: #variant(#name),
names: &[#_name, #(#aliases),*],
desc: #description,
group: #group,
examples: &[#(#examples),*],
can_blacklist: #can_blacklist,
supports_dm: #supports_dm,
args: &[#(&#arg_idents),*],
hooks: &[#(&#hooks),*],
};
#[allow(missing_docs)]
#visibility fn #name<'fut> (#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, ()> {
use ::serenity::futures::future::FutureExt;
async move {
#(#body)*;
}.boxed()
}
});
tokens.into()
}
#[proc_macro_attribute]
pub fn check(_attr: TokenStream, input: TokenStream) -> TokenStream {
let mut fun = parse_macro_input!(input as CommandFun);
let n = fun.name.clone();
let name = n.with_suffix(HOOK);
let fn_name = n.with_suffix(CHECK);
let visibility = fun.visibility;
let body = fun.body;
let ret = fun.ret;
populate_fut_lifetimes_on_refs(&mut fun.args);
let args = fun.args;
let hook_path = quote!(crate::framework::Hook);
let uuid = Uuid::new_v4().as_u128();
(quote! {
#[allow(missing_docs)]
#visibility fn #fn_name<'fut>(#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, #ret> {
use ::serenity::futures::future::FutureExt;
async move {
let _output: #ret = { #(#body)* };
#[allow(unreachable_code)]
_output
}.boxed()
}
#[allow(missing_docs)]
pub static #name: #hook_path = #hook_path {
fun: #fn_name,
uuid: #uuid,
};
})
.into()
}

View File

@ -1,331 +0,0 @@
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use syn::{
braced,
parse::{Error, Parse, ParseStream, Result},
spanned::Spanned,
Attribute, Block, FnArg, Ident, Pat, ReturnType, Stmt, Token, Type, Visibility,
};
use crate::{
consts::{ARG, SUBCOMMAND},
util::{Argument, IdentExt2, Parenthesised},
};
fn parse_argument(arg: FnArg) -> Result<Argument> {
match arg {
FnArg::Typed(typed) => {
let pat = typed.pat;
let kind = typed.ty;
match *pat {
Pat::Ident(id) => {
let name = id.ident;
let mutable = id.mutability;
Ok(Argument { mutable, name, kind: *kind })
}
Pat::Wild(wild) => {
let token = wild.underscore_token;
let name = Ident::new("_", token.spans[0]);
Ok(Argument { mutable: None, name, kind: *kind })
}
_ => Err(Error::new(pat.span(), format_args!("unsupported pattern: {:?}", pat))),
}
}
FnArg::Receiver(_) => {
Err(Error::new(arg.span(), format_args!("`self` arguments are prohibited: {:?}", arg)))
}
}
}
#[derive(Debug)]
pub struct CommandFun {
/// `#[...]`-style attributes.
pub attributes: Vec<Attribute>,
/// Populated cooked attributes. These are attributes outside of the realm of this crate's procedural macros
/// and will appear in generated output.
pub visibility: Visibility,
pub name: Ident,
pub args: Vec<Argument>,
pub ret: Type,
pub body: Vec<Stmt>,
}
impl Parse for CommandFun {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let attributes = input.call(Attribute::parse_outer)?;
let visibility = input.parse::<Visibility>()?;
input.parse::<Token![async]>()?;
input.parse::<Token![fn]>()?;
let name = input.parse()?;
// (...)
let Parenthesised(args) = input.parse::<Parenthesised<FnArg>>()?;
let ret = match input.parse::<ReturnType>()? {
ReturnType::Type(_, t) => (*t).clone(),
ReturnType::Default => Type::Verbatim(quote!(())),
};
// { ... }
let bcont;
braced!(bcont in input);
let body = bcont.call(Block::parse_within)?;
let args = args.into_iter().map(parse_argument).collect::<Result<Vec<_>>>()?;
Ok(Self { attributes, visibility, name, args, ret, body })
}
}
impl ToTokens for CommandFun {
fn to_tokens(&self, stream: &mut TokenStream2) {
let Self { attributes: _, visibility, name, args, ret, body } = self;
stream.extend(quote! {
#visibility async fn #name (#(#args),*) -> #ret {
#(#body)*
}
});
}
}
#[derive(Debug)]
pub(crate) enum ApplicationCommandOptionType {
SubCommand,
SubCommandGroup,
String,
Integer,
Boolean,
User,
Channel,
Role,
Mentionable,
Number,
Unknown,
}
impl ApplicationCommandOptionType {
pub fn from_str(s: String) -> Self {
match s.as_str() {
"SubCommand" => Self::SubCommand,
"SubCommandGroup" => Self::SubCommandGroup,
"String" => Self::String,
"Integer" => Self::Integer,
"Boolean" => Self::Boolean,
"User" => Self::User,
"Channel" => Self::Channel,
"Role" => Self::Role,
"Mentionable" => Self::Mentionable,
"Number" => Self::Number,
_ => Self::Unknown,
}
}
}
impl ToTokens for ApplicationCommandOptionType {
fn to_tokens(&self, stream: &mut TokenStream2) {
let path = quote!(
serenity::model::interactions::application_command::ApplicationCommandOptionType
);
let variant = match self {
ApplicationCommandOptionType::SubCommand => quote!(SubCommand),
ApplicationCommandOptionType::SubCommandGroup => quote!(SubCommandGroup),
ApplicationCommandOptionType::String => quote!(String),
ApplicationCommandOptionType::Integer => quote!(Integer),
ApplicationCommandOptionType::Boolean => quote!(Boolean),
ApplicationCommandOptionType::User => quote!(User),
ApplicationCommandOptionType::Channel => quote!(Channel),
ApplicationCommandOptionType::Role => quote!(Role),
ApplicationCommandOptionType::Mentionable => quote!(Mentionable),
ApplicationCommandOptionType::Number => quote!(Number),
ApplicationCommandOptionType::Unknown => quote!(Unknown),
};
stream.extend(quote! {
#path::#variant
});
}
}
#[derive(Debug)]
pub(crate) struct Arg {
pub name: String,
pub description: String,
pub kind: ApplicationCommandOptionType,
pub required: bool,
}
impl Arg {
pub fn as_tokens(&self, ident: &Ident) -> TokenStream2 {
let arg_path = quote!(crate::framework::Arg);
let Arg { name, description, kind, required } = self;
quote! {
#[allow(missing_docs)]
pub static #ident: #arg_path = #arg_path {
name: #name,
description: #description,
kind: #kind,
required: #required,
options: &[]
};
}
}
}
impl Default for Arg {
fn default() -> Self {
Self {
name: String::new(),
description: String::new(),
kind: ApplicationCommandOptionType::String,
required: false,
}
}
}
#[derive(Debug)]
pub(crate) struct Subcommand {
pub name: String,
pub description: String,
pub cmd_args: Vec<Arg>,
}
impl Subcommand {
pub fn as_tokens(&mut self, ident: &Ident) -> TokenStream2 {
let arg_path = quote!(crate::framework::Arg);
let subcommand_path = ApplicationCommandOptionType::SubCommand;
let arg_idents = self
.cmd_args
.iter()
.map(|arg| ident.with_suffix(arg.name.as_str()).with_suffix(ARG))
.collect::<Vec<Ident>>();
let mut tokens = self
.cmd_args
.iter_mut()
.zip(arg_idents.iter())
.map(|(arg, ident)| arg.as_tokens(ident))
.fold(quote! {}, |mut a, b| {
a.extend(b);
a
});
let Subcommand { name, description, .. } = self;
tokens.extend(quote! {
#[allow(missing_docs)]
pub static #ident: #arg_path = #arg_path {
name: #name,
description: #description,
kind: #subcommand_path,
required: false,
options: &[#(&#arg_idents),*],
};
});
tokens
}
}
impl Default for Subcommand {
fn default() -> Self {
Self { name: String::new(), description: String::new(), cmd_args: vec![] }
}
}
impl Subcommand {
pub(crate) fn new(name: String) -> Self {
Self { name, ..Default::default() }
}
}
#[derive(Debug)]
pub(crate) struct SubcommandGroup {
pub name: String,
pub description: String,
pub subcommands: Vec<Subcommand>,
}
impl SubcommandGroup {
pub fn as_tokens(&mut self, ident: &Ident) -> TokenStream2 {
let arg_path = quote!(crate::framework::Arg);
let subcommand_group_path = ApplicationCommandOptionType::SubCommandGroup;
let arg_idents = self
.subcommands
.iter()
.map(|arg| {
ident
.with_suffix(self.name.as_str())
.with_suffix(arg.name.as_str())
.with_suffix(SUBCOMMAND)
})
.collect::<Vec<Ident>>();
let mut tokens = self
.subcommands
.iter_mut()
.zip(arg_idents.iter())
.map(|(subcommand, ident)| subcommand.as_tokens(ident))
.fold(quote! {}, |mut a, b| {
a.extend(b);
a
});
let SubcommandGroup { name, description, .. } = self;
tokens.extend(quote! {
#[allow(missing_docs)]
pub static #ident: #arg_path = #arg_path {
name: #name,
description: #description,
kind: #subcommand_group_path,
required: false,
options: &[#(&#arg_idents),*],
};
});
tokens
}
}
impl Default for SubcommandGroup {
fn default() -> Self {
Self { name: String::new(), description: String::new(), subcommands: vec![] }
}
}
impl SubcommandGroup {
pub(crate) fn new(name: String) -> Self {
Self { name, ..Default::default() }
}
}
#[derive(Debug, Default)]
pub(crate) struct Options {
pub aliases: Vec<String>,
pub description: String,
pub group: String,
pub examples: Vec<String>,
pub can_blacklist: bool,
pub supports_dm: bool,
pub cmd_args: Vec<Arg>,
pub subcommands: Vec<Subcommand>,
pub subcommand_groups: Vec<SubcommandGroup>,
}
impl Options {
#[inline]
pub fn new() -> Self {
Self { group: "None".to_string(), ..Default::default() }
}
}

View File

@ -1,5 +1,3 @@
CREATE DATABASE IF NOT EXISTS reminders;
SET FOREIGN_KEY_CHECKS=0;
USE reminders;

View File

@ -1,13 +0,0 @@
USE reminders;
CREATE TABLE macro (
id INT UNSIGNED AUTO_INCREMENT,
guild_id INT UNSIGNED NOT NULL,
name VARCHAR(100) NOT NULL,
description VARCHAR(100),
commands TEXT NOT NULL,
FOREIGN KEY (guild_id) REFERENCES guilds(id) ON DELETE CASCADE,
PRIMARY KEY (id)
);

View File

@ -48,14 +48,14 @@ CREATE TABLE reminders_new (
PRIMARY KEY (id),
FOREIGN KEY (`channel_id`) REFERENCES channels (`id`) ON DELETE CASCADE,
FOREIGN KEY (`set_by`) REFERENCES users (`id`) ON DELETE SET NULL
FOREIGN KEY (`set_by`) REFERENCES users (`id`) ON DELETE SET NULL,
# disallow having a reminder as restartable if it has no interval
-- , CONSTRAINT restartable_interval_mutex CHECK (`restartable` = 0 OR `interval` IS NULL)
CONSTRAINT restartable_interval_mutex CHECK (`restartable` = 0 OR `interval` IS NULL),
# disallow disabling if interval is unspecified
-- , CONSTRAINT interval_enabled_mutin CHECK (`enabled` = 1 OR `interval` IS NULL)
CONSTRAINT interval_enabled_mutin CHECK (`enabled` = 1 OR `interval` IS NULL),
# disallow an expiry time if interval is unspecified
-- , CONSTRAINT interval_expires_mutin CHECK (`expires` IS NULL OR `interval` IS NOT NULL)
CONSTRAINT interval_expires_mutin CHECK (`expires` IS NULL OR `interval` IS NOT NULL)
);
# import data from other tables
@ -86,7 +86,7 @@ INSERT INTO reminders_new (
reminders.uid,
reminders.name,
reminders.channel_id,
DATE_ADD(FROM_UNIXTIME(0), INTERVAL reminders.`time` SECOND),
FROM_UNIXTIME(reminders.time),
reminders.`interval`,
reminders.enabled,
reminders.expires,
@ -120,7 +120,7 @@ CREATE TABLE embed_fields_new (
PRIMARY KEY (id),
FOREIGN KEY (reminder_id) REFERENCES reminders_new (id) ON DELETE CASCADE
FOREIGN KEY (reminder_id) REFERENCES reminders_new (id)
);
INSERT INTO embed_fields_new (

View File

@ -1,10 +1,9 @@
[package]
name = "regex_command_attr"
version = "0.3.6"
version = "0.2.0"
authors = ["acdenisSK <acdenissk69@gmail.com>", "jellywx <judesouthworth@pm.me>"]
edition = "2018"
description = "Procedural macros for command creation for the Serenity library."
license = "ISC"
description = "Procedural macros for command creation for the RegexFramework for serenity."
[lib]
proc-macro = true
@ -13,4 +12,3 @@ proc-macro = true
quote = "^1.0"
syn = { version = "^1.0", features = ["full", "derive", "extra-traits"] }
proc-macro2 = "1.0"
uuid = { version = "0.8", features = ["v4"] }

View File

@ -1,16 +1,12 @@
use std::fmt::{self, Write};
use proc_macro2::Span;
use syn::{
parse::{Error, Result},
spanned::Spanned,
Attribute, Ident, Lit, LitStr, Meta, NestedMeta, Path,
};
use syn::parse::{Error, Result};
use syn::spanned::Spanned;
use syn::{Attribute, Ident, Lit, LitStr, Meta, NestedMeta, Path};
use crate::{
structures::{ApplicationCommandOptionType, Arg},
util::{AsOption, LitExt},
};
use crate::structures::PermissionLevel;
use crate::util::{AsOption, LitExt};
use std::fmt::{self, Write};
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ValueKind {
@ -23,9 +19,6 @@ pub enum ValueKind {
// #[<name>([<value>, <value>, <value>, ...])]
List,
// #[<name>([<prop> = <value>, <prop> = <value>, ...])]
EqualsList,
// #[<name>(<value>)]
SingleList,
}
@ -36,9 +29,6 @@ impl fmt::Display for ValueKind {
ValueKind::Name => f.pad("`#[<name>]`"),
ValueKind::Equals => f.pad("`#[<name> = <value>]`"),
ValueKind::List => f.pad("`#[<name>([<value>, <value>, <value>, ...])]`"),
ValueKind::EqualsList => {
f.pad("`#[<name>([<prop> = <value>, <prop> = <value>, ...])]`")
}
ValueKind::SingleList => f.pad("`#[<name>(<value>)]`"),
}
}
@ -46,15 +36,24 @@ impl fmt::Display for ValueKind {
fn to_ident(p: Path) -> Result<Ident> {
if p.segments.is_empty() {
return Err(Error::new(p.span(), "cannot convert an empty path to an identifier"));
return Err(Error::new(
p.span(),
"cannot convert an empty path to an identifier",
));
}
if p.segments.len() > 1 {
return Err(Error::new(p.span(), "the path must not have more than one segment"));
return Err(Error::new(
p.span(),
"the path must not have more than one segment",
));
}
if !p.segments[0].arguments.is_empty() {
return Err(Error::new(p.span(), "the singular path segment must not have any arguments"));
return Err(Error::new(
p.span(),
"the singular path segment must not have any arguments",
));
}
Ok(p.segments[0].ident.clone())
@ -63,37 +62,24 @@ fn to_ident(p: Path) -> Result<Ident> {
#[derive(Debug)]
pub struct Values {
pub name: Ident,
pub literals: Vec<(Option<String>, Lit)>,
pub literals: Vec<Lit>,
pub kind: ValueKind,
pub span: Span,
}
impl Values {
#[inline]
pub fn new(
name: Ident,
kind: ValueKind,
literals: Vec<(Option<String>, Lit)>,
span: Span,
) -> Self {
Values { name, literals, kind, span }
pub fn new(name: Ident, kind: ValueKind, literals: Vec<Lit>, span: Span) -> Self {
Values {
name,
literals,
kind,
span,
}
}
}
pub fn parse_values(attr: &Attribute) -> Result<Values> {
fn is_list_or_named_list(meta: &NestedMeta) -> ValueKind {
match meta {
// catch if the nested value is a literal value
NestedMeta::Lit(_) => ValueKind::List,
// catch if the nested value is a meta value
NestedMeta::Meta(m) => match m {
// path => some quoted value
Meta::Path(_) => ValueKind::List,
Meta::List(_) | Meta::NameValue(_) => ValueKind::EqualsList,
},
}
}
let meta = attr.parse_meta()?;
match meta {
@ -110,62 +96,36 @@ pub fn parse_values(attr: &Attribute) -> Result<Values> {
return Err(Error::new(attr.span(), "list cannot be empty"));
}
if is_list_or_named_list(nested.first().unwrap()) == ValueKind::List {
let mut lits = Vec::with_capacity(nested.len());
let mut lits = Vec::with_capacity(nested.len());
for meta in nested {
match meta {
// catch if the nested value is a literal value
NestedMeta::Lit(l) => lits.push((None, l)),
// catch if the nested value is a meta value
NestedMeta::Meta(m) => match m {
// path => some quoted value
Meta::Path(path) => {
let i = to_ident(path)?;
lits.push((None, Lit::Str(LitStr::new(&i.to_string(), i.span()))))
}
Meta::List(_) | Meta::NameValue(_) => {
return Err(Error::new(attr.span(), "cannot nest a list; only accept literals and identifiers at this level"))
}
},
}
}
let kind = if lits.len() == 1 { ValueKind::SingleList } else { ValueKind::List };
Ok(Values::new(name, kind, lits, attr.span()))
} else {
let mut lits = Vec::with_capacity(nested.len());
for meta in nested {
match meta {
// catch if the nested value is a literal value
NestedMeta::Lit(_) => {
return Err(Error::new(attr.span(), "key-value pairs expected"))
for meta in nested {
match meta {
NestedMeta::Lit(l) => lits.push(l),
NestedMeta::Meta(m) => match m {
Meta::Path(path) => {
let i = to_ident(path)?;
lits.push(Lit::Str(LitStr::new(&i.to_string(), i.span())))
}
// catch if the nested value is a meta value
NestedMeta::Meta(m) => match m {
Meta::NameValue(n) => {
let name = to_ident(n.path)?.to_string();
let value = n.lit;
lits.push((Some(name), value));
}
Meta::List(_) | Meta::Path(_) => {
return Err(Error::new(attr.span(), "key-value pairs expected"))
}
},
}
Meta::List(_) | Meta::NameValue(_) => {
return Err(Error::new(attr.span(), "cannot nest a list; only accept literals and identifiers at this level"))
}
},
}
Ok(Values::new(name, ValueKind::EqualsList, lits, attr.span()))
}
let kind = if lits.len() == 1 {
ValueKind::SingleList
} else {
ValueKind::List
};
Ok(Values::new(name, kind, lits, attr.span()))
}
Meta::NameValue(meta) => {
let name = to_ident(meta.path)?;
let lit = meta.lit;
Ok(Values::new(name, ValueKind::Equals, vec![(None, lit)], attr.span()))
Ok(Values::new(name, ValueKind::Equals, vec![lit], attr.span()))
}
}
}
@ -208,7 +168,10 @@ fn validate(values: &Values, forms: &[ValueKind]) -> Result<()> {
return Err(Error::new(
values.span,
// Using the `_args` version here to avoid an allocation.
format_args!("the attribute must be in of these forms:\n{}", DisplaySlice(forms)),
format_args!(
"the attribute must be in of these forms:\n{}",
DisplaySlice(forms)
),
));
}
@ -228,7 +191,11 @@ impl AttributeOption for Vec<String> {
fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::List])?;
Ok(values.literals.into_iter().map(|(_, l)| l.to_str()).collect())
Ok(values
.literals
.into_iter()
.map(|lit| lit.to_str())
.collect())
}
}
@ -237,7 +204,7 @@ impl AttributeOption for String {
fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::Equals, ValueKind::SingleList])?;
Ok(values.literals[0].1.to_str())
Ok(values.literals[0].to_str())
}
}
@ -246,7 +213,7 @@ impl AttributeOption for bool {
fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::Name, ValueKind::SingleList])?;
Ok(values.literals.get(0).map_or(true, |(_, l)| l.to_bool()))
Ok(values.literals.get(0).map_or(true, |l| l.to_bool()))
}
}
@ -255,7 +222,7 @@ impl AttributeOption for Ident {
fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::SingleList])?;
Ok(values.literals[0].1.to_ident())
Ok(values.literals[0].to_ident())
}
}
@ -264,7 +231,7 @@ impl AttributeOption for Vec<Ident> {
fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::List])?;
Ok(values.literals.into_iter().map(|(_, l)| l.to_ident()).collect())
Ok(values.literals.into_iter().map(|l| l.to_ident()).collect())
}
}
@ -272,40 +239,15 @@ impl AttributeOption for Option<String> {
fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::Name, ValueKind::Equals, ValueKind::SingleList])?;
Ok(values.literals.get(0).map(|(_, l)| l.to_str()))
Ok(values.literals.get(0).map(|l| l.to_str()))
}
}
impl AttributeOption for Arg {
impl AttributeOption for PermissionLevel {
fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::EqualsList])?;
validate(&values, &[ValueKind::SingleList])?;
let mut arg: Arg = Default::default();
for (key, value) in &values.literals {
match key {
Some(s) => match s.as_str() {
"name" => {
arg.name = value.to_str();
}
"description" => {
arg.description = value.to_str();
}
"required" => {
arg.required = value.to_bool();
}
"kind" => arg.kind = ApplicationCommandOptionType::from_str(value.to_str()),
_ => {
return Err(Error::new(key.span(), "unexpected attribute"));
}
},
_ => {
return Err(Error::new(key.span(), "unnamed attribute"));
}
}
}
Ok(arg)
Ok(values.literals.get(0).map(|l| PermissionLevel::from_str(&*l.to_str()).unwrap()).unwrap())
}
}
@ -323,7 +265,7 @@ macro_rules! attr_option_num {
fn parse(values: Values) -> Result<Self> {
validate(&values, &[ValueKind::SingleList])?;
Ok(match &values.literals[0].1 {
Ok(match &values.literals[0] {
Lit::Int(l) => l.base10_parse::<$n>()?,
l => {
let s = l.to_str();

View File

@ -0,0 +1,5 @@
pub mod suffixes {
pub const COMMAND: &str = "COMMAND";
}
pub use self::suffixes::*;

View File

@ -0,0 +1,102 @@
#![deny(rust_2018_idioms)]
// FIXME: Remove this in a foreseeable future.
// Currently exists for backwards compatibility to previous Rust versions.
#![recursion_limit = "128"]
#[allow(unused_extern_crates)]
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse::Error, parse_macro_input, spanned::Spanned, Lit};
pub(crate) mod attributes;
pub(crate) mod consts;
pub(crate) mod structures;
#[macro_use]
pub(crate) mod util;
use attributes::*;
use consts::*;
use structures::*;
use util::*;
macro_rules! match_options {
($v:expr, $values:ident, $options:ident, $span:expr => [$($name:ident);*]) => {
match $v {
$(
stringify!($name) => $options.$name = propagate_err!($crate::attributes::parse($values)),
)*
_ => {
return Error::new($span, format_args!("invalid attribute: {:?}", $v))
.to_compile_error()
.into();
},
}
};
}
#[proc_macro_attribute]
pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
let mut fun = parse_macro_input!(input as CommandFun);
let lit_name = if !attr.is_empty() {
parse_macro_input!(attr as Lit).to_str()
} else {
fun.name.to_string()
};
let mut options = Options::new();
for attribute in &fun.attributes {
let span = attribute.span();
let values = propagate_err!(parse_values(attribute));
let name = values.name.to_string();
let name = &name[..];
match_options!(name, values, options, span => [
permission_level;
supports_dm;
can_blacklist
]);
}
let Options {
permission_level,
supports_dm,
can_blacklist,
} = options;
let visibility = fun.visibility;
let name = fun.name.clone();
let body = fun.body;
let n = name.with_suffix(COMMAND);
let cooked = fun.cooked.clone();
let command_path = quote!(crate::framework::Command);
populate_fut_lifetimes_on_refs(&mut fun.args);
let args = fun.args;
(quote! {
#(#cooked)*
pub static #n: #command_path = #command_path {
func: #name,
name: #lit_name,
required_perms: #permission_level,
supports_dm: #supports_dm,
can_blacklist: #can_blacklist,
};
#visibility fn #name<'fut> (#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, ()> {
use ::serenity::futures::future::FutureExt;
async move { #(#body)* }.boxed()
}
})
.into()
}

View File

@ -0,0 +1,231 @@
use crate::util::{Argument, Parenthesised};
use proc_macro2::Span;
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use syn::{
braced,
parse::{Error, Parse, ParseStream, Result},
spanned::Spanned,
Attribute, Block, FnArg, Ident, Pat, Path, PathSegment, Stmt, Token, Visibility,
};
fn parse_argument(arg: FnArg) -> Result<Argument> {
match arg {
FnArg::Typed(typed) => {
let pat = typed.pat;
let kind = typed.ty;
match *pat {
Pat::Ident(id) => {
let name = id.ident;
let mutable = id.mutability;
Ok(Argument {
mutable,
name,
kind: *kind,
})
}
Pat::Wild(wild) => {
let token = wild.underscore_token;
let name = Ident::new("_", token.spans[0]);
Ok(Argument {
mutable: None,
name,
kind: *kind,
})
}
_ => Err(Error::new(
pat.span(),
format_args!("unsupported pattern: {:?}", pat),
)),
}
}
FnArg::Receiver(_) => Err(Error::new(
arg.span(),
format_args!("`self` arguments are prohibited: {:?}", arg),
)),
}
}
/// Test if the attribute is cooked.
fn is_cooked(attr: &Attribute) -> bool {
const COOKED_ATTRIBUTE_NAMES: &[&str] = &[
"cfg", "cfg_attr", "doc", "derive", "inline", "allow", "warn", "deny", "forbid",
];
COOKED_ATTRIBUTE_NAMES.iter().any(|n| attr.path.is_ident(n))
}
/// Removes cooked attributes from a vector of attributes. Uncooked attributes are left in the vector.
///
/// # Return
///
/// Returns a vector of cooked attributes that have been removed from the input vector.
fn remove_cooked(attrs: &mut Vec<Attribute>) -> Vec<Attribute> {
let mut cooked = Vec::new();
// FIXME: Replace with `Vec::drain_filter` once it is stable.
let mut i = 0;
while i < attrs.len() {
if !is_cooked(&attrs[i]) {
i += 1;
continue;
}
cooked.push(attrs.remove(i));
}
cooked
}
#[derive(Debug)]
pub struct CommandFun {
/// `#[...]`-style attributes.
pub attributes: Vec<Attribute>,
/// Populated cooked attributes. These are attributes outside of the realm of this crate's procedural macros
/// and will appear in generated output.
pub cooked: Vec<Attribute>,
pub visibility: Visibility,
pub name: Ident,
pub args: Vec<Argument>,
pub body: Vec<Stmt>,
}
impl Parse for CommandFun {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let mut attributes = input.call(Attribute::parse_outer)?;
// `#[doc = "..."]` is a cooked attribute but it is special-cased for commands.
for attr in &mut attributes {
// Rename documentation comment attributes (`#[doc = "..."]`) to `#[description = "..."]`.
if attr.path.is_ident("doc") {
attr.path = Path::from(PathSegment::from(Ident::new(
"description",
Span::call_site(),
)));
}
}
let cooked = remove_cooked(&mut attributes);
let visibility = input.parse::<Visibility>()?;
input.parse::<Token![async]>()?;
input.parse::<Token![fn]>()?;
let name = input.parse()?;
// (...)
let Parenthesised(args) = input.parse::<Parenthesised<FnArg>>()?;
// { ... }
let bcont;
braced!(bcont in input);
let body = bcont.call(Block::parse_within)?;
let args = args
.into_iter()
.map(parse_argument)
.collect::<Result<Vec<_>>>()?;
Ok(Self {
attributes,
cooked,
visibility,
name,
args,
body,
})
}
}
impl ToTokens for CommandFun {
fn to_tokens(&self, stream: &mut TokenStream2) {
let Self {
attributes: _,
cooked,
visibility,
name,
args,
body,
} = self;
stream.extend(quote! {
#(#cooked)*
#visibility async fn #name (#(#args),*) -> () {
#(#body)*
}
});
}
}
#[derive(Debug)]
pub enum PermissionLevel {
Unrestricted,
Managed,
Restricted,
}
impl Default for PermissionLevel {
fn default() -> Self {
Self::Unrestricted
}
}
impl PermissionLevel {
pub fn from_str(s: &str) -> Option<Self> {
Some(match s.to_uppercase().as_str() {
"UNRESTRICTED" => Self::Unrestricted,
"MANAGED" => Self::Managed,
"RESTRICTED" => Self::Restricted,
_ => return None,
})
}
}
impl ToTokens for PermissionLevel {
fn to_tokens(&self, stream: &mut TokenStream2) {
let path = quote!(crate::framework::PermissionLevel);
let variant;
match self {
Self::Unrestricted => {
variant = quote!(Unrestricted);
}
Self::Managed => {
variant = quote!(Managed);
}
Self::Restricted => {
variant = quote!(Restricted);
}
}
stream.extend(quote! {
#path::#variant
});
}
}
#[derive(Debug, Default)]
pub struct Options {
pub permission_level: PermissionLevel,
pub supports_dm: bool,
pub can_blacklist: bool,
}
impl Options {
#[inline]
pub fn new() -> Self {
let mut options = Self::default();
options.can_blacklist = true;
options.supports_dm = true;
options
}
}

View File

@ -1,5 +1,6 @@
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use proc_macro2::Span;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens};
use syn::{
braced, bracketed, parenthesized,
@ -157,20 +158,3 @@ pub fn populate_fut_lifetimes_on_refs(args: &mut Vec<Argument>) {
}
}
}
pub fn append_line(desc: &mut String, mut line: String) {
if line.starts_with(' ') {
line.remove(0);
}
match line.rfind("\\$") {
Some(i) => {
desc.push_str(line[..i].trim_end());
desc.push(' ');
}
None => {
desc.push_str(&line);
desc.push('\n');
}
}
}

View File

@ -1,3 +0,0 @@
imports_granularity = "Crate"
group_imports = "StdExternalCrate"
use_small_heuristics = "Max"

View File

@ -1,15 +1,43 @@
use chrono::offset::Utc;
use regex_command_attr::command;
use serenity::{builder::CreateEmbedFooter, client::Context};
use crate::{
framework::{CommandInvoke, CreateGenericResponse},
models::CtxData,
THEME_COLOR,
use serenity::{
builder::CreateEmbedFooter,
client::Context,
model::{
channel::Message,
interactions::{Interaction, InteractionResponseType},
},
};
fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEmbedFooter {
let shard_count = ctx.cache.shard_count();
use chrono::offset::Utc;
use crate::{
command_help, consts::DEFAULT_PREFIX, get_ctx_data, language_manager::LanguageManager,
models::CtxGuildData, models::UserData, FrameworkCtx, THEME_COLOR,
};
use inflector::Inflector;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
#[command]
#[can_blacklist(false)]
async fn ping(ctx: &Context, msg: &Message, _args: String) {
let now = SystemTime::now();
let since_epoch = now
.duration_since(UNIX_EPOCH)
.expect("Time calculated as going backwards. Very bad");
let delta = since_epoch.as_millis() as i64 - msg.timestamp.timestamp_millis();
let _ = msg
.channel_id
.say(&ctx, format!("Time taken to receive message: {}ms", delta))
.await;
}
async fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEmbedFooter {
let shard_count = ctx.cache.shard_count().await;
let shard = ctx.shard_id;
move |f| {
@ -23,140 +51,386 @@ fn footer(ctx: &Context) -> impl FnOnce(&mut CreateEmbedFooter) -> &mut CreateEm
}
#[command]
#[description("Get an overview of the bot commands")]
async fn help(ctx: &Context, invoke: &mut CommandInvoke) {
let footer = footer(ctx);
#[can_blacklist(false)]
async fn help(ctx: &Context, msg: &Message, args: String) {
async fn default_help(
ctx: &Context,
msg: &Message,
lm: Arc<LanguageManager>,
prefix: &str,
language: &str,
) {
let desc = lm.get(language, "help/desc").replace("{prefix}", prefix);
let footer = footer(ctx).await;
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().embed(|e| {
e.title("Help")
.color(*THEME_COLOR)
.description(
"__Info Commands__
`/help` `/info` `/donate` `/dashboard` `/clock`
*run these commands with no options*
let _ = msg
.channel_id
.send_message(ctx, |m| {
m.embed(move |e| {
e.title("Help Menu")
.description(desc)
.field(
lm.get(language, "help/setup_title"),
"`lang` `timezone` `meridian`",
true,
)
.field(
lm.get(language, "help/mod_title"),
"`prefix` `blacklist` `restrict` `alias`",
true,
)
.field(
lm.get(language, "help/reminder_title"),
"`remind` `interval` `natural` `look` `countdown`",
true,
)
.field(
lm.get(language, "help/reminder_mod_title"),
"`del` `offset` `pause` `nudge`",
true,
)
.field(
lm.get(language, "help/info_title"),
"`help` `info` `donate` `clock`",
true,
)
.field(
lm.get(language, "help/todo_title"),
"`todo` `todos` `todoc`",
true,
)
.field(lm.get(language, "help/other_title"), "`timer`", true)
.footer(footer)
.color(*THEME_COLOR)
})
})
.await;
}
__Reminder Commands__
`/remind` - Create a new reminder that will send a message at a certain time
`/timer` - Start a timer from now, that will count time passed. Also used to view and remove timers
let (pool, lm) = get_ctx_data(&ctx).await;
__Reminder Management__
`/del` - Delete reminders
`/look` - View reminders
`/pause` - Pause all reminders on the channel
`/offset` - Move all reminders by a certain time
`/nudge` - Move all new reminders on this channel by a certain time
let language = UserData::language_of(&msg.author, &pool);
let prefix = ctx.prefix(msg.guild_id);
__Todo Commands__
`/todo` - Add, view and manage the server, channel or user todo lists
if !args.is_empty() {
let framework = ctx
.data
.read()
.await
.get::<FrameworkCtx>()
.cloned()
.expect("Could not get FrameworkCtx from data");
__Setup Commands__
`/timezone` - Set your timezone (necessary for `/remind` to work properly)
let matched = framework
.commands
.get(args.as_str())
.map(|inner| inner.name);
__Advanced Commands__
`/macro` - Record and replay command sequences
",
)
.footer(footer)
}),
)
.await;
if let Some(command_name) = matched {
command_help(ctx, msg, lm, &prefix.await, &language.await, command_name).await
} else {
default_help(ctx, msg, lm, &prefix.await, &language.await).await;
}
} else {
default_help(ctx, msg, lm, &prefix.await, &language.await).await;
}
}
pub async fn help_interaction(ctx: &Context, interaction: Interaction) {
async fn default_help(
ctx: &Context,
interaction: Interaction,
lm: Arc<LanguageManager>,
language: &str,
) {
let desc = lm.get(language, "help/desc").replace("{prefix}", "/");
let footer = footer(ctx).await;
interaction
.create_interaction_response(ctx, |response| {
response
.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|data| {
data.embed(move |e| {
e.title("Help Menu")
.description(desc)
.field(
lm.get(language, "help/setup_title"),
"`lang` `timezone` `meridian`",
true,
)
.field(
lm.get(language, "help/mod_title"),
"`prefix` `blacklist` `restrict` `alias`",
true,
)
.field(
lm.get(language, "help/reminder_title"),
"`remind` `interval` `natural` `look` `countdown`",
true,
)
.field(
lm.get(language, "help/reminder_mod_title"),
"`del` `offset` `pause` `nudge`",
true,
)
.field(
lm.get(language, "help/info_title"),
"`help` `info` `donate` `clock`",
true,
)
.field(
lm.get(language, "help/todo_title"),
"`todo` `todos` `todoc`",
true,
)
.field(lm.get(language, "help/other_title"), "`timer`", true)
.footer(footer)
.color(*THEME_COLOR)
})
})
})
.await
.unwrap();
}
async fn command_help(
ctx: &Context,
interaction: Interaction,
lm: Arc<LanguageManager>,
language: &str,
command_name: &str,
) {
interaction
.create_interaction_response(ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|data| {
data.embed(move |e| {
e.title(format!("{} Help", command_name.to_title_case()))
.description(
lm.get(&language, &format!("help/{}", command_name))
.replace("{prefix}", "/"),
)
.footer(|f| {
f.text(concat!(
env!("CARGO_PKG_NAME"),
" ver ",
env!("CARGO_PKG_VERSION")
))
})
.color(*THEME_COLOR)
})
})
})
.await
.unwrap();
}
let (pool, lm) = get_ctx_data(&ctx).await;
let language = UserData::language_of(interaction.member.user.id, &pool);
if let Some(data) = &interaction.data {
if let Some(command_name) = data
.options
.first()
.map(|opt| {
opt.value
.clone()
.map(|inner| inner.as_str().unwrap().to_string())
})
.flatten()
{
let framework = ctx
.data
.read()
.await
.get::<FrameworkCtx>()
.cloned()
.expect("Could not get FrameworkCtx from data");
let matched = framework
.commands
.get(&command_name)
.map(|inner| inner.name);
if let Some(command_name) = matched {
command_help(ctx, interaction, lm, &language.await, command_name).await
} else {
default_help(ctx, interaction, lm, &language.await).await;
}
} else {
default_help(ctx, interaction, lm, &language.await).await;
}
} else {
default_help(ctx, interaction, lm, &language.await).await;
}
}
#[command]
#[aliases("invite")]
#[description("Get information about the bot")]
async fn info(ctx: &Context, invoke: &mut CommandInvoke) {
let footer = footer(ctx);
async fn info(ctx: &Context, msg: &Message, _args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
let language = UserData::language_of(&msg.author, &pool);
let prefix = ctx.prefix(msg.guild_id);
let current_user = ctx.cache.current_user();
let footer = footer(ctx).await;
let desc = lm
.get(&language.await, "info")
.replacen("{user}", &current_user.await.name, 1)
.replace("{default_prefix}", &*DEFAULT_PREFIX)
.replace("{prefix}", &prefix.await);
let _ = msg
.channel_id
.send_message(ctx, |m| {
m.embed(move |e| {
e.title("Info")
.description(format!(
"Help: `/help`
**Welcome to Reminder Bot!**
Developer: <@203532103185465344>
Icon: <@253202252821430272>
Find me on https://discord.jellywx.com and on https://github.com/JellyWX :)
Invite the bot: https://invite.reminder-bot.com/
Use our dashboard: https://reminder-bot.com/",
))
.description(desc)
.footer(footer)
.color(*THEME_COLOR)
}),
)
})
})
.await;
}
#[command]
#[description("Details on supporting the bot and Patreon benefits")]
#[group("Info")]
async fn donate(ctx: &Context, invoke: &mut CommandInvoke) {
let footer = footer(ctx);
pub async fn info_interaction(ctx: &Context, interaction: Interaction) {
let (pool, lm) = get_ctx_data(&ctx).await;
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
let language = UserData::language_of(&interaction.member, &pool);
let current_user = ctx.cache.current_user();
let footer = footer(ctx).await;
let desc = lm
.get(&language.await, "info")
.replacen("{user}", &current_user.await.name, 1)
.replace("{default_prefix}", &*DEFAULT_PREFIX)
.replace("{prefix}", "/");
interaction
.create_interaction_response(ctx, |response| {
response
.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|data| {
data.embed(move |e| {
e.title("Info")
.description(desc)
.footer(footer)
.color(*THEME_COLOR)
})
})
})
.await
.unwrap();
}
#[command]
async fn donate(ctx: &Context, msg: &Message, _args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let language = UserData::language_of(&msg.author, &pool).await;
let desc = lm.get(&language, "donate");
let footer = footer(ctx).await;
let _ = msg
.channel_id
.send_message(ctx, |m| {
m.embed(move |e| {
e.title("Donate")
.description("Thinking of adding a monthly contribution? Click below for my Patreon and official bot server :)
**https://www.patreon.com/jellywx/**
**https://discord.jellywx.com/**
When you subscribe, Patreon will automatically rank you up on our Discord server (make sure you link your Patreon and Discord accounts!)
With your new rank, you'll be able to:
• Set repeating reminders with `interval`, `natural` or the dashboard
• Use unlimited uploads on SoundFX
(Also, members of servers you __own__ will be able to set repeating reminders via commands)
Just $2 USD/month!
*Please note, you must be in the JellyWX Discord server to receive Patreon features*")
.description(desc)
.footer(footer)
.color(*THEME_COLOR)
}),
)
})
})
.await;
}
#[command]
#[description("Get the link to the online dashboard")]
#[group("Info")]
async fn dashboard(ctx: &Context, invoke: &mut CommandInvoke) {
let footer = footer(ctx);
pub async fn donate_interaction(ctx: &Context, interaction: Interaction) {
let (pool, lm) = get_ctx_data(&ctx).await;
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().embed(|e| {
let language = UserData::language_of(&interaction.member, &pool).await;
let desc = lm.get(&language, "donate");
let footer = footer(ctx).await;
interaction
.create_interaction_response(ctx, |response| {
response
.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|data| {
data.embed(move |e| {
e.title("Donate")
.description(desc)
.footer(footer)
.color(*THEME_COLOR)
})
})
})
.await
.unwrap();
}
#[command]
async fn dashboard(ctx: &Context, msg: &Message, _args: String) {
let footer = footer(ctx).await;
let _ = msg
.channel_id
.send_message(ctx, |m| {
m.embed(move |e| {
e.title("Dashboard")
.description("**https://reminder-bot.com/dashboard**")
.description("https://reminder-bot.com/dashboard")
.footer(footer)
.color(*THEME_COLOR)
}),
)
})
})
.await;
}
#[command]
#[description("View the current time in your selected timezone")]
#[group("Info")]
async fn clock(ctx: &Context, invoke: &mut CommandInvoke) {
let ud = ctx.user_data(&invoke.author_id()).await.unwrap();
let now = Utc::now().with_timezone(&ud.timezone());
async fn clock(ctx: &Context, msg: &Message, _args: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let _ = invoke
.respond(
ctx.http.clone(),
CreateGenericResponse::new().content(format!("Current time: {}", now.format("%H:%M"))),
let language = UserData::language_of(&msg.author, &pool).await;
let timezone = UserData::timezone_of(&msg.author, &pool).await;
let meridian = UserData::meridian_of(&msg.author, &pool).await;
let now = Utc::now().with_timezone(&timezone);
let clock_display = lm.get(&language, "clock/time");
let _ = msg
.channel_id
.say(
&ctx,
clock_display.replacen("{}", &now.format(meridian.fmt_str()).to_string(), 1),
)
.await;
}
pub async fn clock_interaction(ctx: &Context, interaction: Interaction) {
let (pool, lm) = get_ctx_data(&ctx).await;
let language = UserData::language_of(&interaction.member, &pool).await;
let timezone = UserData::timezone_of(&interaction.member, &pool).await;
let meridian = UserData::meridian_of(&interaction.member, &pool).await;
let now = Utc::now().with_timezone(&timezone);
let clock_display = lm.get(&language, "clock/time");
interaction
.create_interaction_response(ctx, |response| {
response
.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|data| {
data.content(clock_display.replacen(
"{}",
&now.format(meridian.fmt_str()).to_string(),
1,
))
})
})
.await
.unwrap();
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,260 +1,443 @@
use regex_command_attr::command;
use serenity::client::Context;
use crate::{
component_models::{
pager::{Pager, TodoPager},
ComponentDataModel, TodoSelector,
use serenity::{
async_trait,
client::Context,
constants::MESSAGE_CODE_LIMIT,
model::{
channel::Message,
id::{ChannelId, GuildId, UserId},
},
consts::{EMBED_DESCRIPTION_MAX_LENGTH, SELECT_MAX_ENTRIES, THEME_COLOR},
framework::{CommandInvoke, CommandOptions, CreateGenericResponse},
hooks::CHECK_GUILD_PERMISSIONS_HOOK,
SQLPool,
};
#[command]
#[description("Manage todo lists")]
#[subcommandgroup("server")]
#[description("Manage the server todo list")]
#[subcommand("add")]
#[description("Add an item to the server todo list")]
#[arg(
name = "task",
description = "The task to add to the todo list",
kind = "String",
required = true
)]
#[subcommand("view")]
#[description("View and remove from the server todo list")]
#[subcommandgroup("channel")]
#[description("Manage the channel todo list")]
#[subcommand("add")]
#[description("Add to the channel todo list")]
#[arg(
name = "task",
description = "The task to add to the todo list",
kind = "String",
required = true
)]
#[subcommand("view")]
#[description("View and remove from the channel todo list")]
#[subcommandgroup("user")]
#[description("Manage your personal todo list")]
#[subcommand("add")]
#[description("Add to your personal todo list")]
#[arg(
name = "task",
description = "The task to add to the todo list",
kind = "String",
required = true
)]
#[subcommand("view")]
#[description("View and remove from your personal todo list")]
#[hook(CHECK_GUILD_PERMISSIONS_HOOK)]
async fn todo(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions) {
if invoke.guild_id().is_none() && args.subcommand_group != Some("user".to_string()) {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content("Please use `/todo user` in direct messages"),
)
.await;
} else {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
use std::fmt;
let keys = match args.subcommand_group.as_ref().unwrap().as_str() {
"server" => (None, None, invoke.guild_id().map(|g| g.0)),
"channel" => (None, Some(invoke.channel_id().0), invoke.guild_id().map(|g| g.0)),
_ => (Some(invoke.author_id().0), None, None),
use crate::models::CtxGuildData;
use crate::{command_help, get_ctx_data, models::UserData};
use sqlx::MySqlPool;
use std::convert::TryFrom;
#[derive(Debug)]
struct TodoNotFound;
impl std::error::Error for TodoNotFound {}
impl fmt::Display for TodoNotFound {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Todo not found")
}
}
struct Todo {
id: u32,
value: String,
}
struct TodoTarget {
user: UserId,
guild: Option<GuildId>,
channel: Option<ChannelId>,
}
impl TodoTarget {
pub fn command(&self, subcommand_opt: Option<SubCommand>) -> String {
let context = if self.channel.is_some() {
"channel"
} else if self.guild.is_some() {
"guild"
} else {
"user"
};
match args.get("task") {
Some(task) => {
let task = task.to_string();
if let Some(subcommand) = subcommand_opt {
format!("todo {} {}", context, subcommand.to_string())
} else {
format!("todo {}", context)
}
}
sqlx::query!(
"INSERT INTO todos (user_id, channel_id, guild_id, value) VALUES ((SELECT id FROM users WHERE user = ?), (SELECT id FROM channels WHERE channel = ?), (SELECT id FROM guilds WHERE guild = ?), ?)",
keys.0,
keys.1,
keys.2,
task
)
.execute(&pool)
.await
.unwrap();
pub fn name(&self) -> String {
if self.channel.is_some() {
"Channel"
} else if self.guild.is_some() {
"Guild"
} else {
"User"
}
.to_string()
}
let _ = invoke
.respond(&ctx, CreateGenericResponse::new().content("Item added to todo list"))
.await;
pub async fn view(
&self,
pool: MySqlPool,
) -> Result<Vec<Todo>, Box<dyn std::error::Error + Send + Sync>> {
Ok(if let Some(cid) = self.channel {
sqlx::query_as!(
Todo,
"
SELECT id, value FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?)
",
cid.as_u64()
)
.fetch_all(&pool)
.await?
} else if let Some(gid) = self.guild {
sqlx::query_as!(
Todo,
"
SELECT id, value FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL
",
gid.as_u64()
)
.fetch_all(&pool)
.await?
} else {
sqlx::query_as!(
Todo,
"
SELECT id, value FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL
",
self.user.as_u64()
)
.fetch_all(&pool)
.await?
})
}
pub async fn add(
&self,
value: String,
pool: MySqlPool,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if let (Some(cid), Some(gid)) = (self.channel, self.guild) {
sqlx::query!(
"
INSERT INTO todos (user_id, guild_id, channel_id, value) VALUES (
(SELECT id FROM users WHERE user = ?),
(SELECT id FROM guilds WHERE guild = ?),
(SELECT id FROM channels WHERE channel = ?),
?
)
",
self.user.as_u64(),
gid.as_u64(),
cid.as_u64(),
value
)
.execute(&pool)
.await?;
} else if let Some(gid) = self.guild {
sqlx::query!(
"
INSERT INTO todos (user_id, guild_id, value) VALUES (
(SELECT id FROM users WHERE user = ?),
(SELECT id FROM guilds WHERE guild = ?),
?
)
",
self.user.as_u64(),
gid.as_u64(),
value
)
.execute(&pool)
.await?;
} else {
sqlx::query!(
"
INSERT INTO todos (user_id, value) VALUES (
(SELECT id FROM users WHERE user = ?),
?
)
",
self.user.as_u64(),
value
)
.execute(&pool)
.await?;
}
Ok(())
}
pub async fn remove(
&self,
num: usize,
pool: &MySqlPool,
) -> Result<Todo, Box<dyn std::error::Error + Sync + Send>> {
let todos = self.view(pool.clone()).await?;
if let Some(removal_todo) = todos.get(num) {
let deleting = sqlx::query_as!(
Todo,
"
SELECT id, value FROM todos WHERE id = ?
",
removal_todo.id
)
.fetch_one(&pool.clone())
.await?;
sqlx::query!(
"
DELETE FROM todos WHERE id = ?
",
removal_todo.id
)
.execute(pool)
.await?;
Ok(deleting)
} else {
Err(Box::new(TodoNotFound))
}
}
pub async fn clear(
&self,
pool: &MySqlPool,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
if let Some(cid) = self.channel {
sqlx::query!(
"
DELETE FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?)
",
cid.as_u64()
)
.execute(pool)
.await?;
} else if let Some(gid) = self.guild {
sqlx::query!(
"
DELETE FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL
",
gid.as_u64()
)
.execute(pool)
.await?;
} else {
sqlx::query!(
"
DELETE FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL
",
self.user.as_u64()
)
.execute(pool)
.await?;
}
Ok(())
}
async fn execute(&self, ctx: &Context, msg: &Message, subcommand: SubCommand, extra: String) {
let (pool, lm) = get_ctx_data(&ctx).await;
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let prefix = ctx.prefix(msg.guild_id).await;
match subcommand {
SubCommand::View => {
let todo_items = self.view(pool).await.unwrap();
let mut todo_groups = vec!["".to_string()];
let mut char_count = 0;
todo_items.iter().enumerate().for_each(|(count, todo)| {
let display = format!("{}: {}\n", count + 1, todo.value);
if char_count + display.len() > MESSAGE_CODE_LIMIT as usize {
char_count = display.len();
todo_groups.push(display);
} else {
char_count += display.len();
let last_group = todo_groups.pop().unwrap();
todo_groups.push(format!("{}{}", last_group, display));
}
});
for group in todo_groups {
let _ = msg
.channel_id
.send_message(&ctx, |m| {
m.embed(|e| e.title(format!("{} Todo", self.name())).description(group))
})
.await;
}
}
None => {
let values = if let Some(uid) = keys.0 {
sqlx::query!(
"SELECT todos.id, value FROM todos
INNER JOIN users ON todos.user_id = users.id
WHERE users.user = ?",
uid,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>()
} else if let Some(cid) = keys.1 {
sqlx::query!(
"SELECT todos.id, value FROM todos
INNER JOIN channels ON todos.channel_id = channels.id
WHERE channels.channel = ?",
cid,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>()
SubCommand::Add => {
let content = lm
.get(&user_data.language, "todo/added")
.replacen("{name}", &extra, 1);
self.add(extra, pool).await.unwrap();
let _ = msg.channel_id.say(&ctx, content).await;
}
SubCommand::Remove => {
if let Ok(num) = extra.parse::<usize>() {
if let Ok(todo) = self.remove(num - 1, &pool).await {
let content = lm.get(&user_data.language, "todo/removed").replacen(
"{}",
&todo.value,
1,
);
let _ = msg.channel_id.say(&ctx, content).await;
} else {
let _ = msg
.channel_id
.say(&ctx, lm.get(&user_data.language, "todo/error_index"))
.await;
}
} else {
sqlx::query!(
"SELECT todos.id, value FROM todos
INNER JOIN guilds ON todos.guild_id = guilds.id
WHERE guilds.guild = ?",
keys.2,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>()
};
let content = lm
.get(&user_data.language, "todo/error_value")
.replacen("{prefix}", &prefix, 1)
.replacen("{command}", &self.command(Some(subcommand)), 1);
let resp = show_todo_page(&values, 0, keys.0, keys.1, keys.2);
let _ = msg.channel_id.say(&ctx, content).await;
}
}
invoke.respond(&ctx, resp).await.unwrap();
SubCommand::Clear => {
self.clear(&pool).await.unwrap();
let content = lm.get(&user_data.language, "todo/cleared");
let _ = msg.channel_id.say(&ctx, content).await;
}
}
}
}
pub fn max_todo_page(todo_values: &[(usize, String)]) -> usize {
let mut rows = 0;
let mut char_count = 0;
todo_values.iter().enumerate().map(|(c, (_, v))| format!("{}: {}", c, v)).fold(
1,
|mut pages, text| {
rows += 1;
char_count += text.len();
if char_count > EMBED_DESCRIPTION_MAX_LENGTH || rows > SELECT_MAX_ENTRIES {
rows = 1;
char_count = text.len();
pages += 1;
}
pages
},
)
enum SubCommand {
View,
Add,
Remove,
Clear,
}
pub fn show_todo_page(
todo_values: &[(usize, String)],
page: usize,
user_id: Option<u64>,
channel_id: Option<u64>,
guild_id: Option<u64>,
) -> CreateGenericResponse {
let pager = TodoPager::new(page, user_id, channel_id, guild_id);
impl TryFrom<Option<&str>> for SubCommand {
type Error = ();
let pages = max_todo_page(todo_values);
let mut page = page;
if page >= pages {
page = pages - 1;
fn try_from(value: Option<&str>) -> Result<Self, Self::Error> {
match value {
Some("add") => Ok(SubCommand::Add),
Some("remove") => Ok(SubCommand::Remove),
Some("clear") => Ok(SubCommand::Clear),
None | Some("") => Ok(SubCommand::View),
Some(_unrecognised) => Err(()),
}
}
}
let mut char_count = 0;
let mut rows = 0;
let mut skipped_rows = 0;
let mut skipped_char_count = 0;
let mut first_num = 0;
impl ToString for SubCommand {
fn to_string(&self) -> String {
match self {
SubCommand::View => "",
SubCommand::Add => "add",
SubCommand::Remove => "remove",
SubCommand::Clear => "clear",
}
.to_string()
}
}
let mut skipped_pages = 0;
#[async_trait]
trait Execute {
async fn execute(self, ctx: &Context, msg: &Message, extra: String, target: TodoTarget);
}
let (todo_ids, display_vec): (Vec<usize>, Vec<String>) = todo_values
.iter()
.enumerate()
.map(|(c, (i, v))| (i, format!("`{}`: {}", c + 1, v)))
.skip_while(|(_, p)| {
first_num += 1;
skipped_rows += 1;
skipped_char_count += p.len();
#[async_trait]
impl Execute for Result<SubCommand, ()> {
async fn execute(self, ctx: &Context, msg: &Message, extra: String, target: TodoTarget) {
if let Ok(subcommand) = self {
target.execute(ctx, msg, subcommand, extra).await;
} else {
show_help(&ctx, msg, Some(target)).await;
}
}
}
if skipped_char_count > EMBED_DESCRIPTION_MAX_LENGTH
|| skipped_rows > SELECT_MAX_ENTRIES
{
skipped_rows = 1;
skipped_char_count = p.len();
skipped_pages += 1;
}
#[command("todo")]
async fn todo_user(ctx: &Context, msg: &Message, args: String) {
let mut split = args.split(' ');
skipped_pages < page
})
.take_while(|(_, p)| {
rows += 1;
char_count += p.len();
char_count < EMBED_DESCRIPTION_MAX_LENGTH && rows <= SELECT_MAX_ENTRIES
})
.unzip();
let display = display_vec.join("\n");
let title = if user_id.is_some() {
"Your"
} else if channel_id.is_some() {
"Channel"
} else {
"Server"
let target = TodoTarget {
user: msg.author.id,
guild: None,
channel: None,
};
if todo_ids.is_empty() {
CreateGenericResponse::new().embed(|e| {
e.title(format!("{} Todo List", title))
.description("Todo List Empty!")
.footer(|f| f.text(format!("Page {} of {}", page + 1, pages)))
.color(*THEME_COLOR)
})
} else {
let todo_selector =
ComponentDataModel::TodoSelector(TodoSelector { page, user_id, channel_id, guild_id });
let subcommand_opt = SubCommand::try_from(split.next());
CreateGenericResponse::new()
.embed(|e| {
e.title(format!("{} Todo List", title))
.description(display)
.footer(|f| f.text(format!("Page {} of {}", page + 1, pages)))
.color(*THEME_COLOR)
})
.components(|comp| {
pager.create_button_row(pages, comp);
comp.create_action_row(|row| {
row.create_select_menu(|menu| {
menu.custom_id(todo_selector.to_custom_id()).options(|opt| {
for (count, (id, disp)) in todo_ids.iter().zip(&display_vec).enumerate()
{
opt.create_option(|o| {
o.label(format!("Mark {} complete", count + first_num))
.value(id)
.description(disp.split_once(" ").unwrap_or(("", "")).1)
});
}
opt
})
})
})
})
}
subcommand_opt
.execute(ctx, msg, split.collect::<Vec<&str>>().join(" "), target)
.await;
}
#[command("todoc")]
#[supports_dm(false)]
#[permission_level(Managed)]
async fn todo_channel(ctx: &Context, msg: &Message, args: String) {
let mut split = args.split(' ');
let target = TodoTarget {
user: msg.author.id,
guild: msg.guild_id,
channel: Some(msg.channel_id),
};
let subcommand_opt = SubCommand::try_from(split.next());
subcommand_opt
.execute(ctx, msg, split.collect::<Vec<&str>>().join(" "), target)
.await;
}
#[command("todos")]
#[supports_dm(false)]
#[permission_level(Managed)]
async fn todo_guild(ctx: &Context, msg: &Message, args: String) {
let mut split = args.split(' ');
let target = TodoTarget {
user: msg.author.id,
guild: msg.guild_id,
channel: None,
};
let subcommand_opt = SubCommand::try_from(split.next());
subcommand_opt
.execute(ctx, msg, split.collect::<Vec<&str>>().join(" "), target)
.await;
}
async fn show_help(ctx: &Context, msg: &Message, target: Option<TodoTarget>) {
let (pool, lm) = get_ctx_data(&ctx).await;
let language = UserData::language_of(&msg.author, &pool);
let prefix = ctx.prefix(msg.guild_id);
let command = match target {
None => "todo",
Some(t) => {
if t.channel.is_some() {
"todoc"
} else if t.guild.is_some() {
"todos"
} else {
"todo"
}
}
};
command_help(ctx, msg, lm, &prefix.await, &language.await, command).await;
}

View File

@ -1,310 +0,0 @@
pub(crate) mod pager;
use std::io::Cursor;
use chrono_tz::Tz;
use num_integer::Integer;
use rmp_serde::Serializer;
use serde::{Deserialize, Serialize};
use serenity::{
builder::CreateEmbed,
client::Context,
model::{
channel::Channel,
interactions::{message_component::MessageComponentInteraction, InteractionResponseType},
prelude::InteractionApplicationCommandCallbackDataFlags,
},
};
use crate::{
commands::{
moderation_cmds::{max_macro_page, show_macro_page},
reminder_cmds::{max_delete_page, show_delete_page},
todo_cmds::{max_todo_page, show_todo_page},
},
component_models::pager::{DelPager, LookPager, MacroPager, Pager, TodoPager},
consts::{EMBED_DESCRIPTION_MAX_LENGTH, THEME_COLOR},
framework::CommandInvoke,
models::{command_macro::CommandMacro, reminder::Reminder},
SQLPool,
};
#[derive(Deserialize, Serialize)]
#[serde(tag = "type")]
#[repr(u8)]
pub enum ComponentDataModel {
LookPager(LookPager),
DelPager(DelPager),
TodoPager(TodoPager),
DelSelector(DelSelector),
TodoSelector(TodoSelector),
MacroPager(MacroPager),
}
impl ComponentDataModel {
pub fn to_custom_id(&self) -> String {
let mut buf = Vec::new();
self.serialize(&mut Serializer::new(&mut buf)).unwrap();
base64::encode(buf)
}
pub fn from_custom_id(data: &String) -> Self {
let buf = base64::decode(data)
.map_err(|e| format!("Could not decode `custom_id' {}: {:?}", data, e))
.unwrap();
let cur = Cursor::new(buf);
rmp_serde::from_read(cur).unwrap()
}
pub async fn act(&self, ctx: &Context, component: MessageComponentInteraction) {
match self {
ComponentDataModel::LookPager(pager) => {
let flags = pager.flags;
let channel_opt = component.channel_id.to_channel_cached(&ctx);
let channel_id = if let Some(Channel::Guild(channel)) = channel_opt {
if Some(channel.guild_id) == component.guild_id {
flags.channel_id.unwrap_or(component.channel_id)
} else {
component.channel_id
}
} else {
component.channel_id
};
let reminders = Reminder::from_channel(ctx, channel_id, &flags).await;
let pages = reminders
.iter()
.map(|reminder| reminder.display(&flags, &pager.timezone))
.fold(0, |t, r| t + r.len())
.div_ceil(&EMBED_DESCRIPTION_MAX_LENGTH);
let channel_name =
if let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx) {
Some(channel.name)
} else {
None
};
let next_page = pager.next_page(pages);
let mut char_count = 0;
let mut skip_char_count = 0;
let display = reminders
.iter()
.map(|reminder| reminder.display(&flags, &pager.timezone))
.skip_while(|p| {
skip_char_count += p.len();
skip_char_count < EMBED_DESCRIPTION_MAX_LENGTH * next_page as usize
})
.take_while(|p| {
char_count += p.len();
char_count < EMBED_DESCRIPTION_MAX_LENGTH
})
.collect::<Vec<String>>()
.join("\n");
let mut embed = CreateEmbed::default();
embed
.title(format!(
"Reminders{}",
channel_name.map_or(String::new(), |n| format!(" on #{}", n))
))
.description(display)
.footer(|f| f.text(format!("Page {} of {}", next_page + 1, pages)))
.color(*THEME_COLOR);
let _ = component
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::UpdateMessage).interaction_response_data(
|response| {
response.embeds(vec![embed]).components(|comp| {
pager.create_button_row(pages, comp);
comp
})
},
)
})
.await;
}
ComponentDataModel::DelPager(pager) => {
let reminders =
Reminder::from_guild(ctx, component.guild_id, component.user.id).await;
let max_pages = max_delete_page(&reminders, &pager.timezone);
let resp = show_delete_page(&reminders, pager.next_page(max_pages), pager.timezone);
let mut invoke = CommandInvoke::component(component);
let _ = invoke.respond(&ctx, resp).await;
}
ComponentDataModel::DelSelector(selector) => {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let selected_id = component.data.values.join(",");
sqlx::query!("DELETE FROM reminders WHERE FIND_IN_SET(id, ?)", selected_id)
.execute(&pool)
.await
.unwrap();
let reminders =
Reminder::from_guild(ctx, component.guild_id, component.user.id).await;
let resp = show_delete_page(&reminders, selector.page, selector.timezone);
let mut invoke = CommandInvoke::component(component);
let _ = invoke.respond(&ctx, resp).await;
}
ComponentDataModel::TodoPager(pager) => {
if Some(component.user.id.0) == pager.user_id || pager.user_id.is_none() {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let values = if let Some(uid) = pager.user_id {
sqlx::query!(
"SELECT todos.id, value FROM todos
INNER JOIN users ON todos.user_id = users.id
WHERE users.user = ?",
uid,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>()
} else if let Some(cid) = pager.channel_id {
sqlx::query!(
"SELECT todos.id, value FROM todos
INNER JOIN channels ON todos.channel_id = channels.id
WHERE channels.channel = ?",
cid,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>()
} else {
sqlx::query!(
"SELECT todos.id, value FROM todos
INNER JOIN guilds ON todos.guild_id = guilds.id
WHERE guilds.guild = ?",
pager.guild_id,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>()
};
let max_pages = max_todo_page(&values);
let resp = show_todo_page(
&values,
pager.next_page(max_pages),
pager.user_id,
pager.channel_id,
pager.guild_id,
);
let mut invoke = CommandInvoke::component(component);
let _ = invoke.respond(&ctx, resp).await;
} else {
let _ = component
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.flags(
InteractionApplicationCommandCallbackDataFlags::EPHEMERAL,
)
.content("Only the user who performed the command can use these components")
})
})
.await;
}
}
ComponentDataModel::TodoSelector(selector) => {
if Some(component.user.id.0) == selector.user_id || selector.user_id.is_none() {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let selected_id = component.data.values.join(",");
sqlx::query!("DELETE FROM todos WHERE FIND_IN_SET(id, ?)", selected_id)
.execute(&pool)
.await
.unwrap();
let values = sqlx::query!(
// fucking braindead mysql use <=> instead of = for null comparison
"SELECT id, value FROM todos WHERE user_id <=> ? AND channel_id <=> ? AND guild_id <=> ?",
selector.user_id,
selector.channel_id,
selector.guild_id,
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| (row.id as usize, row.value.clone()))
.collect::<Vec<(usize, String)>>();
let resp = show_todo_page(
&values,
selector.page,
selector.user_id,
selector.channel_id,
selector.guild_id,
);
let mut invoke = CommandInvoke::component(component);
let _ = invoke.respond(&ctx, resp).await;
} else {
let _ = component
.create_interaction_response(&ctx, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.flags(
InteractionApplicationCommandCallbackDataFlags::EPHEMERAL,
)
.content("Only the user who performed the command can use these components")
})
})
.await;
}
}
ComponentDataModel::MacroPager(pager) => {
let mut invoke = CommandInvoke::component(component);
let macros = CommandMacro::from_guild(ctx, invoke.guild_id().unwrap()).await;
let max_page = max_macro_page(&macros);
let page = pager.next_page(max_page);
let resp = show_macro_page(&macros, page);
let _ = invoke.respond(&ctx, resp).await;
}
}
}
}
#[derive(Serialize, Deserialize)]
pub struct DelSelector {
pub page: usize,
pub timezone: Tz,
}
#[derive(Serialize, Deserialize)]
pub struct TodoSelector {
pub page: usize,
pub user_id: Option<u64>,
pub channel_id: Option<u64>,
pub guild_id: Option<u64>,
}

View File

@ -1,411 +0,0 @@
// todo split pager out into a single struct
use chrono_tz::Tz;
use serde::{Deserialize, Serialize};
use serde_repr::*;
use serenity::{builder::CreateComponents, model::interactions::message_component::ButtonStyle};
use crate::{component_models::ComponentDataModel, models::reminder::look_flags::LookFlags};
pub trait Pager {
fn next_page(&self, max_pages: usize) -> usize;
fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents);
}
#[derive(Serialize_repr, Deserialize_repr)]
#[repr(u8)]
enum PageAction {
First = 0,
Previous = 1,
Refresh = 2,
Next = 3,
Last = 4,
}
#[derive(Serialize, Deserialize)]
pub struct LookPager {
pub flags: LookFlags,
pub page: usize,
action: PageAction,
pub timezone: Tz,
}
impl Pager for LookPager {
fn next_page(&self, max_pages: usize) -> usize {
match self.action {
PageAction::First => 0,
PageAction::Previous => 0.max(self.page - 1),
PageAction::Refresh => self.page,
PageAction::Next => (max_pages - 1).min(self.page + 1),
PageAction::Last => max_pages - 1,
}
}
fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
let next_page = self.next_page(max_pages);
let (page_first, page_prev, page_refresh, page_next, page_last) =
LookPager::buttons(self.flags, next_page, self.timezone);
comp.create_action_row(|row| {
row.create_button(|b| {
b.label("⏮️")
.style(ButtonStyle::Primary)
.custom_id(page_first.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("◀️")
.style(ButtonStyle::Secondary)
.custom_id(page_prev.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
})
.create_button(|b| {
b.label("▶️")
.style(ButtonStyle::Secondary)
.custom_id(page_next.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
.create_button(|b| {
b.label("⏭️")
.style(ButtonStyle::Primary)
.custom_id(page_last.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
});
}
}
impl LookPager {
pub fn new(flags: LookFlags, timezone: Tz) -> Self {
Self { flags, page: 0, action: PageAction::First, timezone }
}
pub fn buttons(
flags: LookFlags,
page: usize,
timezone: Tz,
) -> (
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
) {
(
ComponentDataModel::LookPager(LookPager {
flags,
page,
action: PageAction::First,
timezone,
}),
ComponentDataModel::LookPager(LookPager {
flags,
page,
action: PageAction::Previous,
timezone,
}),
ComponentDataModel::LookPager(LookPager {
flags,
page,
action: PageAction::Refresh,
timezone,
}),
ComponentDataModel::LookPager(LookPager {
flags,
page,
action: PageAction::Next,
timezone,
}),
ComponentDataModel::LookPager(LookPager {
flags,
page,
action: PageAction::Last,
timezone,
}),
)
}
}
#[derive(Serialize, Deserialize)]
pub struct DelPager {
pub page: usize,
action: PageAction,
pub timezone: Tz,
}
impl Pager for DelPager {
fn next_page(&self, max_pages: usize) -> usize {
match self.action {
PageAction::First => 0,
PageAction::Previous => 0.max(self.page - 1),
PageAction::Refresh => self.page,
PageAction::Next => (max_pages - 1).min(self.page + 1),
PageAction::Last => max_pages - 1,
}
}
fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
let next_page = self.next_page(max_pages);
let (page_first, page_prev, page_refresh, page_next, page_last) =
DelPager::buttons(next_page, self.timezone);
comp.create_action_row(|row| {
row.create_button(|b| {
b.label("⏮️")
.style(ButtonStyle::Primary)
.custom_id(page_first.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("◀️")
.style(ButtonStyle::Secondary)
.custom_id(page_prev.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
})
.create_button(|b| {
b.label("▶️")
.style(ButtonStyle::Secondary)
.custom_id(page_next.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
.create_button(|b| {
b.label("⏭️")
.style(ButtonStyle::Primary)
.custom_id(page_last.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
});
}
}
impl DelPager {
pub fn new(page: usize, timezone: Tz) -> Self {
Self { page, action: PageAction::Refresh, timezone }
}
pub fn buttons(
page: usize,
timezone: Tz,
) -> (
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
) {
(
ComponentDataModel::DelPager(DelPager { page, action: PageAction::First, timezone }),
ComponentDataModel::DelPager(DelPager { page, action: PageAction::Previous, timezone }),
ComponentDataModel::DelPager(DelPager { page, action: PageAction::Refresh, timezone }),
ComponentDataModel::DelPager(DelPager { page, action: PageAction::Next, timezone }),
ComponentDataModel::DelPager(DelPager { page, action: PageAction::Last, timezone }),
)
}
}
#[derive(Deserialize, Serialize)]
pub struct TodoPager {
pub page: usize,
action: PageAction,
pub user_id: Option<u64>,
pub channel_id: Option<u64>,
pub guild_id: Option<u64>,
}
impl Pager for TodoPager {
fn next_page(&self, max_pages: usize) -> usize {
match self.action {
PageAction::First => 0,
PageAction::Previous => 0.max(self.page - 1),
PageAction::Refresh => self.page,
PageAction::Next => (max_pages - 1).min(self.page + 1),
PageAction::Last => max_pages - 1,
}
}
fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
let next_page = self.next_page(max_pages);
let (page_first, page_prev, page_refresh, page_next, page_last) =
TodoPager::buttons(next_page, self.user_id, self.channel_id, self.guild_id);
comp.create_action_row(|row| {
row.create_button(|b| {
b.label("⏮️")
.style(ButtonStyle::Primary)
.custom_id(page_first.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("◀️")
.style(ButtonStyle::Secondary)
.custom_id(page_prev.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
})
.create_button(|b| {
b.label("▶️")
.style(ButtonStyle::Secondary)
.custom_id(page_next.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
.create_button(|b| {
b.label("⏭️")
.style(ButtonStyle::Primary)
.custom_id(page_last.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
});
}
}
impl TodoPager {
pub fn new(
page: usize,
user_id: Option<u64>,
channel_id: Option<u64>,
guild_id: Option<u64>,
) -> Self {
Self { page, action: PageAction::Refresh, user_id, channel_id, guild_id }
}
pub fn buttons(
page: usize,
user_id: Option<u64>,
channel_id: Option<u64>,
guild_id: Option<u64>,
) -> (
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
) {
(
ComponentDataModel::TodoPager(TodoPager {
page,
action: PageAction::First,
user_id,
channel_id,
guild_id,
}),
ComponentDataModel::TodoPager(TodoPager {
page,
action: PageAction::Previous,
user_id,
channel_id,
guild_id,
}),
ComponentDataModel::TodoPager(TodoPager {
page,
action: PageAction::Refresh,
user_id,
channel_id,
guild_id,
}),
ComponentDataModel::TodoPager(TodoPager {
page,
action: PageAction::Next,
user_id,
channel_id,
guild_id,
}),
ComponentDataModel::TodoPager(TodoPager {
page,
action: PageAction::Last,
user_id,
channel_id,
guild_id,
}),
)
}
}
#[derive(Serialize, Deserialize)]
pub struct MacroPager {
pub page: usize,
action: PageAction,
}
impl Pager for MacroPager {
fn next_page(&self, max_pages: usize) -> usize {
match self.action {
PageAction::First => 0,
PageAction::Previous => 0.max(self.page - 1),
PageAction::Refresh => self.page,
PageAction::Next => (max_pages - 1).min(self.page + 1),
PageAction::Last => max_pages - 1,
}
}
fn create_button_row(&self, max_pages: usize, comp: &mut CreateComponents) {
let next_page = self.next_page(max_pages);
let (page_first, page_prev, page_refresh, page_next, page_last) =
MacroPager::buttons(next_page);
comp.create_action_row(|row| {
row.create_button(|b| {
b.label("⏮️")
.style(ButtonStyle::Primary)
.custom_id(page_first.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("◀️")
.style(ButtonStyle::Secondary)
.custom_id(page_prev.to_custom_id())
.disabled(next_page == 0)
})
.create_button(|b| {
b.label("🔁").style(ButtonStyle::Secondary).custom_id(page_refresh.to_custom_id())
})
.create_button(|b| {
b.label("▶️")
.style(ButtonStyle::Secondary)
.custom_id(page_next.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
.create_button(|b| {
b.label("⏭️")
.style(ButtonStyle::Primary)
.custom_id(page_last.to_custom_id())
.disabled(next_page + 1 == max_pages)
})
});
}
}
impl MacroPager {
pub fn new(page: usize) -> Self {
Self { page, action: PageAction::Refresh }
}
pub fn buttons(
page: usize,
) -> (
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
ComponentDataModel,
) {
(
ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::First }),
ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Previous }),
ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Refresh }),
ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Next }),
ComponentDataModel::MacroPager(MacroPager { page, action: PageAction::Last }),
)
}
}

View File

@ -1,8 +1,6 @@
pub const DAY: u64 = 86_400;
pub const HOUR: u64 = 3_600;
pub const MINUTE: u64 = 60;
pub const EMBED_DESCRIPTION_MAX_LENGTH: usize = 4000;
pub const SELECT_MAX_ENTRIES: usize = 25;
pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
@ -10,50 +8,83 @@ const THEME_COLOR_FALLBACK: u32 = 0x8fb677;
use std::{collections::HashSet, env, iter::FromIterator};
use regex::Regex;
use serenity::http::AttachmentType;
use regex::{Regex, RegexBuilder};
lazy_static! {
pub static ref REMIND_INTERVAL: u64 = env::var("REMIND_INTERVAL")
.map(|inner| inner.parse::<u64>().ok())
.ok()
.flatten()
.unwrap_or(10);
pub static ref DEFAULT_AVATAR: AttachmentType<'static> = (
include_bytes!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/assets/",
env!("WEBHOOK_AVATAR", "WEBHOOK_AVATAR not provided for compilation")
)) as &[u8],
env!("WEBHOOK_AVATAR"),
)
.into();
pub static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap();
pub static ref REGEX_ROLE: Regex = Regex::new(r#"<@&(\d+)>"#).unwrap();
pub static ref REGEX_COMMANDS: Regex = Regex::new(r#"([a-z]+)"#).unwrap();
pub static ref REGEX_ALIAS: Regex =
Regex::new(r#"(?P<name>[\S]{1,12})(?:(?: (?P<cmd>.*)$)|$)"#).unwrap();
pub static ref REGEX_CONTENT_SUBSTITUTION: Regex = Regex::new(r#"<<((?P<user>\d+)|(?P<role>.{1,100}))>>"#).unwrap();
pub static ref REGEX_CHANNEL_USER: Regex = Regex::new(r#"\s*<(#|@)(?:!)?(\d+)>\s*"#).unwrap();
pub static ref REGEX_REMIND_COMMAND: Regex = RegexBuilder::new(
r#"(?P<mentions>(?:<@\d+>\s+|<@!\d+>\s+|<#\d+>\s+)*)(?P<time>(?:(?:\d+)(?:s|m|h|d|:|/|-|))+)(?:\s+(?P<interval>(?:(?:\d+)(?:s|m|h|d|))+))?(?:\s+(?P<expires>(?:(?:\d+)(?:s|m|h|d|:|/|-|))+))?\s+(?P<content>.*)"#
)
.dot_matches_new_line(true)
.build()
.unwrap();
pub static ref REGEX_NATURAL_COMMAND_1: Regex = RegexBuilder::new(
r#"(?P<time>.*?)(?:\s+)(?:send|say)(?:\s+)(?P<msg>.*?)(?:(?:\s+)to(?:\s+)(?P<mentions>((?:<@\d+>)|(?:<@!\d+>)|(?:<#\d+>)|(?:\s+))+))?$"#
)
.dot_matches_new_line(true)
.build()
.unwrap();
pub static ref REGEX_NATURAL_COMMAND_2: Regex = RegexBuilder::new(
r#"(?P<msg>.*)(?:\s+)every(?:\s+)(?P<interval>.*?)(?:(?:\s+)(?:until|for)(?:\s+)(?P<expires>.*?))?$"#
)
.dot_matches_new_line(true)
.build()
.unwrap();
pub static ref SUBSCRIPTION_ROLES: HashSet<u64> = HashSet::from_iter(
env::var("SUBSCRIPTION_ROLES")
.map(|var| var
.split(',')
.filter_map(|item| { item.parse::<u64>().ok() })
.collect::<Vec<u64>>())
.unwrap_or_else(|_| Vec::new())
.unwrap_or_else(|_| vec![])
);
pub static ref CNC_GUILD: Option<u64> =
env::var("CNC_GUILD").map(|var| var.parse::<u64>().ok()).ok().flatten();
pub static ref CNC_GUILD: Option<u64> = env::var("CNC_GUILD")
.map(|var| var.parse::<u64>().ok())
.ok()
.flatten();
pub static ref MIN_INTERVAL: i64 = env::var("MIN_INTERVAL")
.ok()
.map(|inner| inner.parse::<i64>().ok())
.flatten()
.unwrap_or(600);
pub static ref MAX_TIME: i64 = env::var("MAX_TIME")
.ok()
.map(|inner| inner.parse::<i64>().ok())
.flatten()
.unwrap_or(60 * 60 * 24 * 365 * 50);
pub static ref LOCAL_TIMEZONE: String =
env::var("LOCAL_TIMEZONE").unwrap_or_else(|_| "UTC".to_string());
pub static ref THEME_COLOR: u32 = env::var("THEME_COLOR")
.map_or(THEME_COLOR_FALLBACK, |inner| u32::from_str_radix(&inner, 16)
.unwrap_or(THEME_COLOR_FALLBACK));
pub static ref LOCAL_LANGUAGE: String =
env::var("LOCAL_LANGUAGE").unwrap_or_else(|_| "EN".to_string());
pub static ref DEFAULT_PREFIX: String =
env::var("DEFAULT_PREFIX").unwrap_or_else(|_| "$".to_string());
pub static ref THEME_COLOR: u32 = env::var("THEME_COLOR").map_or(
THEME_COLOR_FALLBACK,
|inner| u32::from_str_radix(&inner, 16).unwrap_or(THEME_COLOR_FALLBACK)
);
pub static ref PYTHON_LOCATION: String =
env::var("PYTHON_LOCATION").unwrap_or_else(|_| "venv/bin/python3".to_string());
}

File diff suppressed because it is too large Load Diff

View File

@ -1,152 +0,0 @@
use regex_command_attr::check;
use serenity::{client::Context, model::channel::Channel};
use crate::{
framework::{CommandInvoke, CommandOptions, CreateGenericResponse, HookResult},
moderation_cmds, RecordingMacros,
};
#[check]
pub async fn guild_only(
ctx: &Context,
invoke: &mut CommandInvoke,
_args: &CommandOptions,
) -> HookResult {
if invoke.guild_id().is_some() {
HookResult::Continue
} else {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content("This command can only be used in servers"),
)
.await;
HookResult::Halt
}
}
#[check]
pub async fn macro_check(
ctx: &Context,
invoke: &mut CommandInvoke,
args: &CommandOptions,
) -> HookResult {
if let Some(guild_id) = invoke.guild_id() {
if args.command != moderation_cmds::MACRO_CMD_COMMAND.names[0] {
let active_recordings =
ctx.data.read().await.get::<RecordingMacros>().cloned().unwrap();
let mut lock = active_recordings.write().await;
if let Some(command_macro) = lock.get_mut(&(guild_id, invoke.author_id())) {
if command_macro.commands.len() >= 5 {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content("5 commands already recorded. Please use `/macro finish` to end recording."),
)
.await;
} else {
command_macro.commands.push(args.clone());
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content("Command recorded to macro"),
)
.await;
}
HookResult::Halt
} else {
HookResult::Continue
}
} else {
HookResult::Continue
}
} else {
HookResult::Continue
}
}
#[check]
pub async fn check_self_permissions(
ctx: &Context,
invoke: &mut CommandInvoke,
_args: &CommandOptions,
) -> HookResult {
if let Some(guild) = invoke.guild(&ctx) {
let user_id = ctx.cache.current_user_id();
let manage_webhooks =
guild.member_permissions(&ctx, user_id).await.map_or(false, |p| p.manage_webhooks());
let (view_channel, send_messages, embed_links) = invoke
.channel_id()
.to_channel_cached(&ctx)
.map(|c| {
if let Channel::Guild(channel) = c {
channel.permissions_for_user(ctx, user_id).ok()
} else {
None
}
})
.flatten()
.map_or((false, false, false), |p| {
(p.read_messages(), p.send_messages(), p.embed_links())
});
if manage_webhooks && send_messages && embed_links {
HookResult::Continue
} else {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content(format!(
"Please ensure the bot has the correct permissions:
{} **View Channel**
{} **Send Message**
{} **Embed Links**
{} **Manage Webhooks**",
if view_channel { "" } else { "" },
if send_messages { "" } else { "" },
if manage_webhooks { "" } else { "" },
if embed_links { "" } else { "" },
)),
)
.await;
HookResult::Halt
}
} else {
HookResult::Continue
}
}
#[check]
pub async fn check_guild_permissions(
ctx: &Context,
invoke: &mut CommandInvoke,
_args: &CommandOptions,
) -> HookResult {
if let Some(guild) = invoke.guild(&ctx) {
let permissions = guild.member_permissions(&ctx, invoke.author_id()).await.unwrap();
if !permissions.manage_guild() {
let _ = invoke
.respond(
&ctx,
CreateGenericResponse::new().content(
"You must have the \"Manage Server\" permission to use this command",
),
)
.await;
HookResult::Halt
} else {
HookResult::Continue
}
} else {
HookResult::Continue
}
}

65
src/language_manager.rs Normal file
View File

@ -0,0 +1,65 @@
use serde::Deserialize;
use serde_json::from_str;
use serenity::prelude::TypeMapKey;
use std::{collections::HashMap, error::Error, sync::Arc};
use crate::consts::LOCAL_LANGUAGE;
#[derive(Deserialize)]
pub struct LanguageManager {
languages: HashMap<String, String>,
strings: HashMap<String, HashMap<String, String>>,
}
impl LanguageManager {
pub fn from_compiled(content: &'static str) -> Result<Self, Box<dyn Error + Send + Sync>> {
let new: Self = from_str(content.as_ref())?;
Ok(new)
}
pub fn get(&self, language: &str, name: &str) -> &str {
self.strings
.get(language)
.map(|sm| sm.get(name))
.expect(&format!(r#"Language does not exist: "{}""#, language))
.unwrap_or_else(|| {
self.strings
.get(&*LOCAL_LANGUAGE)
.map(|sm| {
sm.get(name)
.expect(&format!(r#"String does not exist: "{}""#, name))
})
.expect("LOCAL_LANGUAGE is not available")
})
}
pub fn get_language(&self, language: &str) -> Option<&str> {
let language_normal = language.to_lowercase();
self.languages
.iter()
.filter(|(k, v)| {
k.to_lowercase() == language_normal || v.to_lowercase() == language_normal
})
.map(|(k, _)| k.as_str())
.next()
}
pub fn get_language_by_flag(&self, flag: &str) -> Option<&str> {
self.languages
.iter()
.filter(|(k, _)| self.get(k, "flag") == flag)
.map(|(k, _)| k.as_str())
.next()
}
pub fn all_languages(&self) -> impl Iterator<Item = (&str, &str)> {
self.languages.iter().map(|(k, v)| (k.as_str(), v.as_str()))
}
}
impl TypeMapKey for LanguageManager {
type Value = Arc<Self>;
}

View File

@ -1,56 +1,62 @@
#![feature(int_roundings)]
#[macro_use]
extern crate lazy_static;
mod commands;
mod component_models;
mod consts;
mod framework;
mod hooks;
mod language_manager;
mod models;
mod sender;
mod time_parser;
use std::{
collections::HashMap,
env,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use chrono_tz::Tz;
use dotenv::dotenv;
use log::info;
use serenity::{
async_trait,
cache::Cache,
client::{bridge::gateway::GatewayIntents, Client},
http::{client::Http, CacheHttp},
model::{
channel::GuildChannel,
gateway::{Activity, Ready},
channel::Message,
guild::{Guild, GuildUnavailable},
id::{GuildId, UserId},
interactions::Interaction,
},
prelude::{Context, EventHandler, TypeMapKey},
utils::shard_id,
};
use sqlx::mysql::MySqlPool;
use tokio::{
sync::RwLock,
time::{Duration, Instant},
};
use dotenv::dotenv;
use std::{collections::HashMap, env, sync::Arc};
use crate::{
commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds},
component_models::ComponentDataModel,
consts::{CNC_GUILD, REMIND_INTERVAL, SUBSCRIPTION_ROLES, THEME_COLOR},
consts::{CNC_GUILD, DEFAULT_PREFIX, SUBSCRIPTION_ROLES, THEME_COLOR},
framework::RegexFramework,
models::command_macro::CommandMacro,
language_manager::LanguageManager,
models::GuildData,
};
use serenity::futures::TryFutureExt;
use inflector::Inflector;
use log::info;
use dashmap::DashMap;
use tokio::sync::RwLock;
use chrono_tz::Tz;
use serenity::model::interactions::{Interaction, InteractionType};
use serenity::model::prelude::ApplicationCommandOptionType;
use std::collections::HashSet;
struct GuildDataCache;
impl TypeMapKey for GuildDataCache {
type Value = Arc<DashMap<GuildId, Arc<RwLock<GuildData>>>>;
}
struct SQLPool;
impl TypeMapKey for SQLPool {
@ -63,54 +69,22 @@ impl TypeMapKey for ReqwestClient {
type Value = Arc<reqwest::Client>;
}
struct FrameworkCtx;
impl TypeMapKey for FrameworkCtx {
type Value = Arc<RegexFramework>;
}
struct PopularTimezones;
impl TypeMapKey for PopularTimezones {
type Value = Arc<Vec<Tz>>;
}
struct RecordingMacros;
impl TypeMapKey for RecordingMacros {
type Value = Arc<RwLock<HashMap<(GuildId, UserId), CommandMacro>>>;
}
struct Handler {
is_loop_running: AtomicBool,
}
struct Handler;
#[async_trait]
impl EventHandler for Handler {
async fn cache_ready(&self, ctx_base: Context, _guilds: Vec<GuildId>) {
info!("Cache Ready!");
info!("Preparing to send reminders");
if !self.is_loop_running.load(Ordering::Relaxed) {
let ctx = ctx_base.clone();
tokio::spawn(async move {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
loop {
let sleep_until = Instant::now() + Duration::from_secs(*REMIND_INTERVAL);
let reminders = sender::Reminder::fetch_reminders(&pool).await;
if reminders.len() > 0 {
info!("Preparing to send {} reminders.", reminders.len());
for reminder in reminders {
reminder.send(pool.clone(), ctx.clone()).await;
}
}
tokio::time::sleep_until(sleep_until).await;
}
});
self.is_loop_running.swap(true, Ordering::Relaxed);
}
}
async fn channel_delete(&self, ctx: Context, channel: &GuildChannel) {
let pool = ctx
.data
@ -136,20 +110,28 @@ DELETE FROM channels WHERE channel = ?
let guild_id = guild.id.as_u64().to_owned();
{
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let _ = sqlx::query!("INSERT INTO guilds (guild) VALUES (?)", guild_id)
.execute(&pool)
.await;
GuildData::from_guild(guild, &pool).await.expect(&format!(
"Failed to create new guild object for {}",
guild_id
));
}
if let Ok(token) = env::var("DISCORDBOTS_TOKEN") {
let shard_count = ctx.cache.shard_count();
let shard_count = ctx.cache.shard_count().await;
let current_shard_id = shard_id(guild_id, shard_count);
let guild_count = ctx
.cache
.guilds()
.await
.iter()
.filter(|g| shard_id(g.as_u64().to_owned(), shard_count) == current_shard_id)
.count() as u64;
@ -171,7 +153,7 @@ DELETE FROM channels WHERE channel = ?
.post(
format!(
"https://top.gg/api/bots/{}/stats",
ctx.cache.current_user_id().as_u64()
ctx.cache.current_user_id().await.as_u64()
)
.as_str(),
)
@ -187,34 +169,55 @@ DELETE FROM channels WHERE channel = ?
}
}
async fn guild_delete(&self, ctx: Context, incomplete: GuildUnavailable, _full: Option<Guild>) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let _ = sqlx::query!("DELETE FROM guilds WHERE guild = ?", incomplete.id.0)
.execute(&pool)
.await;
}
async fn guild_delete(&self, ctx: Context, guild: GuildUnavailable, _guild: Option<Guild>) {
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
async fn ready(&self, ctx: Context, _: Ready) {
ctx.set_activity(Activity::watching("for /remind")).await;
let guild_data_cache = ctx
.data
.read()
.await
.get::<GuildDataCache>()
.cloned()
.unwrap();
guild_data_cache.remove(&guild.id);
sqlx::query!(
"
DELETE FROM guilds WHERE guild = ?
",
guild.id.as_u64()
)
.execute(&pool)
.await
.unwrap();
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
match interaction {
Interaction::ApplicationCommand(application_command) => {
let framework = ctx
.data
.read()
.await
.get::<RegexFramework>()
.cloned()
.expect("RegexFramework not found in context");
match interaction.kind {
InteractionType::ApplicationCommand => {
if let Some(data) = &interaction.data {
match data.name.as_str() {
"timezone" => {
moderation_cmds::timezone_interaction(&ctx, interaction).await
}
"lang" => moderation_cmds::language_interaction(&ctx, interaction).await,
"prefix" => moderation_cmds::prefix_interaction(&ctx, interaction).await,
"help" => info_cmds::help_interaction(&ctx, interaction).await,
"info" => info_cmds::info_interaction(&ctx, interaction).await,
"donate" => info_cmds::donate_interaction(&ctx, interaction).await,
"clock" => info_cmds::clock_interaction(&ctx, interaction).await,
"remind" => reminder_cmds::set_reminder(&ctx, interaction).await,
_ => {}
}
}
}
framework.execute(ctx, application_command).await;
}
Interaction::MessageComponent(component) => {
let component_model = ComponentDataModel::from_custom_id(&component.data.custom_id);
component_model.act(&ctx, component).await;
}
_ => {}
}
}
@ -228,53 +231,94 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let token = env::var("DISCORD_TOKEN").expect("Missing DISCORD_TOKEN from environment");
let application_id = {
let http = Http::new_with_token(&token);
let http = Http::new_with_token(&token);
http.get_current_application_info().await?.id
};
let logged_in_id = http
.get_current_user()
.map_ok(|user| user.id.as_u64().to_owned())
.await?;
let dm_enabled = env::var("DM_ENABLED").map_or(true, |var| var == "1");
let framework = RegexFramework::new()
let framework = RegexFramework::new(logged_in_id)
.default_prefix(DEFAULT_PREFIX.clone())
.case_insensitive(env::var("CASE_INSENSITIVE").map_or(true, |var| var == "1"))
.ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1"))
.debug_guild(env::var("DEBUG_GUILD").map_or(None, |g| {
Some(GuildId(g.parse::<u64>().expect("DEBUG_GUILD must be a guild ID")))
}))
.dm_enabled(dm_enabled)
// info commands
.add_command(&info_cmds::HELP_COMMAND)
.add_command(&info_cmds::INFO_COMMAND)
.add_command(&info_cmds::DONATE_COMMAND)
.add_command(&info_cmds::DASHBOARD_COMMAND)
.add_command(&info_cmds::CLOCK_COMMAND)
.add_command("ping", &info_cmds::PING_COMMAND)
.add_command("help", &info_cmds::HELP_COMMAND)
.add_command("info", &info_cmds::INFO_COMMAND)
.add_command("invite", &info_cmds::INFO_COMMAND)
.add_command("donate", &info_cmds::DONATE_COMMAND)
.add_command("dashboard", &info_cmds::DASHBOARD_COMMAND)
.add_command("clock", &info_cmds::CLOCK_COMMAND)
// reminder commands
.add_command(&reminder_cmds::TIMER_COMMAND)
.add_command(&reminder_cmds::REMIND_COMMAND)
.add_command("timer", &reminder_cmds::TIMER_COMMAND)
.add_command("remind", &reminder_cmds::REMIND_COMMAND)
.add_command("r", &reminder_cmds::REMIND_COMMAND)
.add_command("interval", &reminder_cmds::INTERVAL_COMMAND)
.add_command("i", &reminder_cmds::INTERVAL_COMMAND)
.add_command("natural", &reminder_cmds::NATURAL_COMMAND)
.add_command("n", &reminder_cmds::NATURAL_COMMAND)
.add_command("", &reminder_cmds::NATURAL_COMMAND)
.add_command("countdown", &reminder_cmds::COUNTDOWN_COMMAND)
// management commands
.add_command(&reminder_cmds::DELETE_COMMAND)
.add_command(&reminder_cmds::LOOK_COMMAND)
.add_command(&reminder_cmds::PAUSE_COMMAND)
.add_command(&reminder_cmds::OFFSET_COMMAND)
.add_command(&reminder_cmds::NUDGE_COMMAND)
.add_command("look", &reminder_cmds::LOOK_COMMAND)
.add_command("del", &reminder_cmds::DELETE_COMMAND)
// to-do commands
.add_command(&todo_cmds::TODO_COMMAND)
.add_command("todo", &todo_cmds::TODO_USER_COMMAND)
.add_command("todo user", &todo_cmds::TODO_USER_COMMAND)
.add_command("todoc", &todo_cmds::TODO_CHANNEL_COMMAND)
.add_command("todo channel", &todo_cmds::TODO_CHANNEL_COMMAND)
.add_command("todos", &todo_cmds::TODO_GUILD_COMMAND)
.add_command("todo server", &todo_cmds::TODO_GUILD_COMMAND)
.add_command("todo guild", &todo_cmds::TODO_GUILD_COMMAND)
// moderation commands
.add_command(&moderation_cmds::TIMEZONE_COMMAND)
.add_command(&moderation_cmds::MACRO_CMD_COMMAND)
.add_hook(&hooks::CHECK_SELF_PERMISSIONS_HOOK)
.add_hook(&hooks::MACRO_CHECK_HOOK);
.add_command("blacklist", &moderation_cmds::BLACKLIST_COMMAND)
.add_command("restrict", &moderation_cmds::RESTRICT_COMMAND)
.add_command("timezone", &moderation_cmds::TIMEZONE_COMMAND)
.add_command("meridian", &moderation_cmds::CHANGE_MERIDIAN_COMMAND)
.add_command("prefix", &moderation_cmds::PREFIX_COMMAND)
.add_command("lang", &moderation_cmds::LANGUAGE_COMMAND)
.add_command("pause", &reminder_cmds::PAUSE_COMMAND)
.add_command("offset", &reminder_cmds::OFFSET_COMMAND)
.add_command("nudge", &reminder_cmds::NUDGE_COMMAND)
.add_command("alias", &moderation_cmds::ALIAS_COMMAND)
.add_command("a", &moderation_cmds::ALIAS_COMMAND)
.build();
let framework_arc = Arc::new(framework);
let mut client = Client::builder(&token)
.intents(GatewayIntents::GUILDS)
.application_id(application_id.0)
.event_handler(Handler { is_loop_running: AtomicBool::from(false) })
.intents(if dm_enabled {
GatewayIntents::GUILD_MESSAGES
| GatewayIntents::GUILDS
| GatewayIntents::GUILD_MESSAGE_REACTIONS
| GatewayIntents::DIRECT_MESSAGES
| GatewayIntents::DIRECT_MESSAGE_REACTIONS
} else {
GatewayIntents::GUILD_MESSAGES
| GatewayIntents::GUILDS
| GatewayIntents::GUILD_MESSAGE_REACTIONS
})
.event_handler(Handler)
.framework_arc(framework_arc.clone())
.await
.expect("Error occurred creating client");
let language_manager = Arc::new(
LanguageManager::from_compiled(include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/assets/",
env!("STRINGS_FILE")
)))
.unwrap(),
);
{
let guild_data_cache = dashmap::DashMap::new();
let pool = MySqlPool::connect(
&env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment"),
)
@ -293,18 +337,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut data = client.data.write().await;
data.insert::<GuildDataCache>(Arc::new(guild_data_cache));
data.insert::<SQLPool>(pool);
data.insert::<PopularTimezones>(Arc::new(popular_timezones));
data.insert::<ReqwestClient>(Arc::new(reqwest::Client::new()));
data.insert::<RegexFramework>(framework_arc.clone());
data.insert::<RecordingMacros>(Arc::new(RwLock::new(HashMap::new())));
data.insert::<FrameworkCtx>(framework_arc.clone());
data.insert::<LanguageManager>(language_manager.clone())
}
framework_arc.build_slash(&client.cache_and_http.http).await;
create_interactions(
&client.cache_and_http,
framework_arc.clone(),
language_manager.clone(),
)
.await;
if let Ok((Some(lower), Some(upper))) = env::var("SHARD_RANGE").map(|sr| {
let mut split =
sr.split(',').map(|val| val.parse::<u64>().expect("SHARD_RANGE not an integer"));
let mut split = sr
.split(',')
.map(|val| val.parse::<u64>().expect("SHARD_RANGE not an integer"));
(split.next(), split.next())
}) {
@ -314,14 +366,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.flatten()
.expect("No SHARD_COUNT provided, but SHARD_RANGE was provided");
assert!(lower < upper, "SHARD_RANGE lower limit is not less than the upper limit");
assert!(
lower < upper,
"SHARD_RANGE lower limit is not less than the upper limit"
);
info!("Starting client fragment with shards {}-{}/{}", lower, upper, total_shards);
info!(
"Starting client fragment with shards {}-{}/{}",
lower, upper, total_shards
);
client.start_shard_range([lower, upper], total_shards).await?;
} else if let Ok(total_shards) = env::var("SHARD_COUNT")
.map(|shard_count| shard_count.parse::<u64>().expect("SHARD_COUNT not an integer"))
{
client
.start_shard_range([lower, upper], total_shards)
.await?;
} else if let Ok(total_shards) = env::var("SHARD_COUNT").map(|shard_count| {
shard_count
.parse::<u64>()
.expect("SHARD_COUNT not an integer")
}) {
info!("Starting client with {} shards", total_shards);
client.start_shards(total_shards).await?;
@ -334,9 +396,170 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
async fn create_interactions(
cache_http: impl CacheHttp,
framework: Arc<RegexFramework>,
lm: Arc<LanguageManager>,
) {
let http = cache_http.http();
let app_id = {
let app_info = http.get_current_application_info().await.unwrap();
app_info.id.as_u64().to_owned()
};
if let Some(guild_id) = env::var("TEST_GUILD")
.map(|i| i.parse::<u64>().ok().map(|u| GuildId(u)))
.ok()
.flatten()
{
guild_id
.create_application_command(&http, app_id, |command| {
command
.name("timezone")
.description("Select your local timezone. Do `/timezone` for more information")
.create_interaction_option(|option| {
option
.name("region")
.description("Name of your time region")
.kind(ApplicationCommandOptionType::String)
})
})
.await
.unwrap();
guild_id
.create_application_command(&http, app_id, |command| {
command
.name("lang")
.description("Select your language")
.create_interaction_option(|option| {
option
.name("language")
.description("Name of supported language you wish to use")
.kind(ApplicationCommandOptionType::String)
.required(true);
for (code, language) in lm.all_languages() {
option.add_string_choice(language, code);
}
option
})
})
.await
.unwrap();
guild_id
.create_application_command(&http, app_id, |command| {
command
.name("prefix")
.description("Select the prefix for normal commands")
.create_interaction_option(|option| {
option
.name("prefix")
.description("New prefix to use")
.kind(ApplicationCommandOptionType::String)
.required(true)
})
})
.await
.unwrap();
guild_id
.create_application_command(&http, app_id, |command| {
command
.name("info")
.description("Get information about the bot")
})
.await
.unwrap();
guild_id
.create_application_command(&http, app_id, |command| {
command
.name("donate")
.description("View information about the Patreon")
})
.await
.unwrap();
guild_id
.create_application_command(&http, app_id, |command| {
command
.name("clock")
.description("View the current time in your timezone")
})
.await
.unwrap();
guild_id
.create_application_command(&http, app_id, |command| {
command
.name("help")
.description("Get details about commands. Do `/help` to view all commands")
.create_interaction_option(|option| {
option
.name("command")
.description("Name of the command to view help for")
.kind(ApplicationCommandOptionType::String);
let mut command_set = HashSet::new();
command_set.insert("help");
command_set.insert("info");
command_set.insert("donate");
for (_, command) in &framework.commands {
if !command_set.contains(command.name) {
option.add_string_choice(&command.name, &command.name);
command_set.insert(command.name);
}
}
option
})
})
.await
.unwrap();
guild_id
.create_application_command(&http, app_id, |command| {
command
.name("remind")
.description("Set a reminder")
.create_interaction_option(|option| {
option
.name("message")
.description("Message to send with the reminder")
.kind(ApplicationCommandOptionType::String)
.required(true)
})
.create_interaction_option(|option| {
option
.name("time")
.description("Time to send the reminder")
.kind(ApplicationCommandOptionType::String)
.required(true)
})
.create_interaction_option(|option| {
option
.name("channel")
.description("Channel to send reminder to (default: this channel)")
.kind(ApplicationCommandOptionType::Channel)
.required(false)
})
})
.await
.unwrap();
}
}
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
if let Some(subscription_guild) = *CNC_GUILD {
let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await;
let guild_member = GuildId(subscription_guild)
.member(cache_http, user_id)
.await;
if let Ok(member) = guild_member {
for role in member.roles {
@ -352,15 +575,65 @@ pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<U
}
}
pub async fn check_guild_subscription(
cache_http: impl CacheHttp,
guild_id: impl Into<GuildId>,
pub async fn check_subscription_on_message(
cache_http: impl CacheHttp + AsRef<Cache>,
msg: &Message,
) -> bool {
if let Some(guild) = cache_http.cache().unwrap().guild(guild_id) {
let owner = guild.owner_id;
check_subscription(&cache_http, owner).await
} else {
false
}
check_subscription(&cache_http, &msg.author).await
|| if let Some(guild) = msg.guild(&cache_http).await {
check_subscription(&cache_http, guild.owner_id).await
} else {
false
}
}
pub async fn get_ctx_data(ctx: &&Context) -> (MySqlPool, Arc<LanguageManager>) {
let pool;
let lm;
{
let data = ctx.data.read().await;
pool = data
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool");
lm = data
.get::<LanguageManager>()
.cloned()
.expect("Could not get LanguageManager");
}
(pool, lm)
}
async fn command_help(
ctx: &Context,
msg: &Message,
lm: Arc<LanguageManager>,
prefix: &str,
language: &str,
command_name: &str,
) {
let _ = msg
.channel_id
.send_message(ctx, |m| {
m.embed(move |e| {
e.title(format!("{} Help", command_name.to_title_case()))
.description(
lm.get(&language, &format!("help/{}", command_name))
.replace("{prefix}", &prefix),
)
.footer(|f| {
f.text(concat!(
env!("CARGO_PKG_NAME"),
" ver ",
env!("CARGO_PKG_VERSION")
))
})
.color(*THEME_COLOR)
})
})
.await;
}

452
src/models.rs Normal file
View File

@ -0,0 +1,452 @@
use serenity::{
async_trait,
http::CacheHttp,
model::{
channel::Channel,
guild::Guild,
id::{GuildId, UserId},
user::User,
},
prelude::Context,
};
use sqlx::MySqlPool;
use chrono::NaiveDateTime;
use chrono_tz::Tz;
use log::error;
use crate::{
consts::{DEFAULT_PREFIX, LOCAL_LANGUAGE, LOCAL_TIMEZONE},
GuildDataCache, SQLPool,
};
use std::sync::Arc;
use tokio::sync::RwLock;
#[async_trait]
pub trait CtxGuildData {
async fn guild_data<G: Into<GuildId> + Send + Sync>(
&self,
guild_id: G,
) -> Result<Arc<RwLock<GuildData>>, sqlx::Error>;
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String;
}
#[async_trait]
impl CtxGuildData for Context {
async fn guild_data<G: Into<GuildId> + Send + Sync>(
&self,
guild_id: G,
) -> Result<Arc<RwLock<GuildData>>, sqlx::Error> {
let guild_id = guild_id.into();
let guild = guild_id.to_guild_cached(&self.cache).await.unwrap();
let guild_cache = self
.data
.read()
.await
.get::<GuildDataCache>()
.cloned()
.unwrap();
let x = if let Some(guild_data) = guild_cache.get(&guild_id) {
Ok(guild_data.clone())
} else {
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
match GuildData::from_guild(guild, &pool).await {
Ok(d) => {
let lock = Arc::new(RwLock::new(d));
guild_cache.insert(guild_id, lock.clone());
Ok(lock)
}
Err(e) => Err(e),
}
};
x
}
async fn prefix<G: Into<GuildId> + Send + Sync>(&self, guild_id: Option<G>) -> String {
if let Some(guild_id) = guild_id {
self.guild_data(guild_id)
.await
.unwrap()
.read()
.await
.prefix
.clone()
} else {
DEFAULT_PREFIX.clone()
}
}
}
pub struct GuildData {
pub id: u32,
pub name: Option<String>,
pub prefix: String,
}
impl GuildData {
pub async fn from_guild(guild: Guild, pool: &MySqlPool) -> Result<Self, sqlx::Error> {
let guild_id = guild.id.as_u64().to_owned();
match sqlx::query_as!(
Self,
"
SELECT id, name, prefix FROM guilds WHERE guild = ?
",
guild_id
)
.fetch_one(pool)
.await
{
Ok(mut g) => {
g.name = Some(guild.name);
Ok(g)
}
Err(sqlx::Error::RowNotFound) => {
sqlx::query!(
"
INSERT INTO guilds (guild, name, prefix) VALUES (?, ?, ?)
",
guild_id,
guild.name,
*DEFAULT_PREFIX
)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as!(
Self,
"
SELECT id, name, prefix FROM guilds WHERE guild = ?
",
guild_id
)
.fetch_one(pool)
.await?)
}
Err(e) => {
error!("Unexpected error in guild query: {:?}", e);
Err(e)
}
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE guilds SET name = ?, prefix = ? WHERE id = ?
",
self.name,
self.prefix,
self.id
)
.execute(pool)
.await
.unwrap();
}
}
pub struct ChannelData {
pub id: u32,
pub name: Option<String>,
pub nudge: i16,
pub blacklisted: bool,
pub webhook_id: Option<u64>,
pub webhook_token: Option<String>,
pub paused: bool,
pub paused_until: Option<NaiveDateTime>,
}
impl ChannelData {
pub async fn from_channel(
channel: Channel,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let channel_id = channel.id().as_u64().to_owned();
if let Ok(c) = sqlx::query_as_unchecked!(Self,
"
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
", channel_id)
.fetch_one(pool)
.await {
Ok(c)
}
else {
let props = channel.guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name));
let (guild_id, channel_name) = if let Some((a, b)) = props {
(Some(a), Some(b))
} else {
(None, None)
};
sqlx::query!(
"
INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?))
", channel_id, channel_name, guild_id)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as_unchecked!(Self,
"
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
", channel_id)
.fetch_one(pool)
.await?)
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until = ? WHERE id = ?
", self.name, self.nudge, self.blacklisted, self.webhook_id, self.webhook_token, self.paused, self.paused_until, self.id)
.execute(pool)
.await.unwrap();
}
}
pub struct UserData {
pub id: u32,
pub user: u64,
pub name: String,
pub dm_channel: u32,
pub language: String,
pub timezone: String,
pub meridian_time: bool,
}
pub struct MeridianType(bool);
impl MeridianType {
pub fn fmt_str(&self) -> &str {
if self.0 {
"%Y-%m-%d %I:%M:%S %p"
} else {
"%Y-%m-%d %H:%M:%S"
}
}
pub fn fmt_str_short(&self) -> &str {
if self.0 {
"%I:%M %p"
} else {
"%H:%M"
}
}
}
impl UserData {
pub async fn language_of<U>(user: U, pool: &MySqlPool) -> String
where
U: Into<UserId>,
{
let user_id = user.into().as_u64().to_owned();
match sqlx::query!(
"
SELECT language FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await
{
Ok(r) => r.language,
Err(_) => LOCAL_LANGUAGE.clone(),
}
}
pub async fn timezone_of<U>(user: U, pool: &MySqlPool) -> Tz
where
U: Into<UserId>,
{
let user_id = user.into().as_u64().to_owned();
match sqlx::query!(
"
SELECT timezone FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await
{
Ok(r) => r.timezone,
Err(_) => LOCAL_TIMEZONE.clone(),
}
.parse()
.unwrap()
}
pub async fn meridian_of<U>(user: U, pool: &MySqlPool) -> MeridianType
where
U: Into<UserId>,
{
let user_id = user.into().as_u64().to_owned();
match sqlx::query!(
"
SELECT meridian_time FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await
{
Ok(r) => MeridianType(r.meridian_time != 0),
Err(_) => MeridianType(false),
}
}
pub async fn from_user(
user: &User,
ctx: impl CacheHttp,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user.id.as_u64().to_owned();
match sqlx::query_as_unchecked!(
Self,
"
SELECT id, user, name, dm_channel, IF(language IS NULL, ?, language) AS language, IF(timezone IS NULL, ?, timezone) AS timezone, meridian_time FROM users WHERE user = ?
",
*LOCAL_LANGUAGE, *LOCAL_TIMEZONE, user_id
)
.fetch_one(pool)
.await
{
Ok(c) => Ok(c),
Err(sqlx::Error::RowNotFound) => {
let dm_channel = user.create_dm_channel(ctx).await?;
let dm_id = dm_channel.id.as_u64().to_owned();
let pool_c = pool.clone();
sqlx::query!(
"
INSERT IGNORE INTO channels (channel) VALUES (?)
",
dm_id
)
.execute(&pool_c)
.await?;
sqlx::query!(
"
INSERT INTO users (user, name, dm_channel, language, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?, ?)
", user_id, user.name, dm_id, *LOCAL_LANGUAGE, *LOCAL_TIMEZONE)
.execute(&pool_c)
.await?;
Ok(sqlx::query_as_unchecked!(
Self,
"
SELECT id, user, name, dm_channel, language, timezone, meridian_time FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await?)
}
Err(e) => {
error!("Error querying for user: {:?}", e);
Err(Box::new(e))
},
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE users SET name = ?, language = ?, timezone = ?, meridian_time = ? WHERE id = ?
",
self.name,
self.language,
self.timezone,
self.meridian_time,
self.id
)
.execute(pool)
.await
.unwrap();
}
pub fn timezone(&self) -> Tz {
self.timezone.parse().unwrap()
}
pub fn meridian(&self) -> MeridianType {
MeridianType(self.meridian_time)
}
}
pub struct Timer {
pub name: String,
pub start_time: NaiveDateTime,
pub owner: u64,
}
impl Timer {
pub async fn from_owner(owner: u64, pool: &MySqlPool) -> Vec<Self> {
sqlx::query_as_unchecked!(
Timer,
"
SELECT name, start_time, owner FROM timers WHERE owner = ?
",
owner
)
.fetch_all(pool)
.await
.unwrap()
}
pub async fn count_from_owner(owner: u64, pool: &MySqlPool) -> u32 {
sqlx::query!(
"
SELECT COUNT(1) as count FROM timers WHERE owner = ?
",
owner
)
.fetch_one(pool)
.await
.unwrap()
.count as u32
}
pub async fn create(name: &str, owner: u64, pool: &MySqlPool) {
sqlx::query!(
"
INSERT INTO timers (name, owner) VALUES (?, ?)
",
name,
owner
)
.execute(pool)
.await
.unwrap();
}
}

View File

@ -1,81 +0,0 @@
use chrono::NaiveDateTime;
use serenity::model::channel::Channel;
use sqlx::MySqlPool;
pub struct ChannelData {
pub id: u32,
pub name: Option<String>,
pub nudge: i16,
pub blacklisted: bool,
pub webhook_id: Option<u64>,
pub webhook_token: Option<String>,
pub paused: bool,
pub paused_until: Option<NaiveDateTime>,
}
impl ChannelData {
pub async fn from_channel(
channel: &Channel,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let channel_id = channel.id().as_u64().to_owned();
if let Ok(c) = sqlx::query_as_unchecked!(
Self,
"
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
",
channel_id
)
.fetch_one(pool)
.await
{
Ok(c)
} else {
let props = channel.to_owned().guild().map(|g| (g.guild_id.as_u64().to_owned(), g.name));
let (guild_id, channel_name) = if let Some((a, b)) = props { (Some(a), Some(b)) } else { (None, None) };
sqlx::query!(
"
INSERT IGNORE INTO channels (channel, name, guild_id) VALUES (?, ?, (SELECT id FROM guilds WHERE guild = ?))
",
channel_id,
channel_name,
guild_id
)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as_unchecked!(
Self,
"
SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_until FROM channels WHERE channel = ?
",
channel_id
)
.fetch_one(pool)
.await?)
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE channels SET name = ?, nudge = ?, blacklisted = ?, webhook_id = ?, webhook_token = ?, paused = ?, paused_until \
= ? WHERE id = ?
",
self.name,
self.nudge,
self.blacklisted,
self.webhook_id,
self.webhook_token,
self.paused,
self.paused_until,
self.id
)
.execute(pool)
.await
.unwrap();
}
}

View File

@ -1,33 +0,0 @@
use serenity::{client::Context, model::id::GuildId};
use crate::{framework::CommandOptions, SQLPool};
pub struct CommandMacro {
pub guild_id: GuildId,
pub name: String,
pub description: Option<String>,
pub commands: Vec<CommandOptions>,
}
impl CommandMacro {
pub async fn from_guild(ctx: &Context, guild_id: impl Into<GuildId>) -> Vec<Self> {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let guild_id = guild_id.into();
sqlx::query!(
"SELECT * FROM macro WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)",
guild_id.0
)
.fetch_all(&pool)
.await
.unwrap()
.iter()
.map(|row| Self {
guild_id,
name: row.name.clone(),
description: row.description.clone(),
commands: serde_json::from_str(&row.commands).unwrap(),
})
.collect::<Vec<Self>>()
}
}

View File

@ -1,66 +0,0 @@
pub mod channel_data;
pub mod command_macro;
pub mod reminder;
pub mod timer;
pub mod user_data;
use chrono_tz::Tz;
use serenity::{
async_trait,
model::id::{ChannelId, UserId},
prelude::Context,
};
use crate::{
models::{channel_data::ChannelData, user_data::UserData},
SQLPool,
};
#[async_trait]
pub trait CtxData {
async fn user_data<U: Into<UserId> + Send + Sync>(
&self,
user_id: U,
) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>>;
async fn timezone<U: Into<UserId> + Send + Sync>(&self, user_id: U) -> Tz;
async fn channel_data<C: Into<ChannelId> + Send + Sync>(
&self,
channel_id: C,
) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>>;
}
#[async_trait]
impl CtxData for Context {
async fn user_data<U: Into<UserId> + Send + Sync>(
&self,
user_id: U,
) -> Result<UserData, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user_id.into();
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
let user = user_id.to_user(self).await.unwrap();
UserData::from_user(&user, &self, &pool).await
}
async fn timezone<U: Into<UserId> + Send + Sync>(&self, user_id: U) -> Tz {
let user_id = user_id.into();
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
UserData::timezone_of(user_id, &pool).await
}
async fn channel_data<C: Into<ChannelId> + Send + Sync>(
&self,
channel_id: C,
) -> Result<ChannelData, Box<dyn std::error::Error + Sync + Send>> {
let channel_id = channel_id.into();
let pool = self.data.read().await.get::<SQLPool>().cloned().unwrap();
let channel = channel_id.to_channel_cached(&self).unwrap();
ChannelData::from_channel(&channel, &pool).await
}
}

View File

@ -1,305 +0,0 @@
use std::{collections::HashSet, fmt::Display};
use chrono::{Duration, NaiveDateTime, Utc};
use chrono_tz::Tz;
use serenity::{
client::Context,
http::CacheHttp,
model::{
channel::GuildChannel,
id::{ChannelId, GuildId, UserId},
webhook::Webhook,
},
Result as SerenityResult,
};
use sqlx::MySqlPool;
use crate::{
consts,
consts::{MAX_TIME, MIN_INTERVAL},
models::{
channel_data::ChannelData,
reminder::{content::Content, errors::ReminderError, helper::generate_uid, Reminder},
user_data::UserData,
},
SQLPool,
};
async fn create_webhook(
ctx: impl CacheHttp,
channel: GuildChannel,
name: impl Display,
) -> SerenityResult<Webhook> {
channel.create_webhook_with_avatar(ctx.http(), name, consts::DEFAULT_AVATAR.clone()).await
}
#[derive(Hash, PartialEq, Eq)]
pub enum ReminderScope {
User(u64),
Channel(u64),
}
impl ReminderScope {
pub fn mention(&self) -> String {
match self {
Self::User(id) => format!("<@{}>", id),
Self::Channel(id) => format!("<#{}>", id),
}
}
}
pub struct ReminderBuilder {
pool: MySqlPool,
uid: String,
channel: u32,
utc_time: NaiveDateTime,
timezone: String,
interval: Option<i64>,
expires: Option<NaiveDateTime>,
content: String,
tts: bool,
attachment_name: Option<String>,
attachment: Option<Vec<u8>>,
set_by: Option<u32>,
}
impl ReminderBuilder {
pub async fn build(self) -> Result<Reminder, ReminderError> {
let queried_time = sqlx::query!(
"SELECT DATE_ADD(?, INTERVAL (SELECT nudge FROM channels WHERE id = ?) SECOND) AS `utc_time`",
self.utc_time,
self.channel,
)
.fetch_one(&self.pool)
.await
.unwrap();
match queried_time.utc_time {
Some(utc_time) => {
if utc_time < (Utc::now() - Duration::seconds(60)).naive_local() {
Err(ReminderError::PastTime)
} else {
sqlx::query!(
"
INSERT INTO reminders (
`uid`,
`channel_id`,
`utc_time`,
`timezone`,
`interval`,
`expires`,
`content`,
`tts`,
`attachment_name`,
`attachment`,
`set_by`
) VALUES (
?,
?,
?,
?,
?,
?,
?,
?,
?,
?,
?
)
",
self.uid,
self.channel,
utc_time,
self.timezone,
self.interval,
self.expires,
self.content,
self.tts,
self.attachment_name,
self.attachment,
self.set_by
)
.execute(&self.pool)
.await
.unwrap();
Ok(Reminder::from_uid(&self.pool, self.uid).await.unwrap())
}
}
None => Err(ReminderError::LongTime),
}
}
}
pub struct MultiReminderBuilder<'a> {
scopes: Vec<ReminderScope>,
utc_time: NaiveDateTime,
timezone: Tz,
interval: Option<i64>,
expires: Option<NaiveDateTime>,
content: Content,
set_by: Option<u32>,
ctx: &'a Context,
guild_id: Option<GuildId>,
}
impl<'a> MultiReminderBuilder<'a> {
pub fn new(ctx: &'a Context, guild_id: Option<GuildId>) -> Self {
MultiReminderBuilder {
scopes: vec![],
utc_time: Utc::now().naive_utc(),
timezone: Tz::UTC,
interval: None,
expires: None,
content: Content::new(),
set_by: None,
ctx,
guild_id,
}
}
pub fn content(mut self, content: Content) -> Self {
self.content = content;
self
}
pub fn time<T: Into<i64>>(mut self, time: T) -> Self {
self.utc_time = NaiveDateTime::from_timestamp(time.into(), 0);
self
}
pub fn expires<T: Into<i64>>(mut self, time: Option<T>) -> Self {
if let Some(t) = time {
self.expires = Some(NaiveDateTime::from_timestamp(t.into(), 0));
} else {
self.expires = None;
}
self
}
pub fn author(mut self, user: UserData) -> Self {
self.set_by = Some(user.id);
self.timezone = user.timezone();
self
}
pub fn interval(mut self, interval: Option<i64>) -> Self {
self.interval = interval;
self
}
pub fn set_scopes(&mut self, scopes: Vec<ReminderScope>) {
self.scopes = scopes;
}
pub async fn build(self) -> (HashSet<ReminderError>, HashSet<ReminderScope>) {
let pool = self.ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let mut errors = HashSet::new();
let mut ok_locs = HashSet::new();
if self.interval.map_or(false, |i| (i as i64) < *MIN_INTERVAL) {
errors.insert(ReminderError::ShortInterval);
} else if self.interval.map_or(false, |i| (i as i64) > *MAX_TIME) {
errors.insert(ReminderError::LongInterval);
} else {
for scope in self.scopes {
let db_channel_id = match scope {
ReminderScope::User(user_id) => {
if let Ok(user) = UserId(user_id).to_user(&self.ctx).await {
let user_data =
UserData::from_user(&user, &self.ctx, &pool).await.unwrap();
if let Some(guild_id) = self.guild_id {
if guild_id.member(&self.ctx, user).await.is_err() {
Err(ReminderError::InvalidTag)
} else {
Ok(user_data.dm_channel)
}
} else {
Ok(user_data.dm_channel)
}
} else {
Err(ReminderError::InvalidTag)
}
}
ReminderScope::Channel(channel_id) => {
let channel = ChannelId(channel_id).to_channel(&self.ctx).await.unwrap();
if let Some(guild_channel) = channel.clone().guild() {
if Some(guild_channel.guild_id) != self.guild_id {
Err(ReminderError::InvalidTag)
} else {
let mut channel_data =
ChannelData::from_channel(&channel, &pool).await.unwrap();
if channel_data.webhook_id.is_none()
|| channel_data.webhook_token.is_none()
{
match create_webhook(&self.ctx, guild_channel, "Reminder").await
{
Ok(webhook) => {
channel_data.webhook_id =
Some(webhook.id.as_u64().to_owned());
channel_data.webhook_token = webhook.token;
channel_data.commit_changes(&pool).await;
Ok(channel_data.id)
}
Err(e) => Err(ReminderError::DiscordError(e.to_string())),
}
} else {
Ok(channel_data.id)
}
}
} else {
Err(ReminderError::InvalidTag)
}
}
};
match db_channel_id {
Ok(c) => {
let builder = ReminderBuilder {
pool: pool.clone(),
uid: generate_uid(),
channel: c,
utc_time: self.utc_time,
timezone: self.timezone.to_string(),
interval: self.interval,
expires: self.expires,
content: self.content.content.clone(),
tts: self.content.tts,
attachment_name: self.content.attachment_name.clone(),
attachment: self.content.attachment.clone(),
set_by: self.set_by,
};
match builder.build().await {
Ok(_) => {
ok_locs.insert(scope);
}
Err(e) => {
errors.insert(e);
}
}
}
Err(e) => {
errors.insert(e);
}
}
}
}
(errors, ok_locs)
}
}

View File

@ -1,12 +0,0 @@
pub struct Content {
pub content: String,
pub tts: bool,
pub attachment: Option<Vec<u8>>,
pub attachment_name: Option<String>,
}
impl Content {
pub fn new() -> Self {
Self { content: "".to_string(), tts: false, attachment: None, attachment_name: None }
}
}

View File

@ -1,36 +0,0 @@
use crate::consts::{MAX_TIME, MIN_INTERVAL};
#[derive(PartialEq, Eq, Hash, Debug)]
pub enum ReminderError {
LongTime,
LongInterval,
PastTime,
ShortInterval,
InvalidTag,
DiscordError(String),
}
impl ToString for ReminderError {
fn to_string(&self) -> String {
match self {
ReminderError::LongTime => {
"That time is too far in the future. Please specify a shorter time.".to_string()
}
ReminderError::LongInterval => format!(
"Please ensure the interval specified is less than {max_time} days",
max_time = *MAX_TIME / 86_400
),
ReminderError::PastTime => {
"Please ensure the time provided is in the future. If the time should be in the future, please be more specific with the definition.".to_string()
}
ReminderError::ShortInterval => format!(
"Please ensure the interval provided is longer than {min_interval} seconds",
min_interval = *MIN_INTERVAL
),
ReminderError::InvalidTag => {
"Couldn't find a location by your tag. Your tag must be either a channel or a user (not a role)".to_string()
}
ReminderError::DiscordError(s) => format!("A Discord error occurred: **{}**", s),
}
}
}

View File

@ -1,31 +0,0 @@
use num_integer::Integer;
use rand::{rngs::OsRng, seq::IteratorRandom};
use crate::consts::{CHARACTERS, DAY, HOUR, MINUTE};
pub fn longhand_displacement(seconds: u64) -> String {
let (days, seconds) = seconds.div_rem(&DAY);
let (hours, seconds) = seconds.div_rem(&HOUR);
let (minutes, seconds) = seconds.div_rem(&MINUTE);
let mut sections = vec![];
for (var, name) in
[days, hours, minutes, seconds].iter().zip(["days", "hours", "minutes", "seconds"].iter())
{
if *var > 0 {
sections.push(format!("{} {}", var, name));
}
}
sections.join(", ")
}
pub fn generate_uid() -> String {
let mut generator: OsRng = Default::default();
(0..64)
.map(|_| CHARACTERS.chars().choose(&mut generator).unwrap().to_owned().to_string())
.collect::<Vec<String>>()
.join("")
}

View File

@ -1,23 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_repr::*;
use serenity::model::id::ChannelId;
#[derive(Serialize_repr, Deserialize_repr, Copy, Clone, Debug)]
#[repr(u8)]
pub enum TimeDisplayType {
Absolute = 0,
Relative = 1,
}
#[derive(Serialize, Deserialize, Copy, Clone, Debug)]
pub struct LookFlags {
pub show_disabled: bool,
pub channel_id: Option<ChannelId>,
pub time_display: TimeDisplayType,
}
impl Default for LookFlags {
fn default() -> Self {
Self { show_disabled: true, channel_id: None, time_display: TimeDisplayType::Relative }
}
}

View File

@ -1,284 +0,0 @@
pub mod builder;
pub mod content;
pub mod errors;
mod helper;
pub mod look_flags;
use chrono::{NaiveDateTime, TimeZone};
use chrono_tz::Tz;
use serenity::{
client::Context,
model::id::{ChannelId, GuildId, UserId},
};
use sqlx::MySqlPool;
use crate::{
models::reminder::{
helper::longhand_displacement,
look_flags::{LookFlags, TimeDisplayType},
},
SQLPool,
};
#[derive(Debug, Clone)]
pub struct Reminder {
pub id: u32,
pub uid: String,
pub channel: u64,
pub utc_time: NaiveDateTime,
pub interval: Option<u32>,
pub expires: Option<NaiveDateTime>,
pub enabled: bool,
pub content: String,
pub embed_description: String,
pub set_by: Option<u64>,
}
impl Reminder {
pub async fn from_uid(pool: &MySqlPool, uid: String) -> Option<Self> {
sqlx::query_as_unchecked!(
Self,
"
SELECT
reminders.id,
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.expires,
reminders.enabled,
reminders.content,
reminders.embed_description,
users.user AS set_by
FROM
reminders
INNER JOIN
channels
ON
reminders.channel_id = channels.id
LEFT JOIN
users
ON
reminders.set_by = users.id
WHERE
reminders.uid = ?
",
uid
)
.fetch_one(pool)
.await
.ok()
}
pub async fn from_channel<C: Into<ChannelId>>(
ctx: &Context,
channel_id: C,
flags: &LookFlags,
) -> Vec<Self> {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let enabled = if flags.show_disabled { "0,1" } else { "1" };
let channel_id = channel_id.into();
sqlx::query_as_unchecked!(
Self,
"
SELECT
reminders.id,
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.expires,
reminders.enabled,
reminders.content,
reminders.embed_description,
users.user AS set_by
FROM
reminders
INNER JOIN
channels
ON
reminders.channel_id = channels.id
LEFT JOIN
users
ON
reminders.set_by = users.id
WHERE
channels.channel = ? AND
FIND_IN_SET(reminders.enabled, ?)
ORDER BY
reminders.utc_time
",
channel_id.as_u64(),
enabled,
)
.fetch_all(&pool)
.await
.unwrap()
}
pub async fn from_guild(ctx: &Context, guild_id: Option<GuildId>, user: UserId) -> Vec<Self> {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
if let Some(guild_id) = guild_id {
let guild_opt = guild_id.to_guild_cached(&ctx);
if let Some(guild) = guild_opt {
let channels = guild
.channels
.keys()
.into_iter()
.map(|k| k.as_u64().to_string())
.collect::<Vec<String>>()
.join(",");
sqlx::query_as_unchecked!(
Self,
"
SELECT
reminders.id,
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.expires,
reminders.enabled,
reminders.content,
reminders.embed_description,
users.user AS set_by
FROM
reminders
LEFT JOIN
channels
ON
channels.id = reminders.channel_id
LEFT JOIN
users
ON
reminders.set_by = users.id
WHERE
FIND_IN_SET(channels.channel, ?)
",
channels
)
.fetch_all(&pool)
.await
} else {
sqlx::query_as_unchecked!(
Self,
"
SELECT
reminders.id,
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.expires,
reminders.enabled,
reminders.content,
reminders.embed_description,
users.user AS set_by
FROM
reminders
LEFT JOIN
channels
ON
channels.id = reminders.channel_id
LEFT JOIN
users
ON
reminders.set_by = users.id
WHERE
channels.guild_id = (SELECT id FROM guilds WHERE guild = ?)
",
guild_id.as_u64()
)
.fetch_all(&pool)
.await
}
} else {
sqlx::query_as_unchecked!(
Self,
"
SELECT
reminders.id,
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.expires,
reminders.enabled,
reminders.content,
reminders.embed_description,
users.user AS set_by
FROM
reminders
INNER JOIN
channels
ON
channels.id = reminders.channel_id
LEFT JOIN
users
ON
reminders.set_by = users.id
WHERE
channels.id = (SELECT dm_channel FROM users WHERE user = ?)
",
user.as_u64()
)
.fetch_all(&pool)
.await
}
.unwrap()
}
pub fn display_content(&self) -> &str {
if self.content.is_empty() {
&self.embed_description
} else {
&self.content
}
}
pub fn display_del(&self, count: usize, timezone: &Tz) -> String {
format!(
"**{}**: '{}' *<#{}>* at **{}**",
count + 1,
self.display_content(),
self.channel,
timezone
.timestamp(self.utc_time.timestamp(), 0)
.format("%Y-%m-%d %H:%M:%S")
.to_string()
)
}
pub fn display(&self, flags: &LookFlags, timezone: &Tz) -> String {
let time_display = match flags.time_display {
TimeDisplayType::Absolute => timezone
.timestamp(self.utc_time.timestamp(), 0)
.format("%Y-%m-%d %H:%M:%S")
.to_string(),
TimeDisplayType::Relative => format!("<t:{}:R>", self.utc_time.timestamp()),
};
if let Some(interval) = self.interval {
format!(
"'{}' *occurs next at* **{}**, repeating every **{}** (set by {})",
self.display_content(),
time_display,
longhand_displacement(interval as u64),
self.set_by.map(|i| format!("<@{}>", i)).unwrap_or_else(|| "unknown".to_string())
)
} else {
format!(
"'{}' *occurs next at* **{}** (set by {})",
self.display_content(),
time_display,
self.set_by.map(|i| format!("<@{}>", i)).unwrap_or_else(|| "unknown".to_string())
)
}
}
}

View File

@ -1,49 +0,0 @@
use chrono::NaiveDateTime;
use sqlx::MySqlPool;
pub struct Timer {
pub name: String,
pub start_time: NaiveDateTime,
pub owner: u64,
}
impl Timer {
pub async fn from_owner(owner: u64, pool: &MySqlPool) -> Vec<Self> {
sqlx::query_as_unchecked!(
Timer,
"
SELECT name, start_time, owner FROM timers WHERE owner = ?
",
owner
)
.fetch_all(pool)
.await
.unwrap()
}
pub async fn count_from_owner(owner: u64, pool: &MySqlPool) -> u32 {
sqlx::query!(
"
SELECT COUNT(1) as count FROM timers WHERE owner = ?
",
owner
)
.fetch_one(pool)
.await
.unwrap()
.count as u32
}
pub async fn create(name: &str, owner: u64, pool: &MySqlPool) {
sqlx::query!(
"
INSERT INTO timers (name, owner) VALUES (?, ?)
",
name,
owner
)
.execute(pool)
.await
.unwrap();
}
}

View File

@ -1,126 +0,0 @@
use chrono_tz::Tz;
use log::error;
use serenity::{
http::CacheHttp,
model::{id::UserId, user::User},
};
use sqlx::MySqlPool;
use crate::consts::LOCAL_TIMEZONE;
pub struct UserData {
pub id: u32,
pub user: u64,
pub name: String,
pub dm_channel: u32,
pub timezone: String,
}
impl UserData {
pub async fn timezone_of<U>(user: U, pool: &MySqlPool) -> Tz
where
U: Into<UserId>,
{
let user_id = user.into().as_u64().to_owned();
match sqlx::query!(
"
SELECT timezone FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await
{
Ok(r) => r.timezone,
Err(_) => LOCAL_TIMEZONE.clone(),
}
.parse()
.unwrap()
}
pub async fn from_user(
user: &User,
ctx: impl CacheHttp,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user.id.as_u64().to_owned();
match sqlx::query_as_unchecked!(
Self,
"
SELECT id, user, name, dm_channel, IF(timezone IS NULL, ?, timezone) AS timezone FROM users WHERE user = ?
",
*LOCAL_TIMEZONE,
user_id
)
.fetch_one(pool)
.await
{
Ok(c) => Ok(c),
Err(sqlx::Error::RowNotFound) => {
let dm_channel = user.create_dm_channel(ctx).await?;
let dm_id = dm_channel.id.as_u64().to_owned();
let pool_c = pool.clone();
sqlx::query!(
"
INSERT IGNORE INTO channels (channel) VALUES (?)
",
dm_id
)
.execute(&pool_c)
.await?;
sqlx::query!(
"
INSERT INTO users (user, name, dm_channel, timezone) VALUES (?, ?, (SELECT id FROM channels WHERE channel = ?), ?)
",
user_id,
user.name,
dm_id,
*LOCAL_TIMEZONE
)
.execute(&pool_c)
.await?;
Ok(sqlx::query_as_unchecked!(
Self,
"
SELECT id, user, name, dm_channel, timezone FROM users WHERE user = ?
",
user_id
)
.fetch_one(pool)
.await?)
}
Err(e) => {
error!("Error querying for user: {:?}", e);
Err(Box::new(e))
}
}
}
pub async fn commit_changes(&self, pool: &MySqlPool) {
sqlx::query!(
"
UPDATE users SET name = ?, timezone = ? WHERE id = ?
",
self.name,
self.timezone,
self.id
)
.execute(pool)
.await
.unwrap();
}
pub fn timezone(&self) -> Tz {
self.timezone.parse().unwrap()
}
}

View File

@ -1,552 +0,0 @@
use chrono::Duration;
use chrono_tz::Tz;
use log::{error, info, warn};
use num_integer::Integer;
use regex::{Captures, Regex};
use serenity::{
builder::CreateEmbed,
http::{CacheHttp, Http, StatusCode},
model::{
channel::{Channel, Embed as SerenityEmbed},
id::ChannelId,
webhook::Webhook,
},
Error, Result,
};
use sqlx::{
types::chrono::{NaiveDateTime, Utc},
MySqlPool,
};
lazy_static! {
pub static ref TIMEFROM_REGEX: Regex =
Regex::new(r#"<<timefrom:(?P<time>\d+):(?P<format>.+)?>>"#).unwrap();
pub static ref TIMENOW_REGEX: Regex =
Regex::new(r#"<<timenow:(?P<timezone>(?:\w|/|_)+):(?P<format>.+)?>>"#).unwrap();
}
fn fmt_displacement(format: &str, seconds: u64) -> String {
let mut seconds = seconds;
let mut days: u64 = 0;
let mut hours: u64 = 0;
let mut minutes: u64 = 0;
for (rep, time_type, div) in
[("%d", &mut days, 86400), ("%h", &mut hours, 3600), ("%m", &mut minutes, 60)].iter_mut()
{
if format.contains(*rep) {
let (divided, new_seconds) = seconds.div_rem(&div);
**time_type = divided;
seconds = new_seconds;
}
}
format
.replace("%s", &seconds.to_string())
.replace("%m", &minutes.to_string())
.replace("%h", &hours.to_string())
.replace("%d", &days.to_string())
}
pub fn substitute(string: &str) -> String {
let new = TIMEFROM_REGEX.replace(string, |caps: &Captures| {
let final_time = caps.name("time").unwrap().as_str();
let format = caps.name("format").unwrap().as_str();
if let Ok(final_time) = final_time.parse::<i64>() {
let dt = NaiveDateTime::from_timestamp(final_time, 0);
let now = Utc::now().naive_utc();
let difference = {
if now < dt {
dt - Utc::now().naive_utc()
} else {
Utc::now().naive_utc() - dt
}
};
fmt_displacement(format, difference.num_seconds() as u64)
} else {
String::new()
}
});
TIMENOW_REGEX
.replace(&new, |caps: &Captures| {
let timezone = caps.name("timezone").unwrap().as_str();
println!("{}", timezone);
if let Ok(tz) = timezone.parse::<Tz>() {
let format = caps.name("format").unwrap().as_str();
let now = Utc::now().with_timezone(&tz);
now.format(format).to_string()
} else {
String::new()
}
})
.to_string()
}
struct Embed {
inner: EmbedInner,
fields: Vec<EmbedField>,
}
struct EmbedInner {
title: String,
description: String,
image_url: Option<String>,
thumbnail_url: Option<String>,
footer: String,
footer_url: Option<String>,
author: String,
author_url: Option<String>,
color: u32,
}
struct EmbedField {
title: String,
value: String,
inline: bool,
}
impl Embed {
pub async fn from_id(pool: &MySqlPool, id: u32) -> Option<Self> {
let mut inner = sqlx::query_as_unchecked!(
EmbedInner,
"
SELECT
`embed_title` AS title,
`embed_description` AS description,
`embed_image_url` AS image_url,
`embed_thumbnail_url` AS thumbnail_url,
`embed_footer` AS footer,
`embed_footer_url` AS footer_url,
`embed_author` AS author,
`embed_author_url` AS author_url,
`embed_color` AS color
FROM
reminders
WHERE
`id` = ?
",
id
)
.fetch_one(&pool.clone())
.await
.unwrap();
inner.title = substitute(&inner.title);
inner.description = substitute(&inner.description);
inner.footer = substitute(&inner.footer);
let mut fields = sqlx::query_as_unchecked!(
EmbedField,
"
SELECT
title,
value,
inline
FROM
embed_fields
WHERE
reminder_id = ?
",
id
)
.fetch_all(pool)
.await
.unwrap();
fields.iter_mut().for_each(|mut field| {
field.title = substitute(&field.title);
field.value = substitute(&field.value);
});
let e = Embed { inner, fields };
if e.has_content() {
Some(e)
} else {
None
}
}
pub fn has_content(&self) -> bool {
if self.inner.title.is_empty()
&& self.inner.description.is_empty()
&& self.inner.image_url.is_none()
&& self.inner.thumbnail_url.is_none()
&& self.inner.footer.is_empty()
&& self.inner.footer_url.is_none()
&& self.inner.author.is_empty()
&& self.inner.author_url.is_none()
&& self.fields.is_empty()
{
false
} else {
true
}
}
}
impl Into<CreateEmbed> for Embed {
fn into(self) -> CreateEmbed {
let mut c = CreateEmbed::default();
c.title(&self.inner.title)
.description(&self.inner.description)
.color(self.inner.color)
.author(|a| {
a.name(&self.inner.author);
if let Some(author_icon) = &self.inner.author_url {
a.icon_url(author_icon);
}
a
})
.footer(|f| {
f.text(&self.inner.footer);
if let Some(footer_icon) = &self.inner.footer_url {
f.icon_url(footer_icon);
}
f
});
for field in &self.fields {
c.field(&field.title, &field.value, field.inline);
}
if let Some(image_url) = &self.inner.image_url {
c.image(image_url);
}
if let Some(thumbnail_url) = &self.inner.thumbnail_url {
c.thumbnail(thumbnail_url);
}
c
}
}
#[derive(Debug)]
pub struct Reminder {
id: u32,
channel_id: u64,
webhook_id: Option<u64>,
webhook_token: Option<String>,
channel_paused: bool,
channel_paused_until: Option<NaiveDateTime>,
enabled: bool,
tts: bool,
pin: bool,
content: String,
attachment: Option<Vec<u8>>,
attachment_name: Option<String>,
utc_time: NaiveDateTime,
timezone: String,
restartable: bool,
expires: Option<NaiveDateTime>,
interval: Option<u32>,
avatar: Option<String>,
username: Option<String>,
}
impl Reminder {
pub async fn fetch_reminders(pool: &MySqlPool) -> Vec<Self> {
sqlx::query_as_unchecked!(
Reminder,
"
SELECT
reminders.`id` AS id,
channels.`channel` AS channel_id,
channels.`webhook_id` AS webhook_id,
channels.`webhook_token` AS webhook_token,
channels.`paused` AS channel_paused,
channels.`paused_until` AS channel_paused_until,
reminders.`enabled` AS enabled,
reminders.`tts` AS tts,
reminders.`pin` AS pin,
reminders.`content` AS content,
reminders.`attachment` AS attachment,
reminders.`attachment_name` AS attachment_name,
reminders.`utc_time` AS 'utc_time',
reminders.`timezone` AS timezone,
reminders.`restartable` AS restartable,
reminders.`expires` AS expires,
reminders.`interval` AS 'interval',
reminders.`avatar` AS avatar,
reminders.`username` AS username
FROM
reminders
INNER JOIN
channels
ON
reminders.channel_id = channels.id
WHERE
reminders.`utc_time` < NOW()
",
)
.fetch_all(pool)
.await
.unwrap()
.into_iter()
.map(|mut rem| {
rem.content = substitute(&rem.content);
rem
})
.collect::<Vec<Self>>()
}
async fn reset_webhook(&self, pool: &MySqlPool) {
let _ = sqlx::query!(
"
UPDATE channels SET webhook_id = NULL, webhook_token = NULL WHERE channel = ?
",
self.channel_id
)
.execute(pool)
.await;
}
async fn refresh(&self, pool: &MySqlPool) {
if let Some(interval) = self.interval {
let now = Utc::now().naive_local();
let mut updated_reminder_time = self.utc_time;
while updated_reminder_time < now {
updated_reminder_time += Duration::seconds(interval as i64);
}
if self.expires.map_or(false, |expires| {
NaiveDateTime::from_timestamp(updated_reminder_time.timestamp(), 0) > expires
}) {
self.force_delete(pool).await;
} else {
sqlx::query!(
"
UPDATE reminders SET `utc_time` = ? WHERE `id` = ?
",
updated_reminder_time,
self.id
)
.execute(pool)
.await
.expect(&format!("Could not update time on Reminder {}", self.id));
}
} else {
self.force_delete(pool).await;
}
}
async fn force_delete(&self, pool: &MySqlPool) {
sqlx::query!(
"
DELETE FROM reminders WHERE `id` = ?
",
self.id
)
.execute(pool)
.await
.expect(&format!("Could not delete Reminder {}", self.id));
}
async fn pin_message<M: Into<u64>>(&self, message_id: M, http: impl AsRef<Http>) {
let _ = http.as_ref().pin_message(self.channel_id, message_id.into(), None).await;
}
pub async fn send(&self, pool: MySqlPool, cache_http: impl CacheHttp) {
async fn send_to_channel(
cache_http: impl CacheHttp,
reminder: &Reminder,
embed: Option<CreateEmbed>,
) -> Result<()> {
let channel = ChannelId(reminder.channel_id).to_channel(&cache_http).await;
match channel {
Ok(Channel::Guild(channel)) => {
match channel
.send_message(&cache_http, |m| {
m.content(&reminder.content).tts(reminder.tts);
if let (Some(attachment), Some(name)) =
(&reminder.attachment, &reminder.attachment_name)
{
m.add_file((attachment as &[u8], name.as_str()));
}
if let Some(embed) = embed {
m.set_embed(embed);
}
m
})
.await
{
Ok(m) => {
if reminder.pin {
reminder.pin_message(m.id, cache_http.http()).await;
}
Ok(())
}
Err(e) => Err(e),
}
}
Ok(Channel::Private(channel)) => {
match channel
.send_message(&cache_http.http(), |m| {
m.content(&reminder.content).tts(reminder.tts);
if let (Some(attachment), Some(name)) =
(&reminder.attachment, &reminder.attachment_name)
{
m.add_file((attachment as &[u8], name.as_str()));
}
if let Some(embed) = embed {
m.set_embed(embed);
}
m
})
.await
{
Ok(m) => {
if reminder.pin {
reminder.pin_message(m.id, cache_http.http()).await;
}
Ok(())
}
Err(e) => Err(e),
}
}
Err(e) => Err(e),
_ => Err(Error::Other("Channel not of valid type")),
}
}
async fn send_to_webhook(
cache_http: impl CacheHttp,
reminder: &Reminder,
webhook: Webhook,
embed: Option<CreateEmbed>,
) -> Result<()> {
match webhook
.execute(&cache_http.http(), reminder.pin || reminder.restartable, |w| {
w.content(&reminder.content).tts(reminder.tts);
if let Some(username) = &reminder.username {
w.username(username);
}
if let Some(avatar) = &reminder.avatar {
w.avatar_url(avatar);
}
if let (Some(attachment), Some(name)) =
(&reminder.attachment, &reminder.attachment_name)
{
w.add_file((attachment as &[u8], name.as_str()));
}
if let Some(embed) = embed {
w.embeds(vec![SerenityEmbed::fake(|c| {
*c = embed;
c
})]);
}
w
})
.await
{
Ok(m) => {
if reminder.pin {
if let Some(message) = m {
reminder.pin_message(message.id, cache_http.http()).await;
}
}
Ok(())
}
Err(e) => Err(e),
}
}
if self.enabled
&& !(self.channel_paused
&& self
.channel_paused_until
.map_or(true, |inner| inner >= Utc::now().naive_local()))
{
let _ = sqlx::query!(
"
UPDATE `channels` SET paused = 0, paused_until = NULL WHERE `channel` = ?
",
self.channel_id
)
.execute(&pool.clone())
.await;
let embed = Embed::from_id(&pool.clone(), self.id).await.map(|e| e.into());
let result = if let (Some(webhook_id), Some(webhook_token)) =
(self.webhook_id, &self.webhook_token)
{
let webhook_res =
cache_http.http().get_webhook_with_token(webhook_id, webhook_token).await;
if let Ok(webhook) = webhook_res {
send_to_webhook(cache_http, &self, webhook, embed).await
} else {
warn!("Webhook vanished: {:?}", webhook_res);
self.reset_webhook(&pool.clone()).await;
send_to_channel(cache_http, &self, embed).await
}
} else {
send_to_channel(cache_http, &self, embed).await
};
if let Err(e) = result {
error!("Error sending {:?}: {:?}", self, e);
if let Error::Http(error) = e {
if error.status_code() == Some(StatusCode::from_u16(404).unwrap()) {
error!("Seeing channel is deleted. Removing reminder");
self.force_delete(&pool).await;
} else {
self.refresh(&pool).await;
}
} else {
self.refresh(&pool).await;
}
} else {
self.refresh(&pool).await;
}
} else {
info!("Reminder {} is paused", self.id);
self.refresh(&pool).await;
}
}
}

View File

@ -1,16 +1,15 @@
use std::{
convert::TryFrom,
fmt::{Display, Formatter, Result as FmtResult},
str::from_utf8,
time::{SystemTime, UNIX_EPOCH},
};
use std::time::{SystemTime, UNIX_EPOCH};
use std::fmt::{Display, Formatter, Result as FmtResult};
use crate::consts::{LOCAL_TIMEZONE, PYTHON_LOCATION};
use chrono::{DateTime, Datelike, Timelike, Utc};
use chrono_tz::Tz;
use std::convert::TryFrom;
use std::str::from_utf8;
use tokio::process::Command;
use crate::consts::{LOCAL_TIMEZONE, PYTHON_LOCATION};
#[derive(Debug)]
pub enum InvalidTime {
ParseErrorDMY,
@ -27,15 +26,13 @@ impl Display for InvalidTime {
impl std::error::Error for InvalidTime {}
#[derive(Copy, Clone)]
enum ParseType {
Explicit,
Displacement,
}
#[derive(Clone)]
pub struct TimeParser {
timezone: Tz,
pub timezone: Tz,
inverted: bool,
time_string: String,
parse_type: ParseType,
@ -98,7 +95,10 @@ impl TimeParser {
}
fn process_explicit(&self) -> Result<i64, InvalidTime> {
let mut time = Utc::now().with_timezone(&self.timezone).with_second(0).unwrap();
let mut time = Utc::now()
.with_timezone(&self.timezone)
.with_second(0)
.unwrap();
let mut segments = self.time_string.rsplit('-');
// this segment will always exist even if split fails
@ -106,11 +106,13 @@ impl TimeParser {
let h_m_s = hms.split(':');
for (t, setter) in
h_m_s.take(3).zip(&[DateTime::with_hour, DateTime::with_minute, DateTime::with_second])
{
for (t, setter) in h_m_s.take(3).zip(&[
DateTime::with_hour,
DateTime::with_minute,
DateTime::with_second,
]) {
time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorHMS)?)
.map_or_else(|| Err(InvalidTime::ParseErrorHMS), Ok)?;
.map_or_else(|| Err(InvalidTime::ParseErrorHMS), |inner| Ok(inner))?;
}
if let Some(dmy) = segments.next() {
@ -120,11 +122,13 @@ impl TimeParser {
let month = d_m_y.next();
let year = d_m_y.next();
for (t, setter) in [day, month].iter().zip(&[DateTime::with_day, DateTime::with_month])
for (t, setter) in [day, month]
.iter()
.zip(&[DateTime::with_day, DateTime::with_month])
{
if let Some(t) = t {
time = setter(&time, t.parse().map_err(|_| InvalidTime::ParseErrorDMY)?)
.map_or_else(|| Err(InvalidTime::ParseErrorDMY), Ok)?;
.map_or_else(|| Err(InvalidTime::ParseErrorDMY), |inner| Ok(inner))?;
}
}
@ -132,7 +136,7 @@ impl TimeParser {
if year.len() == 4 {
time = time
.with_year(year.parse().map_err(|_| InvalidTime::ParseErrorDMY)?)
.map_or_else(|| Err(InvalidTime::ParseErrorDMY), Ok)?;
.map_or_else(|| Err(InvalidTime::ParseErrorDMY), |inner| Ok(inner))?;
} else if year.len() == 2 {
time = time
.with_year(
@ -140,9 +144,9 @@ impl TimeParser {
.parse()
.map_err(|_| InvalidTime::ParseErrorDMY)?,
)
.map_or_else(|| Err(InvalidTime::ParseErrorDMY), Ok)?;
.map_or_else(|| Err(InvalidTime::ParseErrorDMY), |inner| Ok(inner))?;
} else {
return Err(InvalidTime::ParseErrorDMY);
Err(InvalidTime::ParseErrorDMY)?;
}
}
}
@ -153,10 +157,10 @@ impl TimeParser {
fn process_displacement(&self) -> Result<i64, InvalidTime> {
let mut current_buffer = "0".to_string();
let mut seconds = 0_i64;
let mut minutes = 0_i64;
let mut hours = 0_i64;
let mut days = 0_i64;
let mut seconds = 0 as i64;
let mut minutes = 0 as i64;
let mut hours = 0 as i64;
let mut days = 0 as i64;
for character in self.time_string.chars() {
match character {
@ -201,7 +205,7 @@ impl TimeParser {
}
}
pub async fn natural_parser(time: &str, timezone: &str) -> Option<i64> {
pub(crate) async fn natural_parser(time: &str, timezone: &str) -> Option<i64> {
Command::new(&*PYTHON_LOCATION)
.arg("-c")
.arg(include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/dp.py")))