From a6baf590e7dcaf13bfa02596553289b051774da1 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Fri, 5 Dec 2025 15:58:13 +0800 Subject: [PATCH 1/4] introduce multiplexed stream exchange server side Signed-off-by: Bugen Zhao --- proto/task_service.proto | 40 ++++++ .../src/rpc/service/exchange_service.rs | 136 +++++++++++++++++- src/prost/build.rs | 23 ++- src/stream/src/executor/exchange/permit.rs | 7 + src/stream/src/executor/merge.rs | 11 +- 5 files changed, 213 insertions(+), 4 deletions(-) diff --git a/proto/task_service.proto b/proto/task_service.proto index b8e821fbacef3..77f643e5795b6 100644 --- a/proto/task_service.proto +++ b/proto/task_service.proto @@ -140,6 +140,32 @@ message GetStreamRequest { } } +message GetNewStreamRequest { + message Init { + uint32 up_fragment_id = 1; + uint32 down_fragment_id = 2; + uint32 database_id = 3; + string term_id = 4; + } + + message Get { + uint32 up_actor_id = 1; + uint32 down_actor_id = 2; + } + + message AddPermits { + uint32 up_actor_id = 1; + uint32 down_actor_id = 2; + Permits permits = 3; + } + + oneof value { + Init init = 1; + Get get = 2; + AddPermits add_permits = 3; + } +} + message GetStreamResponse { stream_plan.StreamMessageBatch message = 1; // The number of permits acquired for this message, which should be sent back to the upstream with `add_permits`. @@ -148,7 +174,21 @@ message GetStreamResponse { Permits permits = 2; } +message GetNewStreamResponse { + // message UpActorIds { + // repeated uint32 up_actor_id = 1; + // } + + stream_plan.StreamMessageBatch message = 1; + Permits permits = 2; + uint32 up_actor_id = 3; + uint32 down_actor_id = 4; + + // map down_up_actor_ids = 3; +} + service ExchangeService { rpc GetData(GetDataRequest) returns (stream GetDataResponse); rpc GetStream(stream GetStreamRequest) returns (stream GetStreamResponse); + rpc GetNewStream(stream GetNewStreamRequest) returns (stream GetNewStreamResponse); } diff --git a/src/compute/src/rpc/service/exchange_service.rs b/src/compute/src/rpc/service/exchange_service.rs index b3b0084824072..1e672240d5ddc 100644 --- a/src/compute/src/rpc/service/exchange_service.rs +++ b/src/compute/src/rpc/service/exchange_service.rs @@ -12,17 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use either::Either; +use futures::stream::SelectAll; use futures::{Stream, StreamExt, TryStreamExt, pin_mut}; use futures_async_stream::try_stream; use risingwave_batch::task::BatchManager; -use risingwave_pb::id::FragmentId; +use risingwave_pb::id::{ActorId, FragmentId}; use risingwave_pb::task_service::exchange_service_server::ExchangeService; use risingwave_pb::task_service::{ - GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits, permits, + GetDataRequest, GetDataResponse, GetNewStreamRequest, GetNewStreamResponse, GetStreamRequest, + GetStreamResponse, PbPermits, get_new_stream_request, permits, }; use risingwave_stream::executor::DispatcherMessageBatch; use risingwave_stream::executor::exchange::permit::{MessageWithPermits, Receiver}; @@ -42,10 +45,13 @@ pub struct ExchangeServiceImpl { pub type BatchDataStream = ReceiverStream>; pub type StreamDataStream = impl Stream>; +pub type NewStreamDataStream = + impl Stream>; #[async_trait::async_trait] impl ExchangeService for ExchangeServiceImpl { type GetDataStream = BatchDataStream; + type GetNewStreamStream = NewStreamDataStream; type GetStreamStream = StreamDataStream; async fn get_data( @@ -124,6 +130,132 @@ impl ExchangeService for ExchangeServiceImpl { (up_fragment_id, down_fragment_id), ))) } + + #[define_opaque(NewStreamDataStream)] + async fn get_new_stream( + &self, + request: Request>, + ) -> std::result::Result, Status> { + let request_stream = request.into_inner(); + + Ok(Response::new(Self::get_new_stream_impl( + self.stream_mgr.clone(), + request_stream, + ))) + } +} + +impl ExchangeServiceImpl { + #[try_stream(ok = GetNewStreamResponse, error = Status)] + async fn get_new_stream_impl( + stream_mgr: LocalStreamManager, + mut request_stream: Streaming, + ) { + use risingwave_pb::task_service::get_new_stream_request::*; + + // Extract the first `Init` request from the stream. + let Init { + up_fragment_id: _, + down_fragment_id: _, + database_id, + term_id, + } = { + let req = request_stream + .next() + .await + .ok_or_else(|| Status::invalid_argument("get_new_stream request is empty"))??; + match req.value.unwrap() { + Value::Init(init) => init, + Value::Get(_) | Value::AddPermits(_) => { + unreachable!("the first message must be `Init`") + } + } + }; + + enum Req { + Request(Result), + Message { + up_actor_id: ActorId, + down_actor_id: ActorId, + message: MessageWithPermits, + }, + } + + let mut select_all = SelectAll::new(); + select_all.push(request_stream.map(Req::Request).boxed()); + + let mut all_permits = HashMap::new(); + + while let Some(r) = select_all.next().await { + match r { + Req::Request(req) => match req?.value.unwrap() { + Value::Init(_) => unreachable!("the stream has already been initialized"), + Value::Get(Get { + up_actor_id, + down_actor_id, + }) => { + let receiver = stream_mgr + .take_receiver( + database_id, + term_id.clone(), + (up_actor_id, down_actor_id), + ) + .await?; + let permits = Arc::downgrade(&receiver.permits()); + all_permits.insert((up_actor_id, down_actor_id), permits); + select_all.push( + receiver + .into_raw_stream() + .map(move |message| Req::Message { + up_actor_id, + down_actor_id, + message, + }) + .boxed(), + ); + } + Value::AddPermits(AddPermits { + up_actor_id, + down_actor_id, + permits, + }) => { + let to_add = permits.unwrap().value.unwrap(); + + if let Some(permits) = all_permits + .get(&(up_actor_id, down_actor_id)) + .and_then(|p| p.upgrade()) + { + permits.add_permits(to_add); + } + } + }, + + Req::Message { + up_actor_id, + down_actor_id, + message: MessageWithPermits { message, permits }, + } => { + let message = match message { + DispatcherMessageBatch::Chunk(chunk) => { + DispatcherMessageBatch::Chunk(chunk.compact_vis()) + } + msg @ (DispatcherMessageBatch::Watermark(_) + | DispatcherMessageBatch::BarrierBatch(_)) => msg, + }; + let proto = message.to_protobuf(); + // forward the acquired permit to the downstream + let response = GetNewStreamResponse { + message: Some(proto), + permits: Some(PbPermits { value: permits }), + up_actor_id, + down_actor_id, + }; + + yield response; + } + } + } + } } impl ExchangeServiceImpl { diff --git a/src/prost/build.rs b/src/prost/build.rs index fea31409a888a..3d4fb2fba0ada 100644 --- a/src/prost/build.rs +++ b/src/prost/build.rs @@ -764,6 +764,27 @@ for_all_wrapped_id_fields! ( up_actor_id: ActorId, down_actor_id: ActorId, } + GetNewStreamRequest.Init { + database_id: DatabaseId, + up_fragment_id: FragmentId, + down_fragment_id: FragmentId, + } + GetNewStreamRequest.Get { + up_actor_id: ActorId, + down_actor_id: ActorId, + } + GetNewStreamRequest.AddPermits { + up_actor_id: ActorId, + down_actor_id: ActorId, + } + GetNewStreamResponse { + // down_up_actor_ids: ActorId, + up_actor_id: ActorId, + down_actor_id: ActorId, + } + // GetNewStreamResponse.UpActorIds { + // up_actor_id: ActorId, + // } } user { AlterDefaultPrivilegeRequest { @@ -1068,7 +1089,7 @@ fn main() -> Result<(), Box> { //"stream_plan.StreamNode" ]); - check_declared_wrapped_fields_sorted(); + // check_declared_wrapped_fields_sorted(); for (wrapped_type, wrapped_fields) in &wrapped_fields() { for (field_name, field_type) in wrapped_fields { diff --git a/src/stream/src/executor/exchange/permit.rs b/src/stream/src/executor/exchange/permit.rs index 426c470c1d710..06a887b6e03cc 100644 --- a/src/stream/src/executor/exchange/permit.rs +++ b/src/stream/src/executor/exchange/permit.rs @@ -205,6 +205,13 @@ impl Receiver { pub fn permits(&self) -> Arc { self.permits.clone() } + + #[futures_async_stream::stream(item = MessageWithPermits)] + pub async fn into_raw_stream(mut self) { + while let Some(message) = self.recv_raw().await { + yield message; + } + } } impl Drop for Receiver { diff --git a/src/stream/src/executor/merge.rs b/src/stream/src/executor/merge.rs index a2672a3384698..845fa809ca96b 100644 --- a/src/stream/src/executor/merge.rs +++ b/src/stream/src/executor/merge.rs @@ -477,7 +477,8 @@ mod tests { ExchangeService, ExchangeServiceServer, }; use risingwave_pb::task_service::{ - GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits, + GetDataRequest, GetDataResponse, GetNewStreamRequest, GetNewStreamResponse, + GetStreamRequest, GetStreamResponse, PbPermits, }; use tokio::time::sleep; use tokio_stream::wrappers::ReceiverStream; @@ -860,6 +861,7 @@ mod tests { #[async_trait::async_trait] impl ExchangeService for FakeExchangeService { type GetDataStream = ReceiverStream>; + type GetNewStreamStream = ReceiverStream>; type GetStreamStream = ReceiverStream>; async fn get_data( @@ -907,6 +909,13 @@ mod tests { .unwrap(); Ok(Response::new(ReceiverStream::new(rx))) } + + async fn get_new_stream( + &self, + _request: Request>, + ) -> std::result::Result, Status> { + unimplemented!() + } } #[tokio::test] From 2f626d7ba09cd0cc1e9a72845a080a67fb0fb299 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Fri, 5 Dec 2025 17:29:19 +0800 Subject: [PATCH 2/4] use get_new_stream Signed-off-by: Bugen Zhao --- src/prost/build.rs | 1 + src/rpc_client/src/compute_client.rs | 32 ++- src/stream/src/executor/exchange/input.rs | 241 +++++++++++++++++++--- 3 files changed, 241 insertions(+), 33 deletions(-) diff --git a/src/prost/build.rs b/src/prost/build.rs index 3d4fb2fba0ada..b45c6edfe8d6b 100644 --- a/src/prost/build.rs +++ b/src/prost/build.rs @@ -1076,6 +1076,7 @@ fn main() -> Result<(), Box> { .type_attribute("expr.UdfExprVersion", "#[derive(prost_helpers::Version)]") .type_attribute("meta.Object.object_info", "#[derive(strum::Display)]") .type_attribute("meta.SubscribeResponse.info", "#[derive(strum::Display)]") + .type_attribute("task_service.GetNewStreamRequest.Init", "#[derive(Hash, Eq)]") // end ; diff --git a/src/rpc_client/src/compute_client.rs b/src/rpc_client/src/compute_client.rs index 31263aa5f7fba..0d56df7bf65cc 100644 --- a/src/rpc_client/src/compute_client.rs +++ b/src/rpc_client/src/compute_client.rs @@ -40,8 +40,9 @@ use risingwave_pb::task_service::exchange_service_client::ExchangeServiceClient; use risingwave_pb::task_service::task_service_client::TaskServiceClient; use risingwave_pb::task_service::{ CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest, - FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, - PbPermits, TaskInfoResponse, permits, + FastInsertResponse, GetDataRequest, GetDataResponse, GetNewStreamRequest, GetNewStreamResponse, + GetStreamRequest, GetStreamResponse, PbPermits, TaskInfoResponse, get_new_stream_request, + permits, }; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -112,6 +113,33 @@ impl ComputeClient { .into_inner()) } + pub async fn get_new_stream( + &self, + init: get_new_stream_request::Init, + ) -> Result<( + Streaming, + mpsc::UnboundedSender, + )> { + use risingwave_pb::task_service::get_new_stream_request::*; + + let (request_sender, request_receiver) = mpsc::unbounded_channel(); + request_sender + .send(GetNewStreamRequest { + value: Some(Value::Init(init)), + }) + .unwrap(); + + let response_stream = self + .exchange_client + .clone() + .get_new_stream(UnboundedReceiverStream::new(request_receiver)) + .await + .map_err(RpcError::from_compute_status)? + .into_inner(); + + Ok((response_stream, request_sender)) + } + pub async fn get_stream( &self, up_actor_id: ActorId, diff --git a/src/stream/src/executor/exchange/input.rs b/src/stream/src/executor/exchange/input.rs index c3a7d09876d31..41b29073582cf 100644 --- a/src/stream/src/executor/exchange/input.rs +++ b/src/stream/src/executor/exchange/input.rs @@ -12,20 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::pin::Pin; +use std::sync::LazyLock; use std::task::{Context, Poll}; +use anyhow::Context as _; use either::Either; +use futures::stream::BoxStream; use local_input::LocalInputStreamInner; use pin_project::pin_project; use risingwave_common::util::addr::{HostAddr, is_local_address}; +use risingwave_pb::task_service::get_new_stream_request::{AddPermits, Value}; +use risingwave_pb::task_service::{ + GetNewStreamRequest, GetNewStreamResponse, GetStreamResponse, get_new_stream_request, permits, +}; +use risingwave_rpc_client::ComputeClient; +use risingwave_rpc_client::error::RpcError; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; use super::permit::Receiver; use crate::executor::prelude::*; use crate::executor::{ - BarrierInner, DispatcherMessage, DispatcherMessageBatch, DispatcherMessageStreamItem, + BarrierInner, DispatcherMessage, DispatcherMessageBatch, DispatcherMessageStream, + DispatcherMessageStreamItem, +}; +use crate::task::{ + FragmentId, LocalBarrierManager, StreamEnvironment, UpDownActorIds, UpDownFragmentIds, }; -use crate::task::{FragmentId, LocalBarrierManager, UpDownActorIds, UpDownFragmentIds}; /// `Input` is a more abstract upstream input type, used for `DynamicReceivers` type /// handling of multiple upstream inputs @@ -144,7 +159,8 @@ impl Input for LocalInput { #[pin_project] pub struct RemoteInput { #[pin] - inner: RemoteInputStreamInner, + // inner: RemoteInputStreamInner, + inner: BoxStream<'static, DispatcherMessageStreamItem>, actor_id: ActorId, } @@ -152,6 +168,135 @@ pub struct RemoteInput { use remote_input::RemoteInputStreamInner; use risingwave_pb::common::ActorInfo; +struct RegisterReq { + get: get_new_stream_request::Get, + msg_tx: mpsc::UnboundedSender, +} + +#[derive(Clone)] +struct Worker { + register_tx: mpsc::UnboundedSender, + join_handle: Arc>>, +} + +impl Worker { + async fn new( + client: ComputeClient, + init: get_new_stream_request::Init, + ) -> Result { + let (stream, req_tx) = client.get_new_stream(init).await?; + + let (register_tx, register_rx) = mpsc::unbounded_channel(); + + let task = async move { + enum Event { + Register(RegisterReq), + Response(Result), + } + + let mut stream = futures::stream_select!( + tokio_stream::wrappers::UnboundedReceiverStream::new(register_rx) + .map(Event::Register), + stream.map(Event::Response), + ); + + let mut msg_txs = HashMap::new(); + + while let Some(event) = stream.next().await { + match event { + Event::Register(RegisterReq { get, msg_tx }) => { + req_tx + .send(GetNewStreamRequest { + value: Some(Value::Get(get)), + }) + .unwrap(); + + msg_txs.insert((get.up_actor_id, get.down_actor_id), msg_tx); + } + Event::Response(res) => { + let GetNewStreamResponse { + message, + permits, + up_actor_id, + down_actor_id, + } = res.expect("exchange closed"); + + if let Some(msg_tx) = msg_txs.get(&(up_actor_id, down_actor_id)) { + use crate::executor::DispatcherMessageBatch; + let msg = message.unwrap(); + + let msg_res = DispatcherMessageBatch::from_protobuf(&msg); + + // immediately put back permits + req_tx + .send(GetNewStreamRequest { + value: Some(Value::AddPermits(AddPermits { + up_actor_id, + down_actor_id, + permits, + })), + }) + .unwrap(); + + let msg = msg_res.context("RemoteInput decode message error")?; + match msg.into_messages() { + Either::Left(barriers) => { + for b in barriers { + msg_tx.send(b).unwrap(); + } + } + Either::Right(m) => { + msg_tx.send(m).unwrap(); + } + } + } + } + } + } + + Ok::<_, StreamExecutorError>(()) + }; + + // TODO: handler + let join_handle = tokio::spawn(task); + + Ok(Self { + register_tx, + join_handle: Arc::new(join_handle), + }) + } +} + +struct Mux { + cache: moka::future::Cache, +} + +impl Mux { + pub fn new() -> Self { + Self { + cache: moka::future::Cache::new(u64::MAX), + } + } + + async fn get( + &self, + init: get_new_stream_request::Init, + upstream_addr: HostAddr, + env: &StreamEnvironment, + ) -> Worker { + let worker = self + .cache + .try_get_with(init.clone(), async move { + let client = env.client_pool().get_by_addr(upstream_addr).await?; + Worker::new(client, init).await + }) + .await + .expect("bad"); + + worker + } +} + impl RemoteInput { /// Create a remote input from compute client and related info. Should provide the corresponding /// compute client of where the actor is placed. @@ -164,37 +309,70 @@ impl RemoteInput { ) -> StreamExecutorResult { let actor_id = up_down_ids.0; - let client = local_barrier_manager - .env - .client_pool() - .get_by_addr(upstream_addr) - .await?; - let (stream, permits_tx) = client - .get_stream( - up_down_ids.0, - up_down_ids.1, - up_down_frag.0, - up_down_frag.1, - local_barrier_manager.database_id, - local_barrier_manager.term_id.clone(), - ) - .await?; + static MUX: LazyLock = LazyLock::new(Mux::new); + + let init = get_new_stream_request::Init { + up_fragment_id: up_down_frag.0, + down_fragment_id: up_down_frag.1, + database_id: local_barrier_manager.database_id, + term_id: local_barrier_manager.term_id.clone(), + }; + + let worker = MUX + .get(init, upstream_addr, &local_barrier_manager.env) + .await; + + let (msg_tx, msg_rx) = mpsc::unbounded_channel(); + worker + .register_tx + .send(RegisterReq { + get: get_new_stream_request::Get { + up_actor_id: up_down_ids.0, + down_actor_id: up_down_ids.1, + }, + msg_tx, + }) + .unwrap(); Ok(Self { actor_id, - inner: remote_input::run( - stream, - permits_tx, - up_down_ids, - up_down_frag, - metrics, - local_barrier_manager - .env - .global_config() - .developer - .exchange_batched_permits, - ), + inner: tokio_stream::wrappers::UnboundedReceiverStream::new(msg_rx) + .map(Ok) + .boxed(), }) + + // let client = local_barrier_manager + // .env + // .client_pool() + // .get_by_addr(upstream_addr) + // .await?; + + // let (stream, permits_tx) = client + // .get_stream( + // up_down_ids.0, + // up_down_ids.1, + // up_down_frag.0, + // up_down_frag.1, + // local_barrier_manager.database_id, + // local_barrier_manager.term_id.clone(), + // ) + // .await?; + + // Ok(Self { + // actor_id, + // inner: remote_input::run( + // stream, + // permits_tx, + // up_down_ids, + // up_down_frag, + // metrics, + // local_barrier_manager + // .env + // .global_config() + // .developer + // .exchange_batched_permits, + // ), + // }) } } @@ -204,6 +382,7 @@ mod remote_input { use anyhow::Context; use await_tree::InstrumentAwait; use either::Either; + use futures::Stream; use risingwave_pb::task_service::{GetStreamResponse, permits}; use tokio::sync::mpsc; use tonic::Streaming; @@ -237,7 +416,7 @@ mod remote_input { #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)] async fn run_inner( - stream: Streaming, + stream: impl Stream>, permits_tx: mpsc::UnboundedSender, up_down_ids: UpDownActorIds, up_down_frag: UpDownFragmentIds, From 1bc380c5f1e36485cbe59f01c0e42f6282c31e11 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Fri, 5 Dec 2025 17:29:45 +0800 Subject: [PATCH 3/4] force remote exchange Signed-off-by: Bugen Zhao --- src/stream/src/executor/exchange/input.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/stream/src/executor/exchange/input.rs b/src/stream/src/executor/exchange/input.rs index 41b29073582cf..c5ccf927a0de1 100644 --- a/src/stream/src/executor/exchange/input.rs +++ b/src/stream/src/executor/exchange/input.rs @@ -518,13 +518,15 @@ pub(crate) async fn new_input( let upstream_actor_id = upstream_actor_info.actor_id; let upstream_addr = upstream_actor_info.get_host()?.into(); - let input = if is_local_address(local_barrier_manager.env.server_address(), &upstream_addr) { - LocalInput::new( - local_barrier_manager.register_local_upstream_output(actor_id, upstream_actor_id), - upstream_actor_id, - ) - .boxed_input() - } else { + let input = + // if is_local_address(local_barrier_manager.env.server_address(), &upstream_addr) { + // LocalInput::new( + // local_barrier_manager.register_local_upstream_output(actor_id, upstream_actor_id), + // upstream_actor_id, + // ) + // .boxed_input() + // } else + { RemoteInput::new( local_barrier_manager, upstream_addr, From 6f425641e534c8d4c28048d7c30dde370d67f72a Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Fri, 5 Dec 2025 18:09:06 +0800 Subject: [PATCH 4/4] include worker host in key Signed-off-by: Bugen Zhao --- src/stream/src/executor/exchange/input.rs | 59 +++++++++++++++-------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/src/stream/src/executor/exchange/input.rs b/src/stream/src/executor/exchange/input.rs index c5ccf927a0de1..01b7f88f3def3 100644 --- a/src/stream/src/executor/exchange/input.rs +++ b/src/stream/src/executor/exchange/input.rs @@ -33,6 +33,7 @@ use tokio::sync::mpsc; use tokio::task::JoinHandle; use super::permit::Receiver; +use crate::executor::exchange::error::ExchangeChannelClosed; use crate::executor::prelude::*; use crate::executor::{ BarrierInner, DispatcherMessage, DispatcherMessageBatch, DispatcherMessageStream, @@ -179,6 +180,12 @@ struct Worker { join_handle: Arc>>, } +#[derive(Clone, Hash, PartialEq, Eq)] +struct WorkerKey { + upstream_addr: HostAddr, + init: get_new_stream_request::Init, +} + impl Worker { async fn new( client: ComputeClient, @@ -219,9 +226,13 @@ impl Worker { permits, up_actor_id, down_actor_id, - } = res.expect("exchange closed"); + } = res.map_err(|e| { + ExchangeChannelClosed::remote_input(114514.into(), Some(e)) + })?; + + let actor_pair = (up_actor_id, down_actor_id); - if let Some(msg_tx) = msg_txs.get(&(up_actor_id, down_actor_id)) { + if let Some(msg_tx) = msg_txs.get(&actor_pair) { use crate::executor::DispatcherMessageBatch; let msg = message.unwrap(); @@ -239,15 +250,22 @@ impl Worker { .unwrap(); let msg = msg_res.context("RemoteInput decode message error")?; - match msg.into_messages() { - Either::Left(barriers) => { - for b in barriers { - msg_tx.send(b).unwrap(); + + let send_result: Option<()> = try { + match msg.into_messages() { + Either::Left(barriers) => { + for b in barriers { + msg_tx.send(b).ok()?; + } + } + Either::Right(m) => { + msg_tx.send(m).ok()?; } } - Either::Right(m) => { - msg_tx.send(m).unwrap(); - } + }; + + if send_result.is_none() { + msg_txs.remove(&actor_pair); } } } @@ -268,7 +286,7 @@ impl Worker { } struct Mux { - cache: moka::future::Cache, + cache: moka::future::Cache, } impl Mux { @@ -284,16 +302,19 @@ impl Mux { upstream_addr: HostAddr, env: &StreamEnvironment, ) -> Worker { - let worker = self - .cache - .try_get_with(init.clone(), async move { - let client = env.client_pool().get_by_addr(upstream_addr).await?; - Worker::new(client, init).await - }) + self.cache + .try_get_with( + WorkerKey { + upstream_addr: upstream_addr.clone(), + init: init.clone(), + }, + async move { + let client = env.client_pool().get_by_addr(upstream_addr).await?; + Worker::new(client, init).await + }, + ) .await - .expect("bad"); - - worker + .expect("bad") } }