ran rustfmt over project. cleared up a couple of clippy things

This commit is contained in:
jude 2020-10-12 21:01:27 +01:00
parent 88596fb399
commit c9fd2fea81
10 changed files with 1374 additions and 939 deletions

View File

@ -1,30 +1,15 @@
use regex_command_attr::command; use regex_command_attr::command;
use serenity::{ use serenity::{client::Context, framework::standard::CommandResult, model::channel::Message};
client::Context,
model::{
channel::{
Message,
},
},
framework::standard::CommandResult,
};
use chrono::offset::Utc; use chrono::offset::Utc;
use crate::{ use crate::{
models::{ models::{GuildData, UserData},
UserData, SQLPool, THEME_COLOR,
GuildData,
},
THEME_COLOR,
SQLPool,
}; };
use std::time::{ use std::time::{SystemTime, UNIX_EPOCH};
SystemTime,
UNIX_EPOCH
};
#[command] #[command]
#[can_blacklist(false)] #[can_blacklist(false)]
@ -36,7 +21,10 @@ async fn ping(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let delta = since_epoch.as_millis() as i64 - msg.timestamp.timestamp_millis(); let delta = since_epoch.as_millis() as i64 - msg.timestamp.timestamp_millis();
let _ = msg.channel_id.say(&ctx, format!("Time taken to receive message: {}ms", delta)).await; let _ = msg
.channel_id
.say(&ctx, format!("Time taken to receive message: {}ms", delta))
.await;
Ok(()) Ok(())
} }
@ -44,92 +32,131 @@ async fn ping(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
#[command] #[command]
#[can_blacklist(false)] #[can_blacklist(false)]
async fn help(ctx: &Context, msg: &Message, _args: String) -> CommandResult { async fn help(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let desc = user_data.response(&pool, "help").await; let desc = user_data.response(&pool, "help").await;
msg.channel_id.send_message(ctx, |m| m msg.channel_id
.embed(move |e| e .send_message(ctx, |m| {
.title("Help") m.embed(move |e| e.title("Help").description(desc).color(THEME_COLOR))
.description(desc) })
.color(THEME_COLOR) .await?;
)
).await?;
Ok(()) Ok(())
} }
#[command] #[command]
async fn info(ctx: &Context, msg: &Message, _args: String) -> CommandResult { async fn info(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap(); let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool)
.await
.unwrap();
let desc = user_data.response(&pool, "info").await let desc = user_data
.response(&pool, "info")
.await
.replacen("{user}", &ctx.cache.current_user().await.name, 1) .replacen("{user}", &ctx.cache.current_user().await.name, 1)
.replacen("{prefix}", &guild_data.prefix, 1); .replacen("{prefix}", &guild_data.prefix, 1);
msg.channel_id.send_message(ctx, |m| m msg.channel_id
.embed(move |e| e .send_message(ctx, |m| {
.title("Info") m.embed(move |e| e.title("Info").description(desc).color(THEME_COLOR))
.description(desc) })
.color(THEME_COLOR) .await?;
)
).await?;
Ok(()) Ok(())
} }
#[command] #[command]
async fn donate(ctx: &Context, msg: &Message, _args: String) -> CommandResult { async fn donate(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let desc = user_data.response(&pool, "donate").await; let desc = user_data.response(&pool, "donate").await;
msg.channel_id.send_message(ctx, |m| m msg.channel_id
.embed(move |e| e .send_message(ctx, |m| {
.title("Donate") m.embed(move |e| e.title("Donate").description(desc).color(THEME_COLOR))
.description(desc) })
.color(THEME_COLOR) .await?;
)
).await?;
Ok(()) Ok(())
} }
#[command] #[command]
async fn dashboard(ctx: &Context, msg: &Message, _args: String) -> CommandResult { async fn dashboard(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
msg.channel_id.send_message(ctx, |m| m msg.channel_id
.embed(move |e| e .send_message(ctx, |m| {
.title("Dashboard") m.embed(move |e| {
.description("https://reminder-bot.com/dashboard") e.title("Dashboard")
.color(THEME_COLOR) .description("https://reminder-bot.com/dashboard")
) .color(THEME_COLOR)
).await?; })
})
.await?;
Ok(()) Ok(())
} }
#[command] #[command]
async fn clock(ctx: &Context, msg: &Message, args: String) -> CommandResult { async fn clock(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let now = Utc::now().with_timezone(&user_data.timezone()); let now = Utc::now().with_timezone(&user_data.timezone());
if args == "12" { if args == "12" {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "clock/time").await.replacen("{}", &now.format("%I:%M:%S %p").to_string(), 1)).await; let _ = msg
} .channel_id
else { .say(
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "clock/time").await.replacen("{}", &now.format("%H:%M:%S").to_string(), 1)).await; &ctx,
user_data.response(&pool, "clock/time").await.replacen(
"{}",
&now.format("%I:%M:%S %p").to_string(),
1,
),
)
.await;
} else {
let _ = msg
.channel_id
.say(
&ctx,
user_data.response(&pool, "clock/time").await.replacen(
"{}",
&now.format("%H:%M:%S").to_string(),
1,
),
)
.await;
} }
Ok(()) Ok(())

View File

@ -1,4 +1,4 @@
pub mod info_cmds; pub mod info_cmds;
pub mod moderation_cmds;
pub mod reminder_cmds; pub mod reminder_cmds;
pub mod todo_cmds; pub mod todo_cmds;
pub mod moderation_cmds;

View File

@ -2,16 +2,8 @@ use regex_command_attr::command;
use serenity::{ use serenity::{
client::Context, client::Context,
model::{ framework::{standard::CommandResult, Framework},
id::RoleId, model::{channel::Message, id::RoleId},
channel::{
Message,
},
},
framework::{
Framework,
standard::CommandResult,
},
}; };
use chrono_tz::Tz; use chrono_tz::Tz;
@ -21,41 +13,40 @@ use chrono::offset::Utc;
use inflector::Inflector; use inflector::Inflector;
use crate::{ use crate::{
models::{ consts::{REGEX_ALIAS, REGEX_CHANNEL, REGEX_COMMANDS, REGEX_ROLE},
ChannelData,
UserData,
GuildData,
},
SQLPool,
FrameworkCtx,
framework::SendIterator, framework::SendIterator,
consts::{ models::{ChannelData, GuildData, UserData},
REGEX_ALIAS, FrameworkCtx, SQLPool,
REGEX_CHANNEL,
REGEX_COMMANDS,
REGEX_ROLE,
},
}; };
use std::iter; use std::iter;
#[command] #[command]
#[supports_dm(false)] #[supports_dm(false)]
#[permission_level(Restricted)] #[permission_level(Restricted)]
#[can_blacklist(false)] #[can_blacklist(false)]
async fn blacklist(ctx: &Context, msg: &Message, args: String) -> CommandResult { async fn blacklist(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let capture_opt = REGEX_CHANNEL.captures(&args).map(|cap| cap.get(1)).flatten(); let capture_opt = REGEX_CHANNEL
.captures(&args)
.map(|cap| cap.get(1))
.flatten();
let mut channel = match capture_opt { let mut channel = match capture_opt {
Some(capture) => Some(capture) => ChannelData::from_id(capture.as_str().parse::<u64>().unwrap(), &pool)
ChannelData::from_id(capture.as_str().parse::<u64>().unwrap(), &pool).await.unwrap(), .await
.unwrap(),
None => None => ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool)
ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool).await.unwrap(), .await
.unwrap(),
}; };
channel.blacklisted = !channel.blacklisted; channel.blacklisted = !channel.blacklisted;
@ -63,8 +54,7 @@ async fn blacklist(ctx: &Context, msg: &Message, args: String) -> CommandResult
if channel.blacklisted { if channel.blacklisted {
let _ = msg.channel_id.say(&ctx, "Blacklisted").await; let _ = msg.channel_id.say(&ctx, "Blacklisted").await;
} } else {
else {
let _ = msg.channel_id.say(&ctx, "Unblacklisted").await; let _ = msg.channel_id.say(&ctx, "Unblacklisted").await;
} }
@ -73,8 +63,13 @@ async fn blacklist(ctx: &Context, msg: &Message, args: String) -> CommandResult
#[command] #[command]
async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult { async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
@ -86,7 +81,9 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let now = Utc::now().with_timezone(&user_data.timezone()); let now = Utc::now().with_timezone(&user_data.timezone());
let content = user_data.response(&pool, "timezone/set_p").await let content = user_data
.response(&pool, "timezone/set_p")
.await
.replacen("{timezone}", &user_data.timezone, 1) .replacen("{timezone}", &user_data.timezone, 1)
.replacen("{time}", &now.format("%H:%M").to_string(), 1); .replacen("{time}", &now.format("%H:%M").to_string(), 1);
@ -94,13 +91,23 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult {
} }
Err(_) => { Err(_) => {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "timezone/no_timezone").await).await; let _ = msg
.channel_id
.say(
&ctx,
user_data.response(&pool, "timezone/no_timezone").await,
)
.await;
} }
} }
} } else {
else { let content = user_data
let content = user_data.response(&pool, "timezone/no_argument").await .response(&pool, "timezone/no_argument")
.replace("{prefix}", &GuildData::prefix_from_id(msg.guild_id, &pool).await) .await
.replace(
"{prefix}",
&GuildData::prefix_from_id(msg.guild_id, &pool).await,
)
.replacen("{timezone}", &user_data.timezone, 1); .replacen("{timezone}", &user_data.timezone, 1);
let _ = msg.channel_id.say(&ctx, content).await; let _ = msg.channel_id.say(&ctx, content).await;
@ -111,25 +118,36 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult {
#[command] #[command]
async fn language(ctx: &Context, msg: &Message, args: String) -> CommandResult { async fn language(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
match sqlx::query!( match sqlx::query!(
" "
SELECT code FROM languages WHERE code = ? OR name = ? SELECT code FROM languages WHERE code = ? OR name = ?
", args, args) ",
.fetch_one(&pool) args,
.await { args
)
.fetch_one(&pool)
.await
{
Ok(row) => { Ok(row) => {
user_data.language = row.code; user_data.language = row.code;
user_data.commit_changes(&pool).await; user_data.commit_changes(&pool).await;
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "lang/set_p").await).await; let _ = msg
}, .channel_id
.say(&ctx, user_data.response(&pool, "lang/set_p").await)
.await;
}
Err(_) => { Err(_) => {
let language_codes = sqlx::query!("SELECT name, code FROM languages") let language_codes = sqlx::query!("SELECT name, code FROM languages")
@ -137,15 +155,24 @@ SELECT code FROM languages WHERE code = ? OR name = ?
.await .await
.unwrap() .unwrap()
.iter() .iter()
.map(|language| format!("{} ({})", language.name.to_title_case(), language.code.to_uppercase())) .map(|language| {
format!(
"{} ({})",
language.name.to_title_case(),
language.code.to_uppercase()
)
})
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("\n"); .join("\n");
let content = user_data.response(&pool, "lang/invalid").await let content =
.replacen("{}", &language_codes, 1); user_data
.response(&pool, "lang/invalid")
.await
.replacen("{}", &language_codes, 1);
let _ = msg.channel_id.say(&ctx, content).await; let _ = msg.channel_id.say(&ctx, content).await;
}, }
} }
Ok(()) Ok(())
@ -155,24 +182,38 @@ SELECT code FROM languages WHERE code = ? OR name = ?
#[supports_dm(false)] #[supports_dm(false)]
#[permission_level(Restricted)] #[permission_level(Restricted)]
async fn prefix(ctx: &Context, msg: &Message, args: String) -> CommandResult { async fn prefix(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let mut guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap(); let mut guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool)
.await
.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
if args.len() > 5 { if args.len() > 5 {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "prefix/too_long").await).await; let _ = msg
} .channel_id
else if args.is_empty() { .say(&ctx, user_data.response(&pool, "prefix/too_long").await)
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "prefix/no_argument").await).await; .await;
} } else if args.is_empty() {
else { let _ = msg
.channel_id
.say(&ctx, user_data.response(&pool, "prefix/no_argument").await)
.await;
} else {
guild_data.prefix = args; guild_data.prefix = args;
guild_data.commit_changes(&pool).await; guild_data.commit_changes(&pool).await;
let content = user_data.response(&pool, "prefix/success").await let content = user_data.response(&pool, "prefix/success").await.replacen(
.replacen("{prefix}", &guild_data.prefix, 1); "{prefix}",
&guild_data.prefix,
1,
);
let _ = msg.channel_id.say(&ctx, content).await; let _ = msg.channel_id.say(&ctx, content).await;
} }
@ -184,17 +225,31 @@ async fn prefix(ctx: &Context, msg: &Message, args: String) -> CommandResult {
#[supports_dm(false)] #[supports_dm(false)]
#[permission_level(Restricted)] #[permission_level(Restricted)]
async fn restrict(ctx: &Context, msg: &Message, args: String) -> CommandResult { async fn restrict(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap(); let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool)
.await
.unwrap();
let role_tag_match = REGEX_ROLE.find(&args); let role_tag_match = REGEX_ROLE.find(&args);
if let Some(role_tag) = role_tag_match { if let Some(role_tag) = role_tag_match {
let commands = REGEX_COMMANDS.find_iter(&args.to_lowercase()).map(|c| c.as_str().to_string()).collect::<Vec<String>>(); let commands = REGEX_COMMANDS
let role_id = RoleId(role_tag.as_str()[3..role_tag.as_str().len()-1].parse::<u64>().unwrap()); .find_iter(&args.to_lowercase())
.map(|c| c.as_str().to_string())
.collect::<Vec<String>>();
let role_id = RoleId(
role_tag.as_str()[3..role_tag.as_str().len() - 1]
.parse::<u64>()
.unwrap(),
);
let role_opt = role_id.to_role_cached(&ctx).await; let role_opt = role_id.to_role_cached(&ctx).await;
@ -202,20 +257,28 @@ async fn restrict(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let _ = sqlx::query!( let _ = sqlx::query!(
" "
DELETE FROM command_restrictions WHERE role_id = (SELECT id FROM roles WHERE role = ?) DELETE FROM command_restrictions WHERE role_id = (SELECT id FROM roles WHERE role = ?)
", role.id.as_u64()) ",
.execute(&pool) role.id.as_u64()
.await; )
.execute(&pool)
.await;
if commands.is_empty() { if commands.is_empty() {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/disabled").await).await; let _ = msg
} .channel_id
else { .say(&ctx, user_data.response(&pool, "restrict/disabled").await)
.await;
} else {
let _ = sqlx::query!( let _ = sqlx::query!(
" "
INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, ?, ?) INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, ?, ?)
", role.id.as_u64(), role.name, guild_data.id) ",
.execute(&pool) role.id.as_u64(),
.await; role.name,
guild_data.id
)
.execute(&pool)
.await;
for command in commands { for command in commands {
let res = sqlx::query!( let res = sqlx::query!(
@ -228,18 +291,22 @@ INSERT INTO command_restrictions (role_id, command) VALUES ((SELECT id FROM role
if res.is_err() { if res.is_err() {
println!("{:?}", res); println!("{:?}", res);
let content = user_data.response(&pool, "restrict/failure").await let content = user_data
.response(&pool, "restrict/failure")
.await
.replacen("{command}", &command, 1); .replacen("{command}", &command, 1);
let _ = msg.channel_id.say(&ctx, content).await; let _ = msg.channel_id.say(&ctx, content).await;
} }
} }
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/enabled").await).await; let _ = msg
.channel_id
.say(&ctx, user_data.response(&pool, "restrict/enabled").await)
.await;
} }
} }
} } else if args.is_empty() {
else if args.is_empty() {
let guild_id = msg.guild_id.unwrap().as_u64().to_owned(); let guild_id = msg.guild_id.unwrap().as_u64().to_owned();
let rows = sqlx::query!( let rows = sqlx::query!(
@ -254,18 +321,29 @@ ON
roles.id = command_restrictions.role_id roles.id = command_restrictions.role_id
WHERE WHERE
roles.guild_id = (SELECT id FROM guilds WHERE guild = ?) roles.guild_id = (SELECT id FROM guilds WHERE guild = ?)
", guild_id) ",
.fetch_all(&pool) guild_id
.await )
.unwrap(); .fetch_all(&pool)
.await
.unwrap();
let display_inner = rows.iter().map(|row| format!("<@&{}> can use {}", row.role, row.command)).collect::<Vec<String>>().join("\n"); let display_inner = rows
let display = user_data.response(&pool, "restrict/allowed").await.replacen("{}", &display_inner, 1); .iter()
.map(|row| format!("<@&{}> can use {}", row.role, row.command))
.collect::<Vec<String>>()
.join("\n");
let display = user_data
.response(&pool, "restrict/allowed")
.await
.replacen("{}", &display_inner, 1);
let _ = msg.channel_id.say(&ctx, display).await; let _ = msg.channel_id.say(&ctx, display).await;
} } else {
else { let _ = msg
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/help").await).await; .channel_id
.say(&ctx, user_data.response(&pool, "restrict/help").await)
.await;
} }
Ok(()) Ok(())
@ -275,8 +353,13 @@ WHERE
#[supports_dm(false)] #[supports_dm(false)]
#[permission_level(Managed)] #[permission_level(Managed)]
async fn alias(ctx: &Context, msg: &Message, args: String) -> CommandResult { async fn alias(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
@ -293,21 +376,21 @@ async fn alias(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let aliases = sqlx::query!( let aliases = sqlx::query!(
" "
SELECT name, command FROM command_aliases WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) SELECT name, command FROM command_aliases WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)
", guild_id) ",
.fetch_all(&pool) guild_id
.await )
.unwrap(); .fetch_all(&pool)
.await
.unwrap();
let content = iter::once("Aliases:".to_string()) let content = iter::once("Aliases:".to_string()).chain(
.chain( aliases
aliases .iter()
.iter() .map(|row| format!("**{}**: `{}`", row.name, row.command)),
.map(|row| format!("**{}**: `{}`", row.name, row.command) );
)
);
let _ = msg.channel_id.say_lines(&ctx, content).await; let _ = msg.channel_id.say_lines(&ctx, content).await;
}, }
"remove" => { "remove" => {
if let Some(command) = command_opt { if let Some(command) = command_opt {
@ -322,19 +405,27 @@ SELECT COUNT(1) AS count FROM command_aliases WHERE name = ? AND guild_id = (SEL
sqlx::query!( sqlx::query!(
" "
DELETE FROM command_aliases WHERE name = ? AND guild_id = (SELECT id FROM guilds WHERE guild = ?) DELETE FROM command_aliases WHERE name = ? AND guild_id = (SELECT id FROM guilds WHERE guild = ?)
", command, guild_id) ",
.execute(&pool) command,
.await guild_id
.unwrap(); )
.execute(&pool)
.await
.unwrap();
let content = user_data.response(&pool, "alias/removed").await.replace("{count}", &deleted_count.count.to_string()); let content = user_data
.response(&pool, "alias/removed")
.await
.replace("{count}", &deleted_count.count.to_string());
let _ = msg.channel_id.say(&ctx, content).await; let _ = msg.channel_id.say(&ctx, content).await;
} else {
let _ = msg
.channel_id
.say(&ctx, user_data.response(&pool, "alias/help").await)
.await;
} }
else { }
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "alias/help").await).await;
}
},
name => { name => {
if let Some(command) = command_opt { if let Some(command) = command_opt {
@ -355,11 +446,13 @@ UPDATE command_aliases SET command = ? WHERE guild_id = (SELECT id FROM guilds W
.unwrap(); .unwrap();
} }
let content = user_data.response(&pool, "alias/created").await.replace("{name}", name); let content = user_data
.response(&pool, "alias/created")
.await
.replace("{name}", name);
let _ = msg.channel_id.say(&ctx, content).await; let _ = msg.channel_id.say(&ctx, content).await;
} } else {
else {
match sqlx::query!( match sqlx::query!(
" "
SELECT command FROM command_aliases WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ? SELECT command FROM command_aliases WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?
@ -386,10 +479,12 @@ SELECT command FROM command_aliases WHERE guild_id = (SELECT id FROM guilds WHER
} }
} }
} }
} } else {
else {
let prefix = GuildData::prefix_from_id(msg.guild_id, &pool).await; let prefix = GuildData::prefix_from_id(msg.guild_id, &pool).await;
let content = user_data.response(&pool, "alias/help").await.replace("{prefix}", &prefix); let content = user_data
.response(&pool, "alias/help")
.await
.replace("{prefix}", &prefix);
let _ = msg.channel_id.say(&ctx, content).await; let _ = msg.channel_id.say(&ctx, content).await;
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,26 +1,19 @@
use regex_command_attr::command; use regex_command_attr::command;
use serenity::{ use serenity::{
constants::MESSAGE_CODE_LIMIT,
client::Context, client::Context,
model::{ constants::MESSAGE_CODE_LIMIT,
id::{
UserId, GuildId, ChannelId,
},
channel::{
Message,
},
},
framework::standard::CommandResult, framework::standard::CommandResult,
model::{
channel::Message,
id::{ChannelId, GuildId, UserId},
},
}; };
use std::fmt; use std::fmt;
use crate::{ use crate::{
models::{ models::{GuildData, UserData},
UserData,
GuildData,
},
SQLPool, SQLPool,
}; };
use sqlx::MySqlPool; use sqlx::MySqlPool;
@ -54,18 +47,15 @@ impl TodoTarget {
pub fn command(&self, subcommand_opt: Option<SubCommand>) -> String { pub fn command(&self, subcommand_opt: Option<SubCommand>) -> String {
let context = if self.channel.is_some() { let context = if self.channel.is_some() {
"channel" "channel"
} } else if self.guild.is_some() {
else if self.guild.is_some() {
"guild" "guild"
} } else {
else {
"user" "user"
}; };
if let Some(subcommand) = subcommand_opt { if let Some(subcommand) = subcommand_opt {
format!("todo {} {}", context, subcommand.to_string()) format!("todo {} {}", context, subcommand.to_string())
} } else {
else {
format!("todo {}", context) format!("todo {}", context)
} }
} }
@ -73,43 +63,56 @@ impl TodoTarget {
pub fn name(&self) -> String { pub fn name(&self) -> String {
if self.channel.is_some() { if self.channel.is_some() {
"Channel" "Channel"
} } else if self.guild.is_some() {
else if self.guild.is_some() {
"Guild" "Guild"
} } else {
else {
"User" "User"
}.to_string() }
.to_string()
} }
pub async fn view(&self, pool: MySqlPool) -> Result<Vec<Todo>, Box<dyn std::error::Error + Send + Sync>> { pub async fn view(
&self,
pool: MySqlPool,
) -> Result<Vec<Todo>, Box<dyn std::error::Error + Send + Sync>> {
Ok(if let Some(cid) = self.channel { Ok(if let Some(cid) = self.channel {
sqlx::query_as!(Todo, sqlx::query_as!(
Todo,
" "
SELECT * FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?) SELECT * FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?)
", cid.as_u64()) ",
.fetch_all(&pool) cid.as_u64()
.await? )
} .fetch_all(&pool)
else if let Some(gid) = self.guild { .await?
sqlx::query_as!(Todo, } else if let Some(gid) = self.guild {
sqlx::query_as!(
Todo,
" "
SELECT * FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL SELECT * FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL
", gid.as_u64()) ",
.fetch_all(&pool) gid.as_u64()
.await? )
} .fetch_all(&pool)
else { .await?
sqlx::query_as!(Todo, } else {
sqlx::query_as!(
Todo,
" "
SELECT * FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL SELECT * FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL
", self.user.as_u64()) ",
.fetch_all(&pool) self.user.as_u64()
.await? )
.fetch_all(&pool)
.await?
}) })
} }
pub async fn add(&self, value: String, pool: MySqlPool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { pub async fn add(
&self,
value: String,
pool: MySqlPool,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if let (Some(cid), Some(gid)) = (self.channel, self.guild) { if let (Some(cid), Some(gid)) = (self.channel, self.guild) {
sqlx::query!( sqlx::query!(
" "
@ -119,11 +122,15 @@ INSERT INTO todos (user_id, guild_id, channel_id, value) VALUES (
(SELECT id FROM channels WHERE channel = ?), (SELECT id FROM channels WHERE channel = ?),
? ?
) )
", self.user.as_u64(), gid.as_u64(), cid.as_u64(), value) ",
.execute(&pool) self.user.as_u64(),
.await?; gid.as_u64(),
} cid.as_u64(),
else if let Some(gid) = self.guild { value
)
.execute(&pool)
.await?;
} else if let Some(gid) = self.guild {
sqlx::query!( sqlx::query!(
" "
INSERT INTO todos (user_id, guild_id, value) VALUES ( INSERT INTO todos (user_id, guild_id, value) VALUES (
@ -131,74 +138,95 @@ INSERT INTO todos (user_id, guild_id, value) VALUES (
(SELECT id FROM guilds WHERE guild = ?), (SELECT id FROM guilds WHERE guild = ?),
? ?
) )
", self.user.as_u64(), gid.as_u64(), value) ",
.execute(&pool) self.user.as_u64(),
.await?; gid.as_u64(),
} value
else { )
.execute(&pool)
.await?;
} else {
sqlx::query!( sqlx::query!(
" "
INSERT INTO todos (user_id, value) VALUES ( INSERT INTO todos (user_id, value) VALUES (
(SELECT id FROM users WHERE user = ?), (SELECT id FROM users WHERE user = ?),
? ?
) )
", self.user.as_u64(), value) ",
.execute(&pool) self.user.as_u64(),
.await?; value
)
.execute(&pool)
.await?;
} }
Ok(()) Ok(())
} }
pub async fn remove(&self, num: usize, pool: &MySqlPool) -> Result<Todo, Box<dyn std::error::Error + Sync + Send>> { pub async fn remove(
&self,
num: usize,
pool: &MySqlPool,
) -> Result<Todo, Box<dyn std::error::Error + Sync + Send>> {
let todos = self.view(pool.clone()).await?; let todos = self.view(pool.clone()).await?;
if let Some(removal_todo) = todos.get(num) { if let Some(removal_todo) = todos.get(num) {
let deleting = sqlx::query_as!(Todo, let deleting = sqlx::query_as!(
Todo,
" "
SELECT * FROM todos WHERE id = ? SELECT * FROM todos WHERE id = ?
", removal_todo.id) ",
.fetch_one(&pool.clone()) removal_todo.id
.await?; )
.fetch_one(&pool.clone())
.await?;
sqlx::query!( sqlx::query!(
" "
DELETE FROM todos WHERE id = ? DELETE FROM todos WHERE id = ?
", removal_todo.id) ",
.execute(pool) removal_todo.id
.await?; )
.execute(pool)
.await?;
Ok(deleting) Ok(deleting)
} } else {
else {
Err(Box::new(TodoNotFound)) Err(Box::new(TodoNotFound))
} }
} }
pub async fn clear(&self, pool: &MySqlPool) -> Result<(), Box<dyn std::error::Error + Sync + Send>> { pub async fn clear(
&self,
pool: &MySqlPool,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
if let Some(cid) = self.channel { if let Some(cid) = self.channel {
sqlx::query!( sqlx::query!(
" "
DELETE FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?) DELETE FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?)
", cid.as_u64()) ",
.execute(pool) cid.as_u64()
.await?; )
} .execute(pool)
else if let Some(gid) = self.guild { .await?;
} else if let Some(gid) = self.guild {
sqlx::query!( sqlx::query!(
" "
DELETE FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL DELETE FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL
", gid.as_u64()) ",
.execute(pool) gid.as_u64()
.await?; )
} .execute(pool)
else { .await?;
} else {
sqlx::query!( sqlx::query!(
" "
DELETE FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL DELETE FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL
", self.user.as_u64()) ",
.execute(pool) self.user.as_u64()
.await?; )
.execute(pool)
.await?;
} }
Ok(()) Ok(())
@ -219,47 +247,53 @@ impl ToString for SubCommand {
SubCommand::Add => "add", SubCommand::Add => "add",
SubCommand::Remove => "remove", SubCommand::Remove => "remove",
SubCommand::Clear => "clear", SubCommand::Clear => "clear",
}.to_string() }
.to_string()
} }
} }
#[command] #[command]
#[permission_level(Managed)] #[permission_level(Managed)]
async fn todo_parse(ctx: &Context, msg: &Message, args: String) -> CommandResult { async fn todo_parse(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let mut split = args.split(' '); let mut split = args.split(' ');
if let Some(target) = split.next() { if let Some(target) = split.next() {
let target_opt = match target { let target_opt = match target {
"user" => "user" => Some(TodoTarget {
Some(TodoTarget {user: msg.author.id, guild: None, channel: None}), user: msg.author.id,
guild: None,
channel: None,
}),
"channel" => "channel" => {
if let Some(gid) = msg.guild_id { if let Some(gid) = msg.guild_id {
Some(TodoTarget {user: msg.author.id, guild: Some(gid), channel: Some(msg.channel_id)}) Some(TodoTarget {
} user: msg.author.id,
else { guild: Some(gid),
channel: Some(msg.channel_id),
})
} else {
None None
}, }
}
"server" | "guild" => { "server" | "guild" => {
if let Some(gid) = msg.guild_id { if let Some(gid) = msg.guild_id {
Some(TodoTarget {user: msg.author.id, guild: Some(gid), channel: None}) Some(TodoTarget {
} user: msg.author.id,
else { guild: Some(gid),
channel: None,
})
} else {
None None
} }
}, }
_ => { _ => None,
None
},
}; };
if let Some(target) = target_opt { if let Some(target) = target_opt {
let subcommand_opt = match split.next() { let subcommand_opt = match split.next() {
Some("add") => Some(SubCommand::Add), Some("add") => Some(SubCommand::Add),
Some("remove") => Some(SubCommand::Remove), Some("remove") => Some(SubCommand::Remove),
@ -272,19 +306,21 @@ async fn todo_parse(ctx: &Context, msg: &Message, args: String) -> CommandResult
}; };
if let Some(subcommand) = subcommand_opt { if let Some(subcommand) = subcommand_opt {
todo(ctx, msg, target, subcommand, split.collect::<Vec<&str>>().join(" ")).await; todo(
} ctx,
else { msg,
target,
subcommand,
split.collect::<Vec<&str>>().join(" "),
)
.await;
} else {
show_help(&ctx, msg, Some(target)).await; show_help(&ctx, msg, Some(target)).await;
} }
} else {
}
else {
show_help(&ctx, msg, None).await; show_help(&ctx, msg, None).await;
} }
} else {
}
else {
show_help(&ctx, msg, None).await; show_help(&ctx, msg, None).await;
} }
@ -292,22 +328,45 @@ async fn todo_parse(ctx: &Context, msg: &Message, args: String) -> CommandResult
} }
async fn show_help(ctx: &Context, msg: &Message, target: Option<TodoTarget>) { async fn show_help(ctx: &Context, msg: &Message, target: Option<TodoTarget>) {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let prefix = GuildData::prefix_from_id(msg.guild_id, &pool).await; let prefix = GuildData::prefix_from_id(msg.guild_id, &pool).await;
let content = user_data.response(&pool, "todo/help").await let content = user_data
.response(&pool, "todo/help")
.await
.replace("{prefix}", &prefix) .replace("{prefix}", &prefix)
.replace("{command}", target.map_or_else(|| "todo user".to_string(), |t| t.command(None)).as_str()); .replace(
"{command}",
target
.map_or_else(|| "todo user".to_string(), |t| t.command(None))
.as_str(),
);
let _ = msg.channel_id.say(&ctx, content).await; let _ = msg.channel_id.say(&ctx, content).await;
} }
async fn todo(ctx: &Context, msg: &Message, target: TodoTarget, subcommand: SubCommand, extra: String) { async fn todo(
let pool = ctx.data.read().await ctx: &Context,
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); msg: &Message,
target: TodoTarget,
subcommand: SubCommand,
extra: String,
) {
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap(); let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let prefix = GuildData::prefix_from_id(msg.guild_id, &pool).await; let prefix = GuildData::prefix_from_id(msg.guild_id, &pool).await;
@ -325,8 +384,7 @@ async fn todo(ctx: &Context, msg: &Message, target: TodoTarget, subcommand: SubC
char_count = display.len(); char_count = display.len();
todo_groups.push(display); todo_groups.push(display);
} } else {
else {
char_count += display.len(); char_count += display.len();
let last_group = todo_groups.pop().unwrap(); let last_group = todo_groups.pop().unwrap();
@ -336,44 +394,54 @@ async fn todo(ctx: &Context, msg: &Message, target: TodoTarget, subcommand: SubC
}); });
for group in todo_groups { for group in todo_groups {
let _ = msg.channel_id.send_message(&ctx, |m| m let _ = msg
.embed(|e| e .channel_id
.title(format!("{} Todo", target.name())) .send_message(&ctx, |m| {
.description(group) m.embed(|e| {
) e.title(format!("{} Todo", target.name()))
).await; .description(group)
})
})
.await;
} }
}, }
SubCommand::Add => { SubCommand::Add => {
let content = user_data.response(&pool, "todo/added").await let content = user_data
.response(&pool, "todo/added")
.await
.replacen("{name}", &extra, 1); .replacen("{name}", &extra, 1);
target.add(extra, pool).await.unwrap(); target.add(extra, pool).await.unwrap();
let _ = msg.channel_id.say(&ctx, content).await; let _ = msg.channel_id.say(&ctx, content).await;
}, }
SubCommand::Remove => { SubCommand::Remove => {
let _ = if let Ok(num) = extra.parse::<usize>() { let _ = if let Ok(num) = extra.parse::<usize>() {
if let Ok(todo) = target.remove(num - 1, &pool).await { if let Ok(todo) = target.remove(num - 1, &pool).await {
let content = user_data.response(&pool, "todo/removed").await let content = user_data.response(&pool, "todo/removed").await.replacen(
.replacen("{}", &todo.value, 1); "{}",
&todo.value,
1,
);
msg.channel_id.say(&ctx, content) msg.channel_id.say(&ctx, content)
} else {
msg.channel_id
.say(&ctx, user_data.response(&pool, "todo/error_index").await)
} }
else { } else {
msg.channel_id.say(&ctx, user_data.response(&pool, "todo/error_index").await) let content = user_data
} .response(&pool, "todo/error_value")
} .await
else {
let content = user_data.response(&pool, "todo/error_value").await
.replacen("{prefix}", &prefix, 1) .replacen("{prefix}", &prefix, 1)
.replacen("{command}", &target.command(Some(subcommand)), 1); .replacen("{command}", &target.command(Some(subcommand)), 1);
msg.channel_id.say(&ctx, content) msg.channel_id.say(&ctx, content)
}.await; }
}, .await;
}
SubCommand::Clear => { SubCommand::Clear => {
target.clear(&pool).await.unwrap(); target.clear(&pool).await.unwrap();
@ -381,6 +449,6 @@ async fn todo(ctx: &Context, msg: &Message, target: TodoTarget, subcommand: SubC
let content = user_data.response(&pool, "todo/cleared").await; let content = user_data.response(&pool, "todo/cleared").await;
let _ = msg.channel_id.say(&ctx, content).await; let _ = msg.channel_id.say(&ctx, content).await;
}, }
} }
} }

View File

@ -6,44 +6,41 @@ pub const MINUTE: u64 = 60;
pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"; pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
use std::{ use std::{collections::HashSet, env, iter::FromIterator};
iter::FromIterator,
env,
collections::HashSet,
};
use lazy_static;
use regex::Regex; use regex::Regex;
lazy_static! { lazy_static! {
pub static ref SUBSCRIPTION_ROLES: HashSet<u64> = HashSet::from_iter(env::var("SUBSCRIPTION_ROLES") pub static ref SUBSCRIPTION_ROLES: HashSet<u64> = HashSet::from_iter(
.map( env::var("SUBSCRIPTION_ROLES")
|var| var .map(|var| var
.split(',') .split(',')
.filter_map(|item| { .filter_map(|item| { item.parse::<u64>().ok() })
item.parse::<u64>().ok() .collect::<Vec<u64>>())
}) .unwrap_or_else(|_| vec![])
.collect::<Vec<u64>>() );
).unwrap_or_else(|_| vec![])); pub static ref CNC_GUILD: Option<u64> = env::var("CNC_GUILD")
.map(|var| var.parse::<u64>().ok())
pub static ref CNC_GUILD: Option<u64> = env::var("CNC_GUILD").map(|var| var.parse::<u64>().ok()).ok().flatten(); .ok()
.flatten();
pub static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap(); pub static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap();
pub static ref REGEX_ROLE: Regex = Regex::new(r#"<@&([0-9]+)>"#).unwrap(); pub static ref REGEX_ROLE: Regex = Regex::new(r#"<@&([0-9]+)>"#).unwrap();
pub static ref REGEX_COMMANDS: Regex = Regex::new(r#"([a-z]+)"#).unwrap(); pub static ref REGEX_COMMANDS: Regex = Regex::new(r#"([a-z]+)"#).unwrap();
pub static ref REGEX_ALIAS: Regex =
pub static ref REGEX_ALIAS: Regex = Regex::new(r#"(?P<name>[\S]{1,12})(?:(?: (?P<cmd>.*)$)|$)"#).unwrap(); Regex::new(r#"(?P<name>[\S]{1,12})(?:(?: (?P<cmd>.*)$)|$)"#).unwrap();
pub static ref REGEX_CHANNEL_USER: Regex = Regex::new(r#"^\s*<(#|@)(?:!)?(\d+)>\s*$"#).unwrap(); pub static ref REGEX_CHANNEL_USER: Regex = Regex::new(r#"^\s*<(#|@)(?:!)?(\d+)>\s*$"#).unwrap();
pub static ref MIN_INTERVAL: i64 = env::var("MIN_INTERVAL")
pub static ref MIN_INTERVAL: i64 = env::var("MIN_INTERVAL").ok().map(|inner| inner.parse::<i64>().ok()).flatten().unwrap_or(600); .ok()
.map(|inner| inner.parse::<i64>().ok())
pub static ref MAX_TIME: i64 = env::var("MAX_TIME").ok().map(|inner| inner.parse::<i64>().ok()).flatten().unwrap_or(60*60*24*365*50); .flatten()
.unwrap_or(600);
pub static ref LOCAL_TIMEZONE: String = env::var("LOCAL_TIMEZONE").unwrap_or_else(|_| "UTC".to_string()); pub static ref MAX_TIME: i64 = env::var("MAX_TIME")
.ok()
pub static ref PYTHON_LOCATION: String = env::var("PYTHON_LOCATION").unwrap_or_else(|_| "venv/bin/python3".to_string()); .map(|inner| inner.parse::<i64>().ok())
.flatten()
.unwrap_or(60 * 60 * 24 * 365 * 50);
pub static ref LOCAL_TIMEZONE: String =
env::var("LOCAL_TIMEZONE").unwrap_or_else(|_| "UTC".to_string());
pub static ref PYTHON_LOCATION: String =
env::var("PYTHON_LOCATION").unwrap_or_else(|_| "venv/bin/python3".to_string());
} }

View File

@ -1,50 +1,29 @@
use async_trait::async_trait; use async_trait::async_trait;
use serenity::{ use serenity::{
constants::MESSAGE_CODE_LIMIT,
http::Http,
Result as SerenityResult,
client::Context, client::Context,
framework::{ constants::MESSAGE_CODE_LIMIT,
Framework, framework::{standard::CommandResult, Framework},
standard::CommandResult,
},
model::{
id::ChannelId,
guild::{
Guild,
Member,
},
channel::{
Channel, GuildChannel, Message,
}
},
futures::prelude::future::BoxFuture, futures::prelude::future::BoxFuture,
http::Http,
model::{
channel::{Channel, GuildChannel, Message},
guild::{Guild, Member},
id::ChannelId,
},
Result as SerenityResult,
}; };
use log::{ use log::{error, info, warn};
warn,
error,
info,
};
use regex::{ use regex::{Match, Regex};
Regex, Match
};
use std::{ use std::{collections::HashMap, env, fmt};
collections::HashMap,
fmt,
env,
};
use crate::{ use crate::{consts::PREFIX, models::ChannelData, SQLPool};
models::ChannelData,
SQLPool,
consts::PREFIX,
};
type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, CommandResult>; type CommandFn =
for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, CommandResult>;
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum PermissionLevel { pub enum PermissionLevel {
@ -63,27 +42,30 @@ pub struct Command {
impl Command { impl Command {
async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool { async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool {
if self.required_perms == PermissionLevel::Unrestricted { if self.required_perms == PermissionLevel::Unrestricted {
true true
} } else {
else {
for role_id in &member.roles { for role_id in &member.roles {
let role = role_id.to_role_cached(&ctx).await; let role = role_id.to_role_cached(&ctx).await;
if let Some(cached_role) = role { if let Some(cached_role) = role {
if cached_role.permissions.manage_guild() { if cached_role.permissions.manage_guild()
return true || (self.required_perms == PermissionLevel::Managed
} && cached_role.permissions.manage_messages())
else if self.required_perms == PermissionLevel::Managed && cached_role.permissions.manage_messages() { {
return true return true;
} }
} }
} }
if self.required_perms == PermissionLevel::Managed { if self.required_perms == PermissionLevel::Managed {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
match sqlx::query!( match sqlx::query!(
" "
@ -102,34 +84,41 @@ WHERE
guilds guilds
WHERE WHERE
guild = ?) guild = ?)
", self.name, guild.id.as_u64()) ",
.fetch_all(&pool) self.name,
.await { guild.id.as_u64()
)
.fetch_all(&pool)
.await
{
Ok(rows) => { Ok(rows) => {
let role_ids = member.roles.iter().map(|r| *r.as_u64()).collect::<Vec<u64>>(); let role_ids = member
.roles
.iter()
.map(|r| *r.as_u64())
.collect::<Vec<u64>>();
for row in rows { for row in rows {
if role_ids.contains(&row.role) { if role_ids.contains(&row.role) {
return true return true;
} }
} }
false false
} }
Err(sqlx::Error::RowNotFound) => { Err(sqlx::Error::RowNotFound) => false,
false
}
Err(e) => { Err(e) => {
warn!("Unexpected error occurred querying command_restrictions: {:?}", e); warn!(
"Unexpected error occurred querying command_restrictions: {:?}",
e
);
false false
} }
} }
} } else {
else {
false false
} }
} }
@ -139,11 +128,11 @@ WHERE
impl fmt::Debug for Command { impl fmt::Debug for Command {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Command") f.debug_struct("Command")
.field("name", &self.name) .field("name", &self.name)
.field("required_perms", &self.required_perms) .field("required_perms", &self.required_perms)
.field("supports_dm", &self.supports_dm) .field("supports_dm", &self.supports_dm)
.field("can_blacklist", &self.can_blacklist) .field("can_blacklist", &self.can_blacklist)
.finish() .finish()
} }
} }
@ -159,12 +148,20 @@ pub struct RegexFramework {
#[async_trait] #[async_trait]
pub trait SendIterator { pub trait SendIterator {
async fn say_lines(self, http: impl AsRef<Http> + Send + Sync + 'async_trait, content: impl Iterator<Item=String> + Send + 'async_trait) -> SerenityResult<()>; async fn say_lines(
self,
http: impl AsRef<Http> + Send + Sync + 'async_trait,
content: impl Iterator<Item = String> + Send + 'async_trait,
) -> SerenityResult<()>;
} }
#[async_trait] #[async_trait]
impl SendIterator for ChannelId { impl SendIterator for ChannelId {
async fn say_lines(self, http: impl AsRef<Http> + Send + Sync + 'async_trait, content: impl Iterator<Item=String> + Send + 'async_trait) -> SerenityResult<()> { async fn say_lines(
self,
http: impl AsRef<Http> + Send + Sync + 'async_trait,
content: impl Iterator<Item = String> + Send + 'async_trait,
) -> SerenityResult<()> {
let mut current_content = String::new(); let mut current_content = String::new();
for line in content { for line in content {
@ -172,8 +169,7 @@ impl SendIterator for ChannelId {
self.say(&http, &current_content).await?; self.say(&http, &current_content).await?;
current_content = line; current_content = line;
} } else {
else {
current_content = format!("{}\n{}", current_content, line); current_content = format!("{}\n{}", current_content, line);
} }
} }
@ -220,10 +216,8 @@ impl RegexFramework {
let command_names; let command_names;
{ {
let mut command_names_vec = self.commands let mut command_names_vec =
.keys() self.commands.keys().map(|k| &k[..]).collect::<Vec<&str>>();
.map(|k| &k[..])
.collect::<Vec<&str>>();
command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len())); command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len()));
@ -245,7 +239,8 @@ impl RegexFramework {
let dm_command_names; let dm_command_names;
{ {
let mut command_names_vec = self.commands let mut command_names_vec = self
.commands
.iter() .iter()
.filter_map(|(key, command)| { .filter_map(|(key, command)| {
if command.supports_dm { if command.supports_dm {
@ -283,8 +278,11 @@ enum PermissionCheck {
#[async_trait] #[async_trait]
impl Framework for RegexFramework { impl Framework for RegexFramework {
async fn dispatch(&self, ctx: Context, msg: Message) { async fn dispatch(&self, ctx: Context, msg: Message) {
async fn check_self_permissions(
async fn check_self_permissions(ctx: &Context, guild: &Guild, channel: &GuildChannel) -> Result<PermissionCheck, Box<dyn std::error::Error + Sync + Send>> { ctx: &Context,
guild: &Guild,
channel: &GuildChannel,
) -> Result<PermissionCheck, Box<dyn std::error::Error + Sync + Send>> {
let user_id = ctx.cache.current_user_id().await; let user_id = ctx.cache.current_user_id().await;
let guild_perms = guild.member_permissions(user_id); let guild_perms = guild.member_permissions(user_id);
@ -294,31 +292,40 @@ impl Framework for RegexFramework {
Ok(if basic_perms && guild_perms.manage_webhooks() { Ok(if basic_perms && guild_perms.manage_webhooks() {
PermissionCheck::All PermissionCheck::All
} } else if basic_perms {
else if basic_perms {
PermissionCheck::Basic PermissionCheck::Basic
} } else {
else {
PermissionCheck::None PermissionCheck::None
}) })
} }
async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option<Match<'_>>) -> bool { async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option<Match<'_>>) -> bool {
if let Some(prefix) = prefix_opt { if let Some(prefix) = prefix_opt {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
match sqlx::query!("SELECT prefix FROM guilds WHERE guild = ?", guild.id.as_u64()) match sqlx::query!(
.fetch_one(&pool) "SELECT prefix FROM guilds WHERE guild = ?",
.await { guild.id.as_u64()
Ok(row) => { )
prefix.as_str() == row.prefix .fetch_one(&pool)
} .await
{
Ok(row) => prefix.as_str() == row.prefix,
Err(sqlx::Error::RowNotFound) => { Err(sqlx::Error::RowNotFound) => {
let _ = sqlx::query!("INSERT INTO guilds (guild, name) VALUES (?, ?)", guild.id.as_u64(), guild.name) let _ = sqlx::query!(
.execute(&pool) "INSERT INTO guilds (guild, name) VALUES (?, ?)",
.await; guild.id.as_u64(),
guild.name
)
.execute(&pool)
.await;
prefix.as_str() == "$" prefix.as_str() == "$"
} }
@ -329,49 +336,62 @@ impl Framework for RegexFramework {
false false
} }
} }
} } else {
else {
true true
} }
} }
// gate to prevent analysing messages unnecessarily // gate to prevent analysing messages unnecessarily
if (msg.author.bot && self.ignore_bots) || if (msg.author.bot && self.ignore_bots)
msg.tts || || msg.tts
msg.content.is_empty() || || msg.content.is_empty()
!msg.attachments.is_empty() {} || !msg.attachments.is_empty()
{
}
// Guild Command // Guild Command
else if let (Some(guild), Some(Channel::Guild(channel))) = (msg.guild(&ctx).await, msg.channel(&ctx).await) { else if let (Some(guild), Some(Channel::Guild(channel))) =
(msg.guild(&ctx).await, msg.channel(&ctx).await)
{
let member = guild.member(&ctx, &msg.author).await.unwrap(); let member = guild.member(&ctx, &msg.author).await.unwrap();
if let Some(full_match) = self.command_matcher.captures(&msg.content[..]) { if let Some(full_match) = self.command_matcher.captures(&msg.content[..]) {
if check_prefix(&ctx, &guild, full_match.name("prefix")).await { if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
match check_self_permissions(&ctx, &guild, &channel).await { match check_self_permissions(&ctx, &guild, &channel).await {
Ok(perms) => match perms { Ok(perms) => match perms {
PermissionCheck::All => { PermissionCheck::All => {
let pool = ctx.data.read().await let pool = ctx
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data"); .data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let command = self.commands.get(full_match.name("cmd").unwrap().as_str()).unwrap(); let command = self
let channel_data = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool).await; .commands
.get(full_match.name("cmd").unwrap().as_str())
.unwrap();
let channel_data = ChannelData::from_channel(
msg.channel(&ctx).await.unwrap(),
&pool,
)
.await;
if !command.can_blacklist || !channel_data.map(|c| c.blacklisted).unwrap_or(false) { if !command.can_blacklist
let args = full_match.name("args") || !channel_data.map(|c| c.blacklisted).unwrap_or(false)
{
let args = full_match
.name("args")
.map(|m| m.as_str()) .map(|m| m.as_str())
.unwrap_or("") .unwrap_or("")
.to_string(); .to_string();
if command.check_permissions(&ctx, &guild, &member).await { if command.check_permissions(&ctx, &guild, &member).await {
(command.func)(&ctx, &msg, args).await.unwrap(); (command.func)(&ctx, &msg, args).await.unwrap();
} } else if command.required_perms == PermissionLevel::Restricted
else if command.required_perms == PermissionLevel::Restricted { {
let _ = msg.channel_id.say(&ctx, "You must have permission level `Manage Server` or greater to use this command.").await; let _ = msg.channel_id.say(&ctx, "You must have permission level `Manage Server` or greater to use this command.").await;
} } else if command.required_perms == PermissionLevel::Managed {
else if command.required_perms == PermissionLevel::Managed {
let _ = msg.channel_id.say(&ctx, "You must have `Manage Messages` or have a role capable of sending reminders to that channel. Please talk to your server admin, and ask them to use the `{prefix}restrict` command to specify allowed roles.").await; let _ = msg.channel_id.say(&ctx, "You must have `Manage Messages` or have a role capable of sending reminders to that channel. Please talk to your server admin, and ask them to use the `{prefix}restrict` command to specify allowed roles.").await;
} }
} }
@ -384,20 +404,26 @@ impl Framework for RegexFramework {
PermissionCheck::None => { PermissionCheck::None => {
warn!("Missing enough permissions for guild {}", guild.id); warn!("Missing enough permissions for guild {}", guild.id);
} }
} },
Err(e) => { Err(e) => {
error!("Error occurred getting permissions in guild {}: {:?}", guild.id, e); error!(
"Error occurred getting permissions in guild {}: {:?}",
guild.id, e
);
} }
} }
} }
} }
} }
// DM Command // DM Command
else if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) { else if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) {
let command = self.commands.get(full_match.name("cmd").unwrap().as_str()).unwrap(); let command = self
let args = full_match.name("args") .commands
.get(full_match.name("cmd").unwrap().as_str())
.unwrap();
let args = full_match
.name("args")
.map(|m| m.as_str()) .map(|m| m.as_str())
.unwrap_or("") .unwrap_or("")
.to_string(); .to_string();

View File

@ -1,58 +1,37 @@
#[macro_use] #[macro_use]
extern crate lazy_static; extern crate lazy_static;
mod models;
mod framework;
mod commands; mod commands;
mod time_parser;
mod consts; mod consts;
mod framework;
mod models;
mod time_parser;
use serenity::{ use serenity::{
cache::Cache, cache::Cache,
http::{ client::{bridge::gateway::GatewayIntents, Client},
CacheHttp,
client::Http,
},
client::{
bridge::gateway::GatewayIntents,
Client,
},
model::{
id::{
GuildId, UserId,
},
channel::Message,
},
framework::Framework, framework::Framework,
http::{client::Http, CacheHttp},
model::{
channel::Message,
id::{GuildId, UserId},
},
prelude::TypeMapKey, prelude::TypeMapKey,
}; };
use sqlx::{ use sqlx::{
mysql::{MySqlConnection, MySqlPool},
Pool, Pool,
mysql::{
MySqlPool,
MySqlConnection,
}
}; };
use dotenv::dotenv; use dotenv::dotenv;
use std::{ use std::{env, sync::Arc};
sync::Arc,
env,
};
use crate::{ use crate::{
commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds},
consts::{CNC_GUILD, PREFIX, SUBSCRIPTION_ROLES},
framework::RegexFramework, framework::RegexFramework,
consts::{
PREFIX, SUBSCRIPTION_ROLES, CNC_GUILD,
},
commands::{
info_cmds,
reminder_cmds,
todo_cmds,
moderation_cmds,
},
}; };
use serenity::futures::TryFutureExt; use serenity::futures::TryFutureExt;
@ -85,23 +64,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let http = Http::new_with_token(&token); let http = Http::new_with_token(&token);
let logged_in_id = http.get_current_user().map_ok(|user| user.id.as_u64().to_owned()).await?; let logged_in_id = http
.get_current_user()
.map_ok(|user| user.id.as_u64().to_owned())
.await?;
let framework = RegexFramework::new(logged_in_id) let framework = RegexFramework::new(logged_in_id)
.ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1")) .ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1"))
.default_prefix(&env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string())) .default_prefix(&env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string()))
.add_command("ping", &info_cmds::PING_COMMAND) .add_command("ping", &info_cmds::PING_COMMAND)
.add_command("help", &info_cmds::HELP_COMMAND) .add_command("help", &info_cmds::HELP_COMMAND)
.add_command("info", &info_cmds::INFO_COMMAND) .add_command("info", &info_cmds::INFO_COMMAND)
.add_command("invite", &info_cmds::INFO_COMMAND) .add_command("invite", &info_cmds::INFO_COMMAND)
.add_command("donate", &info_cmds::DONATE_COMMAND) .add_command("donate", &info_cmds::DONATE_COMMAND)
.add_command("dashboard", &info_cmds::DASHBOARD_COMMAND) .add_command("dashboard", &info_cmds::DASHBOARD_COMMAND)
.add_command("clock", &info_cmds::CLOCK_COMMAND) .add_command("clock", &info_cmds::CLOCK_COMMAND)
.add_command("timer", &reminder_cmds::TIMER_COMMAND) .add_command("timer", &reminder_cmds::TIMER_COMMAND)
.add_command("remind", &reminder_cmds::REMIND_COMMAND) .add_command("remind", &reminder_cmds::REMIND_COMMAND)
.add_command("r", &reminder_cmds::REMIND_COMMAND) .add_command("r", &reminder_cmds::REMIND_COMMAND)
.add_command("interval", &reminder_cmds::INTERVAL_COMMAND) .add_command("interval", &reminder_cmds::INTERVAL_COMMAND)
@ -109,36 +87,39 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.add_command("natural", &reminder_cmds::NATURAL_COMMAND) .add_command("natural", &reminder_cmds::NATURAL_COMMAND)
.add_command("n", &reminder_cmds::NATURAL_COMMAND) .add_command("n", &reminder_cmds::NATURAL_COMMAND)
.add_command("", &reminder_cmds::NATURAL_COMMAND) .add_command("", &reminder_cmds::NATURAL_COMMAND)
.add_command("look", &reminder_cmds::LOOK_COMMAND) .add_command("look", &reminder_cmds::LOOK_COMMAND)
.add_command("del", &reminder_cmds::DELETE_COMMAND) .add_command("del", &reminder_cmds::DELETE_COMMAND)
.add_command("todo", &todo_cmds::TODO_PARSE_COMMAND) .add_command("todo", &todo_cmds::TODO_PARSE_COMMAND)
.add_command("blacklist", &moderation_cmds::BLACKLIST_COMMAND) .add_command("blacklist", &moderation_cmds::BLACKLIST_COMMAND)
.add_command("restrict", &moderation_cmds::RESTRICT_COMMAND) .add_command("restrict", &moderation_cmds::RESTRICT_COMMAND)
.add_command("timezone", &moderation_cmds::TIMEZONE_COMMAND) .add_command("timezone", &moderation_cmds::TIMEZONE_COMMAND)
.add_command("prefix", &moderation_cmds::PREFIX_COMMAND) .add_command("prefix", &moderation_cmds::PREFIX_COMMAND)
.add_command("lang", &moderation_cmds::LANGUAGE_COMMAND) .add_command("lang", &moderation_cmds::LANGUAGE_COMMAND)
.add_command("pause", &reminder_cmds::PAUSE_COMMAND) .add_command("pause", &reminder_cmds::PAUSE_COMMAND)
.add_command("offset", &reminder_cmds::OFFSET_COMMAND) .add_command("offset", &reminder_cmds::OFFSET_COMMAND)
.add_command("nudge", &reminder_cmds::NUDGE_COMMAND) .add_command("nudge", &reminder_cmds::NUDGE_COMMAND)
.add_command("alias", &moderation_cmds::ALIAS_COMMAND) .add_command("alias", &moderation_cmds::ALIAS_COMMAND)
.add_command("a", &moderation_cmds::ALIAS_COMMAND) .add_command("a", &moderation_cmds::ALIAS_COMMAND)
.build(); .build();
let framework_arc = Arc::new(Box::new(framework) as Box<dyn Framework + Send + Sync>); let framework_arc = Arc::new(Box::new(framework) as Box<dyn Framework + Send + Sync>);
let mut client = Client::new(&token) let mut client = Client::new(&token)
.intents(GatewayIntents::GUILD_MESSAGES | GatewayIntents::GUILDS | GatewayIntents::DIRECT_MESSAGES) .intents(
GatewayIntents::GUILD_MESSAGES
| GatewayIntents::GUILDS
| GatewayIntents::DIRECT_MESSAGES,
)
.framework_arc(framework_arc.clone()) .framework_arc(framework_arc.clone())
.await.expect("Error occurred creating client"); .await
.expect("Error occurred creating client");
{ {
let pool = MySqlPool::new(&env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment")).await.unwrap(); let pool = MySqlPool::new(
&env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment"),
)
.await
.unwrap();
let mut data = client.data.write().await; let mut data = client.data.write().await;
@ -152,27 +133,34 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(()) Ok(())
} }
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool { pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
if let Some(subscription_guild) = *CNC_GUILD { if let Some(subscription_guild) = *CNC_GUILD {
let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await; let guild_member = GuildId(subscription_guild)
.member(cache_http, user_id)
.await;
if let Ok(member) = guild_member { if let Ok(member) = guild_member {
for role in member.roles { for role in member.roles {
if SUBSCRIPTION_ROLES.contains(role.as_u64()) { if SUBSCRIPTION_ROLES.contains(role.as_u64()) {
return true return true;
} }
} }
} }
false false
} } else {
else {
true true
} }
} }
pub async fn check_subscription_on_message(cache_http: impl CacheHttp + AsRef<Cache>, msg: &Message) -> bool { pub async fn check_subscription_on_message(
check_subscription(&cache_http, &msg.author).await || cache_http: impl CacheHttp + AsRef<Cache>,
if let Some(guild) = msg.guild(&cache_http).await { check_subscription(&cache_http, guild.owner_id).await } else { false } msg: &Message,
) -> bool {
check_subscription(&cache_http, &msg.author).await
|| if let Some(guild) = msg.guild(&cache_http).await {
check_subscription(&cache_http, guild.owner_id).await
} else {
false
}
} }

View File

@ -1,11 +1,6 @@
use serenity::{ use serenity::{
http::CacheHttp, http::CacheHttp,
model::{ model::{channel::Channel, guild::Guild, id::GuildId, user::User},
id::GuildId,
guild::Guild,
channel::Channel,
user::User,
}
}; };
use std::env; use std::env;
@ -24,49 +19,67 @@ pub struct GuildData {
} }
impl GuildData { impl GuildData {
pub async fn prefix_from_id<T: Into<GuildId>>(guild_id_opt: Option<T>, pool: &MySqlPool) -> String { pub async fn prefix_from_id<T: Into<GuildId>>(
guild_id_opt: Option<T>,
pool: &MySqlPool,
) -> String {
if let Some(guild_id) = guild_id_opt { if let Some(guild_id) = guild_id_opt {
let guild_id = guild_id.into().as_u64().to_owned(); let guild_id = guild_id.into().as_u64().to_owned();
let row = sqlx::query!( let row = sqlx::query!(
" "
SELECT prefix FROM guilds WHERE guild = ? SELECT prefix FROM guilds WHERE guild = ?
", guild_id ",
guild_id
) )
.fetch_one(pool) .fetch_one(pool)
.await; .await;
row.map_or_else(|_| env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string()), |r| r.prefix) row.map_or_else(
} |_| env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string()),
else { |r| r.prefix,
)
} else {
env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string()) env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string())
} }
} }
pub async fn from_guild(guild: Guild, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> { pub async fn from_guild(
guild: Guild,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let guild_id = guild.id.as_u64().to_owned(); let guild_id = guild.id.as_u64().to_owned();
if let Ok(g) = sqlx::query_as!(Self, if let Ok(g) = sqlx::query_as!(
Self,
" "
SELECT id, name, prefix FROM guilds WHERE guild = ? SELECT id, name, prefix FROM guilds WHERE guild = ?
", guild_id) ",
.fetch_one(pool) guild_id
.await { )
.fetch_one(pool)
.await
{
Ok(g) Ok(g)
} } else {
else {
sqlx::query!( sqlx::query!(
" "
INSERT INTO guilds (guild, name, prefix) VALUES (?, ?, ?) INSERT INTO guilds (guild, name, prefix) VALUES (?, ?, ?)
", guild_id, guild.name, env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string())) ",
.execute(&pool.clone()) guild_id,
.await?; guild.name,
env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string())
)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as!(Self, Ok(sqlx::query_as!(
" Self,
"
SELECT id, name, prefix FROM guilds WHERE guild = ? SELECT id, name, prefix FROM guilds WHERE guild = ?
", guild_id) ",
guild_id
)
.fetch_one(pool) .fetch_one(pool)
.await?) .await?)
} }
@ -76,9 +89,14 @@ SELECT id, name, prefix FROM guilds WHERE guild = ?
sqlx::query!( sqlx::query!(
" "
UPDATE guilds SET name = ?, prefix = ? WHERE id = ? UPDATE guilds SET name = ?, prefix = ? WHERE id = ?
", self.name, self.prefix, self.id) ",
.execute(pool) self.name,
.await.unwrap(); self.prefix,
self.id
)
.execute(pool)
.await
.unwrap();
} }
} }
@ -103,9 +121,10 @@ SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_u
.await.ok() .await.ok()
} }
pub async fn from_channel(channel: Channel, pool: &MySqlPool) pub async fn from_channel(
-> Result<Self, Box<dyn std::error::Error + Sync + Send>> channel: Channel,
{ pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let channel_id = channel.id().as_u64().to_owned(); let channel_id = channel.id().as_u64().to_owned();
if let Ok(c) = sqlx::query_as_unchecked!(Self, if let Ok(c) = sqlx::query_as_unchecked!(Self,
@ -162,19 +181,25 @@ pub struct UserData {
} }
impl UserData { impl UserData {
pub async fn from_user(user: &User, ctx: impl CacheHttp, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> { pub async fn from_user(
user: &User,
ctx: impl CacheHttp,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user.id.as_u64().to_owned(); let user_id = user.id.as_u64().to_owned();
if let Ok(c) = sqlx::query_as_unchecked!(Self, if let Ok(c) = sqlx::query_as_unchecked!(
Self,
" "
SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ? SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ?
", user_id) ",
.fetch_one(pool) user_id
.await { )
.fetch_one(pool)
.await
{
Ok(c) Ok(c)
} } else {
else {
let dm_channel = user.create_dm_channel(ctx).await?; let dm_channel = user.create_dm_channel(ctx).await?;
let dm_id = dm_channel.id.as_u64().to_owned(); let dm_id = dm_channel.id.as_u64().to_owned();
@ -183,9 +208,11 @@ SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ?
sqlx::query!( sqlx::query!(
" "
INSERT INTO channels (channel) VALUES (?) INSERT INTO channels (channel) VALUES (?)
", dm_id) ",
.execute(&pool_c) dm_id
.await?; )
.execute(&pool_c)
.await?;
sqlx::query!( sqlx::query!(
" "
@ -194,12 +221,15 @@ INSERT INTO users (user, name, dm_channel) VALUES (?, ?, (SELECT id FROM channel
.execute(&pool_c) .execute(&pool_c)
.await?; .await?;
Ok(sqlx::query_as_unchecked!(Self, Ok(sqlx::query_as_unchecked!(
Self,
" "
SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ? SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ?
", user_id) ",
.fetch_one(pool) user_id
.await?) )
.fetch_one(pool)
.await?)
} }
} }
@ -207,9 +237,15 @@ SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ?
sqlx::query!( sqlx::query!(
" "
UPDATE users SET name = ?, language = ?, timezone = ? WHERE id = ? UPDATE users SET name = ?, language = ?, timezone = ? WHERE id = ?
", self.name, self.language, self.timezone, self.id) ",
.execute(pool) self.name,
.await.unwrap(); self.language,
self.timezone,
self.id
)
.execute(pool)
.await
.unwrap();
} }
pub async fn response(&self, pool: &MySqlPool, name: &str) -> String { pub async fn response(&self, pool: &MySqlPool, name: &str) -> String {
@ -221,7 +257,8 @@ SELECT value FROM strings WHERE (language = ? OR language = 'EN') AND name = ? O
.await .await
.unwrap_or_else(|_| panic!("No string with that name: {}", name)); .unwrap_or_else(|_| panic!("No string with that name: {}", name));
row.value.unwrap_or_else(|| panic!("Null string with that name: {}", name)) row.value
.unwrap_or_else(|| panic!("Null string with that name: {}", name))
} }
pub fn timezone(&self) -> Tz { pub fn timezone(&self) -> Tz {
@ -237,33 +274,41 @@ pub struct Timer {
impl Timer { impl Timer {
pub async fn from_owner(owner: u64, pool: &MySqlPool) -> Vec<Self> { pub async fn from_owner(owner: u64, pool: &MySqlPool) -> Vec<Self> {
sqlx::query_as_unchecked!(Timer, sqlx::query_as_unchecked!(
Timer,
" "
SELECT name, start_time, owner FROM timers WHERE owner = ? SELECT name, start_time, owner FROM timers WHERE owner = ?
", owner) ",
.fetch_all(pool) owner
.await )
.unwrap() .fetch_all(pool)
.await
.unwrap()
} }
pub async fn count_from_owner(owner: u64, pool: &MySqlPool) -> u32 { pub async fn count_from_owner(owner: u64, pool: &MySqlPool) -> u32 {
sqlx::query!( sqlx::query!(
" "
SELECT COUNT(1) as count FROM timers WHERE owner = ? SELECT COUNT(1) as count FROM timers WHERE owner = ?
", owner) ",
.fetch_one(pool) owner
.await )
.unwrap() .fetch_one(pool)
.count as u32 .await
.unwrap()
.count as u32
} }
pub async fn create(name: &str, owner: u64, pool: &MySqlPool) { pub async fn create(name: &str, owner: u64, pool: &MySqlPool) {
sqlx::query!( sqlx::query!(
" "
INSERT INTO timers (name, owner) VALUES (?, ?) INSERT INTO timers (name, owner) VALUES (?, ?)
", name, owner) ",
.execute(pool) name,
.await owner
.unwrap(); )
.execute(pool)
.await
.unwrap();
} }
} }

View File

@ -1,16 +1,9 @@
use std::time::{ use std::time::{SystemTime, UNIX_EPOCH};
SystemTime,
UNIX_EPOCH,
};
use std::fmt::{ use std::fmt::{Display, Formatter, Result as FmtResult};
Formatter,
Display,
Result as FmtResult,
};
use chrono_tz::Tz;
use chrono::TimeZone; use chrono::TimeZone;
use chrono_tz::Tz;
use std::convert::TryFrom; use std::convert::TryFrom;
#[derive(Debug)] #[derive(Debug)]
@ -55,8 +48,7 @@ impl TimeParser {
let parse_type = if input.contains('/') || input.contains(':') { let parse_type = if input.contains('/') || input.contains(':') {
ParseType::Explicit ParseType::Explicit
} } else {
else {
ParseType::Displacement ParseType::Displacement
}; };
@ -70,9 +62,7 @@ impl TimeParser {
pub fn timestamp(&self) -> Result<i64, InvalidTime> { pub fn timestamp(&self) -> Result<i64, InvalidTime> {
match self.parse_type { match self.parse_type {
ParseType::Explicit => { ParseType::Explicit => Ok(self.process_explicit()?),
Ok(self.process_explicit()?)
},
ParseType::Displacement => { ParseType::Displacement => {
let now = SystemTime::now(); let now = SystemTime::now();
@ -81,7 +71,7 @@ impl TimeParser {
.expect("Time calculated as going backwards. Very bad"); .expect("Time calculated as going backwards. Very bad");
Ok(since_epoch.as_secs() as i64 + self.process_displacement()?) Ok(since_epoch.as_secs() as i64 + self.process_displacement()?)
}, }
} }
} }
@ -94,11 +84,9 @@ impl TimeParser {
.expect("Time calculated as going backwards. Very bad"); .expect("Time calculated as going backwards. Very bad");
Ok(self.process_explicit()? - since_epoch.as_secs() as i64) Ok(self.process_explicit()? - since_epoch.as_secs() as i64)
}, }
ParseType::Displacement => { ParseType::Displacement => Ok(self.process_displacement()?),
Ok(self.process_displacement()?)
},
} }
} }
@ -112,7 +100,7 @@ impl TimeParser {
0 => Ok("%d-".to_string()), 0 => Ok("%d-".to_string()),
1 => Ok("%d/%m-".to_string()), 1 => Ok("%d/%m-".to_string()),
2 => Ok("%d/%m/%Y-".to_string()), 2 => Ok("%d/%m/%Y-".to_string()),
_ => Err(InvalidTime::ParseErrorDMY) _ => Err(InvalidTime::ParseErrorDMY),
} }
} else { } else {
Ok("".to_string()) Ok("".to_string())
@ -122,13 +110,16 @@ impl TimeParser {
match colons { match colons {
1 => Ok("%H:%M"), 1 => Ok("%H:%M"),
2 => Ok("%H:%M:%S"), 2 => Ok("%H:%M:%S"),
_ => Err(InvalidTime::ParseErrorHMS) _ => Err(InvalidTime::ParseErrorHMS),
} }
} else { } else {
Ok("") Ok("")
}?; }?;
let dt = self.timezone.datetime_from_str(self.time_string.as_str(), &parse_string).map_err(|_| InvalidTime::ParseErrorChrono)?; let dt = self
.timezone
.datetime_from_str(self.time_string.as_str(), &parse_string)
.map_err(|_| InvalidTime::ParseErrorChrono)?;
Ok(dt.timestamp() as i64) Ok(dt.timestamp() as i64)
} }
@ -143,40 +134,42 @@ impl TimeParser {
for character in self.time_string.chars() { for character in self.time_string.chars() {
match character { match character {
's' => { 's' => {
seconds = current_buffer.parse::<i64>().unwrap(); seconds = current_buffer.parse::<i64>().unwrap();
current_buffer = String::from("0"); current_buffer = String::from("0");
}, }
'm' => { 'm' => {
minutes = current_buffer.parse::<i64>().unwrap(); minutes = current_buffer.parse::<i64>().unwrap();
current_buffer = String::from("0"); current_buffer = String::from("0");
}, }
'h' => { 'h' => {
hours = current_buffer.parse::<i64>().unwrap(); hours = current_buffer.parse::<i64>().unwrap();
current_buffer = String::from("0"); current_buffer = String::from("0");
}, }
'd' => { 'd' => {
days = current_buffer.parse::<i64>().unwrap(); days = current_buffer.parse::<i64>().unwrap();
current_buffer = String::from("0"); current_buffer = String::from("0");
}, }
c => { c => {
if c.is_digit(10) { if c.is_digit(10) {
current_buffer += &c.to_string(); current_buffer += &c.to_string();
} else {
return Err(InvalidTime::ParseErrorDisplacement);
} }
else { }
return Err(InvalidTime::ParseErrorDisplacement)
}
},
} }
} }
let full = (seconds + (minutes * 60) + (hours * 3600) + (days * 86400) + current_buffer.parse::<i64>().unwrap()) * let full = (seconds
if self.inverted { -1 } else { 1 }; + (minutes * 60)
+ (hours * 3600)
+ (days * 86400)
+ current_buffer.parse::<i64>().unwrap())
* if self.inverted { -1 } else { 1 };
Ok(full) Ok(full)
} }