Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions src/brpc/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Stream::Stream()
, _remote_consumed(0)
, _cur_buf_size(0)
, _local_consumed(0)
, _atomic_local_consumed(0)
, _parse_rpc_response(false)
, _pending_buf(NULL)
, _start_idle_timer_us(0)
Expand Down Expand Up @@ -287,14 +288,21 @@ void Stream::SetConnected(const StreamSettings* remote_settings) {
CHECK(_host_socket != NULL);
RPC_VLOG << "stream=" << id() << " is connected to stream_id="
<< _remote_settings.stream_id() << " at host_socket=" << *_host_socket;
_connected = true;
_connected.store(true, butil::memory_order_release);
_connect_meta.ec = 0;
TriggerOnConnectIfNeed();
if (remote_settings == NULL) {
// Start the timer at server-side
// Client-side timer would triggered in Consume after received the first
// message which is the very RPC response
StartIdleTimer();
} else {
// send first feedback for client-side stream if it already consumed data
if (_remote_settings.need_feedback()) {
auto consumed_bytes = _atomic_local_consumed.load(butil::memory_order_acquire);
if (consumed_bytes > 0)
SendFeedback(consumed_bytes);
}
}
}

Expand Down Expand Up @@ -620,20 +628,34 @@ int Stream::Consume(void *meta, bthread::TaskIterator<butil::IOBuf*>& iter) {
}
mb.flush();

if (s->_remote_settings.need_feedback() && mb.total_length() > 0) {
s->_local_consumed += mb.total_length();
s->SendFeedback();
auto total_length = mb.total_length();
if (total_length > 0) {
// fast path for connected stream
if (s->_connected.load(butil::memory_order_acquire)){
if (s->_remote_settings.need_feedback()) {
s->_local_consumed += total_length;
s->SendFeedback(s->_local_consumed);
}
} else {
// Under the scenario of batch creation of Streams, there is concurrency between SetConnected and Consume for the same stream,
// and it is necessary to ensure the memory order.
s->_local_consumed = s->_atomic_local_consumed.fetch_add(total_length, butil::memory_order_release) + total_length;
if (s->_connected.load(butil::memory_order_acquire) && s->_remote_settings.need_feedback()) {
s->SendFeedback(s->_local_consumed);
}
}
}

s->StartIdleTimer();
return 0;
}

void Stream::SendFeedback() {
void Stream::SendFeedback(int64_t _consumed_bytes) {
StreamFrameMeta fm;
fm.set_frame_type(FRAME_TYPE_FEEDBACK);
fm.set_stream_id(_remote_settings.stream_id());
fm.set_source_stream_id(id());
fm.mutable_feedback()->set_consumed_size(_local_consumed);
fm.mutable_feedback()->set_consumed_size(_consumed_bytes);
butil::IOBuf out;
policy::PackStreamMessage(&out, fm, NULL);
WriteToHostSocket(&out);
Expand Down
7 changes: 4 additions & 3 deletions src/brpc/stream_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ friend struct butil::DefaultDeleter<Stream>;
void TriggerOnConnectIfNeed();
void Wait(void (*on_writable)(StreamId, void*, int), void* arg,
const timespec* due_time, bool new_thread, bthread_id_t *join_id);
void SendFeedback();
void SendFeedback(int64_t _consumed_bytes);
void StartIdleTimer();
void StopIdleTimer();
void HandleRpcResponse(butil::IOBuf* response_buffer);
Expand Down Expand Up @@ -115,7 +115,7 @@ friend struct butil::DefaultDeleter<Stream>;

bthread_mutex_t _connect_mutex;
ConnectMeta _connect_meta;
bool _connected;
butil::atomic<bool> _connected;
bool _closed;
int _error_code;
std::string _error_text;
Expand All @@ -127,7 +127,8 @@ friend struct butil::DefaultDeleter<Stream>;
bthread_id_list_t _writable_wait_list;

int64_t _local_consumed;
StreamSettings _remote_settings;
butil::atomic<int64_t> _atomic_local_consumed;
StreamSettings _remote_settings;

bool _parse_rpc_response;
bthread::ExecutionQueueId<butil::IOBuf*> _consumer_queue;
Expand Down
236 changes: 235 additions & 1 deletion test/brpc_streaming_rpc_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
// Date: 2015/10/22 16:28:44

#include <gtest/gtest.h>
#include <atomic>
#include "brpc/server.h"

#include "brpc/controller.h"
#include "brpc/channel.h"
#include "brpc/callback.h"
#include "brpc/socket.h"
#include "brpc/stream_impl.h"
#include "brpc/policy/streaming_rpc_protocol.h"
Expand Down Expand Up @@ -54,7 +56,7 @@ class MyServiceWithStream : public test::EchoService {
const ::test::EchoRequest* request,
::test::EchoResponse* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_gurad(done);
brpc::ClosureGuard done_guard(done);
response->set_message(request->message());
brpc::Controller* cntl = (brpc::Controller*)controller;
brpc::StreamId response_stream;
Expand All @@ -78,6 +80,158 @@ class StreamingRpcTest : public testing::Test {
test::EchoResponse response;
};

struct BatchStreamFeedbackRaceState {
brpc::StreamId server_first_stream_id{brpc::INVALID_STREAM_ID};
brpc::StreamId server_extra_stream_id{brpc::INVALID_STREAM_ID};
brpc::StreamId client_extra_stream_id{brpc::INVALID_STREAM_ID};

std::atomic<int> server_first_write_rc{-1};
std::atomic<int> server_second_write_rc{-1};
std::atomic<bool> client_got_first_msg{false};
std::atomic<bool> client_got_second_msg{false};
std::atomic<bool> server_write_done{false};
std::atomic<bool> rpc_done{false};

bthread_t server_send_tid{0};
std::atomic<bool> server_send_started{false};
};

class BatchStreamClientHandler : public brpc::StreamInputHandler {
public:
explicit BatchStreamClientHandler(BatchStreamFeedbackRaceState* state)
: _state(state) {}

int on_received_messages(brpc::StreamId id,
butil::IOBuf* const messages[],
size_t size) override {
if (id != _state->client_extra_stream_id) {
// This test only cares about extra stream in batch creation.
return 0;
}
for (size_t i = 0; i < size; ++i) {
const size_t len = messages[i]->length();
messages[i]->clear();
// First payload: 64 bytes. Second payload: 1 byte.
if (len == 64) {
_state->client_got_first_msg.store(true, std::memory_order_release);
} else if (len == 1) {
_state->client_got_second_msg.store(true, std::memory_order_release);
}
}
return 0;
}

void on_idle_timeout(brpc::StreamId /*id*/) override {}

void on_closed(brpc::StreamId /*id*/) override {}

void on_failed(brpc::StreamId /*id*/, int /*error_code*/, const std::string& /*error_text*/) override {}

private:
BatchStreamFeedbackRaceState* _state;
};

static void* SendTwoMessagesOnServerExtraStream(void* arg) {
auto* state = static_cast<BatchStreamFeedbackRaceState*>(arg);
const brpc::StreamId sid = state->server_extra_stream_id;

// Wait until server-side stream is connected.
const int64_t connect_deadline_us = butil::gettimeofday_us() + 2 * 1000 * 1000L;
bool connected = false;
while (butil::gettimeofday_us() < connect_deadline_us) {
brpc::SocketUniquePtr ptr;
if (brpc::Socket::Address(sid, &ptr) == 0) {
brpc::Stream* s = static_cast<brpc::Stream*>(ptr->conn());
if (s->_host_socket != NULL && s->_connected) {
connected = true;
break;
}
}
usleep(1000);
}

if (!connected) {
state->server_first_write_rc.store(ETIMEDOUT, std::memory_order_relaxed);
state->server_second_write_rc.store(ETIMEDOUT, std::memory_order_relaxed);
state->server_write_done.store(true, std::memory_order_release);
return NULL;
}

// 1) Send a payload exactly equal to max_buf_size(64).
{
std::string payload(64, 'a');
butil::IOBuf out;
out.append(payload);
state->server_first_write_rc.store(brpc::StreamWrite(sid, out), std::memory_order_relaxed);
}

// 2) Then send another byte. This write should become writable only after
// client sends FEEDBACK with consumed_size >= 64.
const int64_t write_deadline_us = butil::gettimeofday_us() + 2 * 1000 * 1000L;
int rc = -1;
while (butil::gettimeofday_us() < write_deadline_us) {
butil::IOBuf out;
out.append("b", 1);
rc = brpc::StreamWrite(sid, out);
if (rc == 0) {
break;
}
if (rc != EAGAIN) {
break;
}
const timespec duetime = butil::milliseconds_from_now(100);
(void)brpc::StreamWait(sid, &duetime);
}
state->server_second_write_rc.store(rc, std::memory_order_relaxed);
state->server_write_done.store(true, std::memory_order_release);
return NULL;
}

class MyServiceWithBatchStream : public test::EchoService {
public:
MyServiceWithBatchStream(const brpc::StreamOptions& options,
BatchStreamFeedbackRaceState* state)
: _options(options), _state(state) {}

void Echo(::google::protobuf::RpcController* controller,
const ::test::EchoRequest* request,
::test::EchoResponse* response,
::google::protobuf::Closure* done) override {
brpc::ClosureGuard done_guard(done);
response->set_message(request->message());
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);

brpc::StreamIds response_streams;
ASSERT_EQ(0, brpc::StreamAccept(response_streams, *cntl, &_options));
ASSERT_EQ(2u, response_streams.size());
_state->server_first_stream_id = response_streams[0];
_state->server_extra_stream_id = response_streams[1];

bthread_t tid;
ASSERT_EQ(0, bthread_start_background(
&tid, &BTHREAD_ATTR_NORMAL,
SendTwoMessagesOnServerExtraStream, _state));
_state->server_send_tid = tid;
_state->server_send_started.store(true, std::memory_order_release);
}

private:
brpc::StreamOptions _options;
BatchStreamFeedbackRaceState* _state;
};

static void SetAtomicTrue(std::atomic<bool>* f) {
f->store(true, std::memory_order_release);
}

static bool WaitForTrue(const std::atomic<bool>& f, int timeout_ms) {
const int64_t deadline_us = butil::gettimeofday_us() + (int64_t)timeout_ms * 1000L;
while (!f.load(std::memory_order_acquire) && butil::gettimeofday_us() < deadline_us) {
usleep(1000);
}
return f.load(std::memory_order_acquire);
}

TEST_F(StreamingRpcTest, sanity) {
brpc::Server server;
MyServiceWithStream service;
Expand All @@ -98,6 +252,86 @@ TEST_F(StreamingRpcTest, sanity) {
server.Join();
}

TEST_F(StreamingRpcTest, batch_create_stream_feedback_race) {
BatchStreamFeedbackRaceState state;
BatchStreamClientHandler client_handler(&state);

brpc::StreamOptions server_stream_opt;
// Make server-side sender sensitive to FEEDBACK quickly.
server_stream_opt.max_buf_size = 16;

brpc::Server server;
MyServiceWithBatchStream service(server_stream_opt, &state);
ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE));
ASSERT_EQ(0, server.Start(9007, NULL));

brpc::Channel channel;
ASSERT_EQ(0, channel.Init("127.0.0.1:9007", NULL));

brpc::Controller cntl;
brpc::StreamIds request_streams;
brpc::StreamOptions client_stream_opt;
client_stream_opt.handler = &client_handler;
client_stream_opt.max_buf_size = 0;
ASSERT_EQ(0, brpc::StreamCreate(request_streams, 2, cntl, &client_stream_opt));
ASSERT_EQ(2u, request_streams.size());
state.client_extra_stream_id = request_streams[1];

// Block SetConnected() on the extra stream to enlarge the race window.
brpc::SocketUniquePtr client_extra_ptr;
ASSERT_EQ(0, brpc::Socket::Address(state.client_extra_stream_id, &client_extra_ptr));
brpc::Stream* client_extra_stream = static_cast<brpc::Stream*>(client_extra_ptr->conn());
bthread_mutex_lock(&client_extra_stream->_connect_mutex);
struct UnlockGuard {
bthread_mutex_t* m;
~UnlockGuard() {
if (m) {
bthread_mutex_unlock(m);
}
}
} unlock_guard{&client_extra_stream->_connect_mutex};

BRPC_SCOPE_EXIT {
if (state.server_extra_stream_id != brpc::INVALID_STREAM_ID) {
brpc::StreamClose(state.server_extra_stream_id);
}
if (state.server_first_stream_id != brpc::INVALID_STREAM_ID) {
brpc::StreamClose(state.server_first_stream_id);
}
for (auto sid : request_streams) {
brpc::StreamClose(sid);
}

if (state.server_send_tid) {
bthread_join(state.server_send_tid, NULL);
}
server.Stop(0);
server.Join();
};

test::EchoService_Stub stub(&channel);
stub.Echo(&cntl, &request, &response, brpc::NewCallback(SetAtomicTrue, &state.rpc_done));

// Wait until client consumes the first 64B payload on extra stream.
ASSERT_TRUE(WaitForTrue(state.client_got_first_msg, 2000));

// Unblock SetConnected(); the fix in PR 3215 should send the first FEEDBACK
// with consumed_size=64 here, making server-side stream writable again.
bthread_mutex_unlock(&client_extra_stream->_connect_mutex);
unlock_guard.m = NULL;

ASSERT_TRUE(WaitForTrue(state.rpc_done, 2000));
ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText();

// Wait for server-side send thread to be started.
ASSERT_TRUE(WaitForTrue(state.server_send_started, 2000));

ASSERT_TRUE(WaitForTrue(state.server_write_done, 2000));
ASSERT_EQ(0, state.server_first_write_rc.load(std::memory_order_relaxed));
ASSERT_EQ(0, state.server_second_write_rc.load(std::memory_order_relaxed));
ASSERT_TRUE(WaitForTrue(state.client_got_second_msg, 2000));
}

struct HandlerControl {
HandlerControl()
: block(false)
Expand Down
Loading