diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index 7de81eff7a775..ae5fde39d8550 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -2082,29 +2082,28 @@ TEST(TestRecordBatchStreamReader, NotEnoughDictionaries) { // error ASSERT_OK_AND_ASSIGN(auto buffer, out->Finish()); - auto AssertFailsWith = [](std::shared_ptr stream, const std::string& ex_error) { + auto Read = [](std::shared_ptr stream) -> Status { io::BufferReader reader(stream); - ASSERT_OK_AND_ASSIGN(auto ipc_reader, RecordBatchStreamReader::Open(&reader)); + ARROW_ASSIGN_OR_RAISE(auto ipc_reader, RecordBatchStreamReader::Open(&reader)); std::shared_ptr batch; - Status s = ipc_reader->ReadNext(&batch); - ASSERT_TRUE(s.IsInvalid()); - ASSERT_EQ(ex_error, s.message().substr(0, ex_error.size())); + return ipc_reader->ReadNext(&batch); }; // Stream terminates before reading all dictionaries std::shared_ptr truncated_stream; SpliceMessages(buffer, {0, 1}, &truncated_stream); - std::string ex_message = - ("IPC stream ended without reading the expected number (3)" - " of dictionaries"); - AssertFailsWith(truncated_stream, ex_message); + ASSERT_RAISES_WITH_MESSAGE(Invalid, + "Invalid: IPC stream ended without " + "reading the expected number (3) of dictionaries", + Read(truncated_stream)); // One of the dictionaries is missing, then we see a record batch SpliceMessages(buffer, {0, 1, 2, 4}, &truncated_stream); - ex_message = - ("IPC stream did not have the expected number (3) of dictionaries " - "at the start of the stream"); - AssertFailsWith(truncated_stream, ex_message); + ASSERT_RAISES_WITH_MESSAGE(Invalid, + "Invalid: IPC stream did not have " + "the expected number (3) of dictionaries " + "at the start of the stream", + Read(truncated_stream)); } TEST(TestRecordBatchStreamReader, MalformedInput) { diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 694cc732253b3..6a5ab4598e72d 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -849,88 +849,113 @@ Status ReadDictionary(const Message& message, const IpcReadContext& context, return ReadDictionary(*message.metadata(), context, kind, reader.get()); } -// ---------------------------------------------------------------------- -// RecordBatchStreamReader implementation - -class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { +// Streaming format decoder +class StreamDecoderInternal : public MessageDecoderListener { public: - Status Open(std::unique_ptr message_reader, - const IpcReadOptions& options) { - message_reader_ = std::move(message_reader); - options_ = options; + enum State { + SCHEMA, + INITIAL_DICTIONARIES, + RECORD_BATCHES, + EOS, + }; - // Read schema - ARROW_ASSIGN_OR_RAISE(std::unique_ptr message, ReadNextMessage()); - if (!message) { - return Status::Invalid("Tried reading schema message, was null or length 0"); - } + explicit StreamDecoderInternal(std::shared_ptr listener, + IpcReadOptions options) + : listener_(std::move(listener)), + options_(std::move(options)), + state_(State::SCHEMA), + field_inclusion_mask_(), + num_required_initial_dictionaries_(0), + num_read_initial_dictionaries_(0), + dictionary_memo_(), + schema_(nullptr), + out_schema_(nullptr), + stats_(), + swap_endian_(false) {} - RETURN_NOT_OK(UnpackSchemaMessage(*message, options, &dictionary_memo_, &schema_, - &out_schema_, &field_inclusion_mask_, - &swap_endian_)); + Status OnMessageDecoded(std::unique_ptr message) override { + ++stats_.num_messages; + switch (state_) { + case State::SCHEMA: + ARROW_RETURN_NOT_OK(OnSchemaMessageDecoded(std::move(message))); + break; + case State::INITIAL_DICTIONARIES: + ARROW_RETURN_NOT_OK(OnInitialDictionaryMessageDecoded(std::move(message))); + break; + case State::RECORD_BATCHES: + ARROW_RETURN_NOT_OK(OnRecordBatchMessageDecoded(std::move(message))); + break; + case State::EOS: + break; + } return Status::OK(); } - Status ReadNext(std::shared_ptr* batch) override { - ARROW_ASSIGN_OR_RAISE(auto batch_with_metadata, ReadNext()); - *batch = std::move(batch_with_metadata.batch); - return Status::OK(); + Status OnEOS() override { + state_ = State::EOS; + return listener_->OnEOS(); } - Result ReadNext() override { - if (!have_read_initial_dictionaries_) { - RETURN_NOT_OK(ReadInitialDictionaries()); - } - - RecordBatchWithMetadata batch_with_metadata; - if (empty_stream_) { - // ARROW-6006: Degenerate case where stream contains no data, we do not - // bother trying to read a RecordBatch message from the stream - return batch_with_metadata; - } + Listener* raw_listener() const { return listener_.get(); } - // Continue to read other dictionaries, if any - std::unique_ptr message; - ARROW_ASSIGN_OR_RAISE(message, ReadNextMessage()); + std::shared_ptr schema() const { return out_schema_; } - while (message != nullptr && message->type() == MessageType::DICTIONARY_BATCH) { - RETURN_NOT_OK(ReadDictionary(*message)); - ARROW_ASSIGN_OR_RAISE(message, ReadNextMessage()); - } + ReadStats stats() const { return stats_; } - if (message == nullptr) { - // End of stream - return batch_with_metadata; - } + State state() const { return state_; } - CHECK_HAS_BODY(*message); - ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); - IpcReadContext context(&dictionary_memo_, options_, swap_endian_); - return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_, - context, reader.get()); + int num_required_initial_dictionaries() const { + return num_required_initial_dictionaries_; } - std::shared_ptr schema() const override { return out_schema_; } - - ReadStats stats() const override { return stats_; } + int num_read_initial_dictionaries() const { return num_read_initial_dictionaries_; } private: - Result> ReadNextMessage() { - ARROW_ASSIGN_OR_RAISE(auto message, message_reader_->ReadNextMessage()); - if (message) { - ++stats_.num_messages; - switch (message->type()) { - case MessageType::RECORD_BATCH: - ++stats_.num_record_batches; - break; - case MessageType::DICTIONARY_BATCH: - ++stats_.num_dictionary_batches; - break; - default: - break; - } + Status OnSchemaMessageDecoded(std::unique_ptr message) { + RETURN_NOT_OK(UnpackSchemaMessage(*message, options_, &dictionary_memo_, &schema_, + &out_schema_, &field_inclusion_mask_, + &swap_endian_)); + + num_required_initial_dictionaries_ = dictionary_memo_.fields().num_dicts(); + num_read_initial_dictionaries_ = 0; + if (num_required_initial_dictionaries_ == 0) { + state_ = State::RECORD_BATCHES; + RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_)); + } else { + state_ = State::INITIAL_DICTIONARIES; + } + return Status::OK(); + } + + Status OnInitialDictionaryMessageDecoded(std::unique_ptr message) { + if (message->type() != MessageType::DICTIONARY_BATCH) { + return Status::Invalid("IPC stream did not have the expected number (", + num_required_initial_dictionaries_, + ") of dictionaries at the start of the stream"); + } + RETURN_NOT_OK(ReadDictionary(*message)); + num_read_initial_dictionaries_++; + if (num_read_initial_dictionaries_ == num_required_initial_dictionaries_) { + state_ = State::RECORD_BATCHES; + ARROW_RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_)); + } + return Status::OK(); + } + + Status OnRecordBatchMessageDecoded(std::unique_ptr message) { + if (message->type() == MessageType::DICTIONARY_BATCH) { + return ReadDictionary(*message); + } else { + CHECK_HAS_BODY(*message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); + IpcReadContext context(&dictionary_memo_, options_, swap_endian_); + ARROW_ASSIGN_OR_RAISE( + auto batch_with_metadata, + ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_, + context, reader.get())); + ++stats_.num_record_batches; + return listener_->OnRecordBatchWithMetadataDecoded(batch_with_metadata); } - return std::move(message); } // Read dictionary from dictionary batch @@ -938,6 +963,7 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { DictionaryKind kind; IpcReadContext context(&dictionary_memo_, options_, swap_endian_); RETURN_NOT_OK(::arrow::ipc::ReadDictionary(message, context, &kind)); + ++stats_.num_dictionary_batches; switch (kind) { case DictionaryKind::New: break; @@ -951,60 +977,86 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { return Status::OK(); } - Status ReadInitialDictionaries() { - // We must receive all dictionaries before reconstructing the - // first record batch. Subsequent dictionary deltas modify the memo - std::unique_ptr message; - - // TODO(wesm): In future, we may want to reconcile the ids in the stream with - // those found in the schema - const auto num_dicts = dictionary_memo_.fields().num_dicts(); - for (int i = 0; i < num_dicts; ++i) { - ARROW_ASSIGN_OR_RAISE(message, ReadNextMessage()); - if (!message) { - if (i == 0) { - /// ARROW-6006: If we fail to find any dictionaries in the stream, then - /// it may be that the stream has a schema but no actual data. In such - /// case we communicate that we were unable to find the dictionaries - /// (but there was no failure otherwise), so the caller can decide what - /// to do - empty_stream_ = true; - break; - } else { - // ARROW-6126, the stream terminated before receiving the expected - // number of dictionaries - return Status::Invalid("IPC stream ended without reading the expected number (", - num_dicts, ") of dictionaries"); - } - } + std::shared_ptr listener_; + const IpcReadOptions options_; + State state_; + std::vector field_inclusion_mask_; + int num_required_initial_dictionaries_; + int num_read_initial_dictionaries_; + DictionaryMemo dictionary_memo_; + std::shared_ptr schema_; + std::shared_ptr out_schema_; + ReadStats stats_; + bool swap_endian_; +}; - if (message->type() != MessageType::DICTIONARY_BATCH) { - return Status::Invalid("IPC stream did not have the expected number (", num_dicts, - ") of dictionaries at the start of the stream"); - } - RETURN_NOT_OK(ReadDictionary(*message)); +// ---------------------------------------------------------------------- +// RecordBatchStreamReader implementation + +class RecordBatchStreamReaderImpl : public RecordBatchStreamReader, + public StreamDecoderInternal { + public: + RecordBatchStreamReaderImpl(std::unique_ptr message_reader, + const IpcReadOptions& options) + : RecordBatchStreamReader(), + StreamDecoderInternal(std::make_shared(), options), + message_reader_(std::move(message_reader)) {} + + Status Init() { + // Read schema + ARROW_ASSIGN_OR_RAISE(auto message, message_reader_->ReadNextMessage()); + if (!message) { + return Status::Invalid("Tried reading schema message, was null or length 0"); } + return OnMessageDecoded(std::move(message)); + } - have_read_initial_dictionaries_ = true; + Status ReadNext(std::shared_ptr* batch) override { + ARROW_ASSIGN_OR_RAISE(auto batch_with_metadata, ReadNext()); + *batch = std::move(batch_with_metadata.batch); return Status::OK(); } - std::unique_ptr message_reader_; - IpcReadOptions options_; - std::vector field_inclusion_mask_; - - bool have_read_initial_dictionaries_ = false; - - // Flag to set in case where we fail to observe all dictionaries in a stream, - // and so the reader should not attempt to parse any messages - bool empty_stream_ = false; + Result ReadNext() override { + auto collect_listener = checked_cast(raw_listener()); + while (collect_listener->num_record_batches() == 0 && + state() != StreamDecoderInternal::State::EOS) { + ARROW_ASSIGN_OR_RAISE(auto message, message_reader_->ReadNextMessage()); + if (!message) { // End of stream + if (state() == StreamDecoderInternal::State::INITIAL_DICTIONARIES) { + if (num_read_initial_dictionaries() == 0) { + // ARROW-6006: If we fail to find any dictionaries in the + // stream, then it may be that the stream has a schema + // but no actual data. In such case we communicate that + // we were unable to find the dictionaries (but there was + // no failure otherwise), so the caller can decide what + // to do + return RecordBatchWithMetadata{nullptr, nullptr}; + } else { + // ARROW-6126, the stream terminated before receiving the + // expected number of dictionaries + return Status::Invalid( + "IPC stream ended without reading the " + "expected number (", + num_required_initial_dictionaries(), ") of dictionaries"); + } + } else { + return RecordBatchWithMetadata{nullptr, nullptr}; + } + } + ARROW_RETURN_NOT_OK(OnMessageDecoded(std::move(message))); + } + return collect_listener->PopRecordBatchWithMetadata(); + } - ReadStats stats_; + std::shared_ptr schema() const override { + return StreamDecoderInternal::schema(); + } - DictionaryMemo dictionary_memo_; - std::shared_ptr schema_, out_schema_; + ReadStats stats() const override { return StreamDecoderInternal::stats(); } - bool swap_endian_; + private: + std::unique_ptr message_reader_; }; // ---------------------------------------------------------------------- @@ -1013,8 +1065,9 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { Result> RecordBatchStreamReader::Open( std::unique_ptr message_reader, const IpcReadOptions& options) { // Private ctor - auto result = std::make_shared(); - RETURN_NOT_OK(result->Open(std::move(message_reader), options)); + auto result = + std::make_shared(std::move(message_reader), options); + RETURN_NOT_OK(result->Init()); return result; } @@ -1907,46 +1960,17 @@ Status Listener::OnRecordBatchDecoded(std::shared_ptr record_batch) return Status::NotImplemented("OnRecordBatchDecoded() callback isn't implemented"); } -class StreamDecoder::StreamDecoderImpl : public MessageDecoderListener { - private: - enum State { - SCHEMA, - INITIAL_DICTIONARIES, - RECORD_BATCHES, - EOS, - }; +Status Listener::OnRecordBatchWithMetadataDecoded( + RecordBatchWithMetadata record_batch_with_metadata) { + return OnRecordBatchDecoded(std::move(record_batch_with_metadata.batch)); +} +class StreamDecoder::StreamDecoderImpl : public StreamDecoderInternal { public: explicit StreamDecoderImpl(std::shared_ptr listener, IpcReadOptions options) - : listener_(std::move(listener)), - options_(std::move(options)), - state_(State::SCHEMA), + : StreamDecoderInternal(std::move(listener), options), message_decoder_(std::shared_ptr(this, [](void*) {}), - options_.memory_pool), - n_required_dictionaries_(0) {} - - Status OnMessageDecoded(std::unique_ptr message) override { - ++stats_.num_messages; - switch (state_) { - case State::SCHEMA: - ARROW_RETURN_NOT_OK(OnSchemaMessageDecoded(std::move(message))); - break; - case State::INITIAL_DICTIONARIES: - ARROW_RETURN_NOT_OK(OnInitialDictionaryMessageDecoded(std::move(message))); - break; - case State::RECORD_BATCHES: - ARROW_RETURN_NOT_OK(OnRecordBatchMessageDecoded(std::move(message))); - break; - case State::EOS: - break; - } - return Status::OK(); - } - - Status OnEOS() override { - state_ = State::EOS; - return listener_->OnEOS(); - } + options.memory_pool) {} Status Consume(const uint8_t* data, int64_t size) { return message_decoder_.Consume(data, size); @@ -1956,88 +1980,10 @@ class StreamDecoder::StreamDecoderImpl : public MessageDecoderListener { return message_decoder_.Consume(std::move(buffer)); } - std::shared_ptr schema() const { return out_schema_; } - int64_t next_required_size() const { return message_decoder_.next_required_size(); } - ReadStats stats() const { return stats_; } - private: - Status OnSchemaMessageDecoded(std::unique_ptr message) { - RETURN_NOT_OK(UnpackSchemaMessage(*message, options_, &dictionary_memo_, &schema_, - &out_schema_, &field_inclusion_mask_, - &swap_endian_)); - - n_required_dictionaries_ = dictionary_memo_.fields().num_fields(); - if (n_required_dictionaries_ == 0) { - state_ = State::RECORD_BATCHES; - RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_)); - } else { - state_ = State::INITIAL_DICTIONARIES; - } - return Status::OK(); - } - - Status OnInitialDictionaryMessageDecoded(std::unique_ptr message) { - if (message->type() != MessageType::DICTIONARY_BATCH) { - return Status::Invalid("IPC stream did not have the expected number (", - dictionary_memo_.fields().num_fields(), - ") of dictionaries at the start of the stream"); - } - RETURN_NOT_OK(ReadDictionary(*message)); - n_required_dictionaries_--; - if (n_required_dictionaries_ == 0) { - state_ = State::RECORD_BATCHES; - ARROW_RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_)); - } - return Status::OK(); - } - - Status OnRecordBatchMessageDecoded(std::unique_ptr message) { - if (message->type() == MessageType::DICTIONARY_BATCH) { - return ReadDictionary(*message); - } else { - CHECK_HAS_BODY(*message); - ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); - IpcReadContext context(&dictionary_memo_, options_, swap_endian_); - ARROW_ASSIGN_OR_RAISE( - auto batch_with_metadata, - ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_, - context, reader.get())); - ++stats_.num_record_batches; - return listener_->OnRecordBatchDecoded(std::move(batch_with_metadata.batch)); - } - } - - // Read dictionary from dictionary batch - Status ReadDictionary(const Message& message) { - DictionaryKind kind; - IpcReadContext context(&dictionary_memo_, options_, swap_endian_); - RETURN_NOT_OK(::arrow::ipc::ReadDictionary(message, context, &kind)); - ++stats_.num_dictionary_batches; - switch (kind) { - case DictionaryKind::New: - break; - case DictionaryKind::Delta: - ++stats_.num_dictionary_deltas; - break; - case DictionaryKind::Replacement: - ++stats_.num_replaced_dictionaries; - break; - } - return Status::OK(); - } - - std::shared_ptr listener_; - const IpcReadOptions options_; - State state_; MessageDecoder message_decoder_; - std::vector field_inclusion_mask_; - int n_required_dictionaries_; - DictionaryMemo dictionary_memo_; - std::shared_ptr schema_, out_schema_; - ReadStats stats_; - bool swap_endian_; }; StreamDecoder::StreamDecoder(std::shared_ptr listener, IpcReadOptions options) { diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h index ad7969b31c991..edc25608542f1 100644 --- a/cpp/src/arrow/ipc/reader.h +++ b/cpp/src/arrow/ipc/reader.h @@ -251,7 +251,8 @@ class ARROW_EXPORT Listener { /// \see StreamDecoder virtual Status OnEOS(); - /// \brief Called when a record batch is decoded. + /// \brief Called when a record batch is decoded and + /// OnRecordBatchWithMetadataDecoded() isn't overrided. /// /// The default implementation just returns /// arrow::Status::NotImplemented(). @@ -262,6 +263,19 @@ class ARROW_EXPORT Listener { /// \see StreamDecoder virtual Status OnRecordBatchDecoded(std::shared_ptr record_batch); + /// \brief Called when a record batch with custom metadata is decoded. + /// + /// The default implementation just calls OnRecordBatchDecoded() + /// without custom metadata. + /// + /// \param[in] record_batch_with_metadata a record batch with custom + /// metadata decoded + /// \return Status + /// + /// \see StreamDecoder + virtual Status OnRecordBatchWithMetadataDecoded( + RecordBatchWithMetadata record_batch_with_metadata); + /// \brief Called when a schema is decoded. /// /// The default implementation just returns arrow::Status::OK(). @@ -280,7 +294,7 @@ class ARROW_EXPORT Listener { /// \since 0.17.0 class ARROW_EXPORT CollectListener : public Listener { public: - CollectListener() : schema_(), record_batches_() {} + CollectListener() : schema_(), record_batches_(), metadatas_() {} virtual ~CollectListener() = default; Status OnSchemaDecoded(std::shared_ptr schema) override { @@ -288,8 +302,10 @@ class ARROW_EXPORT CollectListener : public Listener { return Status::OK(); } - Status OnRecordBatchDecoded(std::shared_ptr record_batch) override { - record_batches_.push_back(std::move(record_batch)); + Status OnRecordBatchWithMetadataDecoded( + RecordBatchWithMetadata record_batch_with_metadata) override { + record_batches_.push_back(std::move(record_batch_with_metadata.batch)); + metadatas_.push_back(std::move(record_batch_with_metadata.custom_metadata)); return Status::OK(); } @@ -297,13 +313,43 @@ class ARROW_EXPORT CollectListener : public Listener { std::shared_ptr schema() const { return schema_; } /// \return the all decoded record batches - std::vector> record_batches() const { + const std::vector>& record_batches() const { return record_batches_; } + /// \return the all decoded metadatas + const std::vector>& metadatas() const { + return metadatas_; + } + + /// \return the number of collected record batches + int64_t num_record_batches() const { return record_batches_.size(); } + + /// \return the last decoded record batch and remove it from + /// record_batches + std::shared_ptr PopRecordBatch() { + auto record_batch_with_metadata = PopRecordBatchWithMetadata(); + return std::move(record_batch_with_metadata.batch); + } + + /// \return the last decoded record batch with custom metadata and + /// remove it from record_batches + RecordBatchWithMetadata PopRecordBatchWithMetadata() { + RecordBatchWithMetadata record_batch_with_metadata; + if (record_batches_.empty()) { + return record_batch_with_metadata; + } + record_batch_with_metadata.batch = std::move(record_batches_.back()); + record_batch_with_metadata.custom_metadata = std::move(metadatas_.back()); + record_batches_.pop_back(); + metadatas_.pop_back(); + return record_batch_with_metadata; + } + private: std::shared_ptr schema_; std::vector> record_batches_; + std::vector> metadatas_; }; /// \brief Push style stream decoder that receives data from user. diff --git a/cpp/src/arrow/status.cc b/cpp/src/arrow/status.cc index 168b05df3397a..368e03cac0bd2 100644 --- a/cpp/src/arrow/status.cc +++ b/cpp/src/arrow/status.cc @@ -120,6 +120,24 @@ std::string Status::ToString() const { return result; } +std::string Status::ToStringWithoutContextLines() const { + auto message = ToString(); +#ifdef ARROW_EXTRA_ERROR_CONTEXT + while (true) { + auto last_new_line_position = message.rfind("\n"); + if (last_new_line_position == std::string::npos) { + break; + } + // TODO: We may want to check /:\d+ / + if (message.find(":", last_new_line_position) == std::string::npos) { + break; + } + message = message.substr(0, last_new_line_position); + } +#endif + return message; +} + void Status::Abort() const { Abort(std::string()); } void Status::Abort(const std::string& message) const { diff --git a/cpp/src/arrow/status.h b/cpp/src/arrow/status.h index 1b9ba28637835..ac384fc389a49 100644 --- a/cpp/src/arrow/status.h +++ b/cpp/src/arrow/status.h @@ -314,6 +314,12 @@ class ARROW_EXPORT [[nodiscard]] Status : public util::EqualityComparable