177 lines
5.6 KiB
Rust
177 lines
5.6 KiB
Rust
use crate::web::{
|
|
check_authorization,
|
|
guards::transaction::Transaction,
|
|
routes::{
|
|
dashboard::{
|
|
create_reminder, CreateReminder, ImportBody, ReminderCsv, ReminderTemplateCsv, TodoCsv,
|
|
},
|
|
JsonResult,
|
|
},
|
|
};
|
|
use crate::Database;
|
|
use base64::{prelude::BASE64_STANDARD, Engine};
|
|
use csv::{QuoteStyle, WriterBuilder};
|
|
use log::warn;
|
|
use rocket::{
|
|
get,
|
|
http::CookieJar,
|
|
put,
|
|
serde::json::{json, Json},
|
|
State,
|
|
};
|
|
use serenity::{
|
|
client::Context,
|
|
model::id::{ChannelId, GuildId, UserId},
|
|
};
|
|
use sqlx::{MySql, Pool};
|
|
|
|
#[get("/api/guild/<id>/export/todos")]
|
|
pub async fn export(
|
|
id: u64,
|
|
cookies: &CookieJar<'_>,
|
|
ctx: &State<Context>,
|
|
pool: &State<Pool<MySql>>,
|
|
) -> JsonResult {
|
|
check_authorization(cookies, ctx.inner(), id).await?;
|
|
|
|
let mut csv_writer = WriterBuilder::new().quote_style(QuoteStyle::Always).from_writer(vec![]);
|
|
|
|
match sqlx::query_as_unchecked!(
|
|
TodoCsv,
|
|
"SELECT value, CONCAT('#', channels.channel) AS channel_id FROM todos
|
|
LEFT JOIN channels ON todos.channel_id = channels.id
|
|
INNER JOIN guilds ON todos.guild_id = guilds.id
|
|
WHERE guilds.guild = ?",
|
|
id
|
|
)
|
|
.fetch_all(pool.inner())
|
|
.await
|
|
{
|
|
Ok(todos) => {
|
|
todos.iter().for_each(|todo| {
|
|
csv_writer.serialize(todo).unwrap();
|
|
});
|
|
|
|
match csv_writer.into_inner() {
|
|
Ok(inner) => match String::from_utf8(inner) {
|
|
Ok(encoded) => Ok(json!({ "body": encoded })),
|
|
|
|
Err(e) => {
|
|
warn!("Failed to write UTF-8: {:?}", e);
|
|
|
|
json_err!("Failed to write UTF-8")
|
|
}
|
|
},
|
|
|
|
Err(e) => {
|
|
warn!("Failed to extract CSV: {:?}", e);
|
|
|
|
json_err!("Failed to extract CSV")
|
|
}
|
|
}
|
|
}
|
|
|
|
Err(e) => {
|
|
warn!("Could not fetch templates from {}: {:?}", id, e);
|
|
|
|
json_err!("Failed to query templates")
|
|
}
|
|
}
|
|
}
|
|
|
|
#[put("/api/guild/<id>/export/todos", data = "<body>")]
|
|
pub async fn import(
|
|
id: u64,
|
|
cookies: &CookieJar<'_>,
|
|
body: Json<ImportBody>,
|
|
ctx: &State<Context>,
|
|
pool: &State<Pool<MySql>>,
|
|
) -> JsonResult {
|
|
check_authorization(cookies, ctx.inner(), id).await?;
|
|
|
|
let channels_res = GuildId::new(id).channels(&ctx.inner()).await;
|
|
|
|
match channels_res {
|
|
Ok(channels) => match BASE64_STANDARD.decode(&body.body) {
|
|
Ok(body) => {
|
|
let mut reader = csv::Reader::from_reader(body.as_slice());
|
|
|
|
let query_placeholder = "(?, (SELECT id FROM channels WHERE channel = ?), (SELECT id FROM guilds WHERE guild = ?))";
|
|
let mut query_params = vec![];
|
|
|
|
for result in reader.deserialize::<TodoCsv>() {
|
|
match result {
|
|
Ok(record) => match record.channel_id {
|
|
Some(channel_id) => {
|
|
let channel_id = channel_id.split_at(1).1;
|
|
|
|
match channel_id.parse::<u64>() {
|
|
Ok(channel_id) => {
|
|
if channels.contains_key(&ChannelId::new(channel_id)) {
|
|
query_params.push((record.value, Some(channel_id), id));
|
|
} else {
|
|
return json_err!(format!(
|
|
"Invalid channel ID {}",
|
|
channel_id
|
|
));
|
|
}
|
|
}
|
|
|
|
Err(_) => {
|
|
return json_err!(format!(
|
|
"Invalid channel ID {}",
|
|
channel_id
|
|
));
|
|
}
|
|
}
|
|
}
|
|
|
|
None => {
|
|
query_params.push((record.value, None, id));
|
|
}
|
|
},
|
|
|
|
Err(e) => {
|
|
warn!("Couldn't deserialize CSV row: {:?}", e);
|
|
|
|
return json_err!("Deserialize error. Aborted");
|
|
}
|
|
}
|
|
}
|
|
|
|
let query_str = format!(
|
|
"INSERT INTO todos (value, channel_id, guild_id) VALUES {}",
|
|
vec![query_placeholder].repeat(query_params.len()).join(",")
|
|
);
|
|
let mut query = sqlx::query(&query_str);
|
|
|
|
for param in query_params {
|
|
query = query.bind(param.0).bind(param.1).bind(param.2);
|
|
}
|
|
|
|
let res = query.execute(pool.inner()).await;
|
|
|
|
match res {
|
|
Ok(_) => Ok(json!({})),
|
|
|
|
Err(e) => {
|
|
warn!("Couldn't execute todo query: {:?}", e);
|
|
|
|
json_err!("An unexpected error occured.")
|
|
}
|
|
}
|
|
}
|
|
|
|
Err(_) => {
|
|
json_err!("Malformed base64")
|
|
}
|
|
},
|
|
|
|
Err(e) => {
|
|
warn!("Couldn't fetch channels for guild {}: {:?}", id, e);
|
|
|
|
json_err!("Couldn't fetch channels.")
|
|
}
|
|
}
|
|
}
|