working on the basic reminder commands

This commit is contained in:
jude
2020-09-11 17:41:15 +01:00
parent b7fd89e861
commit e222927858
7 changed files with 222 additions and 24 deletions

View File

@ -45,7 +45,7 @@ async fn help(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let desc = user_data.response(&pool, "help").await;
msg.channel_id.send_message(ctx, |m| m
@ -64,7 +64,7 @@ async fn info(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let desc = user_data.response(&pool, "info").await;
msg.channel_id.send_message(ctx, |m| m
@ -83,7 +83,7 @@ async fn donate(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let desc = user_data.response(&pool, "donate").await;
msg.channel_id.send_message(ctx, |m| m
@ -115,7 +115,7 @@ async fn clock(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let tz: Tz = user_data.timezone.parse().unwrap();

View File

@ -72,7 +72,7 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let mut user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
if args.len() > 0 {
match args.parse::<Tz>() {
@ -101,7 +101,7 @@ async fn language(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let mut user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
match sqlx::query!(
"
@ -134,7 +134,7 @@ async fn prefix(ctx: &Context, msg: &Message, args: String) -> CommandResult {
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let mut guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap();
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
if args.len() > 5 {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "prefix/too_long").await).await;
@ -159,7 +159,7 @@ async fn restrict(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap();
let role_tag_match = REGEX_ROLE.find(&args);
@ -245,7 +245,7 @@ async fn alias(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let guild_id = msg.guild_id.unwrap().as_u64().to_owned();

View File

@ -1,8 +1,15 @@
use custom_error::custom_error;
use regex_command_attr::command;
use serenity::{
client::Context,
model::{
id::{
UserId,
ChannelId,
GuildId,
},
channel::{
Message,
},
@ -36,8 +43,12 @@ use std::{
use regex::Regex;
use serde_json::json;
lazy_static! {
static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap();
static ref REGEX_CHANNEL_USER: Regex = Regex::new(r#"^\s*<(#|@)(?:!)?(\d+)>\s*$"#).unwrap();
}
@ -48,7 +59,7 @@ async fn pause(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let mut channel = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool).await.unwrap();
if args.len() == 0 {
@ -93,7 +104,7 @@ async fn offset(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
if args.len() == 0 {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "offset/help").await).await;
@ -144,7 +155,7 @@ async fn nudge(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let mut channel = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool).await.unwrap();
if args.len() == 0 {
@ -245,7 +256,7 @@ async fn look(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let flags = LookFlags::from_string(&args);
@ -314,7 +325,7 @@ async fn delete(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "del/listing").await).await;
@ -423,7 +434,7 @@ async fn timer(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let user_data = UserData::from_id(&msg.author, &ctx, &pool).await.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let mut args_iter = args.splitn(2, " ");
@ -491,3 +502,171 @@ DELETE FROM timers WHERE owner = ? AND name = ?
Ok(())
}
enum RemindCommand {
Remind,
Interval,
}
enum ReminderScope {
User(u64),
Channel(u64),
}
custom_error!{ReminderError
LongTime = "Time too long",
LongInterval = "Interval too long",
PastTime = "Time has already passed",
ShortInterval = "Interval too short",
InvalidTag = "Invalid reminder scope",
NotEnoughArgs = "Not enough args",
InvalidTime = "Invalid time provided",
DiscordError = "Bad response received from Discord"
}
#[command]
#[permission_level(Managed)]
async fn remind(ctx: &Context, msg: &Message, args: String) -> CommandResult {
remind_command(ctx, msg, args, RemindCommand::Remind).await;
Ok(())
}
#[command]
#[permission_level(Managed)]
async fn interval(ctx: &Context, msg: &Message, args: String) -> CommandResult {
remind_command(ctx, msg, args, RemindCommand::Interval).await;
Ok(())
}
async fn remind_command(ctx: &Context, msg: &Message, args: String, command: RemindCommand) {
let user_data;
{
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
}
let mut args_iter = args.split(' ').filter(|s| s.len() > 0);
if let Some(first_arg) = args_iter.next().map(|s| s.to_string()) {
let scope_id;
let time_parser;
let content;
let guild_id = msg.guild_id;
if let Some((Some(scope_match), Some(id_match))) = REGEX_CHANNEL_USER
.captures(&first_arg)
.map(|cap| (cap.get(1), cap.get(2))) {
if scope_match.as_str() == "@" {
scope_id = ReminderScope::User(id_match.as_str().parse::<u64>().unwrap());
}
else {
scope_id = ReminderScope::Channel(id_match.as_str().parse::<u64>().unwrap());
}
if let Some(next) = args_iter.next().map(|inner| inner.to_string()) {
time_parser = TimeParser::new(next, user_data.timezone.parse().unwrap());
content = args_iter.collect::<Vec<&str>>().join(" ");
// TODO replace unwrap with converting response into discord response
create_reminder(ctx, guild_id, scope_id, time_parser, content).await.unwrap();
}
else {
}
}
else {
scope_id = ReminderScope::Channel(msg.channel_id.as_u64().to_owned());
time_parser = TimeParser::new(first_arg, user_data.timezone.parse().unwrap());
content = args_iter.collect::<Vec<&str>>().join(" ");
// TODO replace unwrap with converting response into discord response
create_reminder(ctx, guild_id, scope_id, time_parser, content).await.unwrap();
}
}
else {
}
}
async fn create_reminder(ctx: &Context, guild_id: Option<GuildId>, scope_id: ReminderScope, time_parser: TimeParser, content: String)
-> Result<(), ReminderError> {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let db_channel_id = match scope_id {
ReminderScope::User(user_id) => {
let user = UserId(user_id).to_user(&ctx).await.unwrap();
let user_data = UserData::from_user(&user, &ctx, &pool).await.unwrap();
user_data.dm_channel
},
ReminderScope::Channel(channel_id) => {
let channel = ChannelId(channel_id).to_channel(&ctx).await.unwrap();
if channel.clone().guild().map(|gc| gc.guild_id) != guild_id {
return Err(ReminderError::InvalidTag)
}
let mut channel_data = ChannelData::from_channel(channel.clone(), &pool).await.unwrap();
if let Some(guild_channel) = channel.guild() {
if channel_data.webhook_token.is_none() || channel_data.webhook_id.is_none() {
if let Ok(webhook) = ctx.http.create_webhook(guild_channel.id.as_u64().to_owned(), &json!({"name": "Reminder"})).await {
channel_data.webhook_id = Some(webhook.id.as_u64().to_owned());
channel_data.webhook_token = Some(webhook.token);
channel_data.commit_changes(&pool).await;
}
else {
return Err(ReminderError::DiscordError)
}
}
}
channel_data.id
},
};
// validate time, channel, content
if content.len() == 0 {
Err(ReminderError::NotEnoughArgs)
}
else {
match time_parser.timestamp() {
Ok(time) => {
let unix_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64;
if time > unix_time {
if time > unix_time + 60*60*24*365*50 {
Err(ReminderError::LongTime)
}
else {
Ok(())
}
}
else {
Err(ReminderError::PastTime)
}
},
Err(_) => {
Err(ReminderError::InvalidTime)
},
}
}
}