// Request message of the `GetMuxStream` RPC.
//
// A single multiplexed stream carries the requests of multiple (upstream actor, downstream actor)
// pairs. The first message on the stream must be `Init`; every later message is either `Register`
// or `AddPermits`.
message GetMuxStreamRequest {
  // Identifies the fragment pair, database and term this multiplexed stream belongs to.
  message Init {
    uint32 up_fragment_id = 1;
    uint32 down_fragment_id = 2;
    uint32 database_id = 3;
    string term_id = 4;
  }

  // Adds one actor pair to the multiplexed stream.
  message Register {
    uint32 up_actor_id = 1;
    uint32 down_actor_id = 2;
  }

  // Returns permits for one previously registered actor pair.
  message AddPermits {
    uint32 up_actor_id = 1;
    uint32 down_actor_id = 2;
    Permits permits = 3;
  }

  oneof value {
    // The first message, which tells the upstream which fragment pair this exchange stream is for.
    Init init = 1;
    // A subsequent message that registers a new actor pair on this multiplexed stream.
    Register register = 2;
    // A subsequent message that gives permits back to the upstream to achieve back-pressure.
    AddPermits add_permits = 3;
  }
}
+ stream_plan.StreamMessageBatch message = 1; + Permits permits = 2; + uint32 up_actor_id = 3; + uint32 down_actor_id = 4; +} + service ExchangeService { rpc GetData(GetDataRequest) returns (stream GetDataResponse); rpc GetStream(stream GetStreamRequest) returns (stream GetStreamResponse); + rpc GetMuxStream(stream GetMuxStreamRequest) returns (stream GetMuxStreamResponse); } diff --git a/src/common/src/config/mod.rs b/src/common/src/config/mod.rs index 54d3dbac0cfdb..9c53dc3bc9b31 100644 --- a/src/common/src/config/mod.rs +++ b/src/common/src/config/mod.rs @@ -221,6 +221,16 @@ pub mod default { 0 } + pub fn stream_exchange_force_remote() -> bool { + // TODO: default to false, as it's for debugging only + true + } + + pub fn stream_exchange_remote_use_multiplexing() -> bool { + // TODO: default to false until it's tested to be stable + true + } + pub fn stream_dml_channel_initial_permits() -> usize { 32768 } diff --git a/src/common/src/config/streaming.rs b/src/common/src/config/streaming.rs index 57cec1a2b9939..5c40f30b7fccb 100644 --- a/src/common/src/config/streaming.rs +++ b/src/common/src/config/streaming.rs @@ -107,6 +107,14 @@ pub struct StreamingDeveloperConfig { #[serde(default = "default::developer::stream_exchange_concurrent_dispatchers")] pub exchange_concurrent_dispatchers: usize, + /// Force all exchanges to be remote exchanges, i.e., use gRPC. This is for debugging only. + #[serde(default = "default::developer::stream_exchange_force_remote")] + pub exchange_force_remote: bool, + + /// Use new experimental multiplexing implementation for remote exchange. + #[serde(default = "default::developer::stream_exchange_remote_use_multiplexing")] + pub exchange_remote_use_multiplexing: bool, + /// The initial permits for a dml channel, i.e., the maximum row count can be buffered in /// the channel. 
#[serde(default = "default::developer::stream_dml_channel_initial_permits")] diff --git a/src/compute/src/rpc/service/exchange_service.rs b/src/compute/src/rpc/service/exchange_service.rs index b3b0084824072..e508729a39f34 100644 --- a/src/compute/src/rpc/service/exchange_service.rs +++ b/src/compute/src/rpc/service/exchange_service.rs @@ -12,17 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use either::Either; +use futures::stream::SelectAll; use futures::{Stream, StreamExt, TryStreamExt, pin_mut}; use futures_async_stream::try_stream; use risingwave_batch::task::BatchManager; -use risingwave_pb::id::FragmentId; +use risingwave_pb::id::{ActorId, FragmentId}; use risingwave_pb::task_service::exchange_service_server::ExchangeService; use risingwave_pb::task_service::{ - GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits, permits, + GetDataRequest, GetDataResponse, GetMuxStreamRequest, GetMuxStreamResponse, GetStreamRequest, + GetStreamResponse, PbPermits, permits, }; use risingwave_stream::executor::DispatcherMessageBatch; use risingwave_stream::executor::exchange::permit::{MessageWithPermits, Receiver}; @@ -42,10 +45,13 @@ pub struct ExchangeServiceImpl { pub type BatchDataStream = ReceiverStream>; pub type StreamDataStream = impl Stream>; +pub type NewStreamDataStream = + impl Stream>; #[async_trait::async_trait] impl ExchangeService for ExchangeServiceImpl { type GetDataStream = BatchDataStream; + type GetMuxStreamStream = NewStreamDataStream; type GetStreamStream = StreamDataStream; async fn get_data( @@ -124,6 +130,131 @@ impl ExchangeService for ExchangeServiceImpl { (up_fragment_id, down_fragment_id), ))) } + + #[define_opaque(NewStreamDataStream)] + async fn get_mux_stream( + &self, + request: Request>, + ) -> std::result::Result, Status> { + let request_stream = request.into_inner(); + 
+ Ok(Response::new(Self::get_mux_stream_impl( + self.stream_mgr.clone(), + request_stream, + ))) + } +} + +impl ExchangeServiceImpl { + #[try_stream(ok = GetMuxStreamResponse, error = Status)] + async fn get_mux_stream_impl( + stream_mgr: LocalStreamManager, + mut request_stream: Streaming, + ) { + use risingwave_pb::task_service::get_mux_stream_request::*; + + // Extract the first `Init` request from the stream. + let Init { + up_fragment_id: _, + down_fragment_id: _, + database_id, + term_id, + } = { + let req = request_stream + .next() + .await + .ok_or_else(|| Status::invalid_argument("get_mux_stream request is empty"))??; + match req.value.unwrap() { + Value::Init(init) => init, + Value::Register(_) | Value::AddPermits(_) => { + unreachable!("the first message must be `Init`") + } + } + }; + + enum Req { + Request(Result), + Message { + up_actor_id: ActorId, + down_actor_id: ActorId, + message: MessageWithPermits, + }, + } + + let mut select_all = SelectAll::new(); + select_all.push(request_stream.map(Req::Request).boxed()); + + let mut all_permits = HashMap::new(); + + while let Some(r) = select_all.next().await { + match r { + Req::Request(req) => match req?.value.unwrap() { + Value::Init(_) => unreachable!("the stream has already been initialized"), + Value::Register(Register { + up_actor_id, + down_actor_id, + }) => { + let receiver = stream_mgr + .take_receiver( + database_id, + term_id.clone(), + (up_actor_id, down_actor_id), + ) + .await?; + let permits = Arc::downgrade(&receiver.permits()); + all_permits.insert((up_actor_id, down_actor_id), permits); + select_all.push( + receiver + .into_raw_stream() + .map(move |message| Req::Message { + up_actor_id, + down_actor_id, + message, + }) + .boxed(), + ); + } + Value::AddPermits(AddPermits { + up_actor_id, + down_actor_id, + permits, + }) => { + if let Some(to_add) = permits.unwrap().value + && let Some(permits) = all_permits + .get(&(up_actor_id, down_actor_id)) + .and_then(|p| p.upgrade()) + { + 
permits.add_permits(to_add); + } + } + }, + + Req::Message { + up_actor_id, + down_actor_id, + message: MessageWithPermits { message, permits }, + } => { + let message = match message { + DispatcherMessageBatch::Chunk(chunk) => { + DispatcherMessageBatch::Chunk(chunk.compact_vis()) + } + msg @ (DispatcherMessageBatch::Watermark(_) + | DispatcherMessageBatch::BarrierBatch(_)) => msg, + }; + let proto = message.to_protobuf(); + // forward the acquired permit to the downstream + let response = GetMuxStreamResponse { + message: Some(proto), + permits: Some(PbPermits { value: permits }), + up_actor_id, + down_actor_id, + }; + + yield response; + } + } + } + } } impl ExchangeServiceImpl { diff --git a/src/config/example.toml b/src/config/example.toml index c651eefe55fdf..bd50958d26e7b 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -190,6 +190,8 @@ exchange_initial_permits = 2048 exchange_batched_permits = 256 exchange_concurrent_barriers = 1 exchange_concurrent_dispatchers = 0 +exchange_force_remote = true +exchange_remote_use_multiplexing = true dml_channel_initial_permits = 32768 hash_agg_max_dirty_groups_heap_size = 67108864 memory_controller_threshold_aggressive = 0.9 diff --git a/src/prost/build.rs b/src/prost/build.rs index fea31409a888a..3582a7276c65e 100644 --- a/src/prost/build.rs +++ b/src/prost/build.rs @@ -757,6 +757,23 @@ for_all_wrapped_id_fields! 
( FastInsertRequest { table_id: TableId, } + GetMuxStreamRequest.AddPermits { + up_actor_id: ActorId, + down_actor_id: ActorId, + } + GetMuxStreamRequest.Init { + database_id: DatabaseId, + up_fragment_id: FragmentId, + down_fragment_id: FragmentId, + } + GetMuxStreamRequest.Register { + up_actor_id: ActorId, + down_actor_id: ActorId, + } + GetMuxStreamResponse { + up_actor_id: ActorId, + down_actor_id: ActorId, + } GetStreamRequest.Get { database_id: DatabaseId, up_fragment_id: FragmentId, @@ -1055,6 +1072,7 @@ fn main() -> Result<(), Box> { .type_attribute("expr.UdfExprVersion", "#[derive(prost_helpers::Version)]") .type_attribute("meta.Object.object_info", "#[derive(strum::Display)]") .type_attribute("meta.SubscribeResponse.info", "#[derive(strum::Display)]") + .type_attribute("task_service.GetMuxStreamRequest.Init", "#[derive(Hash, Eq)]") // end ; diff --git a/src/rpc_client/src/compute_client.rs b/src/rpc_client/src/compute_client.rs index 31263aa5f7fba..352683288d816 100644 --- a/src/rpc_client/src/compute_client.rs +++ b/src/rpc_client/src/compute_client.rs @@ -40,8 +40,9 @@ use risingwave_pb::task_service::exchange_service_client::ExchangeServiceClient; use risingwave_pb::task_service::task_service_client::TaskServiceClient; use risingwave_pb::task_service::{ CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest, - FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, - PbPermits, TaskInfoResponse, permits, + FastInsertResponse, GetDataRequest, GetDataResponse, GetMuxStreamRequest, GetMuxStreamResponse, + GetStreamRequest, GetStreamResponse, PbPermits, TaskInfoResponse, get_mux_stream_request, + permits, }; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -112,6 +113,33 @@ impl ComputeClient { .into_inner()) } + pub async fn get_mux_stream( + &self, + init: get_mux_stream_request::Init, + ) -> Result<( + Streaming, + mpsc::UnboundedSender, + )> { + use 
risingwave_pb::task_service::get_mux_stream_request::*; + + let (request_sender, request_receiver) = mpsc::unbounded_channel(); + request_sender + .send(GetMuxStreamRequest { + value: Some(Value::Init(init)), + }) + .unwrap(); + + let response_stream = self + .exchange_client + .clone() + .get_mux_stream(UnboundedReceiverStream::new(request_receiver)) + .await + .map_err(RpcError::from_compute_status)? + .into_inner(); + + Ok((response_stream, request_sender)) + } + pub async fn get_stream( &self, up_actor_id: ActorId, diff --git a/src/stream/src/executor/exchange/input.rs b/src/stream/src/executor/exchange/input.rs index c3a7d09876d31..5f755b8b258bd 100644 --- a/src/stream/src/executor/exchange/input.rs +++ b/src/stream/src/executor/exchange/input.rs @@ -12,15 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod mux_remote_input; + use std::pin::Pin; use std::task::{Context, Poll}; -use either::Either; +use futures::future::Either; use local_input::LocalInputStreamInner; use pin_project::pin_project; use risingwave_common::util::addr::{HostAddr, is_local_address}; +use risingwave_pb::common::ActorInfo; use super::permit::Receiver; +use crate::executor::exchange::input::mux_remote_input::MuxRemoteInputStream; +use crate::executor::exchange::input::remote_input::RemoteInputStream; use crate::executor::prelude::*; use crate::executor::{ BarrierInner, DispatcherMessage, DispatcherMessageBatch, DispatcherMessageStreamItem, @@ -87,7 +92,6 @@ impl LocalInput { mod local_input { use await_tree::InstrumentAwait; - use either::Either; use crate::executor::exchange::error::ExchangeChannelClosed; use crate::executor::exchange::permit::Receiver; @@ -106,15 +110,8 @@ mod local_input { async fn run_inner(mut channel: Receiver, upstream_actor_id: ActorId) { let span = await_tree::span!("LocalInput (actor {upstream_actor_id})").verbose(); while let Some(msg) = channel.recv().instrument_await(span.clone()).await { - 
match msg.into_messages() { - Either::Left(barriers) => { - for b in barriers { - yield b; - } - } - Either::Right(m) => { - yield m; - } + for msg in msg.into_messages() { + yield msg; } } // Always emit an error outside the loop. This is because we use barrier as the control @@ -144,14 +141,11 @@ impl Input for LocalInput { #[pin_project] pub struct RemoteInput { #[pin] - inner: RemoteInputStreamInner, + inner: Either, actor_id: ActorId, } -use remote_input::RemoteInputStreamInner; -use risingwave_pb::common::ActorInfo; - impl RemoteInput { /// Create a remote input from compute client and related info. Should provide the corresponding /// compute client of where the actor is placed. @@ -162,6 +156,22 @@ impl RemoteInput { up_down_frag: UpDownFragmentIds, metrics: Arc, ) -> StreamExecutorResult { + if local_barrier_manager + .env + .global_config() + .developer + .exchange_remote_use_multiplexing + { + return RemoteInput::new_mux( + local_barrier_manager, + upstream_addr, + up_down_ids, + up_down_frag, + metrics, + ) + .await; + } + let actor_id = up_down_ids.0; let client = local_barrier_manager @@ -169,6 +179,7 @@ impl RemoteInput { .client_pool() .get_by_addr(upstream_addr) .await?; + let (stream, permits_tx) = client .get_stream( up_down_ids.0, @@ -182,7 +193,7 @@ impl RemoteInput { Ok(Self { actor_id, - inner: remote_input::run( + inner: Either::Left(remote_input::run( stream, permits_tx, up_down_ids, @@ -193,7 +204,7 @@ impl RemoteInput { .global_config() .developer .exchange_batched_permits, - ), + )), }) } } @@ -203,7 +214,7 @@ mod remote_input { use anyhow::Context; use await_tree::InstrumentAwait; - use either::Either; + use futures::Stream; use risingwave_pb::task_service::{GetStreamResponse, permits}; use tokio::sync::mpsc; use tonic::Streaming; @@ -214,9 +225,9 @@ mod remote_input { use crate::executor::{DispatcherMessage, StreamExecutorError}; use crate::task::{UpDownActorIds, UpDownFragmentIds}; - pub(super) type RemoteInputStreamInner = impl 
crate::executor::DispatcherMessageStream; + pub(super) type RemoteInputStream = impl crate::executor::DispatcherMessageStream; - #[define_opaque(RemoteInputStreamInner)] + #[define_opaque(RemoteInputStream)] pub(super) fn run( stream: Streaming, permits_tx: mpsc::UnboundedSender, @@ -224,7 +235,7 @@ mod remote_input { up_down_frag: UpDownFragmentIds, metrics: Arc, batched_permits_limit: usize, - ) -> RemoteInputStreamInner { + ) -> RemoteInputStream { run_inner( stream, permits_tx, @@ -237,7 +248,7 @@ mod remote_input { #[try_stream(ok = DispatcherMessage, error = StreamExecutorError)] async fn run_inner( - stream: Streaming, + stream: impl Stream>, permits_tx: mpsc::UnboundedSender, up_down_ids: UpDownActorIds, up_down_frag: UpDownFragmentIds, @@ -288,15 +299,8 @@ mod remote_input { } let msg = msg_res.context("RemoteInput decode message error")?; - match msg.into_messages() { - Either::Left(barriers) => { - for b in barriers { - yield b; - } - } - Either::Right(m) => { - yield m; - } + for msg in msg.into_messages() { + yield msg; } } @@ -336,10 +340,18 @@ pub(crate) async fn new_input( upstream_actor_info: &ActorInfo, upstream_fragment_id: FragmentId, ) -> StreamExecutorResult { + let force_remote = local_barrier_manager + .env + .global_config() + .developer + .exchange_force_remote; + let upstream_actor_id = upstream_actor_info.actor_id; let upstream_addr = upstream_actor_info.get_host()?.into(); - let input = if is_local_address(local_barrier_manager.env.server_address(), &upstream_addr) { + let input = if !force_remote + && is_local_address(local_barrier_manager.env.server_address(), &upstream_addr) + { LocalInput::new( local_barrier_manager.register_local_upstream_output(actor_id, upstream_actor_id), upstream_actor_id, @@ -361,13 +373,19 @@ pub(crate) async fn new_input( } impl DispatcherMessageBatch { - fn into_messages(self) -> Either, DispatcherMessage> { + fn into_messages(self) -> impl ExactSizeIterator { + use either::Either; + match self { 
DispatcherMessageBatch::BarrierBatch(barriers) => { Either::Left(barriers.into_iter().map(DispatcherMessage::Barrier)) } - DispatcherMessageBatch::Chunk(c) => Either::Right(DispatcherMessage::Chunk(c)), - DispatcherMessageBatch::Watermark(w) => Either::Right(DispatcherMessage::Watermark(w)), + DispatcherMessageBatch::Chunk(c) => { + Either::Right(std::iter::once(DispatcherMessage::Chunk(c))) + } + DispatcherMessageBatch::Watermark(w) => { + Either::Right(std::iter::once(DispatcherMessage::Watermark(w))) + } } } } diff --git a/src/stream/src/executor/exchange/input/mux_remote_input.rs b/src/stream/src/executor/exchange/input/mux_remote_input.rs new file mode 100644 index 0000000000000..ee99b4c8b965b --- /dev/null +++ b/src/stream/src/executor/exchange/input/mux_remote_input.rs @@ -0,0 +1,247 @@ +// Copyright 2025 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::collections::HashMap; +use std::sync::{Arc, LazyLock}; + +use anyhow::Context as _; +use futures::StreamExt; +use futures::future::Either; +use risingwave_common::util::addr::HostAddr; +use risingwave_pb::task_service::get_mux_stream_request::{self, AddPermits, Value}; +use risingwave_pb::task_service::{GetMuxStreamRequest, GetMuxStreamResponse}; +use risingwave_rpc_client::ComputeClient; +use risingwave_rpc_client::error::RpcError; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; + +use crate::executor::exchange::error::ExchangeChannelClosed; +use crate::executor::exchange::input::RemoteInput; +use crate::executor::prelude::StreamingMetrics; +use crate::executor::{DispatcherMessage, StreamExecutorError, StreamExecutorResult}; +use crate::task::{LocalBarrierManager, StreamEnvironment, UpDownActorIds, UpDownFragmentIds}; + +struct RegisterReq { + register: get_mux_stream_request::Register, + msg_tx: mpsc::UnboundedSender>, +} + +#[derive(Clone)] +struct Worker { + register_tx: mpsc::UnboundedSender, + join_handle: Arc>>, +} + +#[derive(Clone, Hash, PartialEq, Eq)] +struct WorkerKey { + upstream_addr: HostAddr, + init: get_mux_stream_request::Init, +} + +impl Worker { + async fn new( + client: ComputeClient, + init: get_mux_stream_request::Init, + ) -> Result { + let (stream, req_tx) = client.get_mux_stream(init).await?; + + let (register_tx, register_rx) = mpsc::unbounded_channel(); + + let task = async move { + enum Event { + Register(RegisterReq), + Response(Result), + } + + let mut stream = futures::stream_select!( + tokio_stream::wrappers::UnboundedReceiverStream::new(register_rx) + .map(Event::Register), + stream.map(Event::Response), + ); + + let mut msg_txs = HashMap::new(); + + while let Some(event) = stream.next().await { + match event { + Event::Register(RegisterReq { register, msg_tx }) => { + req_tx + .send(GetMuxStreamRequest { + value: Some(Value::Register(register)), + }) + .map_err(|_| { + 
ExchangeChannelClosed::remote_input(114514.into(), None) + })?; + + msg_txs.insert((register.up_actor_id, register.down_actor_id), msg_tx); + } + Event::Response(res) => { + let GetMuxStreamResponse { + message, + permits, + up_actor_id, + down_actor_id, + } = res.map_err(|e| { + ExchangeChannelClosed::remote_input(114514.into(), Some(e)) + })?; + + let actor_pair = (up_actor_id, down_actor_id); + + if let Some(msg_tx) = msg_txs.get(&actor_pair) { + use crate::executor::DispatcherMessageBatch; + let msg = message.unwrap(); + + let msg_res = DispatcherMessageBatch::from_protobuf(&msg); + + let send_result: Result<(), ()> = try { + // TODO(mux): batch putting back permits + req_tx + .send(GetMuxStreamRequest { + value: Some(Value::AddPermits(AddPermits { + up_actor_id, + down_actor_id, + permits, + })), + }) + .map_err(|_| ())?; + + match msg_res { + Ok(msg) => { + for msg in msg.into_messages() { + msg_tx.send(Ok(msg)).map_err(|_| ())?; + } + } + Err(e) => { + msg_tx.send(Err(e)).map_err(|_| ())?; + } + } + }; + + if send_result.is_err() { + msg_txs.remove(&actor_pair); + } + } + } + } + } + + Ok::<_, StreamExecutorError>(()) + }; + + let join_handle = tokio::spawn(task); + + Ok(Self { + register_tx, + join_handle: Arc::new(join_handle), + }) + } +} + +struct Mux { + cache: moka::future::Cache>>, +} + +impl Mux { + pub fn new() -> Self { + Self { + cache: moka::future::Cache::new(u64::MAX), + } + } + + async fn get( + &self, + init: get_mux_stream_request::Init, + upstream_addr: HostAddr, + env: &StreamEnvironment, + ) -> StreamExecutorResult { + self.cache + .entry(WorkerKey { + upstream_addr: upstream_addr.clone(), + init: init.clone(), + }) + .or_insert_with_if( + async move { + let client = env.client_pool().get_by_addr(upstream_addr).await?; + let worker = Worker::new(client, init).await?; + Ok(worker) + }, + |w| match w { + Ok(w) => w.join_handle.is_finished(), + Err(_) => true, + }, + ) + .await + .into_value() + .context("failed to connect to upstream") + 
.map_err(Into::into) + } +} + +impl RemoteInput { + /// Create a remote input with the experimental multiplexing implementation. + pub(super) async fn new_mux( + local_barrier_manager: &LocalBarrierManager, + upstream_addr: HostAddr, + up_down_ids: UpDownActorIds, + up_down_frag: UpDownFragmentIds, + _metrics: Arc, + ) -> StreamExecutorResult { + let actor_id = up_down_ids.0; + + static MUX: LazyLock = LazyLock::new(Mux::new); + + let init = get_mux_stream_request::Init { + up_fragment_id: up_down_frag.0, + down_fragment_id: up_down_frag.1, + database_id: local_barrier_manager.database_id, + term_id: local_barrier_manager.term_id.clone(), + }; + + let worker = MUX + .get(init, upstream_addr, &local_barrier_manager.env) + .await?; + + let (msg_tx, msg_rx) = mpsc::unbounded_channel(); + worker + .register_tx + .send(RegisterReq { + register: get_mux_stream_request::Register { + up_actor_id: up_down_ids.0, + down_actor_id: up_down_ids.1, + }, + msg_tx, + }) + .ok() + .context("failed to connect to upstream")?; + + Ok(Self { + actor_id, + inner: Either::Right(make_input_stream(msg_rx)), + }) + } +} + +pub(super) type MuxRemoteInputStream = impl crate::executor::DispatcherMessageStream; + +#[define_opaque(MuxRemoteInputStream)] +fn make_input_stream( + msg_rx: mpsc::UnboundedReceiver>, +) -> MuxRemoteInputStream { + tokio_stream::wrappers::UnboundedReceiverStream::new(msg_rx).chain(futures::stream::once( + async { + // Always emit an error outside the loop. This is because we use barrier as the control + // message to stop the stream. Reaching here means the channel is closed unexpectedly. 
+ Err(ExchangeChannelClosed::remote_input(1919810.into(), None).into()) + }, + )) +} diff --git a/src/stream/src/executor/exchange/permit.rs b/src/stream/src/executor/exchange/permit.rs index 426c470c1d710..06a887b6e03cc 100644 --- a/src/stream/src/executor/exchange/permit.rs +++ b/src/stream/src/executor/exchange/permit.rs @@ -205,6 +205,13 @@ impl Receiver { pub fn permits(&self) -> Arc { self.permits.clone() } + + #[futures_async_stream::stream(item = MessageWithPermits)] + pub async fn into_raw_stream(mut self) { + while let Some(message) = self.recv_raw().await { + yield message; + } + } } impl Drop for Receiver { diff --git a/src/stream/src/executor/merge.rs b/src/stream/src/executor/merge.rs index a2672a3384698..0a473e418ddc3 100644 --- a/src/stream/src/executor/merge.rs +++ b/src/stream/src/executor/merge.rs @@ -477,7 +477,8 @@ mod tests { ExchangeService, ExchangeServiceServer, }; use risingwave_pb::task_service::{ - GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits, + GetDataRequest, GetDataResponse, GetMuxStreamRequest, GetMuxStreamResponse, + GetStreamRequest, GetStreamResponse, PbPermits, }; use tokio::time::sleep; use tokio_stream::wrappers::ReceiverStream; @@ -860,6 +861,7 @@ mod tests { #[async_trait::async_trait] impl ExchangeService for FakeExchangeService { type GetDataStream = ReceiverStream>; + type GetMuxStreamStream = ReceiverStream>; type GetStreamStream = ReceiverStream>; async fn get_data( @@ -907,6 +909,13 @@ mod tests { .unwrap(); Ok(Response::new(ReceiverStream::new(rx))) } + + async fn get_mux_stream( + &self, + _request: Request>, + ) -> std::result::Result, Status> { + unimplemented!() + } } #[tokio::test]