From 5f37a6c3822c0cf4a43033cd60eb4115a8e9f6c6 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Mon, 11 Dec 2023 00:36:10 +0900 Subject: [PATCH] GH-39163: [C++] Add missing data copy in StreamDecoder::Consume(data) We need to copy data for metadata message because it may be used in subsequent `Consume(data)` calls. We can't assume that the given `data` is still valid in subsequent `Consume(data)` calls. We also need to copy buffered `data` because it's used in subsequent `Consume(data)` calls. --- cpp/src/arrow/ipc/message.cc | 37 ++++++++++++++++++++-------- cpp/src/arrow/ipc/read_write_test.cc | 11 ++++++--- cpp/src/arrow/ipc/reader.cc | 33 +++++++++++++++++++++++++ cpp/src/arrow/ipc/reader.h | 14 ++++++----- 4 files changed, 76 insertions(+), 19 deletions(-) diff --git a/cpp/src/arrow/ipc/message.cc b/cpp/src/arrow/ipc/message.cc index 36754518d29d4..fbcd6f139b6d2 100644 --- a/cpp/src/arrow/ipc/message.cc +++ b/cpp/src/arrow/ipc/message.cc @@ -626,10 +626,24 @@ class MessageDecoder::MessageDecoderImpl { RETURN_NOT_OK(ConsumeMetadataLengthData(data, next_required_size_)); break; case State::METADATA: { - auto buffer = std::make_shared(data, next_required_size_); + // We need to copy metadata because it's used in + // ConsumeBody(). ConsumeBody() may be called from another + // ConsumeData(). We can't assume that the given data for + // the current ConsumeData() call is still valid in the + // next ConsumeData() call. So we need to copy metadata + // here. + ARROW_ASSIGN_OR_RAISE(std::shared_ptr buffer, + AllocateBuffer(next_required_size_, pool_)); + memcpy(buffer->mutable_data(), data, next_required_size_); RETURN_NOT_OK(ConsumeMetadataBuffer(buffer)); } break; case State::BODY: { + // We don't need to copy the given data for body because + // we can assume that a decoded record batch should be + // valid only in a listener_->OnMessageDecoded() call. If + // the passed message is needed to be valid after the + // call, it's a listener_'s responsibility. 
The listener_ + // may copy the data for it. auto buffer = std::make_shared(data, next_required_size_); RETURN_NOT_OK(ConsumeBodyBuffer(buffer)); } break; @@ -645,7 +659,12 @@ class MessageDecoder::MessageDecoderImpl { return Status::OK(); } - chunks_.push_back(std::make_shared(data, size)); + // We need to copy unused data because the given data for the + // current ConsumeData() call may be invalid in the next + // ConsumeData() call. + ARROW_ASSIGN_OR_RAISE(std::shared_ptr chunk, AllocateBuffer(size, pool_)); + memcpy(chunk->mutable_data(), data, size); + chunks_.push_back(std::move(chunk)); buffered_size_ += size; return ConsumeChunks(); } @@ -830,8 +849,7 @@ class MessageDecoder::MessageDecoderImpl { } buffered_size_ -= next_required_size_; } else { - ARROW_ASSIGN_OR_RAISE(auto metadata, AllocateBuffer(next_required_size_, pool_)); - metadata_ = std::shared_ptr(metadata.release()); + ARROW_ASSIGN_OR_RAISE(metadata_, AllocateBuffer(next_required_size_, pool_)); RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, metadata_->mutable_data())); } return ConsumeMetadata(); @@ -846,9 +864,8 @@ class MessageDecoder::MessageDecoderImpl { next_required_size_ = skip_body_ ? 
0 : body_length; RETURN_NOT_OK(listener_->OnBody()); if (next_required_size_ == 0) { - ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_)); - std::shared_ptr shared_body(body.release()); - return ConsumeBody(&shared_body); + auto body = std::make_shared(nullptr, 0); + return ConsumeBody(&body); } else { return Status::OK(); } @@ -872,10 +889,10 @@ class MessageDecoder::MessageDecoderImpl { buffered_size_ -= used_size; return Status::OK(); } else { - ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(next_required_size_, pool_)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr body, + AllocateBuffer(next_required_size_, pool_)); RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, body->mutable_data())); - std::shared_ptr shared_body(body.release()); - return ConsumeBody(&shared_body); + return ConsumeBody(&body); } } diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index 17c4c5636d5b0..e77c760d6c716 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -1334,7 +1334,7 @@ struct StreamDecoderWriterHelper : public StreamWriterHelper { Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches, ReadStats* out_stats = nullptr, MetadataVector* out_metadata_list = nullptr) override { - auto listener = std::make_shared(); + auto listener = std::make_shared(true); StreamDecoder decoder(listener, options); RETURN_NOT_OK(DoConsume(&decoder)); *out_batches = listener->record_batches(); @@ -1358,7 +1358,10 @@ struct StreamDecoderWriterHelper : public StreamWriterHelper { struct StreamDecoderDataWriterHelper : public StreamDecoderWriterHelper { Status DoConsume(StreamDecoder* decoder) override { - return decoder->Consume(buffer_->data(), buffer_->size()); + // This data is valid only in this function. 
+ ARROW_ASSIGN_OR_RAISE(auto temporary_buffer, + Buffer::Copy(buffer_, arrow::default_cpu_memory_manager())); + return decoder->Consume(temporary_buffer->data(), temporary_buffer->size()); } }; @@ -1369,7 +1372,9 @@ struct StreamDecoderBufferWriterHelper : public StreamDecoderWriterHelper { struct StreamDecoderSmallChunksWriterHelper : public StreamDecoderWriterHelper { Status DoConsume(StreamDecoder* decoder) override { for (int64_t offset = 0; offset < buffer_->size() - 1; ++offset) { - RETURN_NOT_OK(decoder->Consume(buffer_->data() + offset, 1)); + // This data is valid only in this block. + ARROW_ASSIGN_OR_RAISE(auto temporary_buffer, buffer_->CopySlice(offset, 1)); + RETURN_NOT_OK(decoder->Consume(temporary_buffer->data(), temporary_buffer->size())); } return Status::OK(); } diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index d272c78560f82..fe7789afa167e 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -2052,6 +2052,39 @@ Status Listener::OnRecordBatchWithMetadataDecoded( return OnRecordBatchDecoded(std::move(record_batch_with_metadata.batch)); } +namespace { +Status CopyArrayData(std::shared_ptr data) { + auto& buffers = data->buffers; + for (size_t i = 0; i < buffers.size(); ++i) { + auto& buffer = buffers[i]; + if (!buffer) { + continue; + } + ARROW_ASSIGN_OR_RAISE(buffers[i], Buffer::Copy(buffer, buffer->memory_manager())); + } + for (auto child_data : data->child_data) { + ARROW_RETURN_NOT_OK(CopyArrayData(child_data)); + } + if (data->dictionary) { + ARROW_RETURN_NOT_OK(CopyArrayData(data->dictionary)); + } + return Status::OK(); +} +}; // namespace + +Status CollectListener::OnRecordBatchWithMetadataDecoded( + RecordBatchWithMetadata record_batch_with_metadata) { + auto record_batch = std::move(record_batch_with_metadata.batch); + if (copy_record_batch_) { + for (auto column_data : record_batch->column_data()) { + ARROW_RETURN_NOT_OK(CopyArrayData(column_data)); + } + } + 
record_batches_.push_back(std::move(record_batch)); + metadatas_.push_back(std::move(record_batch_with_metadata.custom_metadata)); + return Status::OK(); +} + class StreamDecoder::StreamDecoderImpl : public StreamDecoderInternal { public: explicit StreamDecoderImpl(std::shared_ptr listener, IpcReadOptions options) diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h index 888f59a627771..8eb83844130da 100644 --- a/cpp/src/arrow/ipc/reader.h +++ b/cpp/src/arrow/ipc/reader.h @@ -317,7 +317,12 @@ class ARROW_EXPORT Listener { /// \since 0.17.0 class ARROW_EXPORT CollectListener : public Listener { public: - CollectListener() : schema_(), filtered_schema_(), record_batches_(), metadatas_() {} + explicit CollectListener(bool copy_record_batch = false) + : copy_record_batch_(copy_record_batch), + schema_(), + filtered_schema_(), + record_batches_(), + metadatas_() {} virtual ~CollectListener() = default; Status OnSchemaDecoded(std::shared_ptr schema, @@ -328,11 +333,7 @@ class ARROW_EXPORT CollectListener : public Listener { } Status OnRecordBatchWithMetadataDecoded( - RecordBatchWithMetadata record_batch_with_metadata) override { - record_batches_.push_back(std::move(record_batch_with_metadata.batch)); - metadatas_.push_back(std::move(record_batch_with_metadata.custom_metadata)); - return Status::OK(); - } + RecordBatchWithMetadata record_batch_with_metadata) override; /// \return the decoded schema std::shared_ptr schema() const { return schema_; } @@ -375,6 +376,7 @@ class ARROW_EXPORT CollectListener : public Listener { } private: + bool copy_record_batch_; std::shared_ptr schema_; std::shared_ptr filtered_schema_; std::vector> record_batches_;