changed how commands get invoked

This commit is contained in:
jellywx 2021-06-09 16:54:31 +01:00
parent e3f30ab085
commit 99e1807097
9 changed files with 657 additions and 328 deletions

View File

@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="dataSourceStorageLocal"> <component name="dataSourceStorageLocal" created-in="CL-211.7442.42">
<data-source name="MySQL for 5.1 - soundfx@localhost" uuid="1067c1d0-1386-4a39-b3f5-6d48d6f279eb"> <data-source name="MySQL for 5.1 - soundfx@localhost" uuid="1067c1d0-1386-4a39-b3f5-6d48d6f279eb">
<database-info product="" version="" jdbc-version="" driver-name="" driver-version="" dbms="MYSQL" exact-version="0" /> <database-info product="" version="" jdbc-version="" driver-name="" driver-version="" dbms="MYSQL" exact-version="0" />
<secret-storage>master_key</secret-storage> <secret-storage>master_key</secret-storage>

View File

@ -6,7 +6,7 @@ edition = "2018"
[dependencies] [dependencies]
songbird = { git = "https://github.com/FelixMcFelix/songbird", branch = "ws-fix" } songbird = { git = "https://github.com/FelixMcFelix/songbird", branch = "ws-fix" }
serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["voice", "collector"] } serenity = { git = "https://github.com/serenity-rs/serenity", branch = "next", features = ["voice", "collector", "unstable_discord_api"] }
sqlx = { version = "0.5", default-features = false, features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal"] } sqlx = { version = "0.5", default-features = false, features = ["runtime-tokio-rustls", "macros", "mysql", "bigdecimal"] }
dotenv = "0.15" dotenv = "0.15"
tokio = { version = "1", features = ["fs", "process", "io-util"] } tokio = { version = "1", features = ["fs", "process", "io-util"] }

View File

@ -57,11 +57,15 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
let name = &name[..]; let name = &name[..];
match_options!(name, values, options, span => [ match_options!(name, values, options, span => [
permission_level permission_level;
allow_slash
]); ]);
} }
let Options { permission_level } = options; let Options {
permission_level,
allow_slash,
} = options;
propagate_err!(create_declaration_validations(&mut fun, DeclarFor::Command)); propagate_err!(create_declaration_validations(&mut fun, DeclarFor::Command));
@ -88,6 +92,7 @@ pub fn command(attr: TokenStream, input: TokenStream) -> TokenStream {
func: #name, func: #name,
name: #lit_name, name: #lit_name,
required_perms: #permission_level, required_perms: #permission_level,
allow_slash: #allow_slash,
}; };
#visibility fn #name<'fut> (#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, #ret> { #visibility fn #name<'fut> (#(#args),*) -> ::serenity::futures::future::BoxFuture<'fut, #ret> {

View File

@ -226,11 +226,15 @@ impl ToTokens for PermissionLevel {
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Options { pub struct Options {
pub permission_level: PermissionLevel, pub permission_level: PermissionLevel,
pub allow_slash: bool,
} }
impl Options { impl Options {
#[inline] #[inline]
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Self {
permission_level: PermissionLevel::default(),
allow_slash: false,
}
} }
} }

View File

@ -182,7 +182,7 @@ pub fn create_declaration_validations(fun: &mut CommandFun, dec_for: DeclarFor)
} }
let context: Type = parse_quote!(&serenity::client::Context); let context: Type = parse_quote!(&serenity::client::Context);
let message: Type = parse_quote!(&serenity::model::channel::Message); let message: Type = parse_quote!(&(dyn crate::framework::CommandInvoke + Sync + Send));
let args: Type = parse_quote!(serenity::framework::standard::Args); let args: Type = parse_quote!(serenity::framework::standard::Args);
let args2: Type = parse_quote!(&mut serenity::framework::standard::Args); let args2: Type = parse_quote!(&mut serenity::framework::standard::Args);
let options: Type = parse_quote!(&serenity::framework::standard::CommandOptions); let options: Type = parse_quote!(&serenity::framework::standard::CommandOptions);

View File

@ -1,7 +1,20 @@
use serenity::async_trait; use crate::{
use songbird::Event; guild_data::CtxGuildData,
use songbird::EventContext; join_channel, play_audio,
use songbird::EventHandler as SongbirdEventHandler; sound::{JoinSoundCtx, Sound},
MySQL, ReqwestClient,
};
use serenity::{
async_trait,
client::{Context, EventHandler},
model::{channel::Channel, guild::Guild, id::GuildId, voice::VoiceState},
utils::shard_id,
};
use songbird::{Event, EventContext, EventHandler as SongbirdEventHandler};
use std::{collections::HashMap, env};
pub struct RestartTrack; pub struct RestartTrack;
@ -15,3 +28,130 @@ impl SongbirdEventHandler for RestartTrack {
None None
} }
} }
pub struct Handler;
#[serenity::async_trait]
impl EventHandler for Handler {
async fn guild_create(&self, ctx: Context, guild: Guild, is_new: bool) {
if is_new {
if let Ok(token) = env::var("DISCORDBOTS_TOKEN") {
let shard_count = ctx.cache.shard_count().await;
let current_shard_id = shard_id(guild.id.as_u64().to_owned(), shard_count);
let guild_count = ctx
.cache
.guilds()
.await
.iter()
.filter(|g| shard_id(g.as_u64().to_owned(), shard_count) == current_shard_id)
.count() as u64;
let mut hm = HashMap::new();
hm.insert("server_count", guild_count);
hm.insert("shard_id", current_shard_id);
hm.insert("shard_count", shard_count);
let client = ctx
.data
.read()
.await
.get::<ReqwestClient>()
.cloned()
.expect("Could not get ReqwestClient from data");
let response = client
.post(
format!(
"https://top.gg/api/bots/{}/stats",
ctx.cache.current_user_id().await.as_u64()
)
.as_str(),
)
.header("Authorization", token)
.json(&hm)
.send()
.await;
if let Err(res) = response {
println!("DiscordBots Response: {:?}", res);
}
}
}
}
async fn voice_state_update(
&self,
ctx: Context,
guild_id_opt: Option<GuildId>,
old: Option<VoiceState>,
new: VoiceState,
) {
if let Some(past_state) = old {
if let (Some(guild_id), None) = (guild_id_opt, new.channel_id) {
if let Some(channel_id) = past_state.channel_id {
if let Some(Channel::Guild(channel)) = channel_id.to_channel_cached(&ctx).await
{
if channel.members(&ctx).await.map(|m| m.len()).unwrap_or(0) <= 1 {
let songbird = songbird::get(&ctx).await.unwrap();
let _ = songbird.remove(guild_id).await;
}
}
}
}
} else if let (Some(guild_id), Some(user_channel)) = (guild_id_opt, new.channel_id) {
if let Some(guild) = ctx.cache.guild(guild_id).await {
let pool = ctx
.data
.read()
.await
.get::<MySQL>()
.cloned()
.expect("Could not get SQLPool from data");
let guild_data_opt = ctx.guild_data(guild.id).await;
if let Ok(guild_data) = guild_data_opt {
let volume;
let allowed_greets;
{
let read = guild_data.read().await;
volume = read.volume;
allowed_greets = read.allow_greets;
}
if allowed_greets {
if let Some(join_id) = ctx.join_sound(new.user_id).await {
let mut sound = sqlx::query_as_unchecked!(
Sound,
"
SELECT name, id, plays, public, server_id, uploader_id
FROM sounds
WHERE id = ?
",
join_id
)
.fetch_one(&pool)
.await
.unwrap();
let (handler, _) = join_channel(&ctx, guild, user_channel).await;
let _ = play_audio(
&mut sound,
volume,
&mut handler.lock().await,
pool,
false,
)
.await;
}
}
}
}
}
}
}

View File

@ -2,13 +2,17 @@ use serenity::{
async_trait, async_trait,
client::Context, client::Context,
constants::MESSAGE_CODE_LIMIT, constants::MESSAGE_CODE_LIMIT,
framework::{standard::Args, Framework}, framework::{
standard::{Args, CommandResult, Delimiter},
Framework,
},
futures::prelude::future::BoxFuture, futures::prelude::future::BoxFuture,
http::Http, http::Http,
model::{ model::{
channel::{Channel, GuildChannel, Message}, channel::{Channel, GuildChannel, Message},
guild::{Guild, Member}, guild::{Guild, Member},
id::ChannelId, id::{ChannelId, GuildId, UserId},
interactions::Interaction,
}, },
Result as SerenityResult, Result as SerenityResult,
}; };
@ -20,9 +24,164 @@ use regex::{Match, Regex, RegexBuilder};
use std::{collections::HashMap, fmt}; use std::{collections::HashMap, fmt};
use crate::{guild_data::CtxGuildData, MySQL}; use crate::{guild_data::CtxGuildData, MySQL};
use serenity::framework::standard::{CommandResult, Delimiter}; use serenity::builder::CreateEmbed;
use serenity::cache::Cache;
use serenity::model::prelude::InteractionResponseType;
use std::sync::Arc;
type CommandFn = for<'fut> fn(&'fut Context, &'fut Message, Args) -> BoxFuture<'fut, CommandResult>; type CommandFn = for<'fut> fn(
&'fut Context,
&'fut (dyn CommandInvoke + Sync + Send),
Args,
) -> BoxFuture<'fut, CommandResult>;
pub struct CreateGenericResponse {
content: String,
embed: Option<CreateEmbed>,
}
impl CreateGenericResponse {
pub fn new() -> Self {
Self {
content: "".to_string(),
embed: None,
}
}
pub fn content<D: ToString>(mut self, content: D) -> Self {
self.content = content.to_string();
self
}
pub fn embed<F: FnOnce(&mut CreateEmbed) -> &mut CreateEmbed>(mut self, f: F) -> Self {
let mut embed = CreateEmbed::default();
f(&mut embed);
self.embed = Some(embed);
self
}
}
#[async_trait]
pub trait CommandInvoke {
fn channel_id(&self) -> ChannelId;
fn guild_id(&self) -> Option<GuildId>;
async fn guild(&self, cache: Arc<Cache>) -> Option<Guild>;
fn author_id(&self) -> UserId;
async fn member(&self, context: &Context) -> SerenityResult<Member>;
fn msg(&self) -> Option<Message>;
fn interaction(&self) -> Option<Interaction>;
async fn respond(
&self,
http: Arc<Http>,
generic_response: CreateGenericResponse,
) -> SerenityResult<()>;
}
#[async_trait]
impl CommandInvoke for Message {
fn channel_id(&self) -> ChannelId {
self.channel_id
}
fn guild_id(&self) -> Option<GuildId> {
self.guild_id
}
async fn guild(&self, cache: Arc<Cache>) -> Option<Guild> {
self.guild(cache).await
}
fn author_id(&self) -> UserId {
self.author.id
}
async fn member(&self, context: &Context) -> SerenityResult<Member> {
self.member(context).await
}
fn msg(&self) -> Option<Message> {
Some(self.clone())
}
fn interaction(&self) -> Option<Interaction> {
None
}
async fn respond(
&self,
http: Arc<Http>,
generic_response: CreateGenericResponse,
) -> SerenityResult<()> {
self.channel_id
.send_message(http, |m| {
m.content(generic_response.content);
if let Some(embed) = generic_response.embed {
m.set_embed(embed.clone());
}
m
})
.await
.map(|_| ())
}
}
#[async_trait]
impl CommandInvoke for Interaction {
fn channel_id(&self) -> ChannelId {
self.channel_id.unwrap()
}
fn guild_id(&self) -> Option<GuildId> {
self.guild_id
}
async fn guild(&self, cache: Arc<Cache>) -> Option<Guild> {
self.guild(cache).await
}
fn author_id(&self) -> UserId {
self.member.as_ref().unwrap().user.id
}
async fn member(&self, _: &Context) -> SerenityResult<Member> {
Ok(self.member.clone().unwrap())
}
fn msg(&self) -> Option<Message> {
None
}
fn interaction(&self) -> Option<Interaction> {
Some(self.clone())
}
async fn respond(
&self,
http: Arc<Http>,
generic_response: CreateGenericResponse,
) -> SerenityResult<()> {
self.create_interaction_response(http, |r| {
r.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|d| {
d.content(generic_response.content);
if let Some(embed) = generic_response.embed {
d.set_embed(embed.clone());
}
d
})
})
.await
.map(|_| ())
}
}
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum PermissionLevel { pub enum PermissionLevel {
@ -35,6 +194,7 @@ pub struct Command {
pub name: &'static str, pub name: &'static str,
pub required_perms: PermissionLevel, pub required_perms: PermissionLevel,
pub func: CommandFn, pub func: CommandFn,
pub allow_slash: bool,
} }
impl Command { impl Command {

File diff suppressed because it is too large Load Diff

View File

@ -121,13 +121,16 @@ pub struct Sound {
} }
impl Sound { impl Sound {
pub async fn search_for_sound( pub async fn search_for_sound<G: Into<u64>, U: Into<u64>>(
query: &str, query: &str,
guild_id: u64, guild_id: G,
user_id: u64, user_id: U,
db_pool: MySqlPool, db_pool: MySqlPool,
strict: bool, strict: bool,
) -> Result<Vec<Sound>, sqlx::Error> { ) -> Result<Vec<Sound>, sqlx::Error> {
let guild_id = guild_id.into();
let user_id = user_id.into();
fn extract_id(s: &str) -> Option<u32> { fn extract_id(s: &str) -> Option<u32> {
if s.len() > 3 && s.to_lowercase().starts_with("id:") { if s.len() > 3 && s.to_lowercase().starts_with("id:") {
match s[3..].parse::<u32>() { match s[3..].parse::<u32>() {
@ -403,8 +406,8 @@ INSERT INTO sounds (name, server_id, uploader_id, public, src)
} }
} }
pub async fn get_user_sounds( pub async fn get_user_sounds<U: Into<u64>>(
user_id: u64, user_id: U,
db_pool: MySqlPool, db_pool: MySqlPool,
) -> Result<Vec<Sound>, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<Vec<Sound>, Box<dyn std::error::Error + Send + Sync>> {
let sounds = sqlx::query_as_unchecked!( let sounds = sqlx::query_as_unchecked!(
@ -414,7 +417,7 @@ SELECT name, id, plays, public, server_id, uploader_id
FROM sounds FROM sounds
WHERE uploader_id = ? WHERE uploader_id = ?
", ",
user_id user_id.into()
) )
.fetch_all(&db_pool) .fetch_all(&db_pool)
.await?; .await?;
@ -422,8 +425,8 @@ SELECT name, id, plays, public, server_id, uploader_id
Ok(sounds) Ok(sounds)
} }
pub async fn get_guild_sounds( pub async fn get_guild_sounds<G: Into<u64>>(
guild_id: u64, guild_id: G,
db_pool: MySqlPool, db_pool: MySqlPool,
) -> Result<Vec<Sound>, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<Vec<Sound>, Box<dyn std::error::Error + Send + Sync>> {
let sounds = sqlx::query_as_unchecked!( let sounds = sqlx::query_as_unchecked!(
@ -433,7 +436,7 @@ SELECT name, id, plays, public, server_id, uploader_id
FROM sounds FROM sounds
WHERE server_id = ? WHERE server_id = ?
", ",
guild_id guild_id.into()
) )
.fetch_all(&db_pool) .fetch_all(&db_pool)
.await?; .await?;