interval months/interval seconds

This commit is contained in:
jellywx 2022-02-01 23:04:31 +00:00
parent 4f9eb58c16
commit fad28faabb
8 changed files with 619 additions and 276 deletions

565
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,4 @@
USE reminders;
ALTER TABLE reminders.reminders RENAME COLUMN `interval` TO `interval_seconds`;
ALTER TABLE reminders.reminders ADD COLUMN `interval_months` INT UNSIGNED DEFAULT NULL;

View File

@ -19,6 +19,7 @@ use crate::{
consts::{EMBED_DESCRIPTION_MAX_LENGTH, REGEX_CHANNEL_USER, SELECT_MAX_ENTRIES, THEME_COLOR},
framework::{CommandInvoke, CommandOptions, CreateGenericResponse, OptionValue},
hooks::CHECK_GUILD_PERMISSIONS_HOOK,
interval_parser::parse_duration,
models::{
reminder::{
builder::{MultiReminderBuilder, ReminderScope},
@ -720,11 +721,8 @@ async fn remind(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions)
&& check_guild_subscription(&ctx, invoke.guild_id().unwrap()).await)
{
(
humantime::parse_duration(&repeat.to_string())
.or_else(|_| {
humantime::parse_duration(&format!("1 {}", repeat.to_string()))
})
.map(|duration| duration.as_secs() as i64)
parse_duration(&repeat.to_string())
.or_else(|_| parse_duration(&format!("1 {}", repeat.to_string())))
.ok(),
{
if let Some(arg) = args.get("expires") {
@ -769,6 +767,7 @@ async fn remind(ctx: &Context, invoke: &mut CommandInvoke, args: CommandOptions)
.author(user_data)
.content(content)
.time(time)
.timezone(timezone)
.expires(expires)
.interval(interval);

View File

@ -11,7 +11,7 @@ const THEME_COLOR_FALLBACK: u32 = 0x8fb677;
use std::{collections::HashSet, env, iter::FromIterator};
use regex::Regex;
use serenity::http::AttachmentType;
use serenity::model::prelude::AttachmentType;
lazy_static! {
pub static ref DEFAULT_AVATAR: AttachmentType<'static> = (

247
src/interval_parser.rs Normal file
View File

@ -0,0 +1,247 @@
/*
Copyright 2021 Paul Colomiets, 2022 Jude Southworth
Permission is hereby granted, free of charge, to any person obtaining a copy of this software
and associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or
substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
use std::{error::Error as StdError, fmt, str::Chars};
/// Error parsing human-friendly duration
#[derive(Debug, PartialEq, Clone)]
pub enum Error {
/// Invalid character during parsing
///
/// More specifically anything that is not alphanumeric is prohibited
///
/// The field is an byte offset of the character in the string.
InvalidCharacter(usize),
/// Non-numeric value where number is expected
///
/// This usually means that either time unit is broken into words,
/// e.g. `m sec` instead of `msec`, or just number is omitted,
/// for example `2 hours min` instead of `2 hours 1 min`
///
/// The field is an byte offset of the errorneous character
/// in the string.
NumberExpected(usize),
/// Unit in the number is not one of allowed units
///
/// See documentation of `parse_duration` for the list of supported
/// time units.
///
/// The two fields are start and end (exclusive) of the slice from
/// the original string, containing errorneous value
UnknownUnit {
/// Start of the invalid unit inside the original string
start: usize,
/// End of the invalid unit inside the original string
end: usize,
/// The unit verbatim
unit: String,
/// A number associated with the unit
value: u64,
},
/// The numeric value is too large
///
/// Usually this means value is too large to be useful. If user writes
/// data in subsecond units, then the maximum is about 3k years. When
/// using seconds, or larger units, the limit is even larger.
NumberOverflow,
/// The value was an empty string (or consists only whitespace)
Empty,
}
impl StdError for Error {}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::InvalidCharacter(offset) => write!(f, "invalid character at {}", offset),
Error::NumberExpected(offset) => write!(f, "expected number at {}", offset),
Error::UnknownUnit { unit, value, .. } if &unit == &"" => {
write!(f, "time unit needed, for example {0}sec or {0}ms", value,)
}
Error::UnknownUnit { unit, .. } => {
write!(
f,
"unknown time unit {:?}, \
supported units: ns, us, ms, sec, min, hours, days, \
weeks, months, years (and few variations)",
unit
)
}
Error::NumberOverflow => write!(f, "number is too large"),
Error::Empty => write!(f, "value was empty"),
}
}
}
trait OverflowOp: Sized {
fn mul(self, other: Self) -> Result<Self, Error>;
fn add(self, other: Self) -> Result<Self, Error>;
}
impl OverflowOp for u64 {
fn mul(self, other: Self) -> Result<Self, Error> {
self.checked_mul(other).ok_or(Error::NumberOverflow)
}
fn add(self, other: Self) -> Result<Self, Error> {
self.checked_add(other).ok_or(Error::NumberOverflow)
}
}
#[derive(Copy, Clone)]
pub struct Interval {
pub month: u64,
pub sec: u64,
}
struct Parser<'a> {
iter: Chars<'a>,
src: &'a str,
current: (u64, u64, u64),
}
impl<'a> Parser<'a> {
fn off(&self) -> usize {
self.src.len() - self.iter.as_str().len()
}
fn parse_first_char(&mut self) -> Result<Option<u64>, Error> {
let off = self.off();
for c in self.iter.by_ref() {
match c {
'0'..='9' => {
return Ok(Some(c as u64 - '0' as u64));
}
c if c.is_whitespace() => continue,
_ => {
return Err(Error::NumberExpected(off));
}
}
}
Ok(None)
}
fn parse_unit(&mut self, n: u64, start: usize, end: usize) -> Result<(), Error> {
let (mut month, mut sec, nsec) = match &self.src[start..end] {
"nanos" | "nsec" | "ns" => (0u64, 0u64, n),
"usec" | "us" => (0, 0u64, n.mul(1000)?),
"millis" | "msec" | "ms" => (0, 0u64, n.mul(1_000_000)?),
"seconds" | "second" | "secs" | "sec" | "s" => (0, n, 0),
"minutes" | "minute" | "min" | "mins" | "m" => (0, n.mul(60)?, 0),
"hours" | "hour" | "hr" | "hrs" | "h" => (0, n.mul(3600)?, 0),
"days" | "day" | "d" => (0, n.mul(86400)?, 0),
"weeks" | "week" | "w" => (0, n.mul(86400 * 7)?, 0),
"months" | "month" | "M" => (n, 0, 0),
"years" | "year" | "y" => (12, 0, 0),
_ => {
return Err(Error::UnknownUnit {
start,
end,
unit: self.src[start..end].to_string(),
value: n,
});
}
};
let mut nsec = self.current.2 + nsec;
if nsec > 1_000_000_000 {
sec = sec + nsec / 1_000_000_000;
nsec %= 1_000_000_000;
}
sec = self.current.1 + sec;
month = self.current.0 + month;
self.current = (month, sec, nsec);
Ok(())
}
fn parse(mut self) -> Result<Interval, Error> {
let mut n = self.parse_first_char()?.ok_or(Error::Empty)?;
'outer: loop {
let mut off = self.off();
while let Some(c) = self.iter.next() {
match c {
'0'..='9' => {
n = n
.checked_mul(10)
.and_then(|x| x.checked_add(c as u64 - '0' as u64))
.ok_or(Error::NumberOverflow)?;
}
c if c.is_whitespace() => {}
'a'..='z' | 'A'..='Z' => {
break;
}
_ => {
return Err(Error::InvalidCharacter(off));
}
}
off = self.off();
}
let start = off;
let mut off = self.off();
while let Some(c) = self.iter.next() {
match c {
'0'..='9' => {
self.parse_unit(n, start, off)?;
n = c as u64 - '0' as u64;
continue 'outer;
}
c if c.is_whitespace() => break,
'a'..='z' | 'A'..='Z' => {}
_ => {
return Err(Error::InvalidCharacter(off));
}
}
off = self.off();
}
self.parse_unit(n, start, off)?;
n = match self.parse_first_char()? {
Some(n) => n,
None => return Ok(Interval { month: self.current.0, sec: self.current.1 }),
};
}
}
}
/// Parse duration object `1hour 12min 5s`
///
/// The duration object is a concatenation of time spans. Where each time
/// span is an integer number and a suffix. Supported suffixes:
///
/// * `nsec`, `ns` -- nanoseconds
/// * `usec`, `us` -- microseconds
/// * `msec`, `ms` -- milliseconds
/// * `seconds`, `second`, `sec`, `s`
/// * `minutes`, `minute`, `min`, `m`
/// * `hours`, `hour`, `hr`, `h`
/// * `days`, `day`, `d`
/// * `weeks`, `week`, `w`
/// * `months`, `month`, `M` -- defined as 30.44 days
/// * `years`, `year`, `y` -- defined as 365.25 days
///
/// # Examples
///
/// ```
/// use std::time::Duration;
/// use humantime::parse_duration;
///
/// assert_eq!(parse_duration("2h 37min"), Ok(Duration::new(9420, 0)));
/// assert_eq!(parse_duration("32ms"), Ok(Duration::new(0, 32_000_000)));
/// ```
pub fn parse_duration(s: &str) -> Result<Interval, Error> {
Parser { iter: s.chars(), src: s, current: (0, 0, 0) }.parse()
}

View File

@ -7,6 +7,7 @@ mod component_models;
mod consts;
mod framework;
mod hooks;
mod interval_parser;
mod models;
mod time_parser;
@ -17,12 +18,12 @@ use dotenv::dotenv;
use log::info;
use serenity::{
async_trait,
client::{bridge::gateway::GatewayIntents, Client},
client::Client,
http::{client::Http, CacheHttp},
model::{
channel::GuildChannel,
gateway::{Activity, Ready},
guild::{Guild, GuildUnavailable},
gateway::{Activity, GatewayIntents, Ready},
guild::{Guild, UnavailableGuild},
id::{GuildId, UserId},
interactions::Interaction,
},
@ -144,7 +145,7 @@ DELETE FROM channels WHERE channel = ?
}
}
async fn guild_delete(&self, ctx: Context, incomplete: GuildUnavailable, _full: Option<Guild>) {
async fn guild_delete(&self, ctx: Context, incomplete: UnavailableGuild, _full: Option<Guild>) {
let pool = ctx.data.read().await.get::<SQLPool>().cloned().unwrap();
let _ = sqlx::query!("DELETE FROM guilds WHERE guild = ?", incomplete.id.0)
.execute(&pool)

View File

@ -16,7 +16,8 @@ use sqlx::MySqlPool;
use crate::{
consts,
consts::{MAX_TIME, MIN_INTERVAL},
consts::{DAY, MAX_TIME, MIN_INTERVAL},
interval_parser::Interval,
models::{
channel_data::ChannelData,
reminder::{content::Content, errors::ReminderError, helper::generate_uid, Reminder},
@ -54,7 +55,8 @@ pub struct ReminderBuilder {
channel: u32,
utc_time: NaiveDateTime,
timezone: String,
interval: Option<i64>,
interval_secs: Option<i64>,
interval_months: Option<i64>,
expires: Option<NaiveDateTime>,
content: String,
tts: bool,
@ -86,7 +88,8 @@ INSERT INTO reminders (
`channel_id`,
`utc_time`,
`timezone`,
`interval`,
`interval_seconds`,
`interval_months`,
`expires`,
`content`,
`tts`,
@ -104,6 +107,7 @@ INSERT INTO reminders (
?,
?,
?,
?,
?
)
",
@ -111,7 +115,8 @@ INSERT INTO reminders (
self.channel,
utc_time,
self.timezone,
self.interval,
self.interval_secs,
self.interval_months,
self.expires,
self.content,
self.tts,
@ -136,7 +141,7 @@ pub struct MultiReminderBuilder<'a> {
scopes: Vec<ReminderScope>,
utc_time: NaiveDateTime,
timezone: Tz,
interval: Option<i64>,
interval: Option<Interval>,
expires: Option<NaiveDateTime>,
content: Content,
set_by: Option<u32>,
@ -159,6 +164,12 @@ impl<'a> MultiReminderBuilder<'a> {
}
}
pub fn timezone(mut self, timezone: Tz) -> Self {
self.timezone = timezone;
self
}
pub fn content(mut self, content: Content) -> Self {
self.content = content;
@ -188,7 +199,7 @@ impl<'a> MultiReminderBuilder<'a> {
self
}
pub fn interval(mut self, interval: Option<i64>) -> Self {
pub fn interval(mut self, interval: Option<Interval>) -> Self {
self.interval = interval;
self
@ -205,9 +216,10 @@ impl<'a> MultiReminderBuilder<'a> {
let mut ok_locs = HashSet::new();
if self.interval.map_or(false, |i| (i as i64) < *MIN_INTERVAL) {
if self.interval.map_or(false, |i| ((i.sec + i.month * 30 * DAY) as i64) < *MIN_INTERVAL) {
errors.insert(ReminderError::ShortInterval);
} else if self.interval.map_or(false, |i| (i as i64) > *MAX_TIME) {
} else if self.interval.map_or(false, |i| ((i.sec + i.month * 30 * DAY) as i64) > *MAX_TIME)
{
errors.insert(ReminderError::LongInterval);
} else {
for scope in self.scopes {
@ -275,7 +287,8 @@ impl<'a> MultiReminderBuilder<'a> {
channel: c,
utc_time: self.utc_time,
timezone: self.timezone.to_string(),
interval: self.interval,
interval_secs: self.interval.map(|i| i.sec as i64),
interval_months: self.interval.map(|i| i.month as i64),
expires: self.expires,
content: self.content.content.clone(),
tts: self.content.tts,

View File

@ -13,10 +13,7 @@ use serenity::{
use sqlx::MySqlPool;
use crate::{
models::reminder::{
helper::longhand_displacement,
look_flags::{LookFlags, TimeDisplayType},
},
models::reminder::look_flags::{LookFlags, TimeDisplayType},
SQLPool,
};
@ -26,7 +23,8 @@ pub struct Reminder {
pub uid: String,
pub channel: u64,
pub utc_time: NaiveDateTime,
pub interval: Option<u32>,
pub interval_seconds: Option<u32>,
pub interval_months: Option<u32>,
pub expires: Option<NaiveDateTime>,
pub enabled: bool,
pub content: String,
@ -44,7 +42,8 @@ SELECT
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.interval_seconds,
reminders.interval_months,
reminders.expires,
reminders.enabled,
reminders.content,
@ -88,7 +87,8 @@ SELECT
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.interval_seconds,
reminders.interval_months,
reminders.expires,
reminders.enabled,
reminders.content,
@ -141,7 +141,8 @@ SELECT
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.interval_seconds,
reminders.interval_months,
reminders.expires,
reminders.enabled,
reminders.content,
@ -173,7 +174,8 @@ SELECT
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.interval_seconds,
reminders.interval_months,
reminders.expires,
reminders.enabled,
reminders.content,
@ -206,7 +208,8 @@ SELECT
reminders.uid,
channels.channel,
reminders.utc_time,
reminders.interval,
reminders.interval_seconds,
reminders.interval_months,
reminders.expires,
reminders.enabled,
reminders.content,
@ -264,12 +267,11 @@ WHERE
TimeDisplayType::Relative => format!("<t:{}:R>", self.utc_time.timestamp()),
};
if let Some(interval) = self.interval {
if self.interval_seconds.is_some() || self.interval_months.is_some() {
format!(
"'{}' *occurs next at* **{}**, repeating every **{}** (set by {})",
"'{}' *occurs next at* **{}**, repeating (set by {})",
self.display_content(),
time_display,
longhand_displacement(interval as u64),
self.set_by.map(|i| format!("<@{}>", i)).unwrap_or_else(|| "unknown".to_string())
)
} else {