ran rustfmt over project. cleared up a couple of clippy things

This commit is contained in:
jude 2020-10-12 21:01:27 +01:00
parent 88596fb399
commit c9fd2fea81
10 changed files with 1374 additions and 939 deletions

View File

@ -1,30 +1,15 @@
use regex_command_attr::command;
use serenity::{
client::Context,
model::{
channel::{
Message,
},
},
framework::standard::CommandResult,
};
use serenity::{client::Context, framework::standard::CommandResult, model::channel::Message};
use chrono::offset::Utc;
use crate::{
models::{
UserData,
GuildData,
},
THEME_COLOR,
SQLPool,
models::{GuildData, UserData},
SQLPool, THEME_COLOR,
};
use std::time::{
SystemTime,
UNIX_EPOCH
};
use std::time::{SystemTime, UNIX_EPOCH};
#[command]
#[can_blacklist(false)]
@ -36,7 +21,10 @@ async fn ping(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let delta = since_epoch.as_millis() as i64 - msg.timestamp.timestamp_millis();
let _ = msg.channel_id.say(&ctx, format!("Time taken to receive message: {}ms", delta)).await;
let _ = msg
.channel_id
.say(&ctx, format!("Time taken to receive message: {}ms", delta))
.await;
Ok(())
}
@ -44,92 +32,131 @@ async fn ping(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
#[command]
#[can_blacklist(false)]
async fn help(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let desc = user_data.response(&pool, "help").await;
msg.channel_id.send_message(ctx, |m| m
.embed(move |e| e
.title("Help")
.description(desc)
.color(THEME_COLOR)
)
).await?;
msg.channel_id
.send_message(ctx, |m| {
m.embed(move |e| e.title("Help").description(desc).color(THEME_COLOR))
})
.await?;
Ok(())
}
#[command]
async fn info(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap();
let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool)
.await
.unwrap();
let desc = user_data.response(&pool, "info").await
let desc = user_data
.response(&pool, "info")
.await
.replacen("{user}", &ctx.cache.current_user().await.name, 1)
.replacen("{prefix}", &guild_data.prefix, 1);
msg.channel_id.send_message(ctx, |m| m
.embed(move |e| e
.title("Info")
.description(desc)
.color(THEME_COLOR)
)
).await?;
msg.channel_id
.send_message(ctx, |m| {
m.embed(move |e| e.title("Info").description(desc).color(THEME_COLOR))
})
.await?;
Ok(())
}
#[command]
async fn donate(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let desc = user_data.response(&pool, "donate").await;
msg.channel_id.send_message(ctx, |m| m
.embed(move |e| e
.title("Donate")
.description(desc)
.color(THEME_COLOR)
)
).await?;
msg.channel_id
.send_message(ctx, |m| {
m.embed(move |e| e.title("Donate").description(desc).color(THEME_COLOR))
})
.await?;
Ok(())
}
#[command]
async fn dashboard(ctx: &Context, msg: &Message, _args: String) -> CommandResult {
msg.channel_id.send_message(ctx, |m| m
.embed(move |e| e
.title("Dashboard")
.description("https://reminder-bot.com/dashboard")
.color(THEME_COLOR)
)
).await?;
msg.channel_id
.send_message(ctx, |m| {
m.embed(move |e| {
e.title("Dashboard")
.description("https://reminder-bot.com/dashboard")
.color(THEME_COLOR)
})
})
.await?;
Ok(())
}
#[command]
async fn clock(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let now = Utc::now().with_timezone(&user_data.timezone());
if args == "12" {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "clock/time").await.replacen("{}", &now.format("%I:%M:%S %p").to_string(), 1)).await;
}
else {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "clock/time").await.replacen("{}", &now.format("%H:%M:%S").to_string(), 1)).await;
let _ = msg
.channel_id
.say(
&ctx,
user_data.response(&pool, "clock/time").await.replacen(
"{}",
&now.format("%I:%M:%S %p").to_string(),
1,
),
)
.await;
} else {
let _ = msg
.channel_id
.say(
&ctx,
user_data.response(&pool, "clock/time").await.replacen(
"{}",
&now.format("%H:%M:%S").to_string(),
1,
),
)
.await;
}
Ok(())

View File

@ -1,4 +1,4 @@
pub mod info_cmds;
pub mod moderation_cmds;
pub mod reminder_cmds;
pub mod todo_cmds;
pub mod moderation_cmds;

View File

@ -2,16 +2,8 @@ use regex_command_attr::command;
use serenity::{
client::Context,
model::{
id::RoleId,
channel::{
Message,
},
},
framework::{
Framework,
standard::CommandResult,
},
framework::{standard::CommandResult, Framework},
model::{channel::Message, id::RoleId},
};
use chrono_tz::Tz;
@ -21,41 +13,40 @@ use chrono::offset::Utc;
use inflector::Inflector;
use crate::{
models::{
ChannelData,
UserData,
GuildData,
},
SQLPool,
FrameworkCtx,
consts::{REGEX_ALIAS, REGEX_CHANNEL, REGEX_COMMANDS, REGEX_ROLE},
framework::SendIterator,
consts::{
REGEX_ALIAS,
REGEX_CHANNEL,
REGEX_COMMANDS,
REGEX_ROLE,
},
models::{ChannelData, GuildData, UserData},
FrameworkCtx, SQLPool,
};
use std::iter;
#[command]
#[supports_dm(false)]
#[permission_level(Restricted)]
#[can_blacklist(false)]
async fn blacklist(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let capture_opt = REGEX_CHANNEL.captures(&args).map(|cap| cap.get(1)).flatten();
let capture_opt = REGEX_CHANNEL
.captures(&args)
.map(|cap| cap.get(1))
.flatten();
let mut channel = match capture_opt {
Some(capture) =>
ChannelData::from_id(capture.as_str().parse::<u64>().unwrap(), &pool).await.unwrap(),
Some(capture) => ChannelData::from_id(capture.as_str().parse::<u64>().unwrap(), &pool)
.await
.unwrap(),
None =>
ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool).await.unwrap(),
None => ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool)
.await
.unwrap(),
};
channel.blacklisted = !channel.blacklisted;
@ -63,8 +54,7 @@ async fn blacklist(ctx: &Context, msg: &Message, args: String) -> CommandResult
if channel.blacklisted {
let _ = msg.channel_id.say(&ctx, "Blacklisted").await;
}
else {
} else {
let _ = msg.channel_id.say(&ctx, "Unblacklisted").await;
}
@ -73,8 +63,13 @@ async fn blacklist(ctx: &Context, msg: &Message, args: String) -> CommandResult
#[command]
async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
@ -86,7 +81,9 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let now = Utc::now().with_timezone(&user_data.timezone());
let content = user_data.response(&pool, "timezone/set_p").await
let content = user_data
.response(&pool, "timezone/set_p")
.await
.replacen("{timezone}", &user_data.timezone, 1)
.replacen("{time}", &now.format("%H:%M").to_string(), 1);
@ -94,13 +91,23 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult {
}
Err(_) => {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "timezone/no_timezone").await).await;
let _ = msg
.channel_id
.say(
&ctx,
user_data.response(&pool, "timezone/no_timezone").await,
)
.await;
}
}
}
else {
let content = user_data.response(&pool, "timezone/no_argument").await
.replace("{prefix}", &GuildData::prefix_from_id(msg.guild_id, &pool).await)
} else {
let content = user_data
.response(&pool, "timezone/no_argument")
.await
.replace(
"{prefix}",
&GuildData::prefix_from_id(msg.guild_id, &pool).await,
)
.replacen("{timezone}", &user_data.timezone, 1);
let _ = msg.channel_id.say(&ctx, content).await;
@ -111,25 +118,36 @@ async fn timezone(ctx: &Context, msg: &Message, args: String) -> CommandResult {
#[command]
async fn language(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let mut user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
match sqlx::query!(
"
SELECT code FROM languages WHERE code = ? OR name = ?
", args, args)
.fetch_one(&pool)
.await {
",
args,
args
)
.fetch_one(&pool)
.await
{
Ok(row) => {
user_data.language = row.code;
user_data.commit_changes(&pool).await;
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "lang/set_p").await).await;
},
let _ = msg
.channel_id
.say(&ctx, user_data.response(&pool, "lang/set_p").await)
.await;
}
Err(_) => {
let language_codes = sqlx::query!("SELECT name, code FROM languages")
@ -137,15 +155,24 @@ SELECT code FROM languages WHERE code = ? OR name = ?
.await
.unwrap()
.iter()
.map(|language| format!("{} ({})", language.name.to_title_case(), language.code.to_uppercase()))
.map(|language| {
format!(
"{} ({})",
language.name.to_title_case(),
language.code.to_uppercase()
)
})
.collect::<Vec<String>>()
.join("\n");
let content = user_data.response(&pool, "lang/invalid").await
.replacen("{}", &language_codes, 1);
let content =
user_data
.response(&pool, "lang/invalid")
.await
.replacen("{}", &language_codes, 1);
let _ = msg.channel_id.say(&ctx, content).await;
},
}
}
Ok(())
@ -155,24 +182,38 @@ SELECT code FROM languages WHERE code = ? OR name = ?
#[supports_dm(false)]
#[permission_level(Restricted)]
async fn prefix(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let mut guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap();
let mut guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool)
.await
.unwrap();
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
if args.len() > 5 {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "prefix/too_long").await).await;
}
else if args.is_empty() {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "prefix/no_argument").await).await;
}
else {
let _ = msg
.channel_id
.say(&ctx, user_data.response(&pool, "prefix/too_long").await)
.await;
} else if args.is_empty() {
let _ = msg
.channel_id
.say(&ctx, user_data.response(&pool, "prefix/no_argument").await)
.await;
} else {
guild_data.prefix = args;
guild_data.commit_changes(&pool).await;
let content = user_data.response(&pool, "prefix/success").await
.replacen("{prefix}", &guild_data.prefix, 1);
let content = user_data.response(&pool, "prefix/success").await.replacen(
"{prefix}",
&guild_data.prefix,
1,
);
let _ = msg.channel_id.say(&ctx, content).await;
}
@ -184,17 +225,31 @@ async fn prefix(ctx: &Context, msg: &Message, args: String) -> CommandResult {
#[supports_dm(false)]
#[permission_level(Restricted)]
async fn restrict(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool).await.unwrap();
let guild_data = GuildData::from_guild(msg.guild(&ctx).await.unwrap(), &pool)
.await
.unwrap();
let role_tag_match = REGEX_ROLE.find(&args);
if let Some(role_tag) = role_tag_match {
let commands = REGEX_COMMANDS.find_iter(&args.to_lowercase()).map(|c| c.as_str().to_string()).collect::<Vec<String>>();
let role_id = RoleId(role_tag.as_str()[3..role_tag.as_str().len()-1].parse::<u64>().unwrap());
let commands = REGEX_COMMANDS
.find_iter(&args.to_lowercase())
.map(|c| c.as_str().to_string())
.collect::<Vec<String>>();
let role_id = RoleId(
role_tag.as_str()[3..role_tag.as_str().len() - 1]
.parse::<u64>()
.unwrap(),
);
let role_opt = role_id.to_role_cached(&ctx).await;
@ -202,20 +257,28 @@ async fn restrict(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let _ = sqlx::query!(
"
DELETE FROM command_restrictions WHERE role_id = (SELECT id FROM roles WHERE role = ?)
", role.id.as_u64())
.execute(&pool)
.await;
",
role.id.as_u64()
)
.execute(&pool)
.await;
if commands.is_empty() {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/disabled").await).await;
}
else {
let _ = msg
.channel_id
.say(&ctx, user_data.response(&pool, "restrict/disabled").await)
.await;
} else {
let _ = sqlx::query!(
"
INSERT IGNORE INTO roles (role, name, guild_id) VALUES (?, ?, ?)
", role.id.as_u64(), role.name, guild_data.id)
.execute(&pool)
.await;
",
role.id.as_u64(),
role.name,
guild_data.id
)
.execute(&pool)
.await;
for command in commands {
let res = sqlx::query!(
@ -228,18 +291,22 @@ INSERT INTO command_restrictions (role_id, command) VALUES ((SELECT id FROM role
if res.is_err() {
println!("{:?}", res);
let content = user_data.response(&pool, "restrict/failure").await
let content = user_data
.response(&pool, "restrict/failure")
.await
.replacen("{command}", &command, 1);
let _ = msg.channel_id.say(&ctx, content).await;
}
}
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/enabled").await).await;
let _ = msg
.channel_id
.say(&ctx, user_data.response(&pool, "restrict/enabled").await)
.await;
}
}
}
else if args.is_empty() {
} else if args.is_empty() {
let guild_id = msg.guild_id.unwrap().as_u64().to_owned();
let rows = sqlx::query!(
@ -254,18 +321,29 @@ ON
roles.id = command_restrictions.role_id
WHERE
roles.guild_id = (SELECT id FROM guilds WHERE guild = ?)
", guild_id)
.fetch_all(&pool)
.await
.unwrap();
",
guild_id
)
.fetch_all(&pool)
.await
.unwrap();
let display_inner = rows.iter().map(|row| format!("<@&{}> can use {}", row.role, row.command)).collect::<Vec<String>>().join("\n");
let display = user_data.response(&pool, "restrict/allowed").await.replacen("{}", &display_inner, 1);
let display_inner = rows
.iter()
.map(|row| format!("<@&{}> can use {}", row.role, row.command))
.collect::<Vec<String>>()
.join("\n");
let display = user_data
.response(&pool, "restrict/allowed")
.await
.replacen("{}", &display_inner, 1);
let _ = msg.channel_id.say(&ctx, display).await;
}
else {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "restrict/help").await).await;
} else {
let _ = msg
.channel_id
.say(&ctx, user_data.response(&pool, "restrict/help").await)
.await;
}
Ok(())
@ -275,8 +353,13 @@ WHERE
#[supports_dm(false)]
#[permission_level(Managed)]
async fn alias(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
@ -293,21 +376,21 @@ async fn alias(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let aliases = sqlx::query!(
"
SELECT name, command FROM command_aliases WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?)
", guild_id)
.fetch_all(&pool)
.await
.unwrap();
",
guild_id
)
.fetch_all(&pool)
.await
.unwrap();
let content = iter::once("Aliases:".to_string())
.chain(
aliases
.iter()
.map(|row| format!("**{}**: `{}`", row.name, row.command)
)
);
let content = iter::once("Aliases:".to_string()).chain(
aliases
.iter()
.map(|row| format!("**{}**: `{}`", row.name, row.command)),
);
let _ = msg.channel_id.say_lines(&ctx, content).await;
},
}
"remove" => {
if let Some(command) = command_opt {
@ -322,19 +405,27 @@ SELECT COUNT(1) AS count FROM command_aliases WHERE name = ? AND guild_id = (SEL
sqlx::query!(
"
DELETE FROM command_aliases WHERE name = ? AND guild_id = (SELECT id FROM guilds WHERE guild = ?)
", command, guild_id)
.execute(&pool)
.await
.unwrap();
",
command,
guild_id
)
.execute(&pool)
.await
.unwrap();
let content = user_data.response(&pool, "alias/removed").await.replace("{count}", &deleted_count.count.to_string());
let content = user_data
.response(&pool, "alias/removed")
.await
.replace("{count}", &deleted_count.count.to_string());
let _ = msg.channel_id.say(&ctx, content).await;
} else {
let _ = msg
.channel_id
.say(&ctx, user_data.response(&pool, "alias/help").await)
.await;
}
else {
let _ = msg.channel_id.say(&ctx, user_data.response(&pool, "alias/help").await).await;
}
},
}
name => {
if let Some(command) = command_opt {
@ -355,11 +446,13 @@ UPDATE command_aliases SET command = ? WHERE guild_id = (SELECT id FROM guilds W
.unwrap();
}
let content = user_data.response(&pool, "alias/created").await.replace("{name}", name);
let content = user_data
.response(&pool, "alias/created")
.await
.replace("{name}", name);
let _ = msg.channel_id.say(&ctx, content).await;
}
else {
} else {
match sqlx::query!(
"
SELECT command FROM command_aliases WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND name = ?
@ -386,10 +479,12 @@ SELECT command FROM command_aliases WHERE guild_id = (SELECT id FROM guilds WHER
}
}
}
}
else {
} else {
let prefix = GuildData::prefix_from_id(msg.guild_id, &pool).await;
let content = user_data.response(&pool, "alias/help").await.replace("{prefix}", &prefix);
let content = user_data
.response(&pool, "alias/help")
.await
.replace("{prefix}", &prefix);
let _ = msg.channel_id.say(&ctx, content).await;
}

File diff suppressed because it is too large Load Diff

View File

@ -1,26 +1,19 @@
use regex_command_attr::command;
use serenity::{
constants::MESSAGE_CODE_LIMIT,
client::Context,
model::{
id::{
UserId, GuildId, ChannelId,
},
channel::{
Message,
},
},
constants::MESSAGE_CODE_LIMIT,
framework::standard::CommandResult,
model::{
channel::Message,
id::{ChannelId, GuildId, UserId},
},
};
use std::fmt;
use crate::{
models::{
UserData,
GuildData,
},
models::{GuildData, UserData},
SQLPool,
};
use sqlx::MySqlPool;
@ -54,18 +47,15 @@ impl TodoTarget {
pub fn command(&self, subcommand_opt: Option<SubCommand>) -> String {
let context = if self.channel.is_some() {
"channel"
}
else if self.guild.is_some() {
} else if self.guild.is_some() {
"guild"
}
else {
} else {
"user"
};
if let Some(subcommand) = subcommand_opt {
format!("todo {} {}", context, subcommand.to_string())
}
else {
} else {
format!("todo {}", context)
}
}
@ -73,43 +63,56 @@ impl TodoTarget {
pub fn name(&self) -> String {
if self.channel.is_some() {
"Channel"
}
else if self.guild.is_some() {
} else if self.guild.is_some() {
"Guild"
}
else {
} else {
"User"
}.to_string()
}
.to_string()
}
pub async fn view(&self, pool: MySqlPool) -> Result<Vec<Todo>, Box<dyn std::error::Error + Send + Sync>> {
pub async fn view(
&self,
pool: MySqlPool,
) -> Result<Vec<Todo>, Box<dyn std::error::Error + Send + Sync>> {
Ok(if let Some(cid) = self.channel {
sqlx::query_as!(Todo,
sqlx::query_as!(
Todo,
"
SELECT * FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?)
", cid.as_u64())
.fetch_all(&pool)
.await?
}
else if let Some(gid) = self.guild {
sqlx::query_as!(Todo,
",
cid.as_u64()
)
.fetch_all(&pool)
.await?
} else if let Some(gid) = self.guild {
sqlx::query_as!(
Todo,
"
SELECT * FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL
", gid.as_u64())
.fetch_all(&pool)
.await?
}
else {
sqlx::query_as!(Todo,
",
gid.as_u64()
)
.fetch_all(&pool)
.await?
} else {
sqlx::query_as!(
Todo,
"
SELECT * FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL
", self.user.as_u64())
.fetch_all(&pool)
.await?
",
self.user.as_u64()
)
.fetch_all(&pool)
.await?
})
}
pub async fn add(&self, value: String, pool: MySqlPool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
pub async fn add(
&self,
value: String,
pool: MySqlPool,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if let (Some(cid), Some(gid)) = (self.channel, self.guild) {
sqlx::query!(
"
@ -119,11 +122,15 @@ INSERT INTO todos (user_id, guild_id, channel_id, value) VALUES (
(SELECT id FROM channels WHERE channel = ?),
?
)
", self.user.as_u64(), gid.as_u64(), cid.as_u64(), value)
.execute(&pool)
.await?;
}
else if let Some(gid) = self.guild {
",
self.user.as_u64(),
gid.as_u64(),
cid.as_u64(),
value
)
.execute(&pool)
.await?;
} else if let Some(gid) = self.guild {
sqlx::query!(
"
INSERT INTO todos (user_id, guild_id, value) VALUES (
@ -131,74 +138,95 @@ INSERT INTO todos (user_id, guild_id, value) VALUES (
(SELECT id FROM guilds WHERE guild = ?),
?
)
", self.user.as_u64(), gid.as_u64(), value)
.execute(&pool)
.await?;
}
else {
",
self.user.as_u64(),
gid.as_u64(),
value
)
.execute(&pool)
.await?;
} else {
sqlx::query!(
"
INSERT INTO todos (user_id, value) VALUES (
(SELECT id FROM users WHERE user = ?),
?
)
", self.user.as_u64(), value)
.execute(&pool)
.await?;
",
self.user.as_u64(),
value
)
.execute(&pool)
.await?;
}
Ok(())
}
pub async fn remove(&self, num: usize, pool: &MySqlPool) -> Result<Todo, Box<dyn std::error::Error + Sync + Send>> {
pub async fn remove(
&self,
num: usize,
pool: &MySqlPool,
) -> Result<Todo, Box<dyn std::error::Error + Sync + Send>> {
let todos = self.view(pool.clone()).await?;
if let Some(removal_todo) = todos.get(num) {
let deleting = sqlx::query_as!(Todo,
let deleting = sqlx::query_as!(
Todo,
"
SELECT * FROM todos WHERE id = ?
", removal_todo.id)
.fetch_one(&pool.clone())
.await?;
",
removal_todo.id
)
.fetch_one(&pool.clone())
.await?;
sqlx::query!(
"
DELETE FROM todos WHERE id = ?
", removal_todo.id)
.execute(pool)
.await?;
",
removal_todo.id
)
.execute(pool)
.await?;
Ok(deleting)
}
else {
} else {
Err(Box::new(TodoNotFound))
}
}
pub async fn clear(&self, pool: &MySqlPool) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
pub async fn clear(
&self,
pool: &MySqlPool,
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
if let Some(cid) = self.channel {
sqlx::query!(
"
DELETE FROM todos WHERE channel_id = (SELECT id FROM channels WHERE channel = ?)
", cid.as_u64())
.execute(pool)
.await?;
}
else if let Some(gid) = self.guild {
",
cid.as_u64()
)
.execute(pool)
.await?;
} else if let Some(gid) = self.guild {
sqlx::query!(
"
DELETE FROM todos WHERE guild_id = (SELECT id FROM guilds WHERE guild = ?) AND channel_id IS NULL
", gid.as_u64())
.execute(pool)
.await?;
}
else {
",
gid.as_u64()
)
.execute(pool)
.await?;
} else {
sqlx::query!(
"
DELETE FROM todos WHERE user_id = (SELECT id FROM users WHERE user = ?) AND guild_id IS NULL
", self.user.as_u64())
.execute(pool)
.await?;
",
self.user.as_u64()
)
.execute(pool)
.await?;
}
Ok(())
@ -219,47 +247,53 @@ impl ToString for SubCommand {
SubCommand::Add => "add",
SubCommand::Remove => "remove",
SubCommand::Clear => "clear",
}.to_string()
}
.to_string()
}
}
#[command]
#[permission_level(Managed)]
async fn todo_parse(ctx: &Context, msg: &Message, args: String) -> CommandResult {
let mut split = args.split(' ');
if let Some(target) = split.next() {
let target_opt = match target {
"user" =>
Some(TodoTarget {user: msg.author.id, guild: None, channel: None}),
"user" => Some(TodoTarget {
user: msg.author.id,
guild: None,
channel: None,
}),
"channel" =>
"channel" => {
if let Some(gid) = msg.guild_id {
Some(TodoTarget {user: msg.author.id, guild: Some(gid), channel: Some(msg.channel_id)})
}
else {
Some(TodoTarget {
user: msg.author.id,
guild: Some(gid),
channel: Some(msg.channel_id),
})
} else {
None
},
}
}
"server" | "guild" => {
if let Some(gid) = msg.guild_id {
Some(TodoTarget {user: msg.author.id, guild: Some(gid), channel: None})
}
else {
Some(TodoTarget {
user: msg.author.id,
guild: Some(gid),
channel: None,
})
} else {
None
}
},
}
_ => {
None
},
_ => None,
};
if let Some(target) = target_opt {
let subcommand_opt = match split.next() {
Some("add") => Some(SubCommand::Add),
Some("remove") => Some(SubCommand::Remove),
@ -272,19 +306,21 @@ async fn todo_parse(ctx: &Context, msg: &Message, args: String) -> CommandResult
};
if let Some(subcommand) = subcommand_opt {
todo(ctx, msg, target, subcommand, split.collect::<Vec<&str>>().join(" ")).await;
}
else {
todo(
ctx,
msg,
target,
subcommand,
split.collect::<Vec<&str>>().join(" "),
)
.await;
} else {
show_help(&ctx, msg, Some(target)).await;
}
}
else {
} else {
show_help(&ctx, msg, None).await;
}
}
else {
} else {
show_help(&ctx, msg, None).await;
}
@ -292,22 +328,45 @@ async fn todo_parse(ctx: &Context, msg: &Message, args: String) -> CommandResult
}
async fn show_help(ctx: &Context, msg: &Message, target: Option<TodoTarget>) {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let prefix = GuildData::prefix_from_id(msg.guild_id, &pool).await;
let content = user_data.response(&pool, "todo/help").await
let content = user_data
.response(&pool, "todo/help")
.await
.replace("{prefix}", &prefix)
.replace("{command}", target.map_or_else(|| "todo user".to_string(), |t| t.command(None)).as_str());
.replace(
"{command}",
target
.map_or_else(|| "todo user".to_string(), |t| t.command(None))
.as_str(),
);
let _ = msg.channel_id.say(&ctx, content).await;
}
async fn todo(ctx: &Context, msg: &Message, target: TodoTarget, subcommand: SubCommand, extra: String) {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
async fn todo(
ctx: &Context,
msg: &Message,
target: TodoTarget,
subcommand: SubCommand,
extra: String,
) {
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let user_data = UserData::from_user(&msg.author, &ctx, &pool).await.unwrap();
let prefix = GuildData::prefix_from_id(msg.guild_id, &pool).await;
@ -325,8 +384,7 @@ async fn todo(ctx: &Context, msg: &Message, target: TodoTarget, subcommand: SubC
char_count = display.len();
todo_groups.push(display);
}
else {
} else {
char_count += display.len();
let last_group = todo_groups.pop().unwrap();
@ -336,44 +394,54 @@ async fn todo(ctx: &Context, msg: &Message, target: TodoTarget, subcommand: SubC
});
for group in todo_groups {
let _ = msg.channel_id.send_message(&ctx, |m| m
.embed(|e| e
.title(format!("{} Todo", target.name()))
.description(group)
)
).await;
let _ = msg
.channel_id
.send_message(&ctx, |m| {
m.embed(|e| {
e.title(format!("{} Todo", target.name()))
.description(group)
})
})
.await;
}
},
}
SubCommand::Add => {
let content = user_data.response(&pool, "todo/added").await
let content = user_data
.response(&pool, "todo/added")
.await
.replacen("{name}", &extra, 1);
target.add(extra, pool).await.unwrap();
let _ = msg.channel_id.say(&ctx, content).await;
},
}
SubCommand::Remove => {
let _ = if let Ok(num) = extra.parse::<usize>() {
if let Ok(todo) = target.remove(num - 1, &pool).await {
let content = user_data.response(&pool, "todo/removed").await
.replacen("{}", &todo.value, 1);
let content = user_data.response(&pool, "todo/removed").await.replacen(
"{}",
&todo.value,
1,
);
msg.channel_id.say(&ctx, content)
} else {
msg.channel_id
.say(&ctx, user_data.response(&pool, "todo/error_index").await)
}
else {
msg.channel_id.say(&ctx, user_data.response(&pool, "todo/error_index").await)
}
}
else {
let content = user_data.response(&pool, "todo/error_value").await
} else {
let content = user_data
.response(&pool, "todo/error_value")
.await
.replacen("{prefix}", &prefix, 1)
.replacen("{command}", &target.command(Some(subcommand)), 1);
msg.channel_id.say(&ctx, content)
}.await;
},
}
.await;
}
SubCommand::Clear => {
target.clear(&pool).await.unwrap();
@ -381,6 +449,6 @@ async fn todo(ctx: &Context, msg: &Message, target: TodoTarget, subcommand: SubC
let content = user_data.response(&pool, "todo/cleared").await;
let _ = msg.channel_id.say(&ctx, content).await;
},
}
}
}

View File

@ -6,44 +6,41 @@ pub const MINUTE: u64 = 60;
pub const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
use std::{
iter::FromIterator,
env,
collections::HashSet,
};
use lazy_static;
use std::{collections::HashSet, env, iter::FromIterator};
use regex::Regex;
lazy_static! {
pub static ref SUBSCRIPTION_ROLES: HashSet<u64> = HashSet::from_iter(env::var("SUBSCRIPTION_ROLES")
.map(
|var| var
pub static ref SUBSCRIPTION_ROLES: HashSet<u64> = HashSet::from_iter(
env::var("SUBSCRIPTION_ROLES")
.map(|var| var
.split(',')
.filter_map(|item| {
item.parse::<u64>().ok()
})
.collect::<Vec<u64>>()
).unwrap_or_else(|_| vec![]));
pub static ref CNC_GUILD: Option<u64> = env::var("CNC_GUILD").map(|var| var.parse::<u64>().ok()).ok().flatten();
.filter_map(|item| { item.parse::<u64>().ok() })
.collect::<Vec<u64>>())
.unwrap_or_else(|_| vec![])
);
pub static ref CNC_GUILD: Option<u64> = env::var("CNC_GUILD")
.map(|var| var.parse::<u64>().ok())
.ok()
.flatten();
pub static ref REGEX_CHANNEL: Regex = Regex::new(r#"^\s*<#(\d+)>\s*$"#).unwrap();
pub static ref REGEX_ROLE: Regex = Regex::new(r#"<@&([0-9]+)>"#).unwrap();
pub static ref REGEX_COMMANDS: Regex = Regex::new(r#"([a-z]+)"#).unwrap();
pub static ref REGEX_ALIAS: Regex = Regex::new(r#"(?P<name>[\S]{1,12})(?:(?: (?P<cmd>.*)$)|$)"#).unwrap();
pub static ref REGEX_ALIAS: Regex =
Regex::new(r#"(?P<name>[\S]{1,12})(?:(?: (?P<cmd>.*)$)|$)"#).unwrap();
pub static ref REGEX_CHANNEL_USER: Regex = Regex::new(r#"^\s*<(#|@)(?:!)?(\d+)>\s*$"#).unwrap();
pub static ref MIN_INTERVAL: i64 = env::var("MIN_INTERVAL").ok().map(|inner| inner.parse::<i64>().ok()).flatten().unwrap_or(600);
pub static ref MAX_TIME: i64 = env::var("MAX_TIME").ok().map(|inner| inner.parse::<i64>().ok()).flatten().unwrap_or(60*60*24*365*50);
pub static ref LOCAL_TIMEZONE: String = env::var("LOCAL_TIMEZONE").unwrap_or_else(|_| "UTC".to_string());
pub static ref PYTHON_LOCATION: String = env::var("PYTHON_LOCATION").unwrap_or_else(|_| "venv/bin/python3".to_string());
pub static ref MIN_INTERVAL: i64 = env::var("MIN_INTERVAL")
.ok()
.map(|inner| inner.parse::<i64>().ok())
.flatten()
.unwrap_or(600);
pub static ref MAX_TIME: i64 = env::var("MAX_TIME")
.ok()
.map(|inner| inner.parse::<i64>().ok())
.flatten()
.unwrap_or(60 * 60 * 24 * 365 * 50);
pub static ref LOCAL_TIMEZONE: String =
env::var("LOCAL_TIMEZONE").unwrap_or_else(|_| "UTC".to_string());
pub static ref PYTHON_LOCATION: String =
env::var("PYTHON_LOCATION").unwrap_or_else(|_| "venv/bin/python3".to_string());
}

View File

@ -1,50 +1,29 @@
use async_trait::async_trait;
use serenity::{
constants::MESSAGE_CODE_LIMIT,
http::Http,
Result as SerenityResult,
client::Context,
framework::{
Framework,
standard::CommandResult,
},
model::{
id::ChannelId,
guild::{
Guild,
Member,
},
channel::{
Channel, GuildChannel, Message,
}
},
constants::MESSAGE_CODE_LIMIT,
framework::{standard::CommandResult, Framework},
futures::prelude::future::BoxFuture,
http::Http,
model::{
channel::{Channel, GuildChannel, Message},
guild::{Guild, Member},
id::ChannelId,
},
Result as SerenityResult,
};
use log::{
warn,
error,
info,
};
use log::{error, info, warn};
use regex::{
Regex, Match
};
use regex::{Match, Regex};
use std::{
collections::HashMap,
fmt,
env,
};
use std::{collections::HashMap, env, fmt};
use crate::{
models::ChannelData,
SQLPool,
consts::PREFIX,
};
use crate::{consts::PREFIX, models::ChannelData, SQLPool};
type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, CommandResult>;
type CommandFn =
for<'fut> fn(&'fut Context, &'fut Message, String) -> BoxFuture<'fut, CommandResult>;
#[derive(Debug, PartialEq)]
pub enum PermissionLevel {
@ -63,27 +42,30 @@ pub struct Command {
impl Command {
async fn check_permissions(&self, ctx: &Context, guild: &Guild, member: &Member) -> bool {
if self.required_perms == PermissionLevel::Unrestricted {
true
}
else {
} else {
for role_id in &member.roles {
let role = role_id.to_role_cached(&ctx).await;
if let Some(cached_role) = role {
if cached_role.permissions.manage_guild() {
return true
}
else if self.required_perms == PermissionLevel::Managed && cached_role.permissions.manage_messages() {
return true
if cached_role.permissions.manage_guild()
|| (self.required_perms == PermissionLevel::Managed
&& cached_role.permissions.manage_messages())
{
return true;
}
}
}
if self.required_perms == PermissionLevel::Managed {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
match sqlx::query!(
"
@ -102,34 +84,41 @@ WHERE
guilds
WHERE
guild = ?)
", self.name, guild.id.as_u64())
.fetch_all(&pool)
.await {
",
self.name,
guild.id.as_u64()
)
.fetch_all(&pool)
.await
{
Ok(rows) => {
let role_ids = member.roles.iter().map(|r| *r.as_u64()).collect::<Vec<u64>>();
let role_ids = member
.roles
.iter()
.map(|r| *r.as_u64())
.collect::<Vec<u64>>();
for row in rows {
if role_ids.contains(&row.role) {
return true
return true;
}
}
false
}
Err(sqlx::Error::RowNotFound) => {
false
}
Err(sqlx::Error::RowNotFound) => false,
Err(e) => {
warn!("Unexpected error occurred querying command_restrictions: {:?}", e);
warn!(
"Unexpected error occurred querying command_restrictions: {:?}",
e
);
false
}
}
}
else {
} else {
false
}
}
@ -139,11 +128,11 @@ WHERE
impl fmt::Debug for Command {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Command")
.field("name", &self.name)
.field("required_perms", &self.required_perms)
.field("supports_dm", &self.supports_dm)
.field("can_blacklist", &self.can_blacklist)
.finish()
.field("name", &self.name)
.field("required_perms", &self.required_perms)
.field("supports_dm", &self.supports_dm)
.field("can_blacklist", &self.can_blacklist)
.finish()
}
}
@ -159,12 +148,20 @@ pub struct RegexFramework {
#[async_trait]
pub trait SendIterator {
async fn say_lines(self, http: impl AsRef<Http> + Send + Sync + 'async_trait, content: impl Iterator<Item=String> + Send + 'async_trait) -> SerenityResult<()>;
async fn say_lines(
self,
http: impl AsRef<Http> + Send + Sync + 'async_trait,
content: impl Iterator<Item = String> + Send + 'async_trait,
) -> SerenityResult<()>;
}
#[async_trait]
impl SendIterator for ChannelId {
async fn say_lines(self, http: impl AsRef<Http> + Send + Sync + 'async_trait, content: impl Iterator<Item=String> + Send + 'async_trait) -> SerenityResult<()> {
async fn say_lines(
self,
http: impl AsRef<Http> + Send + Sync + 'async_trait,
content: impl Iterator<Item = String> + Send + 'async_trait,
) -> SerenityResult<()> {
let mut current_content = String::new();
for line in content {
@ -172,8 +169,7 @@ impl SendIterator for ChannelId {
self.say(&http, &current_content).await?;
current_content = line;
}
else {
} else {
current_content = format!("{}\n{}", current_content, line);
}
}
@ -220,10 +216,8 @@ impl RegexFramework {
let command_names;
{
let mut command_names_vec = self.commands
.keys()
.map(|k| &k[..])
.collect::<Vec<&str>>();
let mut command_names_vec =
self.commands.keys().map(|k| &k[..]).collect::<Vec<&str>>();
command_names_vec.sort_unstable_by(|a, b| b.len().cmp(&a.len()));
@ -245,7 +239,8 @@ impl RegexFramework {
let dm_command_names;
{
let mut command_names_vec = self.commands
let mut command_names_vec = self
.commands
.iter()
.filter_map(|(key, command)| {
if command.supports_dm {
@ -283,8 +278,11 @@ enum PermissionCheck {
#[async_trait]
impl Framework for RegexFramework {
async fn dispatch(&self, ctx: Context, msg: Message) {
async fn check_self_permissions(ctx: &Context, guild: &Guild, channel: &GuildChannel) -> Result<PermissionCheck, Box<dyn std::error::Error + Sync + Send>> {
async fn check_self_permissions(
ctx: &Context,
guild: &Guild,
channel: &GuildChannel,
) -> Result<PermissionCheck, Box<dyn std::error::Error + Sync + Send>> {
let user_id = ctx.cache.current_user_id().await;
let guild_perms = guild.member_permissions(user_id);
@ -294,31 +292,40 @@ impl Framework for RegexFramework {
Ok(if basic_perms && guild_perms.manage_webhooks() {
PermissionCheck::All
}
else if basic_perms {
} else if basic_perms {
PermissionCheck::Basic
}
else {
} else {
PermissionCheck::None
})
}
async fn check_prefix(ctx: &Context, guild: &Guild, prefix_opt: Option<Match<'_>>) -> bool {
if let Some(prefix) = prefix_opt {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
match sqlx::query!("SELECT prefix FROM guilds WHERE guild = ?", guild.id.as_u64())
.fetch_one(&pool)
.await {
Ok(row) => {
prefix.as_str() == row.prefix
}
match sqlx::query!(
"SELECT prefix FROM guilds WHERE guild = ?",
guild.id.as_u64()
)
.fetch_one(&pool)
.await
{
Ok(row) => prefix.as_str() == row.prefix,
Err(sqlx::Error::RowNotFound) => {
let _ = sqlx::query!("INSERT INTO guilds (guild, name) VALUES (?, ?)", guild.id.as_u64(), guild.name)
.execute(&pool)
.await;
let _ = sqlx::query!(
"INSERT INTO guilds (guild, name) VALUES (?, ?)",
guild.id.as_u64(),
guild.name
)
.execute(&pool)
.await;
prefix.as_str() == "$"
}
@ -329,49 +336,62 @@ impl Framework for RegexFramework {
false
}
}
}
else {
} else {
true
}
}
// gate to prevent analysing messages unnecessarily
if (msg.author.bot && self.ignore_bots) ||
msg.tts ||
msg.content.is_empty() ||
!msg.attachments.is_empty() {}
if (msg.author.bot && self.ignore_bots)
|| msg.tts
|| msg.content.is_empty()
|| !msg.attachments.is_empty()
{
}
// Guild Command
else if let (Some(guild), Some(Channel::Guild(channel))) = (msg.guild(&ctx).await, msg.channel(&ctx).await) {
else if let (Some(guild), Some(Channel::Guild(channel))) =
(msg.guild(&ctx).await, msg.channel(&ctx).await)
{
let member = guild.member(&ctx, &msg.author).await.unwrap();
if let Some(full_match) = self.command_matcher.captures(&msg.content[..]) {
if check_prefix(&ctx, &guild, full_match.name("prefix")).await {
match check_self_permissions(&ctx, &guild, &channel).await {
Ok(perms) => match perms {
PermissionCheck::All => {
let pool = ctx.data.read().await
.get::<SQLPool>().cloned().expect("Could not get SQLPool from data");
let pool = ctx
.data
.read()
.await
.get::<SQLPool>()
.cloned()
.expect("Could not get SQLPool from data");
let command = self.commands.get(full_match.name("cmd").unwrap().as_str()).unwrap();
let channel_data = ChannelData::from_channel(msg.channel(&ctx).await.unwrap(), &pool).await;
let command = self
.commands
.get(full_match.name("cmd").unwrap().as_str())
.unwrap();
let channel_data = ChannelData::from_channel(
msg.channel(&ctx).await.unwrap(),
&pool,
)
.await;
if !command.can_blacklist || !channel_data.map(|c| c.blacklisted).unwrap_or(false) {
let args = full_match.name("args")
if !command.can_blacklist
|| !channel_data.map(|c| c.blacklisted).unwrap_or(false)
{
let args = full_match
.name("args")
.map(|m| m.as_str())
.unwrap_or("")
.to_string();
if command.check_permissions(&ctx, &guild, &member).await {
(command.func)(&ctx, &msg, args).await.unwrap();
}
else if command.required_perms == PermissionLevel::Restricted {
} else if command.required_perms == PermissionLevel::Restricted
{
let _ = msg.channel_id.say(&ctx, "You must have permission level `Manage Server` or greater to use this command.").await;
}
else if command.required_perms == PermissionLevel::Managed {
} else if command.required_perms == PermissionLevel::Managed {
let _ = msg.channel_id.say(&ctx, "You must have `Manage Messages` or have a role capable of sending reminders to that channel. Please talk to your server admin, and ask them to use the `{prefix}restrict` command to specify allowed roles.").await;
}
}
@ -384,20 +404,26 @@ impl Framework for RegexFramework {
PermissionCheck::None => {
warn!("Missing enough permissions for guild {}", guild.id);
}
}
},
Err(e) => {
error!("Error occurred getting permissions in guild {}: {:?}", guild.id, e);
error!(
"Error occurred getting permissions in guild {}: {:?}",
guild.id, e
);
}
}
}
}
}
// DM Command
else if let Some(full_match) = self.dm_regex_matcher.captures(&msg.content[..]) {
let command = self.commands.get(full_match.name("cmd").unwrap().as_str()).unwrap();
let args = full_match.name("args")
let command = self
.commands
.get(full_match.name("cmd").unwrap().as_str())
.unwrap();
let args = full_match
.name("args")
.map(|m| m.as_str())
.unwrap_or("")
.to_string();

View File

@ -1,58 +1,37 @@
#[macro_use]
extern crate lazy_static;
mod models;
mod framework;
mod commands;
mod time_parser;
mod consts;
mod framework;
mod models;
mod time_parser;
use serenity::{
cache::Cache,
http::{
CacheHttp,
client::Http,
},
client::{
bridge::gateway::GatewayIntents,
Client,
},
model::{
id::{
GuildId, UserId,
},
channel::Message,
},
client::{bridge::gateway::GatewayIntents, Client},
framework::Framework,
http::{client::Http, CacheHttp},
model::{
channel::Message,
id::{GuildId, UserId},
},
prelude::TypeMapKey,
};
use sqlx::{
mysql::{MySqlConnection, MySqlPool},
Pool,
mysql::{
MySqlPool,
MySqlConnection,
}
};
use dotenv::dotenv;
use std::{
sync::Arc,
env,
};
use std::{env, sync::Arc};
use crate::{
commands::{info_cmds, moderation_cmds, reminder_cmds, todo_cmds},
consts::{CNC_GUILD, PREFIX, SUBSCRIPTION_ROLES},
framework::RegexFramework,
consts::{
PREFIX, SUBSCRIPTION_ROLES, CNC_GUILD,
},
commands::{
info_cmds,
reminder_cmds,
todo_cmds,
moderation_cmds,
},
};
use serenity::futures::TryFutureExt;
@ -85,23 +64,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let http = Http::new_with_token(&token);
let logged_in_id = http.get_current_user().map_ok(|user| user.id.as_u64().to_owned()).await?;
let logged_in_id = http
.get_current_user()
.map_ok(|user| user.id.as_u64().to_owned())
.await?;
let framework = RegexFramework::new(logged_in_id)
.ignore_bots(env::var("IGNORE_BOTS").map_or(true, |var| var == "1"))
.default_prefix(&env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string()))
.add_command("ping", &info_cmds::PING_COMMAND)
.add_command("help", &info_cmds::HELP_COMMAND)
.add_command("info", &info_cmds::INFO_COMMAND)
.add_command("invite", &info_cmds::INFO_COMMAND)
.add_command("donate", &info_cmds::DONATE_COMMAND)
.add_command("dashboard", &info_cmds::DASHBOARD_COMMAND)
.add_command("clock", &info_cmds::CLOCK_COMMAND)
.add_command("timer", &reminder_cmds::TIMER_COMMAND)
.add_command("remind", &reminder_cmds::REMIND_COMMAND)
.add_command("r", &reminder_cmds::REMIND_COMMAND)
.add_command("interval", &reminder_cmds::INTERVAL_COMMAND)
@ -109,36 +87,39 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.add_command("natural", &reminder_cmds::NATURAL_COMMAND)
.add_command("n", &reminder_cmds::NATURAL_COMMAND)
.add_command("", &reminder_cmds::NATURAL_COMMAND)
.add_command("look", &reminder_cmds::LOOK_COMMAND)
.add_command("del", &reminder_cmds::DELETE_COMMAND)
.add_command("todo", &todo_cmds::TODO_PARSE_COMMAND)
.add_command("blacklist", &moderation_cmds::BLACKLIST_COMMAND)
.add_command("restrict", &moderation_cmds::RESTRICT_COMMAND)
.add_command("timezone", &moderation_cmds::TIMEZONE_COMMAND)
.add_command("prefix", &moderation_cmds::PREFIX_COMMAND)
.add_command("lang", &moderation_cmds::LANGUAGE_COMMAND)
.add_command("pause", &reminder_cmds::PAUSE_COMMAND)
.add_command("offset", &reminder_cmds::OFFSET_COMMAND)
.add_command("nudge", &reminder_cmds::NUDGE_COMMAND)
.add_command("alias", &moderation_cmds::ALIAS_COMMAND)
.add_command("a", &moderation_cmds::ALIAS_COMMAND)
.build();
let framework_arc = Arc::new(Box::new(framework) as Box<dyn Framework + Send + Sync>);
let mut client = Client::new(&token)
.intents(GatewayIntents::GUILD_MESSAGES | GatewayIntents::GUILDS | GatewayIntents::DIRECT_MESSAGES)
.intents(
GatewayIntents::GUILD_MESSAGES
| GatewayIntents::GUILDS
| GatewayIntents::DIRECT_MESSAGES,
)
.framework_arc(framework_arc.clone())
.await.expect("Error occurred creating client");
.await
.expect("Error occurred creating client");
{
let pool = MySqlPool::new(&env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment")).await.unwrap();
let pool = MySqlPool::new(
&env::var("DATABASE_URL").expect("Missing DATABASE_URL from environment"),
)
.await
.unwrap();
let mut data = client.data.write().await;
@ -152,27 +133,34 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
pub async fn check_subscription(cache_http: impl CacheHttp, user_id: impl Into<UserId>) -> bool {
if let Some(subscription_guild) = *CNC_GUILD {
let guild_member = GuildId(subscription_guild).member(cache_http, user_id).await;
let guild_member = GuildId(subscription_guild)
.member(cache_http, user_id)
.await;
if let Ok(member) = guild_member {
for role in member.roles {
if SUBSCRIPTION_ROLES.contains(role.as_u64()) {
return true
return true;
}
}
}
false
}
else {
} else {
true
}
}
pub async fn check_subscription_on_message(cache_http: impl CacheHttp + AsRef<Cache>, msg: &Message) -> bool {
check_subscription(&cache_http, &msg.author).await ||
if let Some(guild) = msg.guild(&cache_http).await { check_subscription(&cache_http, guild.owner_id).await } else { false }
pub async fn check_subscription_on_message(
cache_http: impl CacheHttp + AsRef<Cache>,
msg: &Message,
) -> bool {
check_subscription(&cache_http, &msg.author).await
|| if let Some(guild) = msg.guild(&cache_http).await {
check_subscription(&cache_http, guild.owner_id).await
} else {
false
}
}

View File

@ -1,11 +1,6 @@
use serenity::{
http::CacheHttp,
model::{
id::GuildId,
guild::Guild,
channel::Channel,
user::User,
}
model::{channel::Channel, guild::Guild, id::GuildId, user::User},
};
use std::env;
@ -24,49 +19,67 @@ pub struct GuildData {
}
impl GuildData {
pub async fn prefix_from_id<T: Into<GuildId>>(guild_id_opt: Option<T>, pool: &MySqlPool) -> String {
pub async fn prefix_from_id<T: Into<GuildId>>(
guild_id_opt: Option<T>,
pool: &MySqlPool,
) -> String {
if let Some(guild_id) = guild_id_opt {
let guild_id = guild_id.into().as_u64().to_owned();
let row = sqlx::query!(
"
SELECT prefix FROM guilds WHERE guild = ?
", guild_id
",
guild_id
)
.fetch_one(pool)
.await;
.fetch_one(pool)
.await;
row.map_or_else(|_| env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string()), |r| r.prefix)
}
else {
row.map_or_else(
|_| env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string()),
|r| r.prefix,
)
} else {
env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string())
}
}
pub async fn from_guild(guild: Guild, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
pub async fn from_guild(
guild: Guild,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let guild_id = guild.id.as_u64().to_owned();
if let Ok(g) = sqlx::query_as!(Self,
if let Ok(g) = sqlx::query_as!(
Self,
"
SELECT id, name, prefix FROM guilds WHERE guild = ?
", guild_id)
.fetch_one(pool)
.await {
",
guild_id
)
.fetch_one(pool)
.await
{
Ok(g)
}
else {
} else {
sqlx::query!(
"
INSERT INTO guilds (guild, name, prefix) VALUES (?, ?, ?)
", guild_id, guild.name, env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string()))
.execute(&pool.clone())
.await?;
",
guild_id,
guild.name,
env::var("DEFAULT_PREFIX").unwrap_or_else(|_| PREFIX.to_string())
)
.execute(&pool.clone())
.await?;
Ok(sqlx::query_as!(Self,
"
Ok(sqlx::query_as!(
Self,
"
SELECT id, name, prefix FROM guilds WHERE guild = ?
", guild_id)
",
guild_id
)
.fetch_one(pool)
.await?)
}
@ -76,9 +89,14 @@ SELECT id, name, prefix FROM guilds WHERE guild = ?
sqlx::query!(
"
UPDATE guilds SET name = ?, prefix = ? WHERE id = ?
", self.name, self.prefix, self.id)
.execute(pool)
.await.unwrap();
",
self.name,
self.prefix,
self.id
)
.execute(pool)
.await
.unwrap();
}
}
@ -103,9 +121,10 @@ SELECT id, name, nudge, blacklisted, webhook_id, webhook_token, paused, paused_u
.await.ok()
}
pub async fn from_channel(channel: Channel, pool: &MySqlPool)
-> Result<Self, Box<dyn std::error::Error + Sync + Send>>
{
pub async fn from_channel(
channel: Channel,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let channel_id = channel.id().as_u64().to_owned();
if let Ok(c) = sqlx::query_as_unchecked!(Self,
@ -162,19 +181,25 @@ pub struct UserData {
}
impl UserData {
pub async fn from_user(user: &User, ctx: impl CacheHttp, pool: &MySqlPool) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
pub async fn from_user(
user: &User,
ctx: impl CacheHttp,
pool: &MySqlPool,
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let user_id = user.id.as_u64().to_owned();
if let Ok(c) = sqlx::query_as_unchecked!(Self,
if let Ok(c) = sqlx::query_as_unchecked!(
Self,
"
SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ?
", user_id)
.fetch_one(pool)
.await {
",
user_id
)
.fetch_one(pool)
.await
{
Ok(c)
}
else {
} else {
let dm_channel = user.create_dm_channel(ctx).await?;
let dm_id = dm_channel.id.as_u64().to_owned();
@ -183,9 +208,11 @@ SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ?
sqlx::query!(
"
INSERT INTO channels (channel) VALUES (?)
", dm_id)
.execute(&pool_c)
.await?;
",
dm_id
)
.execute(&pool_c)
.await?;
sqlx::query!(
"
@ -194,12 +221,15 @@ INSERT INTO users (user, name, dm_channel) VALUES (?, ?, (SELECT id FROM channel
.execute(&pool_c)
.await?;
Ok(sqlx::query_as_unchecked!(Self,
Ok(sqlx::query_as_unchecked!(
Self,
"
SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ?
", user_id)
.fetch_one(pool)
.await?)
",
user_id
)
.fetch_one(pool)
.await?)
}
}
@ -207,9 +237,15 @@ SELECT id, user, name, dm_channel, language, timezone FROM users WHERE user = ?
sqlx::query!(
"
UPDATE users SET name = ?, language = ?, timezone = ? WHERE id = ?
", self.name, self.language, self.timezone, self.id)
.execute(pool)
.await.unwrap();
",
self.name,
self.language,
self.timezone,
self.id
)
.execute(pool)
.await
.unwrap();
}
pub async fn response(&self, pool: &MySqlPool, name: &str) -> String {
@ -221,7 +257,8 @@ SELECT value FROM strings WHERE (language = ? OR language = 'EN') AND name = ? O
.await
.unwrap_or_else(|_| panic!("No string with that name: {}", name));
row.value.unwrap_or_else(|| panic!("Null string with that name: {}", name))
row.value
.unwrap_or_else(|| panic!("Null string with that name: {}", name))
}
pub fn timezone(&self) -> Tz {
@ -237,33 +274,41 @@ pub struct Timer {
impl Timer {
pub async fn from_owner(owner: u64, pool: &MySqlPool) -> Vec<Self> {
sqlx::query_as_unchecked!(Timer,
sqlx::query_as_unchecked!(
Timer,
"
SELECT name, start_time, owner FROM timers WHERE owner = ?
", owner)
.fetch_all(pool)
.await
.unwrap()
",
owner
)
.fetch_all(pool)
.await
.unwrap()
}
pub async fn count_from_owner(owner: u64, pool: &MySqlPool) -> u32 {
sqlx::query!(
"
SELECT COUNT(1) as count FROM timers WHERE owner = ?
", owner)
.fetch_one(pool)
.await
.unwrap()
.count as u32
",
owner
)
.fetch_one(pool)
.await
.unwrap()
.count as u32
}
pub async fn create(name: &str, owner: u64, pool: &MySqlPool) {
sqlx::query!(
"
INSERT INTO timers (name, owner) VALUES (?, ?)
", name, owner)
.execute(pool)
.await
.unwrap();
",
name,
owner
)
.execute(pool)
.await
.unwrap();
}
}

View File

@ -1,16 +1,9 @@
use std::time::{
SystemTime,
UNIX_EPOCH,
};
use std::time::{SystemTime, UNIX_EPOCH};
use std::fmt::{
Formatter,
Display,
Result as FmtResult,
};
use std::fmt::{Display, Formatter, Result as FmtResult};
use chrono_tz::Tz;
use chrono::TimeZone;
use chrono_tz::Tz;
use std::convert::TryFrom;
#[derive(Debug)]
@ -55,8 +48,7 @@ impl TimeParser {
let parse_type = if input.contains('/') || input.contains(':') {
ParseType::Explicit
}
else {
} else {
ParseType::Displacement
};
@ -70,9 +62,7 @@ impl TimeParser {
pub fn timestamp(&self) -> Result<i64, InvalidTime> {
match self.parse_type {
ParseType::Explicit => {
Ok(self.process_explicit()?)
},
ParseType::Explicit => Ok(self.process_explicit()?),
ParseType::Displacement => {
let now = SystemTime::now();
@ -81,7 +71,7 @@ impl TimeParser {
.expect("Time calculated as going backwards. Very bad");
Ok(since_epoch.as_secs() as i64 + self.process_displacement()?)
},
}
}
}
@ -94,11 +84,9 @@ impl TimeParser {
.expect("Time calculated as going backwards. Very bad");
Ok(self.process_explicit()? - since_epoch.as_secs() as i64)
},
}
ParseType::Displacement => {
Ok(self.process_displacement()?)
},
ParseType::Displacement => Ok(self.process_displacement()?),
}
}
@ -112,7 +100,7 @@ impl TimeParser {
0 => Ok("%d-".to_string()),
1 => Ok("%d/%m-".to_string()),
2 => Ok("%d/%m/%Y-".to_string()),
_ => Err(InvalidTime::ParseErrorDMY)
_ => Err(InvalidTime::ParseErrorDMY),
}
} else {
Ok("".to_string())
@ -122,13 +110,16 @@ impl TimeParser {
match colons {
1 => Ok("%H:%M"),
2 => Ok("%H:%M:%S"),
_ => Err(InvalidTime::ParseErrorHMS)
_ => Err(InvalidTime::ParseErrorHMS),
}
} else {
Ok("")
}?;
let dt = self.timezone.datetime_from_str(self.time_string.as_str(), &parse_string).map_err(|_| InvalidTime::ParseErrorChrono)?;
let dt = self
.timezone
.datetime_from_str(self.time_string.as_str(), &parse_string)
.map_err(|_| InvalidTime::ParseErrorChrono)?;
Ok(dt.timestamp() as i64)
}
@ -143,40 +134,42 @@ impl TimeParser {
for character in self.time_string.chars() {
match character {
's' => {
seconds = current_buffer.parse::<i64>().unwrap();
current_buffer = String::from("0");
},
}
'm' => {
minutes = current_buffer.parse::<i64>().unwrap();
current_buffer = String::from("0");
},
}
'h' => {
hours = current_buffer.parse::<i64>().unwrap();
current_buffer = String::from("0");
},
}
'd' => {
days = current_buffer.parse::<i64>().unwrap();
current_buffer = String::from("0");
},
}
c => {
if c.is_digit(10) {
current_buffer += &c.to_string();
} else {
return Err(InvalidTime::ParseErrorDisplacement);
}
else {
return Err(InvalidTime::ParseErrorDisplacement)
}
},
}
}
}
let full = (seconds + (minutes * 60) + (hours * 3600) + (days * 86400) + current_buffer.parse::<i64>().unwrap()) *
if self.inverted { -1 } else { 1 };
let full = (seconds
+ (minutes * 60)
+ (hours * 3600)
+ (days * 86400)
+ current_buffer.parse::<i64>().unwrap())
* if self.inverted { -1 } else { 1 };
Ok(full)
}