Skip to content

Commit

Permalink
Update discord connector to latest serenity library
Browse files Browse the repository at this point in the history
Signed-off-by: Darach Ennis <darach@gmail.com>
  • Loading branch information
darach committed May 13, 2024
1 parent dc27e5b commit 4acac1e
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 141 deletions.
2 changes: 1 addition & 1 deletion tremor-connectors/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ bytes = "1.0"
tempfile = { version = "3.8", default-features = false }
env_logger = "0.11"
tremor-connectors-test-helpers = { path = "../tremor-connectors-test-helpers", version = "0.13.0-rc.23" }
tide = { version = "0.16", default-features = false } # TODO remove tide from TestHttpServer
tide = { version = "0.16", default-features = false } # TODO remove tide from TestHttpServer
tokio = { version = "1.34", default-features = false, features = [
"full",
"test-util",
Expand Down
229 changes: 105 additions & 124 deletions tremor-connectors/src/impls/discord/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
use super::utils::{as_snowflake, get_snowflake, to_reactions, DiscordMessage};
use crate::channel::{Receiver, Sender};
use serenity::{
all::{CreateEmbed, CreateEmbedAuthor, CreateEmbedFooter, CreateMessage, EditMember},
model::{
channel::{Channel, ChannelCategory, GuildChannel, Message, Reaction},
channel::{GuildChannel, Message, Reaction},
event::{
ChannelPinsUpdateEvent, GuildMembersChunkEvent, ResumedEvent, TypingStartEvent,
VoiceServerUpdateEvent,
ChannelPinsUpdateEvent, GuildMemberUpdateEvent, GuildMembersChunkEvent, ResumedEvent,
TypingStartEvent, VoiceServerUpdateEvent,
},
guild::{Emoji, Guild, Member, PartialGuild, Role},
id::{ChannelId, EmojiId, GuildId, MessageId, RoleId, UserId},
Expand All @@ -29,6 +30,7 @@ use serenity::{
},
prelude::*,
};
use simd_json::prelude::ValueObjectAccess;
use std::collections::HashMap;
use tokio::task;
use tremor_value::{prelude::*, to_value};
Expand Down Expand Up @@ -57,6 +59,7 @@ impl Handler {
}
}
}

#[async_trait::async_trait]
impl EventHandler for Handler {
// We use the cache_ready event just in case some cache operation is required in whatever use
Expand All @@ -71,23 +74,17 @@ impl EventHandler for Handler {
}
}

async fn channel_create(&self, _ctx: Context, channel: &GuildChannel) {
self.forward(DiscordMessage::ChannelCreate(channel.clone()))
.await;
}

async fn category_create(&self, _ctx: Context, category: &ChannelCategory) {
self.forward(DiscordMessage::CategoryCreate(category.clone()))
.await;
async fn channel_create(&self, _ctx: Context, channel: GuildChannel) {
self.forward(DiscordMessage::ChannelCreate(channel)).await;
}

async fn category_delete(&self, _ctx: Context, category: &ChannelCategory) {
self.forward(DiscordMessage::CategoryDelete(category.clone()))
.await;
}

async fn channel_delete(&self, _ctx: Context, channel: &GuildChannel) {
self.forward(DiscordMessage::ChannelDelete(channel.clone()))
async fn channel_delete(
&self,
_ctx: Context,
channel: GuildChannel,
maybe_message: Option<Vec<Message>>,
) {
self.forward(DiscordMessage::ChannelDelete(channel, maybe_message))
.await;
}

Expand Down Expand Up @@ -133,7 +130,7 @@ impl EventHandler for Handler {
self.forward(DiscordMessage::TypingStart(e)).await;
}

async fn channel_update(&self, _ctx: Context, old: Option<Channel>, new: Channel) {
async fn channel_update(&self, _ctx: Context, old: Option<GuildChannel>, new: GuildChannel) {
self.forward(DiscordMessage::ChannelUpdate { old, new })
.await;
}
Expand All @@ -148,7 +145,7 @@ impl EventHandler for Handler {
.await;
}

async fn guild_create(&self, _ctx: Context, guild: Guild, is_new: bool) {
async fn guild_create(&self, _ctx: Context, guild: Guild, is_new: Option<bool>) {
self.forward(DiscordMessage::GuildCreate { guild, is_new })
.await;
}
Expand Down Expand Up @@ -200,11 +197,13 @@ impl EventHandler for Handler {
&self,
_ctx: Context,
old_if_available: Option<Member>,
new: Member,
member: Option<Member>,
event: GuildMemberUpdateEvent,
) {
self.forward(DiscordMessage::MemberUpdate {
old_if_available,
new,
member,
event,
})
.await;
}
Expand Down Expand Up @@ -245,11 +244,6 @@ impl EventHandler for Handler {
.await;
}

async fn guild_unavailable(&self, _ctx: Context, guild_id: GuildId) {
self.forward(DiscordMessage::GuildUnavailable(guild_id))
.await;
}

async fn guild_update(
&self,
_ctx: Context,
Expand Down Expand Up @@ -325,9 +319,17 @@ impl EventHandler for Handler {
self.forward(DiscordMessage::Resume(resume)).await;
}

async fn user_update(&self, _ctx: Context, old_data: CurrentUser, new: CurrentUser) {
self.forward(DiscordMessage::UserUpdate { old_data, new })
async fn user_update(&self, _ctx: Context, old_data: Option<CurrentUser>, new: CurrentUser) {
if let Some(old_data) = old_data {
self.forward(DiscordMessage::UserUpdate { old_data, new })
.await;
} else {
self.forward(DiscordMessage::UserUpdate {
old_data: new.clone(),
new,
})
.await;
}
}

async fn voice_server_update(&self, _ctx: Context, update: VoiceServerUpdateEvent) {
Expand Down Expand Up @@ -358,16 +360,16 @@ async fn reply_loop(mut rx: Receiver<Value<'static>>, ctx: Context) {
while let Some(reply) = rx.recv().await {
if let Some(reply) = reply.get("guild") {
let guild = if let Some(id) = get_snowflake(reply, "id") {
GuildId(id)
GuildId::new(id)
} else {
error!("guild `id` missing");
continue;
};

if let Some(member) = reply.get("member") {
if let Some(id) = get_snowflake(member, "id") {
let user = UserId(id);
let mut current_member = match guild.member(&ctx, user).await {
let user = UserId::new(id);
let current_member = match guild.member(&ctx, user).await {
Ok(current_member) => current_member,
Err(e) => {
error!("Member error: {}", e);
Expand All @@ -377,7 +379,7 @@ async fn reply_loop(mut rx: Receiver<Value<'static>>, ctx: Context) {
if let Some(to_remove) = member.get_array("remove_roles") {
let to_remove: Vec<_> = to_remove
.iter()
.filter_map(|v| as_snowflake(v).map(RoleId))
.filter_map(|v| as_snowflake(v).map(RoleId::new))
.collect();
if let Err(e) = current_member.remove_roles(&ctx, &to_remove).await {
error!("Role removal error: {}", e);
Expand All @@ -387,24 +389,20 @@ async fn reply_loop(mut rx: Receiver<Value<'static>>, ctx: Context) {
if let Some(to_roles) = member.get_array("add_roles") {
let to_roles: Vec<_> = to_roles
.iter()
.filter_map(|v| as_snowflake(v).map(RoleId))
.filter_map(|v| as_snowflake(v).map(RoleId::new))
.collect();
if let Err(e) = current_member.add_roles(&ctx, &to_roles).await {
error!("Role add error: {}", e);
};
}
let r = guild
.edit_member(&ctx, id, |m| {
if let Some(deafen) = member.get_bool("deafen") {
m.deafen(deafen);
}
if let Some(mute) = member.get_bool("mute") {
m.mute(mute);
}

m
})
.await;
let mut em = EditMember::default();
if let Some(mute) = member.get_bool("mute") {
em = em.mute(mute);
};
if let Some(deaf) = member.get_bool("deaf") {
em = em.deafen(deaf);
};
let r = guild.edit_member(&ctx, id, em).await;
if let Err(e) = r {
error!("Mute/Deafen error: {}", e);
};
Expand All @@ -413,7 +411,7 @@ async fn reply_loop(mut rx: Receiver<Value<'static>>, ctx: Context) {
}
if let Some(reply) = reply.get("message") {
let channel = if let Some(id) = get_snowflake(reply, "channel_id") {
ChannelId(id)
ChannelId::new(id)
} else {
error!("channel_id missing");
continue;
Expand All @@ -440,84 +438,67 @@ async fn reply_loop(mut rx: Receiver<Value<'static>>, ctx: Context) {
}

if let Some(reply) = reply.get("send") {
if let Err(e) = channel
.send_message(&ctx, |m| {
// Normal content
if let Some(content) = reply.get_str("content") {
m.content(content);
};
// Reference to another message
if let Some(reference_message) = get_snowflake(reply, "reference_message") {
let reference_channel = get_snowflake(reply, "reference_channel")
.map_or(channel, ChannelId);
m.reference_message((reference_channel, MessageId(reference_message)));
};

if let Some(tts) = reply.get_bool("tts") {
m.tts(tts);
};

if let Some(embed) = reply.get("embed") {
m.embed(|e| {
if let Some(author) = embed.get("author") {
e.author(|a| {
if let Some(icon_url) = author.get_str("icon_url") {
a.icon_url(icon_url);
};
if let Some(name) = author.get_str("name") {
a.name(name);
};
if let Some(url) = author.get_str("url") {
a.url(url);
};

a
});
};

if let Some(colour) = embed.get_u64("colour") {
e.colour(colour);
};
if let Some(description) = embed.get_str("description") {
e.description(description);
};

if let Some(fields) = embed.get_array("fields") {
e.fields(fields.iter().filter_map(|v| {
let name = v.get_str("name")?;
let value = v.get_str("value")?;
let inline = v.get_bool("inline").unwrap_or_default();
Some((name, value, inline))
}));
};
if let Some(footer) = embed.get("footer") {
e.footer(|f| {
if let Some(text) = footer.as_str() {
f.text(text);
};
if let Some(text) = footer.get_str("text") {
f.text(text);
};
if let Some(icon_url) = footer.get_str("icon_url") {
f.icon_url(icon_url);
};

f
});
};

e
});
};

if let Some(reactions) = reply.get("reactions").and_then(to_reactions) {
m.reactions(reactions);
let mut created_message = CreateMessage::default();
// Normal content
if let Some(content) = reply.get_str("content") {
created_message = created_message.content(content);
};
// Reference to another message
if let Some(reference_message) = get_snowflake(reply, "reference_message") {
let reference_channel =
get_snowflake(reply, "reference_channel").map_or(channel, ChannelId::new);
created_message = created_message
.reference_message((reference_channel, MessageId::new(reference_message)));
};
if let Some(tts) = reply.get_bool("tts") {
created_message = created_message.tts(tts);
};
if let Some(embed) = reply.get("embed") {
let mut created_embed = CreateEmbed::default();
if let Some(author) = embed.get("author") {
if let Some(name) = author.get_str("name") {
let mut create_embed_author = CreateEmbedAuthor::new(name);
if let Some(icon_url) = author.get_str("icon_url") {
create_embed_author = create_embed_author.icon_url(icon_url);
};
if let Some(url) = author.get_str("url") {
create_embed_author = create_embed_author.url(url);
};
created_embed = created_embed.author(create_embed_author);
}
};
if let Some(colour) = embed.get_u64("colour") {
created_embed = created_embed.colour(colour);
};
if let Some(description) = embed.get_str("description") {
created_embed = created_embed.description(description);
};
if let Some(fields) = embed.get_array("fields") {
created_embed = created_embed.fields(fields.iter().filter_map(|v| {
let name = v.get_str("name")?;
let value = v.get_str("value")?;
let inline = v.get_bool("inline").unwrap_or_default();
Some((name, value, inline))
}));
};
if let Some(footer) = embed.get("footer") {
if let Some(text) = footer.as_str() {
let mut created_embed_footer = CreateEmbedFooter::new(text);
if let Some(text) = footer.get_str("text") {
created_embed_footer = created_embed_footer.text(text);
};
if let Some(icon_url) = footer.get_str("icon_url") {
created_embed_footer = created_embed_footer.icon_url(icon_url);
};
created_embed = created_embed.footer(created_embed_footer);
};

m
})
.await
{
};
created_message = created_message.embed(created_embed);
};
if let Some(reactions) = reply.get("reactions").and_then(to_reactions) {
created_message = created_message.reactions(reactions);
};
if let Err(e) = channel.send_message(&ctx, created_message).await {
error!("Discord send error: {}", e);
};
};
Expand Down
Loading

0 comments on commit 4acac1e

Please sign in to comment.