Skip to content

Commit

Permalink
Use RwLock for active collectors collection (#2851)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheCataliasTNT2k authored Apr 26, 2024
1 parent 44586a6 commit 83ba59e
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 11 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ percent-encoding = { version = "2.3.0", optional = true }
mini-moka = { version = "0.10.2", optional = true }
mime_guess = { version = "2.0.4", optional = true }
dashmap = { version = "5.5.3", features = ["serde"], optional = true }
parking_lot = { version = "0.12.1", optional = true }
parking_lot = { version = "0.12.1"}
ed25519-dalek = { version = "2.0.0", optional = true }
typesize = { version = "0.1.6", optional = true, features = ["url", "time", "serde_json", "secrecy", "dashmap", "parking_lot", "nonmax", "extract_map_01", "details"] }
# serde feature only allows for serialisation,
Expand Down Expand Up @@ -83,7 +83,7 @@ default_no_backend = [
builder = ["tokio/fs"]
# Enables the cache, which stores the data received from Discord gateway to provide access to
# complete guild data, channels, users and more without needing HTTP requests.
cache = ["fxhash", "dashmap", "parking_lot"]
cache = ["fxhash", "dashmap"]
# Enables collectors, a utility feature that lets you await interaction events in code with
# zero setup, without needing to setup an InteractionCreate event listener.
collector = ["gateway", "model"]
Expand All @@ -95,7 +95,7 @@ framework = ["client", "model", "utils"]
# Enables gateway support, which allows bots to listen for Discord events.
gateway = ["flate2"]
# Enables HTTP, which enables bots to execute actions on Discord.
http = ["dashmap", "parking_lot", "mime_guess", "percent-encoding"]
http = ["dashmap", "mime_guess", "percent-encoding"]
# Enables wrapper methods around HTTP requests on model types.
# Requires "builder" to configure the requests and "http" to execute them.
# Note: the model type definitions themselves are always active, regardless of this feature.
Expand All @@ -116,7 +116,7 @@ chrono = ["dep:chrono", "typesize?/chrono"]

# This enables all parts of the serenity codebase
# (Note: all feature-gated APIs to be documented should have their features listed here!)
#
#
# Unstable functionality should be gated under the `unstable` feature.
full = ["default", "collector", "voice", "voice_model", "interactions_endpoint"]

Expand Down
4 changes: 3 additions & 1 deletion src/collector.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use futures::future::pending;
use futures::{Stream, StreamExt as _};

Expand Down Expand Up @@ -35,7 +37,7 @@ pub fn collect<T: Send + 'static>(
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();

// Register an event callback in the shard. It's kept alive as long as we return `true`
shard.add_collector(CollectorCallback(Box::new(move |event| match extractor(event) {
shard.add_collector(CollectorCallback(Arc::new(move |event| match extractor(event) {
// If this event matches, we send it to the receiver stream
Some(item) => sender.send(item).is_ok(),
None => !sender.is_closed(),
Expand Down
11 changes: 10 additions & 1 deletion src/gateway/bridge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ mod voice;

use std::fmt;
use std::num::NonZeroU16;
use std::sync::Arc;
use std::time::Duration as StdDuration;

pub use self::event::ShardStageUpdateEvent;
Expand Down Expand Up @@ -97,9 +98,17 @@ pub struct ShardRunnerInfo {
/// Newtype around a callback that will be called on every incoming request. As long as this
/// collector should still receive events, it should return `true`. Once it returns `false`, it is
/// removed.
pub struct CollectorCallback(pub Box<dyn Fn(&Event) -> bool + Send + Sync>);
#[derive(Clone)]
pub struct CollectorCallback(pub Arc<dyn Fn(&Event) -> bool + Send + Sync>);

impl std::fmt::Debug for CollectorCallback {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("CollectorCallback").finish()
}
}

impl PartialEq for CollectorCallback {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}
4 changes: 2 additions & 2 deletions src/gateway/bridge/shard_messenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::model::prelude::*;
pub struct ShardMessenger {
pub(crate) tx: Sender<ShardRunnerMessage>,
#[cfg(feature = "collector")]
pub(crate) collectors: Arc<std::sync::Mutex<Vec<CollectorCallback>>>,
pub(crate) collectors: Arc<parking_lot::RwLock<Vec<CollectorCallback>>>,
}

impl ShardMessenger {
Expand Down Expand Up @@ -211,6 +211,6 @@ impl ShardMessenger {

#[cfg(feature = "collector")]
pub fn add_collector(&self, collector: CollectorCallback) {
self.collectors.lock().expect("poison").push(collector);
self.collectors.write().push(collector);
}
}
17 changes: 14 additions & 3 deletions src/gateway/bridge/shard_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub struct ShardRunner {
pub cache: Arc<Cache>,
pub http: Arc<Http>,
#[cfg(feature = "collector")]
pub(crate) collectors: Arc<std::sync::Mutex<Vec<CollectorCallback>>>,
pub(crate) collectors: Arc<parking_lot::RwLock<Vec<CollectorCallback>>>,
}

impl ShardRunner {
Expand All @@ -66,7 +66,7 @@ impl ShardRunner {
cache: opt.cache,
http: opt.http,
#[cfg(feature = "collector")]
collectors: Arc::new(std::sync::Mutex::new(vec![])),
collectors: Arc::new(parking_lot::RwLock::new(vec![])),
}
}

Expand Down Expand Up @@ -171,7 +171,18 @@ impl ShardRunner {

if let Some(event) = event {
#[cfg(feature = "collector")]
self.collectors.lock().expect("poison").retain_mut(|callback| (callback.0)(&event));
{
let read_lock = self.collectors.read();
// search all collectors to be removed and clone the Arcs
let to_remove: Vec<_> =
read_lock.iter().filter(|callback| !callback.0(&event)).cloned().collect();
drop(read_lock);
// remove all found arcs from the collection
// this compares the inner pointer of the Arc
if !to_remove.is_empty() {
self.collectors.write().retain(|f| !to_remove.contains(f));
}
}
spawn_named(
"shard_runner::dispatch",
dispatch_model(
Expand Down

0 comments on commit 83ba59e

Please sign in to comment.