diff --git a/src/brpc/stream.cpp b/src/brpc/stream.cpp index 2a4430548f..a2a106a8b1 100644 --- a/src/brpc/stream.cpp +++ b/src/brpc/stream.cpp @@ -52,6 +52,7 @@ Stream::Stream() , _remote_consumed(0) , _cur_buf_size(0) , _local_consumed(0) + , _atomic_local_consumed(0) , _parse_rpc_response(false) , _pending_buf(NULL) , _start_idle_timer_us(0) @@ -287,7 +288,7 @@ void Stream::SetConnected(const StreamSettings* remote_settings) { CHECK(_host_socket != NULL); RPC_VLOG << "stream=" << id() << " is connected to stream_id=" << _remote_settings.stream_id() << " at host_socket=" << *_host_socket; - _connected = true; + _connected.store(true, butil::memory_order_release); _connect_meta.ec = 0; TriggerOnConnectIfNeed(); if (remote_settings == NULL) { @@ -295,6 +296,13 @@ void Stream::SetConnected(const StreamSettings* remote_settings) { // Client-side timer would triggered in Consume after received the first // message which is the very RPC response StartIdleTimer(); + } else { + // send first feedback for client-side stream if it already consumed data + if (_remote_settings.need_feedback()) { + auto consumed_bytes = _atomic_local_consumed.load(butil::memory_order_acquire); + if (consumed_bytes > 0) + SendFeedback(consumed_bytes); + } } } @@ -620,20 +628,34 @@ int Stream::Consume(void *meta, bthread::TaskIterator& iter) { } mb.flush(); - if (s->_remote_settings.need_feedback() && mb.total_length() > 0) { - s->_local_consumed += mb.total_length(); - s->SendFeedback(); + auto total_length = mb.total_length(); + if (total_length > 0) { + // fast path for connected stream + if (s->_connected.load(butil::memory_order_acquire)){ + if (s->_remote_settings.need_feedback()) { + s->_local_consumed += total_length; + s->SendFeedback(s->_local_consumed); + } + } else { + // Under the scenario of batch creation of Streams, there is concurrency between SetConnected and Consume for the same stream, + // and it is necessary to ensure the memory order. + s->_local_consumed = s->_atomic_local_consumed.fetch_add(total_length, butil::memory_order_release) + total_length; + if (s->_connected.load(butil::memory_order_acquire) && s->_remote_settings.need_feedback()) { + s->SendFeedback(s->_local_consumed); + } + } } + s->StartIdleTimer(); return 0; } -void Stream::SendFeedback() { +void Stream::SendFeedback(int64_t _consumed_bytes) { StreamFrameMeta fm; fm.set_frame_type(FRAME_TYPE_FEEDBACK); fm.set_stream_id(_remote_settings.stream_id()); fm.set_source_stream_id(id()); - fm.mutable_feedback()->set_consumed_size(_local_consumed); + fm.mutable_feedback()->set_consumed_size(_consumed_bytes); butil::IOBuf out; policy::PackStreamMessage(&out, fm, NULL); WriteToHostSocket(&out); diff --git a/src/brpc/stream_impl.h b/src/brpc/stream_impl.h index 5ff7cb04a2..284b33ca33 100644 --- a/src/brpc/stream_impl.h +++ b/src/brpc/stream_impl.h @@ -81,7 +81,7 @@ friend struct butil::DefaultDeleter; void TriggerOnConnectIfNeed(); void Wait(void (*on_writable)(StreamId, void*, int), void* arg, const timespec* due_time, bool new_thread, bthread_id_t *join_id); - void SendFeedback(); + void SendFeedback(int64_t _consumed_bytes); void StartIdleTimer(); void StopIdleTimer(); void HandleRpcResponse(butil::IOBuf* response_buffer); @@ -115,7 +115,7 @@ friend struct butil::DefaultDeleter; bthread_mutex_t _connect_mutex; ConnectMeta _connect_meta; - bool _connected; + butil::atomic _connected; bool _closed; int _error_code; std::string _error_text; @@ -127,7 +127,8 @@ friend struct butil::DefaultDeleter; bthread_id_list_t _writable_wait_list; int64_t _local_consumed; - StreamSettings _remote_settings; + butil::atomic _atomic_local_consumed; + StreamSettings _remote_settings; bool _parse_rpc_response; bthread::ExecutionQueueId _consumer_queue; diff --git a/test/brpc_streaming_rpc_unittest.cpp b/test/brpc_streaming_rpc_unittest.cpp index 056ea9a963..ecb88c6150 100644 --- a/test/brpc_streaming_rpc_unittest.cpp +++ b/test/brpc_streaming_rpc_unittest.cpp @@ -20,10 +20,12 @@ // Date: 2015/10/22 16:28:44 #include +#include #include "brpc/server.h" #include "brpc/controller.h" #include "brpc/channel.h" +#include "brpc/callback.h" #include "brpc/socket.h" #include "brpc/stream_impl.h" #include "brpc/policy/streaming_rpc_protocol.h" @@ -54,7 +56,7 @@ class MyServiceWithStream : public test::EchoService { const ::test::EchoRequest* request, ::test::EchoResponse* response, ::google::protobuf::Closure* done) { - brpc::ClosureGuard done_gurad(done); + brpc::ClosureGuard done_guard(done); response->set_message(request->message()); brpc::Controller* cntl = (brpc::Controller*)controller; brpc::StreamId response_stream; @@ -78,6 +80,158 @@ class StreamingRpcTest : public testing::Test { test::EchoResponse response; }; +struct BatchStreamFeedbackRaceState { + brpc::StreamId server_first_stream_id{brpc::INVALID_STREAM_ID}; + brpc::StreamId server_extra_stream_id{brpc::INVALID_STREAM_ID}; + brpc::StreamId client_extra_stream_id{brpc::INVALID_STREAM_ID}; + + std::atomic server_first_write_rc{-1}; + std::atomic server_second_write_rc{-1}; + std::atomic client_got_first_msg{false}; + std::atomic client_got_second_msg{false}; + std::atomic server_write_done{false}; + std::atomic rpc_done{false}; + + bthread_t server_send_tid{0}; + std::atomic server_send_started{false}; +}; + +class BatchStreamClientHandler : public brpc::StreamInputHandler { +public: + explicit BatchStreamClientHandler(BatchStreamFeedbackRaceState* state) + : _state(state) {} + + int on_received_messages(brpc::StreamId id, + butil::IOBuf* const messages[], + size_t size) override { + if (id != _state->client_extra_stream_id) { + // This test only cares about extra stream in batch creation. + return 0; + } + for (size_t i = 0; i < size; ++i) { + const size_t len = messages[i]->length(); + messages[i]->clear(); + // First payload: 64 bytes. Second payload: 1 byte. + if (len == 64) { + _state->client_got_first_msg.store(true, std::memory_order_release); + } else if (len == 1) { + _state->client_got_second_msg.store(true, std::memory_order_release); + } + } + return 0; + } + + void on_idle_timeout(brpc::StreamId /*id*/) override {} + + void on_closed(brpc::StreamId /*id*/) override {} + + void on_failed(brpc::StreamId /*id*/, int /*error_code*/, const std::string& /*error_text*/) override {} + +private: + BatchStreamFeedbackRaceState* _state; +}; + +static void* SendTwoMessagesOnServerExtraStream(void* arg) { + auto* state = static_cast(arg); + const brpc::StreamId sid = state->server_extra_stream_id; + + // Wait until server-side stream is connected. + const int64_t connect_deadline_us = butil::gettimeofday_us() + 2 * 1000 * 1000L; + bool connected = false; + while (butil::gettimeofday_us() < connect_deadline_us) { + brpc::SocketUniquePtr ptr; + if (brpc::Socket::Address(sid, &ptr) == 0) { + brpc::Stream* s = static_cast(ptr->conn()); + if (s->_host_socket != NULL && s->_connected) { + connected = true; + break; + } + } + usleep(1000); + } + + if (!connected) { + state->server_first_write_rc.store(ETIMEDOUT, std::memory_order_relaxed); + state->server_second_write_rc.store(ETIMEDOUT, std::memory_order_relaxed); + state->server_write_done.store(true, std::memory_order_release); + return NULL; + } + + // 1) Send a payload exactly equal to max_buf_size(64). + { + std::string payload(64, 'a'); + butil::IOBuf out; + out.append(payload); + state->server_first_write_rc.store(brpc::StreamWrite(sid, out), std::memory_order_relaxed); + } + + // 2) Then send another byte. This write should become writable only after + // client sends FEEDBACK with consumed_size >= 64. + const int64_t write_deadline_us = butil::gettimeofday_us() + 2 * 1000 * 1000L; + int rc = -1; + while (butil::gettimeofday_us() < write_deadline_us) { + butil::IOBuf out; + out.append("b", 1); + rc = brpc::StreamWrite(sid, out); + if (rc == 0) { + break; + } + if (rc != EAGAIN) { + break; + } + const timespec duetime = butil::milliseconds_from_now(100); + (void)brpc::StreamWait(sid, &duetime); + } + state->server_second_write_rc.store(rc, std::memory_order_relaxed); + state->server_write_done.store(true, std::memory_order_release); + return NULL; +} + +class MyServiceWithBatchStream : public test::EchoService { +public: + MyServiceWithBatchStream(const brpc::StreamOptions& options, + BatchStreamFeedbackRaceState* state) + : _options(options), _state(state) {} + + void Echo(::google::protobuf::RpcController* controller, + const ::test::EchoRequest* request, + ::test::EchoResponse* response, + ::google::protobuf::Closure* done) override { + brpc::ClosureGuard done_guard(done); + response->set_message(request->message()); + brpc::Controller* cntl = static_cast(controller); + + brpc::StreamIds response_streams; + ASSERT_EQ(0, brpc::StreamAccept(response_streams, *cntl, &_options)); + ASSERT_EQ(2u, response_streams.size()); + _state->server_first_stream_id = response_streams[0]; + _state->server_extra_stream_id = response_streams[1]; + + bthread_t tid; + ASSERT_EQ(0, bthread_start_background( + &tid, &BTHREAD_ATTR_NORMAL, + SendTwoMessagesOnServerExtraStream, _state)); + _state->server_send_tid = tid; + _state->server_send_started.store(true, std::memory_order_release); + } + +private: + brpc::StreamOptions _options; + BatchStreamFeedbackRaceState* _state; +}; + +static void SetAtomicTrue(std::atomic* f) { + f->store(true, std::memory_order_release); +} + +static bool WaitForTrue(const std::atomic& f, int timeout_ms) { + const int64_t deadline_us = butil::gettimeofday_us() + (int64_t)timeout_ms * 1000L; + while (!f.load(std::memory_order_acquire) && butil::gettimeofday_us() < deadline_us) { + usleep(1000); + } + return f.load(std::memory_order_acquire); +} + TEST_F(StreamingRpcTest, sanity) { brpc::Server server; MyServiceWithStream service; @@ -98,6 +252,86 @@ TEST_F(StreamingRpcTest, sanity) { server.Join(); } +TEST_F(StreamingRpcTest, batch_create_stream_feedback_race) { + BatchStreamFeedbackRaceState state; + BatchStreamClientHandler client_handler(&state); + + brpc::StreamOptions server_stream_opt; + // Make server-side sender sensitive to FEEDBACK quickly. + server_stream_opt.max_buf_size = 16; + + brpc::Server server; + MyServiceWithBatchStream service(server_stream_opt, &state); + ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); + ASSERT_EQ(0, server.Start(9007, NULL)); + + brpc::Channel channel; + ASSERT_EQ(0, channel.Init("127.0.0.1:9007", NULL)); + + brpc::Controller cntl; + brpc::StreamIds request_streams; + brpc::StreamOptions client_stream_opt; + client_stream_opt.handler = &client_handler; + client_stream_opt.max_buf_size = 0; + ASSERT_EQ(0, brpc::StreamCreate(request_streams, 2, cntl, &client_stream_opt)); + ASSERT_EQ(2u, request_streams.size()); + state.client_extra_stream_id = request_streams[1]; + + // Block SetConnected() on the extra stream to enlarge the race window. + brpc::SocketUniquePtr client_extra_ptr; + ASSERT_EQ(0, brpc::Socket::Address(state.client_extra_stream_id, &client_extra_ptr)); + brpc::Stream* client_extra_stream = static_cast(client_extra_ptr->conn()); + bthread_mutex_lock(&client_extra_stream->_connect_mutex); + struct UnlockGuard { + bthread_mutex_t* m; + ~UnlockGuard() { + if (m) { + bthread_mutex_unlock(m); + } + } + } unlock_guard{&client_extra_stream->_connect_mutex}; + + BRPC_SCOPE_EXIT { + if (state.server_extra_stream_id != brpc::INVALID_STREAM_ID) { + brpc::StreamClose(state.server_extra_stream_id); + } + if (state.server_first_stream_id != brpc::INVALID_STREAM_ID) { + brpc::StreamClose(state.server_first_stream_id); + } + for (auto sid : request_streams) { + brpc::StreamClose(sid); + } + + if (state.server_send_tid) { + bthread_join(state.server_send_tid, NULL); + } + server.Stop(0); + server.Join(); + }; + + test::EchoService_Stub stub(&channel); + stub.Echo(&cntl, &request, &response, brpc::NewCallback(SetAtomicTrue, &state.rpc_done)); + + // Wait until client consumes the first 64B payload on extra stream. + ASSERT_TRUE(WaitForTrue(state.client_got_first_msg, 2000)); + + // Unblock SetConnected(); the fix in PR 3215 should send the first FEEDBACK + // with consumed_size=64 here, making server-side stream writable again. + bthread_mutex_unlock(&client_extra_stream->_connect_mutex); + unlock_guard.m = NULL; + + ASSERT_TRUE(WaitForTrue(state.rpc_done, 2000)); + ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText(); + + // Wait for server-side send thread to be started. + ASSERT_TRUE(WaitForTrue(state.server_send_started, 2000)); + + ASSERT_TRUE(WaitForTrue(state.server_write_done, 2000)); + ASSERT_EQ(0, state.server_first_write_rc.load(std::memory_order_relaxed)); + ASSERT_EQ(0, state.server_second_write_rc.load(std::memory_order_relaxed)); + ASSERT_TRUE(WaitForTrue(state.client_got_second_msg, 2000)); +} + struct HandlerControl { HandlerControl() : block(false)