diff --git a/proto/task_service.proto b/proto/task_service.proto index 88f7f269bb4b9..cecd2a81046ee 100644 --- a/proto/task_service.proto +++ b/proto/task_service.proto @@ -140,6 +140,35 @@ message GetStreamRequest { } } +message GetMuxStreamRequest { + message Init { + uint32 up_fragment_id = 1; + uint32 down_fragment_id = 2; + uint32 database_id = 3; + string term_id = 4; + } + + message Register { + uint32 up_actor_id = 1; + uint32 down_actor_id = 2; + } + + message AddPermits { + uint32 up_actor_id = 1; + uint32 down_actor_id = 2; + Permits permits = 3; + } + + oneof value { + // The first message, which tells the upstream which fragment pair this exchange stream is for. + Init init = 1; + // The following messages, which registers a new actor pair to this multiplexed stream. + Register register = 2; + // The following messages, which adds the permits back to the upstream to achieve back-pressure. + AddPermits add_permits = 3; + } +} + message GetStreamResponse { stream_plan.StreamMessageBatch message = 1; // The number of permits acquired for this message, which should be sent back to the upstream with `add_permits`. @@ -148,10 +177,19 @@ message GetStreamResponse { Permits permits = 2; } +message GetMuxStreamResponse { + // TODO(mux): batch the same message (typically barrier) for different actor pairs. 
+ stream_plan.StreamMessageBatch message = 1; + Permits permits = 2; + uint32 up_actor_id = 3; + uint32 down_actor_id = 4; +} + service BatchExchangeService { rpc GetData(GetDataRequest) returns (stream GetDataResponse); } service StreamExchangeService { rpc GetStream(stream GetStreamRequest) returns (stream GetStreamResponse); + rpc GetMuxStream(stream GetMuxStreamRequest) returns (stream GetMuxStreamResponse); } diff --git a/src/common/src/config/mod.rs b/src/common/src/config/mod.rs index 54d3dbac0cfdb..6155950a00899 100644 --- a/src/common/src/config/mod.rs +++ b/src/common/src/config/mod.rs @@ -221,6 +221,16 @@ pub mod default { 0 } + pub fn stream_exchange_force_remote() -> bool { + // TODO(mux): default to false, as it's for debugging only + true + } + + pub fn stream_exchange_remote_use_multiplexing() -> bool { + // TODO(mux): default to false until it's tested to be stable + true + } + pub fn stream_dml_channel_initial_permits() -> usize { 32768 } diff --git a/src/common/src/config/streaming.rs b/src/common/src/config/streaming.rs index 57cec1a2b9939..cb84b9ef6f184 100644 --- a/src/common/src/config/streaming.rs +++ b/src/common/src/config/streaming.rs @@ -107,6 +107,15 @@ pub struct StreamingDeveloperConfig { #[serde(default = "default::developer::stream_exchange_concurrent_dispatchers")] pub exchange_concurrent_dispatchers: usize, + /// Force all exchanges to be remote exchanges, i.e., use gRPC even for local exchanges between + /// actors on the same compute node. This is for debugging only. + #[serde(default = "default::developer::stream_exchange_force_remote")] + pub exchange_force_remote: bool, + + /// Use new experimental multiplexing implementation for remote exchange. + #[serde(default = "default::developer::stream_exchange_remote_use_multiplexing")] + pub exchange_remote_use_multiplexing: bool, + /// The initial permits for a dml channel, i.e., the maximum row count can be buffered in /// the channel. 
#[serde(default = "default::developer::stream_dml_channel_initial_permits")] diff --git a/src/compute/src/rpc/service/stream_exchange_service.rs b/src/compute/src/rpc/service/stream_exchange_service.rs index 4b814b6932526..41004c46ed2fc 100644 --- a/src/compute/src/rpc/service/stream_exchange_service.rs +++ b/src/compute/src/rpc/service/stream_exchange_service.rs @@ -20,7 +20,10 @@ use futures::{Stream, StreamExt, TryStreamExt, pin_mut}; use futures_async_stream::try_stream; use risingwave_pb::id::FragmentId; use risingwave_pb::task_service::stream_exchange_service_server::StreamExchangeService; -use risingwave_pb::task_service::{GetStreamRequest, GetStreamResponse, PbPermits, permits}; +use risingwave_pb::task_service::{ + GetMuxStreamRequest, GetMuxStreamResponse, GetStreamRequest, GetStreamResponse, PbPermits, + permits, +}; use risingwave_stream::executor::DispatcherMessageBatch; use risingwave_stream::executor::exchange::permit::{MessageWithPermits, Receiver}; use risingwave_stream::task::LocalStreamManager; @@ -28,8 +31,11 @@ use tonic::{Request, Response, Status, Streaming}; pub mod metrics; pub use metrics::{GLOBAL_STREAM_EXCHANGE_SERVICE_METRICS, StreamExchangeServiceMetrics}; +mod mux; pub type StreamDataStream = impl Stream>; +pub type MuxStreamDataStream = + impl Stream>; #[derive(Clone)] pub struct StreamExchangeServiceImpl { @@ -39,6 +45,7 @@ pub struct StreamExchangeServiceImpl { #[async_trait::async_trait] impl StreamExchangeService for StreamExchangeServiceImpl { + type GetMuxStreamStream = MuxStreamDataStream; type GetStreamStream = StreamDataStream; #[define_opaque(StreamDataStream)] @@ -92,6 +99,19 @@ impl StreamExchangeService for StreamExchangeServiceImpl { (up_fragment_id, down_fragment_id), ))) } + + #[define_opaque(MuxStreamDataStream)] + async fn get_mux_stream( + &self, + request: Request>, + ) -> std::result::Result, Status> { + let request_stream = request.into_inner(); + + Ok(Response::new(Self::get_mux_stream_impl( + 
+            self.stream_mgr.clone(),
+            request_stream,
+        )))
+    }
 }
 
 impl StreamExchangeServiceImpl {
diff --git a/src/compute/src/rpc/service/stream_exchange_service/mux.rs b/src/compute/src/rpc/service/stream_exchange_service/mux.rs
new file mode 100644
index 0000000000000..06edb238ae957
--- /dev/null
+++ b/src/compute/src/rpc/service/stream_exchange_service/mux.rs
@@ -0,0 +1,185 @@
+// Copyright 2025 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::HashMap;
+use std::sync::{Arc, Weak};
+
+use futures::StreamExt;
+use futures::stream::SelectAll;
+use futures_async_stream::try_stream;
+use risingwave_pb::id::ActorId;
+use risingwave_pb::task_service::{GetMuxStreamRequest, GetMuxStreamResponse, PbPermits};
+use risingwave_stream::executor::DispatcherMessageBatch;
+use risingwave_stream::executor::exchange::permit::{MessageWithPermits, Permits};
+use risingwave_stream::task::LocalStreamManager;
+use tonic::{Status, Streaming};
+
+use crate::rpc::service::stream_exchange_service::StreamExchangeServiceImpl;
+
+impl StreamExchangeServiceImpl {
+    /// Serves one multiplexed exchange stream that carries the messages of multiple actor pairs.
+    ///
+    /// The first request must be `Init`, which provides the database and term used to look up
+    /// receivers. Every following `Register` takes the receiver of an actor pair and merges its
+    /// messages into the response stream, tagged with the actor ids of the pair. Every following
+    /// `AddPermits` adds permits back to the permit handle registered for the given actor pair.
+    ///
+    /// # Errors
+    ///
+    /// The stream ends with `Status::invalid_argument` if it is empty, if a request carries no
+    /// `value`, or if the requests arrive out of order. Errors from the request stream and from
+    /// `take_receiver` are propagated.
+    #[try_stream(ok = GetMuxStreamResponse, error = Status)]
+    pub(super) async fn get_mux_stream_impl(
+        stream_mgr: LocalStreamManager,
+        mut request_stream: Streaming<GetMuxStreamRequest>,
+    ) {
+        use risingwave_pb::task_service::get_mux_stream_request::*;
+
+        // Extract the first `Init` request from the stream.
+        let Init {
+            up_fragment_id: _,
+            down_fragment_id: _,
+            database_id,
+            term_id,
+        } = {
+            let req = request_stream
+                .next()
+                .await
+                .ok_or_else(|| Status::invalid_argument("get_mux_stream request is empty"))??;
+            // The request is decoded from the network, so a missing or out-of-order `value` is
+            // reported to the peer as an error status instead of panicking.
+            match req.value {
+                Some(Value::Init(init)) => init,
+                Some(Value::Register(_) | Value::AddPermits(_)) | None => {
+                    return Err(Status::invalid_argument("the first message must be `Init`"));
+                }
+            }
+        };
+
+        enum Event {
+            Request(Result<GetMuxStreamRequest, Status>),
+            ExchangeMessage {
+                up_actor_id: ActorId,
+                down_actor_id: ActorId,
+                message: MessageWithPermits,
+            },
+        }
+
+        // Merge events from the downstream client and all upstream actors.
+        let mut select_all = SelectAll::new();
+        select_all.push(request_stream.map(Event::Request).left_stream());
+
+        // Weak permit handles of all registered actor pairs.
+        let mut permit_handles: HashMap<(ActorId, ActorId), Weak<Permits>> = HashMap::new();
+
+        while let Some(event) = select_all.next().await {
+            match event {
+                // Request from the downstream client.
+                Event::Request(req) => match req?.value {
+                    // Protocol violations by the peer end the stream with an error status
+                    // instead of panicking.
+                    Some(Value::Init(_)) => {
+                        return Err(Status::invalid_argument(
+                            "the stream has already been initialized",
+                        ));
+                    }
+                    None => {
+                        return Err(Status::invalid_argument(
+                            "get_mux_stream request carries no value",
+                        ));
+                    }
+
+                    // Register a new actor pair to this multiplexed stream.
+                    Some(Value::Register(Register {
+                        up_actor_id,
+                        down_actor_id,
+                    })) => {
+                        let receiver = stream_mgr
+                            .take_receiver(
+                                database_id,
+                                term_id.clone(),
+                                (up_actor_id, down_actor_id),
+                            )
+                            .await?;
+                        permit_handles.insert(
+                            (up_actor_id, down_actor_id),
+                            Arc::downgrade(&receiver.permits()),
+                        );
+                        select_all.push(
+                            Box::pin(receiver.into_raw_stream())
+                                .map(move |message| Event::ExchangeMessage {
+                                    up_actor_id,
+                                    down_actor_id,
+                                    message,
+                                })
+                                .right_stream(),
+                        );
+                    }
+
+                    // Add permits back to the upstream.
+                    Some(Value::AddPermits(AddPermits {
+                        up_actor_id,
+                        down_actor_id,
+                        permits,
+                    })) => {
+                        // A missing `permits` field is handled like an empty one: nothing to add.
+                        let Some(permits) = permits.and_then(|p| p.value) else {
+                            continue;
+                        };
+                        if let Some(handle) = permit_handles.get(&(up_actor_id, down_actor_id)) {
+                            if let Some(handle) = handle.upgrade() {
+                                handle.add_permits(permits);
+                            } else {
+                                // The channel is already closed, ignore the request.
+                            }
+                        } else {
+                            tracing::warn!(
+                                %up_actor_id,
+                                %down_actor_id,
+                                ?permits,
+                                "add permits to unregistered actor pair",
+                            );
+                        }
+                    }
+                },
+
+                // Exchange message from the upstream.
+                Event::ExchangeMessage {
+                    up_actor_id,
+                    down_actor_id,
+                    message: MessageWithPermits { message, permits },
+                } => {
+                    let message = match message {
+                        DispatcherMessageBatch::Chunk(chunk) => {
+                            DispatcherMessageBatch::Chunk(chunk.compact_vis())
+                        }
+                        msg @ (DispatcherMessageBatch::Watermark(_)
+                        | DispatcherMessageBatch::BarrierBatch(_)) => msg,
+                    };
+                    let proto = message.to_protobuf();
+
+                    let response = GetMuxStreamResponse {
+                        message: Some(proto),
+                        permits: Some(PbPermits { value: permits }),
+                        up_actor_id,
+                        down_actor_id,
+                    };
+
+                    yield response;
+                }
+            }
+        }
+    }
+}
diff --git a/src/config/example.toml b/src/config/example.toml
index c651eefe55fdf..bd50958d26e7b 100644
--- a/src/config/example.toml
+++ b/src/config/example.toml
@@ -190,6 +190,8 @@ exchange_initial_permits = 2048
 exchange_batched_permits = 256
 exchange_concurrent_barriers = 1
 exchange_concurrent_dispatchers = 0
+exchange_force_remote = true
+exchange_remote_use_multiplexing = true
 dml_channel_initial_permits = 32768
 hash_agg_max_dirty_groups_heap_size = 67108864
 memory_controller_threshold_aggressive = 0.9
diff --git a/src/error/src/tonic.rs b/src/error/src/tonic.rs
index 5daa29ed99f4f..d01c52523f103 100644
--- a/src/error/src/tonic.rs
+++ b/src/error/src/tonic.rs
@@ -127,7 +127,7 @@ where
 
 /// A wrapper of [`tonic::Status`] that provides better error message and extracts
 /// the source chain from the `details` field.
-#[derive(Debug)] +#[derive(Debug, Clone)] pub struct TonicStatusWrapper { inner: tonic::Status, diff --git a/src/prost/build.rs b/src/prost/build.rs index fea31409a888a..3582a7276c65e 100644 --- a/src/prost/build.rs +++ b/src/prost/build.rs @@ -757,6 +757,23 @@ for_all_wrapped_id_fields! ( FastInsertRequest { table_id: TableId, } + GetMuxStreamRequest.AddPermits { + up_actor_id: ActorId, + down_actor_id: ActorId, + } + GetMuxStreamRequest.Init { + database_id: DatabaseId, + up_fragment_id: FragmentId, + down_fragment_id: FragmentId, + } + GetMuxStreamRequest.Register { + up_actor_id: ActorId, + down_actor_id: ActorId, + } + GetMuxStreamResponse { + up_actor_id: ActorId, + down_actor_id: ActorId, + } GetStreamRequest.Get { database_id: DatabaseId, up_fragment_id: FragmentId, @@ -1055,6 +1072,7 @@ fn main() -> Result<(), Box> { .type_attribute("expr.UdfExprVersion", "#[derive(prost_helpers::Version)]") .type_attribute("meta.Object.object_info", "#[derive(strum::Display)]") .type_attribute("meta.SubscribeResponse.info", "#[derive(strum::Display)]") + .type_attribute("task_service.GetMuxStreamRequest.Init", "#[derive(Hash, Eq)]") // end ; diff --git a/src/rpc_client/src/compute_client.rs b/src/rpc_client/src/compute_client.rs index 3641ded7d8bcb..7e4c4f86d5404 100644 --- a/src/rpc_client/src/compute_client.rs +++ b/src/rpc_client/src/compute_client.rs @@ -41,8 +41,9 @@ use risingwave_pb::task_service::stream_exchange_service_client::StreamExchangeS use risingwave_pb::task_service::task_service_client::TaskServiceClient; use risingwave_pb::task_service::{ CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest, - FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, - PbPermits, TaskInfoResponse, permits, + FastInsertResponse, GetDataRequest, GetDataResponse, GetMuxStreamRequest, GetMuxStreamResponse, + GetStreamRequest, GetStreamResponse, PbPermits, TaskInfoResponse, get_mux_stream_request, + 
permits, }; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -117,6 +118,42 @@ impl ComputeClient { .into_inner()) } + pub async fn get_mux_stream( + &self, + init: get_mux_stream_request::Init, + ) -> Result<( + Streaming, + mpsc::UnboundedSender, + )> { + use risingwave_pb::task_service::get_mux_stream_request::*; + + // Create channel used for future requests (including register new actor pairs and add permits) to the upstream. + let (request_sender, request_receiver) = mpsc::unbounded_channel(); + request_sender + .send(GetMuxStreamRequest { + value: Some(Value::Init(init.clone())), + }) + .unwrap(); + + let response_stream = self + .stream_exchange_client + .clone() + .get_mux_stream(UnboundedReceiverStream::new(request_receiver)) + .await + .inspect_err(|_| { + tracing::error!( + "failed to create mux stream from remote_input {} from fragment {} to fragment {}", + self.addr, + init.up_fragment_id, + init.down_fragment_id + ) + }) + .map_err(RpcError::from_compute_status)? + .into_inner(); + + Ok((response_stream, request_sender)) + } + pub async fn get_stream( &self, up_actor_id: ActorId, diff --git a/src/stream/src/executor/exchange/error.rs b/src/stream/src/executor/exchange/error.rs index 334e54f73f123..0492c0cd3c06d 100644 --- a/src/stream/src/executor/exchange/error.rs +++ b/src/stream/src/executor/exchange/error.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use risingwave_pb::id::FragmentId; use risingwave_rpc_client::error::TonicStatusWrapper; use crate::task::ActorId; @@ -22,7 +23,7 @@ use crate::task::ActorId; /// exits or panics on other errors, or the network connection is broken. /// Therefore, this error is usually not the root case of the failure in the /// streaming graph. 
-#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ExchangeChannelClosed { message: String, @@ -83,6 +84,17 @@ impl ExchangeChannelClosed { } } + /// Creates a new error indicating that the multiplexed exchange channel from the remote + /// upstream is closed unexpectedly, with an optional gRPC error as the cause. + pub fn remote_input_fragment(fragment: FragmentId, source: Option) -> Self { + Self { + message: format!( + "multiplexed exchange channel from remote fragment {fragment} closed unexpectedly", + ), + source: source.map(Into::into), + } + } + /// Creates a new error indicating that the exchange channel to the downstream /// actor is closed unexpectedly. pub fn output(downstream: ActorId) -> Self { diff --git a/src/stream/src/executor/exchange/input.rs b/src/stream/src/executor/exchange/input.rs index c3a7d09876d31..19cbb92f9848b 100644 --- a/src/stream/src/executor/exchange/input.rs +++ b/src/stream/src/executor/exchange/input.rs @@ -12,15 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+pub use mux_remote_input::MuxExchangeWorkers; + +mod mux_remote_input; + use std::pin::Pin; use std::task::{Context, Poll}; -use either::Either; +use futures::future::Either; use local_input::LocalInputStreamInner; use pin_project::pin_project; use risingwave_common::util::addr::{HostAddr, is_local_address}; +use risingwave_pb::common::ActorInfo; use super::permit::Receiver; +use crate::executor::exchange::input::mux_remote_input::MuxRemoteInputStream; +use crate::executor::exchange::input::remote_input::RemoteInputStream; use crate::executor::prelude::*; use crate::executor::{ BarrierInner, DispatcherMessage, DispatcherMessageBatch, DispatcherMessageStreamItem, @@ -87,7 +94,6 @@ impl LocalInput { mod local_input { use await_tree::InstrumentAwait; - use either::Either; use crate::executor::exchange::error::ExchangeChannelClosed; use crate::executor::exchange::permit::Receiver; @@ -106,15 +112,8 @@ mod local_input { async fn run_inner(mut channel: Receiver, upstream_actor_id: ActorId) { let span = await_tree::span!("LocalInput (actor {upstream_actor_id})").verbose(); while let Some(msg) = channel.recv().instrument_await(span.clone()).await { - match msg.into_messages() { - Either::Left(barriers) => { - for b in barriers { - yield b; - } - } - Either::Right(m) => { - yield m; - } + for msg in msg.into_messages() { + yield msg; } } // Always emit an error outside the loop. This is because we use barrier as the control @@ -144,14 +143,11 @@ impl Input for LocalInput { #[pin_project] pub struct RemoteInput { #[pin] - inner: RemoteInputStreamInner, + inner: Either, actor_id: ActorId, } -use remote_input::RemoteInputStreamInner; -use risingwave_pb::common::ActorInfo; - impl RemoteInput { /// Create a remote input from compute client and related info. Should provide the corresponding /// compute client of where the actor is placed. 
@@ -161,6 +157,40 @@ impl RemoteInput { up_down_ids: UpDownActorIds, up_down_frag: UpDownFragmentIds, metrics: Arc, + ) -> StreamExecutorResult { + if local_barrier_manager + .env + .global_config() + .developer + .exchange_remote_use_multiplexing + { + RemoteInput::new_mux( + local_barrier_manager, + upstream_addr, + up_down_ids, + up_down_frag, + metrics, + ) + .await + } else { + RemoteInput::new_simple( + local_barrier_manager, + upstream_addr, + up_down_ids, + up_down_frag, + metrics, + ) + .await + } + } + + /// Create a remote input with the simple, per actor-pair implementation. + pub(crate) async fn new_simple( + local_barrier_manager: &LocalBarrierManager, + upstream_addr: HostAddr, + up_down_ids: UpDownActorIds, + up_down_frag: UpDownFragmentIds, + metrics: Arc, ) -> StreamExecutorResult { let actor_id = up_down_ids.0; @@ -169,6 +199,7 @@ impl RemoteInput { .client_pool() .get_by_addr(upstream_addr) .await?; + let (stream, permits_tx) = client .get_stream( up_down_ids.0, @@ -182,7 +213,7 @@ impl RemoteInput { Ok(Self { actor_id, - inner: remote_input::run( + inner: Either::Left(remote_input::run( stream, permits_tx, up_down_ids, @@ -193,7 +224,7 @@ impl RemoteInput { .global_config() .developer .exchange_batched_permits, - ), + )), }) } } @@ -203,7 +234,6 @@ mod remote_input { use anyhow::Context; use await_tree::InstrumentAwait; - use either::Either; use risingwave_pb::task_service::{GetStreamResponse, permits}; use tokio::sync::mpsc; use tonic::Streaming; @@ -214,9 +244,9 @@ mod remote_input { use crate::executor::{DispatcherMessage, StreamExecutorError}; use crate::task::{UpDownActorIds, UpDownFragmentIds}; - pub(super) type RemoteInputStreamInner = impl crate::executor::DispatcherMessageStream; + pub(super) type RemoteInputStream = impl crate::executor::DispatcherMessageStream; - #[define_opaque(RemoteInputStreamInner)] + #[define_opaque(RemoteInputStream)] pub(super) fn run( stream: Streaming, permits_tx: mpsc::UnboundedSender, @@ -224,7 +254,7 @@ 
mod remote_input { up_down_frag: UpDownFragmentIds, metrics: Arc, batched_permits_limit: usize, - ) -> RemoteInputStreamInner { + ) -> RemoteInputStream { run_inner( stream, permits_tx, @@ -288,15 +318,8 @@ mod remote_input { } let msg = msg_res.context("RemoteInput decode message error")?; - match msg.into_messages() { - Either::Left(barriers) => { - for b in barriers { - yield b; - } - } - Either::Right(m) => { - yield m; - } + for msg in msg.into_messages() { + yield msg; } } @@ -336,10 +359,18 @@ pub(crate) async fn new_input( upstream_actor_info: &ActorInfo, upstream_fragment_id: FragmentId, ) -> StreamExecutorResult { + let force_remote = local_barrier_manager + .env + .global_config() + .developer + .exchange_force_remote; + let upstream_actor_id = upstream_actor_info.actor_id; let upstream_addr = upstream_actor_info.get_host()?.into(); - let input = if is_local_address(local_barrier_manager.env.server_address(), &upstream_addr) { + let input = if !force_remote + && is_local_address(local_barrier_manager.env.server_address(), &upstream_addr) + { LocalInput::new( local_barrier_manager.register_local_upstream_output(actor_id, upstream_actor_id), upstream_actor_id, @@ -361,13 +392,20 @@ pub(crate) async fn new_input( } impl DispatcherMessageBatch { - fn into_messages(self) -> Either, DispatcherMessage> { + /// Split the batch into multiple messages. 
+    fn into_messages(self) -> impl ExactSizeIterator<Item = DispatcherMessage> {
+        use either::Either;
+
         match self {
             DispatcherMessageBatch::BarrierBatch(barriers) => {
                 Either::Left(barriers.into_iter().map(DispatcherMessage::Barrier))
             }
-            DispatcherMessageBatch::Chunk(c) => Either::Right(DispatcherMessage::Chunk(c)),
-            DispatcherMessageBatch::Watermark(w) => Either::Right(DispatcherMessage::Watermark(w)),
+            DispatcherMessageBatch::Chunk(c) => {
+                Either::Right(std::iter::once(DispatcherMessage::Chunk(c)))
+            }
+            DispatcherMessageBatch::Watermark(w) => {
+                Either::Right(std::iter::once(DispatcherMessage::Watermark(w)))
+            }
         }
     }
 }
diff --git a/src/stream/src/executor/exchange/input/mux_remote_input.rs b/src/stream/src/executor/exchange/input/mux_remote_input.rs
new file mode 100644
index 0000000000000..b9e9661236b27
--- /dev/null
+++ b/src/stream/src/executor/exchange/input/mux_remote_input.rs
@@ -0,0 +1,314 @@
+// Copyright 2025 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use anyhow::Context as _;
+use futures::StreamExt;
+use futures::future::Either;
+use risingwave_common::util::addr::HostAddr;
+use risingwave_pb::id::{ActorId, FragmentId};
+use risingwave_pb::task_service::get_mux_stream_request::{self, AddPermits, Value};
+use risingwave_pb::task_service::{GetMuxStreamRequest, GetMuxStreamResponse};
+use risingwave_rpc_client::error::RpcError;
+use risingwave_rpc_client::{ComputeClient, ComputeClientPoolRef};
+use tokio::sync::mpsc;
+use tokio::task::JoinHandle;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tonic::Streaming;
+
+use crate::executor::exchange::error::ExchangeChannelClosed;
+use crate::executor::exchange::input::RemoteInput;
+use crate::executor::prelude::StreamingMetrics;
+use crate::executor::{DispatcherMessage, DispatcherMessageBatch, StreamExecutorResult};
+use crate::task::{LocalBarrierManager, UpDownActorIds, UpDownFragmentIds};
+
+struct RegisterReq {
+    register: get_mux_stream_request::Register,
+    msg_tx: mpsc::UnboundedSender<StreamExecutorResult<DispatcherMessage>>,
+}
+
+#[derive(Clone)]
+struct Worker {
+    register_tx: mpsc::UnboundedSender<RegisterReq>,
+    join_handle: Arc<JoinHandle<()>>,
+}
+
+impl Worker {
+    /// Create a new worker by calling `get_mux_stream` to the upstream and running the loop.
+    async fn new(
+        client: ComputeClient,
+        init: get_mux_stream_request::Init,
+    ) -> Result<Self, RpcError> {
+        let up_fragment_id = init.up_fragment_id;
+        let (stream, req_tx) = client.get_mux_stream(init).await?;
+        let (register_tx, register_rx) = mpsc::unbounded_channel();
+
+        let join_handle = tokio::spawn(Self::run(up_fragment_id, stream, req_tx, register_rx));
+
+        Ok(Self {
+            register_tx,
+            join_handle: Arc::new(join_handle),
+        })
+    }
+
+    /// Run the event loop of the worker.
+    ///
+    /// Actor pairs received from `register_rx` are registered to the upstream through `req_tx`.
+    /// Responses from `stream` are split into messages and forwarded to the channel of the actor
+    /// pair they belong to, with their permits sent back to the upstream through `req_tx`.
+    ///
+    /// The loop fails if the upstream returns an error or a response without a message, or if
+    /// `req_tx` is closed. The failure is then forwarded to all registered channels.
+    async fn run(
+        up_fragment_id: FragmentId,
+        stream: Streaming<GetMuxStreamResponse>,
+        req_tx: mpsc::UnboundedSender<GetMuxStreamRequest>,
+        register_rx: mpsc::UnboundedReceiver<RegisterReq>,
+    ) {
+        enum Event {
+            Register(RegisterReq),
+            Response(Result<GetMuxStreamResponse, tonic::Status>),
+        }
+
+        // Merge events from the register channel and the upstream response stream.
+        let mut stream = futures::stream_select!(
+            UnboundedReceiverStream::new(register_rx).map(Event::Register),
+            stream.map(Event::Response),
+        );
+
+        // All registered message channels.
+        let mut msg_txs: HashMap<(ActorId, ActorId), mpsc::UnboundedSender<_>> = HashMap::new();
+
+        let result: Result<(), ExchangeChannelClosed> = try {
+            while let Some(event) = stream.next().await {
+                match event {
+                    // Register a new actor pair.
+                    Event::Register(RegisterReq { register, msg_tx }) => {
+                        req_tx
+                            .send(GetMuxStreamRequest {
+                                value: Some(Value::Register(register)),
+                            })
+                            .map_err(|_| {
+                                ExchangeChannelClosed::remote_input_fragment(up_fragment_id, None)
+                            })?;
+
+                        msg_txs.insert((register.up_actor_id, register.down_actor_id), msg_tx);
+                    }
+
+                    // Exchange message from the upstream.
+                    Event::Response(res) => {
+                        let GetMuxStreamResponse {
+                            message,
+                            permits,
+                            up_actor_id,
+                            down_actor_id,
+                        } = res.map_err(|e| {
+                            ExchangeChannelClosed::remote_input_fragment(up_fragment_id, Some(e))
+                        })?;
+
+                        let actor_pair = (up_actor_id, down_actor_id);
+                        let Some(msg_tx) = msg_txs.get(&actor_pair) else {
+                            tracing::warn!(
+                                %up_actor_id,
+                                %down_actor_id,
+                                "received message for unregistered actor pair"
+                            );
+                            continue;
+                        };
+
+                        // TODO(mux): batch putting back permits
+                        req_tx
+                            .send(GetMuxStreamRequest {
+                                value: Some(Value::AddPermits(AddPermits {
+                                    up_actor_id,
+                                    down_actor_id,
+                                    permits,
+                                })),
+                            })
+                            .map_err(|_| {
+                                ExchangeChannelClosed::remote_input_fragment(up_fragment_id, None)
+                            })?;
+
+                        // A response without a message violates the protocol. Fail the worker with
+                        // an error, which is forwarded to all registered actors below, instead of
+                        // panicking inside the spawned task.
+                        let msg = message.ok_or_else(|| {
+                            ExchangeChannelClosed::remote_input_fragment(up_fragment_id, None)
+                        })?;
+
+                        // Any error occurred during sending the message to the specific actor should be
+                        // treated as the actor disconnected. We should gracefully remove the actor from
+                        // the worker, instead of failing the whole worker.
+                        let send_result: Result<(), ()> = try {
+                            match DispatcherMessageBatch::from_protobuf(&msg) {
+                                Ok(msg) => {
+                                    for msg in msg.into_messages() {
+                                        msg_tx.send(Ok(msg)).map_err(|_| ())?;
+                                    }
+                                }
+                                Err(e) => {
+                                    msg_tx.send(Err(e)).map_err(|_| ())?;
+                                }
+                            }
+                        };
+
+                        if send_result.is_err() {
+                            tracing::debug!(
+                                %up_actor_id,
+                                %down_actor_id,
+                                "downstream actor disconnected, removing from mux worker",
+                            );
+                            msg_txs.remove(&actor_pair);
+                        }
+                    }
+                }
+            }
+        };
+
+        // Forward worker error to all registered actors.
+        // Although we always emit an error in `MuxRemoteInputStream`, this can be more accurate.
+        if let Err(e) = result {
+            for (_, msg_tx) in msg_txs {
+                msg_tx.send(Err(e.clone().into())).ok();
+            }
+        }
+    }
+}
+
+#[derive(Clone, Hash, PartialEq, Eq)]
+struct WorkerKey {
+    upstream_addr: HostAddr,
+    init: get_mux_stream_request::Init,
+}
+
+/// Workers for multiplexed exchange.
+#[derive(Clone)] +pub struct MuxExchangeWorkers { + // Note: we store `Result` in the value in order to use moka's `or_insert_with_if`. + cache: moka::future::Cache>>, +} + +impl std::fmt::Debug for MuxExchangeWorkers { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MuxExchangeWorkers").finish_non_exhaustive() + } +} + +impl MuxExchangeWorkers { + /// Create a new `MuxExchangeWorkers`. + pub(crate) fn new() -> Self { + Self { + cache: moka::future::Cache::new(u64::MAX), + } + } + + /// Return a worker for the given remote exchange information. + /// + /// A worker will be reused if the `init` and `upstream_addr` of the remote exchange match, i.e., + /// the exchange data of actor pairs in the same fragment pair, upstream worker node, database + /// and term will be multiplexed. + /// + /// If the worker doesn't exist or the previous one has disconnected, a new one will be created + /// using the given `client_pool`. + async fn get( + &self, + init: get_mux_stream_request::Init, + upstream_addr: HostAddr, + client_pool: ComputeClientPoolRef, + ) -> StreamExecutorResult { + self.cache + .entry(WorkerKey { + upstream_addr: upstream_addr.clone(), + init: init.clone(), + }) + .or_insert_with_if( + async move { + let client = client_pool.get_by_addr(upstream_addr).await?; + let worker = Worker::new(client, init).await?; + Ok(worker) + }, + |w| match w { + // Create a new one if the previous worker exited. + Ok(w) => w.join_handle.is_finished(), + // Create a new one if previous connection failed. + Err(_) => true, + }, + ) + .await + .into_value() + .context("failed to create a mux exchange worker to upstream") + .map_err(Into::into) + } +} + +impl RemoteInput { + /// Create a remote input with the experimental multiplexing implementation. 
+ pub(super) async fn new_mux( + local_barrier_manager: &LocalBarrierManager, + upstream_addr: HostAddr, + up_down_ids: UpDownActorIds, + up_down_frag: UpDownFragmentIds, + _metrics: Arc, + ) -> StreamExecutorResult { + let env = &local_barrier_manager.env; + let actor_id = up_down_ids.0; + + let init = get_mux_stream_request::Init { + up_fragment_id: up_down_frag.0, + down_fragment_id: up_down_frag.1, + database_id: local_barrier_manager.database_id, + term_id: local_barrier_manager.term_id.clone(), + }; + + let worker = env + .mux_exchange_workers() + .get(init, upstream_addr, env.client_pool()) + .await?; + + let (msg_tx, msg_rx) = mpsc::unbounded_channel(); + worker + .register_tx + .send(RegisterReq { + register: get_mux_stream_request::Register { + up_actor_id: up_down_ids.0, + down_actor_id: up_down_ids.1, + }, + msg_tx, + }) + .ok() + // Worker exited immediately after we got it, could be connection error. + .context("failed to connect to upstream")?; + + Ok(Self { + actor_id, + inner: Either::Right(make_input_stream(actor_id, msg_rx)), + }) + } +} + +pub(super) type MuxRemoteInputStream = impl crate::executor::DispatcherMessageStream; + +#[define_opaque(MuxRemoteInputStream)] +fn make_input_stream( + actor_id: ActorId, + msg_rx: mpsc::UnboundedReceiver>, +) -> MuxRemoteInputStream { + UnboundedReceiverStream::new(msg_rx).chain(futures::stream::once(async move { + // Always emit an error after worker exited. This is because we use barrier as the control + // message to stop the stream. Reaching here means the channel is closed unexpectedly. 
+ Err(ExchangeChannelClosed::remote_input(actor_id, None).into()) + })) +} diff --git a/src/stream/src/executor/exchange/permit.rs b/src/stream/src/executor/exchange/permit.rs index 426c470c1d710..5fba4c2c136a1 100644 --- a/src/stream/src/executor/exchange/permit.rs +++ b/src/stream/src/executor/exchange/permit.rs @@ -205,6 +205,16 @@ impl Receiver { pub fn permits(&self) -> Arc { self.permits.clone() } + + /// Convert into a stream of [`MessageWithPermits`]. + #[futures_async_stream::stream(item = MessageWithPermits)] + pub async fn into_raw_stream(mut self) { + // Note: we don't use `tokio_stream::wrapper` because it needs destructuring `self`, + // which is impossible since there's `Drop`. + while let Some(message) = self.recv_raw().await { + yield message; + } + } } impl Drop for Receiver { diff --git a/src/stream/src/executor/merge.rs b/src/stream/src/executor/merge.rs index 75b09e1b73bbc..2083e21beed0c 100644 --- a/src/stream/src/executor/merge.rs +++ b/src/stream/src/executor/merge.rs @@ -476,7 +476,9 @@ mod tests { use risingwave_pb::task_service::stream_exchange_service_server::{ StreamExchangeService, StreamExchangeServiceServer, }; - use risingwave_pb::task_service::{GetStreamRequest, GetStreamResponse, PbPermits}; + use risingwave_pb::task_service::{ + GetMuxStreamRequest, GetMuxStreamResponse, GetStreamRequest, GetStreamResponse, PbPermits, + }; use tokio::time::sleep; use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status, Streaming}; @@ -857,6 +859,7 @@ mod tests { #[async_trait::async_trait] impl StreamExchangeService for FakeExchangeService { + type GetMuxStreamStream = ReceiverStream>; type GetStreamStream = ReceiverStream>; async fn get_stream( @@ -897,6 +900,13 @@ mod tests { .unwrap(); Ok(Response::new(ReceiverStream::new(rx))) } + + async fn get_mux_stream( + &self, + _request: Request>, + ) -> std::result::Result, Status> { + unreachable!("test case should use simple `get_stream`") + } } #[tokio::test] @@ -928,7 
+938,7 @@ mod tests { let test_env = LocalBarrierTestEnv::for_test().await; let remote_input = { - RemoteInput::new( + RemoteInput::new_simple( &test_env.local_barrier_manager, addr.into(), (0.into(), 0.into()), diff --git a/src/stream/src/task/env.rs b/src/stream/src/task/env.rs index dce999d1bfc28..b961f4aa8068b 100644 --- a/src/stream/src/task/env.rs +++ b/src/stream/src/task/env.rs @@ -24,6 +24,8 @@ use risingwave_dml::dml_manager::DmlManagerRef; use risingwave_rpc_client::{ComputeClientPoolRef, MetaClient}; use risingwave_storage::StateStoreImpl; +use crate::executor::exchange::input::MuxExchangeWorkers; + /// The global environment for task execution. /// The instance will be shared by every task. #[derive(Clone, Debug)] @@ -60,6 +62,9 @@ pub struct StreamEnvironment { /// Compute client pool for streaming gRPC exchange. client_pool: ComputeClientPoolRef, + + /// Workers for multiplexed exchange. + mux_exchange_workers: MuxExchangeWorkers, } impl StreamEnvironment { @@ -86,6 +91,7 @@ impl StreamEnvironment { total_mem_val: Arc::new(TrAdder::new()), meta_client: Some(meta_client), client_pool, + mux_exchange_workers: MuxExchangeWorkers::new(), } } @@ -95,9 +101,14 @@ impl StreamEnvironment { use risingwave_dml::dml_manager::DmlManager; use risingwave_rpc_client::ComputeClientPool; use risingwave_storage::monitor::MonitoredStorageMetrics; + + // TODO(mux): remove this after we default `exchange_force_remote` to false + let mut config = StreamingConfig::default(); + config.developer.exchange_force_remote = false; + StreamEnvironment { server_addr: "127.0.0.1:2333".parse().unwrap(), - global_config: Arc::new(StreamingConfig::default()), + global_config: Arc::new(config), worker_id: WorkerNodeId::default(), state_store: StateStoreImpl::shared_in_memory_store(Arc::new( MonitoredStorageMetrics::unused(), @@ -108,6 +119,7 @@ impl StreamEnvironment { total_mem_val: Arc::new(TrAdder::new()), meta_client: None, client_pool: Arc::new(ComputeClientPool::for_test()), + 
mux_exchange_workers: MuxExchangeWorkers::new(), } } @@ -150,4 +162,8 @@ impl StreamEnvironment { pub fn client_pool(&self) -> ComputeClientPoolRef { self.client_pool.clone() } + + pub fn mux_exchange_workers(&self) -> &MuxExchangeWorkers { + &self.mux_exchange_workers + } }