Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions proto/task_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,32 @@ message GetStreamRequest {
}
}

message GetNewStreamRequest {
message Init {
uint32 up_fragment_id = 1;
uint32 down_fragment_id = 2;
uint32 database_id = 3;
string term_id = 4;
}

message Get {
uint32 up_actor_id = 1;
uint32 down_actor_id = 2;
}

message AddPermits {
uint32 up_actor_id = 1;
uint32 down_actor_id = 2;
Permits permits = 3;
}

oneof value {
Init init = 1;
Get get = 2;
AddPermits add_permits = 3;
}
}

message GetStreamResponse {
stream_plan.StreamMessageBatch message = 1;
// The number of permits acquired for this message, which should be sent back to the upstream with `add_permits`.
Expand All @@ -148,7 +174,21 @@ message GetStreamResponse {
Permits permits = 2;
}

message GetNewStreamResponse {
// message UpActorIds {
// repeated uint32 up_actor_id = 1;
// }

stream_plan.StreamMessageBatch message = 1;
Permits permits = 2;
uint32 up_actor_id = 3;
uint32 down_actor_id = 4;

// map<uint32, UpActorIds> down_up_actor_ids = 3;
}

service ExchangeService {
rpc GetData(GetDataRequest) returns (stream GetDataResponse);
rpc GetStream(stream GetStreamRequest) returns (stream GetStreamResponse);
rpc GetNewStream(stream GetNewStreamRequest) returns (stream GetNewStreamResponse);
}
136 changes: 134 additions & 2 deletions src/compute/src/rpc/service/exchange_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;

use either::Either;
use futures::stream::SelectAll;
use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
use futures_async_stream::try_stream;
use risingwave_batch::task::BatchManager;
use risingwave_pb::id::FragmentId;
use risingwave_pb::id::{ActorId, FragmentId};
use risingwave_pb::task_service::exchange_service_server::ExchangeService;
use risingwave_pb::task_service::{
GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits, permits,
GetDataRequest, GetDataResponse, GetNewStreamRequest, GetNewStreamResponse, GetStreamRequest,
GetStreamResponse, PbPermits, get_new_stream_request, permits,
};
use risingwave_stream::executor::DispatcherMessageBatch;
use risingwave_stream::executor::exchange::permit::{MessageWithPermits, Receiver};
Expand All @@ -42,10 +45,13 @@ pub struct ExchangeServiceImpl {

pub type BatchDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
pub type StreamDataStream = impl Stream<Item = std::result::Result<GetStreamResponse, Status>>;
pub type NewStreamDataStream =
impl Stream<Item = std::result::Result<GetNewStreamResponse, Status>>;

#[async_trait::async_trait]
impl ExchangeService for ExchangeServiceImpl {
type GetDataStream = BatchDataStream;
type GetNewStreamStream = NewStreamDataStream;
type GetStreamStream = StreamDataStream;

async fn get_data(
Expand Down Expand Up @@ -124,6 +130,132 @@ impl ExchangeService for ExchangeServiceImpl {
(up_fragment_id, down_fragment_id),
)))
}

#[define_opaque(NewStreamDataStream)]
async fn get_new_stream(
&self,
request: Request<Streaming<GetNewStreamRequest>>,
) -> std::result::Result<Response<Self::GetNewStreamStream>, Status> {
let request_stream = request.into_inner();

Ok(Response::new(Self::get_new_stream_impl(
self.stream_mgr.clone(),
request_stream,
)))
}
}

impl ExchangeServiceImpl {
#[try_stream(ok = GetNewStreamResponse, error = Status)]
async fn get_new_stream_impl(
stream_mgr: LocalStreamManager,
mut request_stream: Streaming<GetNewStreamRequest>,
) {
use risingwave_pb::task_service::get_new_stream_request::*;

// Extract the first `Init` request from the stream.
let Init {
up_fragment_id: _,
down_fragment_id: _,
database_id,
term_id,
} = {
let req = request_stream
.next()
.await
.ok_or_else(|| Status::invalid_argument("get_new_stream request is empty"))??;
match req.value.unwrap() {
Value::Init(init) => init,
Value::Get(_) | Value::AddPermits(_) => {
unreachable!("the first message must be `Init`")
}
}
};

enum Req {
Request(Result<GetNewStreamRequest, Status>),
Message {
up_actor_id: ActorId,
down_actor_id: ActorId,
message: MessageWithPermits,
},
}

let mut select_all = SelectAll::new();
select_all.push(request_stream.map(Req::Request).boxed());

let mut all_permits = HashMap::new();

while let Some(r) = select_all.next().await {
match r {
Req::Request(req) => match req?.value.unwrap() {
Value::Init(_) => unreachable!("the stream has already been initialized"),
Value::Get(Get {
up_actor_id,
down_actor_id,
}) => {
let receiver = stream_mgr
.take_receiver(
database_id,
term_id.clone(),
(up_actor_id, down_actor_id),
)
.await?;
let permits = Arc::downgrade(&receiver.permits());
all_permits.insert((up_actor_id, down_actor_id), permits);
select_all.push(
receiver
.into_raw_stream()
.map(move |message| Req::Message {
up_actor_id,
down_actor_id,
message,
})
.boxed(),
);
}
Value::AddPermits(AddPermits {
up_actor_id,
down_actor_id,
permits,
}) => {
let to_add = permits.unwrap().value.unwrap();

if let Some(permits) = all_permits
.get(&(up_actor_id, down_actor_id))
.and_then(|p| p.upgrade())
{
permits.add_permits(to_add);
}
}
},

Req::Message {
up_actor_id,
down_actor_id,
message: MessageWithPermits { message, permits },
} => {
let message = match message {
DispatcherMessageBatch::Chunk(chunk) => {
DispatcherMessageBatch::Chunk(chunk.compact_vis())
}
msg @ (DispatcherMessageBatch::Watermark(_)
| DispatcherMessageBatch::BarrierBatch(_)) => msg,
};
let proto = message.to_protobuf();
// forward the acquired permit to the downstream
let response = GetNewStreamResponse {
message: Some(proto),
permits: Some(PbPermits { value: permits }),
up_actor_id,
down_actor_id,
};

yield response;
}
}
}
}
}

impl ExchangeServiceImpl {
Expand Down
24 changes: 23 additions & 1 deletion src/prost/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,27 @@ for_all_wrapped_id_fields! (
up_actor_id: ActorId,
down_actor_id: ActorId,
}
GetNewStreamRequest.Init {
database_id: DatabaseId,
up_fragment_id: FragmentId,
down_fragment_id: FragmentId,
}
GetNewStreamRequest.Get {
up_actor_id: ActorId,
down_actor_id: ActorId,
}
GetNewStreamRequest.AddPermits {
up_actor_id: ActorId,
down_actor_id: ActorId,
}
GetNewStreamResponse {
// down_up_actor_ids: ActorId,
up_actor_id: ActorId,
down_actor_id: ActorId,
}
// GetNewStreamResponse.UpActorIds {
// up_actor_id: ActorId,
// }
}
user {
AlterDefaultPrivilegeRequest {
Expand Down Expand Up @@ -1055,6 +1076,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.type_attribute("expr.UdfExprVersion", "#[derive(prost_helpers::Version)]")
.type_attribute("meta.Object.object_info", "#[derive(strum::Display)]")
.type_attribute("meta.SubscribeResponse.info", "#[derive(strum::Display)]")
.type_attribute("task_service.GetNewStreamRequest.Init", "#[derive(Hash, Eq)]")
// end
;

Expand All @@ -1068,7 +1090,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
//"stream_plan.StreamNode"
]);

check_declared_wrapped_fields_sorted();
// check_declared_wrapped_fields_sorted();

for (wrapped_type, wrapped_fields) in &wrapped_fields() {
for (field_name, field_type) in wrapped_fields {
Expand Down
32 changes: 30 additions & 2 deletions src/rpc_client/src/compute_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ use risingwave_pb::task_service::exchange_service_client::ExchangeServiceClient;
use risingwave_pb::task_service::task_service_client::TaskServiceClient;
use risingwave_pb::task_service::{
CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, FastInsertRequest,
FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse,
PbPermits, TaskInfoResponse, permits,
FastInsertResponse, GetDataRequest, GetDataResponse, GetNewStreamRequest, GetNewStreamResponse,
GetStreamRequest, GetStreamResponse, PbPermits, TaskInfoResponse, get_new_stream_request,
permits,
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
Expand Down Expand Up @@ -112,6 +113,33 @@ impl ComputeClient {
.into_inner())
}

pub async fn get_new_stream(
&self,
init: get_new_stream_request::Init,
) -> Result<(
Streaming<GetNewStreamResponse>,
mpsc::UnboundedSender<GetNewStreamRequest>,
)> {
use risingwave_pb::task_service::get_new_stream_request::*;

let (request_sender, request_receiver) = mpsc::unbounded_channel();
request_sender
.send(GetNewStreamRequest {
value: Some(Value::Init(init)),
})
.unwrap();

let response_stream = self
.exchange_client
.clone()
.get_new_stream(UnboundedReceiverStream::new(request_receiver))
.await
.map_err(RpcError::from_compute_status)?
.into_inner();

Ok((response_stream, request_sender))
}

pub async fn get_stream(
&self,
up_actor_id: ActorId,
Expand Down
Loading
Loading