diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index 05a48aec2c7f3..313346b5deced 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -2164,6 +2164,43 @@ TEST(TestRecordBatchStreamReader, MalformedInput) { ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader)); } +namespace { +class EndlessCollectListener : public CollectListener { + public: + EndlessCollectListener() : CollectListener(), decoder_(nullptr) {} + + void SetDecoder(StreamDecoder* decoder) { decoder_ = decoder; } + + arrow::Status OnEOS() override { return decoder_->Reset(); } + + private: + StreamDecoder* decoder_; +}; +} // namespace + +TEST(TestStreamDecoder, Reset) { + auto listener = std::make_shared<EndlessCollectListener>(); + StreamDecoder decoder(listener); + listener->SetDecoder(&decoder); + + std::shared_ptr<RecordBatch> batch; + ASSERT_OK(MakeIntRecordBatch(&batch)); + StreamWriterHelper writer_helper; + ASSERT_OK(writer_helper.Init(batch->schema(), IpcWriteOptions::Defaults())); + ASSERT_OK(writer_helper.WriteBatch(batch)); + ASSERT_OK(writer_helper.Finish()); + + ASSERT_OK_AND_ASSIGN(auto all_buffer, ConcatenateBuffers({writer_helper.buffer_, + writer_helper.buffer_})); + // Consume by Buffer: both concatenated streams must be decoded + ASSERT_OK(decoder.Consume(all_buffer)); + ASSERT_EQ(2, listener->num_record_batches()); + + // Consume by raw data: the same two streams are decoded once more + ASSERT_OK(decoder.Consume(all_buffer->data(), all_buffer->size())); + ASSERT_EQ(4, listener->num_record_batches()); +} + TEST(TestStreamDecoder, NextRequiredSize) { auto listener = std::make_shared<CollectListener>(); StreamDecoder decoder(listener); diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index d603062d81d4a..5dd01f2015dd7 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -932,14 +932,18 @@ class StreamDecoderInternal : public MessageDecoderListener { return listener_->OnEOS(); } + std::shared_ptr<Listener> listener() const { return listener_; } + Listener* raw_listener() const { return 
listener_.get(); } + IpcReadOptions options() const { return options_; } + + State state() const { return state_; } + std::shared_ptr<Schema> schema() const { return filtered_schema_; } ReadStats stats() const { return stats_; } - State state() const { return state_; } - int num_required_initial_dictionaries() const { return num_required_initial_dictionaries_; } @@ -2039,6 +2043,8 @@ class StreamDecoder::StreamDecoderImpl : public StreamDecoderInternal { int64_t next_required_size() const { return message_decoder_.next_required_size(); } + const MessageDecoder* message_decoder() const { return &message_decoder_; } + private: MessageDecoder message_decoder_; }; @@ -2050,10 +2056,75 @@ StreamDecoder::StreamDecoder(std::shared_ptr<Listener> listener, IpcReadOptions StreamDecoder::~StreamDecoder() {} Status StreamDecoder::Consume(const uint8_t* data, int64_t size) { - return impl_->Consume(data, size); + while (size > 0) { + const auto next_required_size = impl_->next_required_size(); + if (next_required_size == 0) { + break; + } + if (size < next_required_size) { + break; + } + ARROW_RETURN_NOT_OK(impl_->Consume(data, next_required_size)); + data += next_required_size; + size -= next_required_size; + } + if (size > 0) { + return impl_->Consume(data, size); + } else { + return arrow::Status::OK(); + } } + Status StreamDecoder::Consume(std::shared_ptr<Buffer> buffer) { - return impl_->Consume(std::move(buffer)); + if (buffer->size() == 0) { + return arrow::Status::OK(); + } + if (impl_->next_required_size() == 0 || buffer->size() <= impl_->next_required_size()) { + return impl_->Consume(std::move(buffer)); + } else { + int64_t offset = 0; + while (true) { + const auto next_required_size = impl_->next_required_size(); + if (next_required_size == 0) { + break; + } + if (buffer->size() - offset <= next_required_size) { + break; + } + if (buffer->is_cpu()) { + switch (impl_->message_decoder()->state()) { + case MessageDecoder::State::INITIAL: + case MessageDecoder::State::METADATA_LENGTH: + // We 
don't need to pass a sliced buffer because + // MessageDecoder doesn't keep a reference to the given + // buffer in these states. + ARROW_RETURN_NOT_OK( + impl_->Consume(buffer->data() + offset, next_required_size)); + break; + default: + ARROW_RETURN_NOT_OK( + impl_->Consume(SliceBuffer(buffer, offset, next_required_size))); + break; + } + } else { + ARROW_RETURN_NOT_OK( + impl_->Consume(SliceBuffer(buffer, offset, next_required_size))); + } + offset += next_required_size; + } + if (buffer->size() - offset == 0) { + return arrow::Status::OK(); + } else if (offset == 0) { + return impl_->Consume(std::move(buffer)); + } else { + return impl_->Consume(SliceBuffer(std::move(buffer), offset)); + } + } +} + +Status StreamDecoder::Reset() { + impl_ = std::make_unique<StreamDecoderImpl>(impl_->listener(), impl_->options()); + return Status::OK(); } std::shared_ptr<Schema> StreamDecoder::schema() const { return impl_->schema(); } diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h index 0d7ae22264052..de4606094049c 100644 --- a/cpp/src/arrow/ipc/reader.h +++ b/cpp/src/arrow/ipc/reader.h @@ -425,6 +425,14 @@ class ARROW_EXPORT StreamDecoder { /// \return Status Status Consume(std::shared_ptr<Buffer> buffer); + /// \brief Reset the internal state. + /// + /// The decoder can then be reused for a new stream; the current + /// listener and read options are kept. + /// + /// \return Status + Status Reset(); + /// \return the shared schema of the record batches in the stream std::shared_ptr<Schema> schema() const;