From edd0bad949b02a38728391971838c5bac2b3e52f Mon Sep 17 00:00:00 2001 From: Giulio Eulisse <10544+ktf@users.noreply.github.com> Date: Mon, 2 Dec 2024 15:47:50 +0100 Subject: [PATCH] DPL Analysis: add RNTuple arrow::Dataset support --- Framework/Core/CMakeLists.txt | 3 + .../include/Framework/RootArrowFilesystem.h | 115 +++ Framework/Core/src/RootArrowFilesystem.cxx | 653 +++++++++++++++++- Framework/Core/test/test_Root2ArrowTable.cxx | 43 ++ 4 files changed, 813 insertions(+), 1 deletion(-) diff --git a/Framework/Core/CMakeLists.txt b/Framework/Core/CMakeLists.txt index 02367afdcc556..d1dde7e78fdd1 100644 --- a/Framework/Core/CMakeLists.txt +++ b/Framework/Core/CMakeLists.txt @@ -159,6 +159,8 @@ o2_add_library(Framework FairMQ::FairMQ ROOT::Tree ROOT::Hist + ROOT::ROOTNTuple + ROOT::ROOTNTupleUtil O2::FrameworkFoundation O2::CommonConstants O2::Headers @@ -298,6 +300,7 @@ add_executable(o2-test-framework-root target_link_libraries(o2-test-framework-root PRIVATE O2::Framework) target_link_libraries(o2-test-framework-root PRIVATE O2::Catch2) target_link_libraries(o2-test-framework-root PRIVATE ROOT::ROOTDataFrame) +target_link_libraries(o2-test-framework-root PRIVATE ROOT::ROOTNTuple) set_property(TARGET o2-test-framework-root PROPERTY RUNTIME_OUTPUT_DIRECTORY ${outdir}) add_test(NAME framework:root COMMAND o2-test-framework-root --skip-benchmarks) add_test(NAME framework:crash COMMAND sh -e -c "PATH=${CMAKE_RUNTIME_OUTPUT_DIRECTORY}:$PATH ${CMAKE_CURRENT_LIST_DIR}/test/test_AllCrashTypes.sh") diff --git a/Framework/Core/include/Framework/RootArrowFilesystem.h b/Framework/Core/include/Framework/RootArrowFilesystem.h index 48d817bc9ddf2..234757ea596b4 100644 --- a/Framework/Core/include/Framework/RootArrowFilesystem.h +++ b/Framework/Core/include/Framework/RootArrowFilesystem.h @@ -23,6 +23,11 @@ class TTree; class TBufferFile; class TDirectoryFile; +namespace ROOT::Experimental +{ +class RNTuple; +} // namespace ROOT::Experimental + namespace o2::framework { @@ -35,6 +40,15 @@ class TTreeFileWriteOptions : public arrow::dataset::FileWriteOptions } }; +class RNTupleFileWriteOptions : public arrow::dataset::FileWriteOptions +{ + public: + RNTupleFileWriteOptions(std::shared_ptr format) + : FileWriteOptions(format) + { + } +}; + // This is to avoid having to implement a bunch of unimplemented methods // for all the possible virtual filesystem we can invent on top of ROOT // data structures. @@ -97,6 +111,19 @@ class TTreeFileSystem : public VirtualRootFileSystemBase virtual TTree* GetTree(arrow::dataset::FileSource source) = 0; }; +// A filesystem which allows me to get a RNTuple +class RNTupleFileSystem : public VirtualRootFileSystemBase +{ + public: + ~RNTupleFileSystem() override; + + std::shared_ptr GetSubFilesystem(arrow::dataset::FileSource source) override + { + return std::dynamic_pointer_cast(shared_from_this()); + }; + virtual ROOT::Experimental::RNTuple* GetRNTuple(arrow::dataset::FileSource source) = 0; +}; + class SingleTreeFileSystem : public TTreeFileSystem { public: @@ -121,6 +148,30 @@ class SingleTreeFileSystem : public TTreeFileSystem TTree* mTree; }; +class SingleRNTupleFileSystem : public RNTupleFileSystem +{ + public: + SingleRNTupleFileSystem(ROOT::Experimental::RNTuple* tuple) + : RNTupleFileSystem(), + mTuple(tuple) + { + } + + std::string type_name() const override + { + return "rntuple"; + } + + ROOT::Experimental::RNTuple* GetRNTuple(arrow::dataset::FileSource) override + { + // Simply return the only TTree we have + return mTuple; + } + + private: + ROOT::Experimental::RNTuple* mTuple; +}; + class TFileFileSystem : public VirtualRootFileSystemBase { public: @@ -179,6 +230,70 @@ class TTreeFileFragment : public arrow::dataset::FileFragment } }; +class RNTupleFileFragment : public arrow::dataset::FileFragment +{ + public: + RNTupleFileFragment(arrow::dataset::FileSource source, + std::shared_ptr format, + arrow::compute::Expression partition_expression, + std::shared_ptr physical_schema) + : FileFragment(std::move(source), std::move(format), std::move(partition_expression), std::move(physical_schema)) + { + } +}; + +class RNTupleFileFormat : public arrow::dataset::FileFormat +{ + size_t& mTotCompressedSize; + size_t& mTotUncompressedSize; + + public: + RNTupleFileFormat(size_t& totalCompressedSize, size_t& totalUncompressedSize) + : FileFormat({}), + mTotCompressedSize(totalCompressedSize), + mTotUncompressedSize(totalUncompressedSize) + { + } + + ~RNTupleFileFormat() override = default; + + std::string type_name() const override + { + return "rntuple"; + } + + bool Equals(const FileFormat& other) const override + { + return other.type_name() == this->type_name(); + } + + arrow::Result IsSupported(const arrow::dataset::FileSource& source) const override + { + auto fs = std::dynamic_pointer_cast(source.filesystem()); + auto subFs = fs->GetSubFilesystem(source); + if (std::dynamic_pointer_cast(subFs)) { + return true; + } + return false; + } + + arrow::Result> Inspect(const arrow::dataset::FileSource& source) const override; + + arrow::Result ScanBatchesAsync( + const std::shared_ptr& options, + const std::shared_ptr& fragment) const override; + + std::shared_ptr DefaultWriteOptions() override; + + arrow::Result> MakeWriter(std::shared_ptr destination, + std::shared_ptr schema, + std::shared_ptr options, + arrow::fs::FileLocator destination_locator) const override; + arrow::Result> MakeFragment( + arrow::dataset::FileSource source, arrow::compute::Expression partition_expression, + std::shared_ptr physical_schema) override; +}; + class TTreeFileFormat : public arrow::dataset::FileFormat { size_t& mTotCompressedSize; diff --git a/Framework/Core/src/RootArrowFilesystem.cxx b/Framework/Core/src/RootArrowFilesystem.cxx index 5f2d21d942d37..9efd895847a9b 100644 --- a/Framework/Core/src/RootArrowFilesystem.cxx +++ b/Framework/Core/src/RootArrowFilesystem.cxx @@ -8,6 +8,7 @@ // In applying this license CERN does not waive the privileges and immunities // granted to it by virtue of its status as an Intergovernmental Organization // or submit itself to any jurisdiction. +#include #include "Framework/RootArrowFilesystem.h" #include "Framework/Endian.h" #include "Framework/RuntimeError.h" @@ -17,6 +18,8 @@ #include #include #include +#include +#include #include #include #include @@ -30,10 +33,20 @@ #include #include #include - +#include +#include +#include +#include +#include #include #include +template class + std::unique_ptr; + +template class + std::shared_ptr; + O2_DECLARE_DYNAMIC_LOG(root_arrow_fs); namespace @@ -85,6 +98,61 @@ auto arrowTypeFromROOT(EDataType type, int size) throw o2::framework::runtime_error_f("Unsupported branch type: %d", static_cast(type)); } } + +struct RootNTupleVisitor : public ROOT::Experimental::Detail::RFieldVisitor { + void VisitArrayField(const ROOT::Experimental::RArrayField& field) override + { + int size = field.GetLength(); + RootNTupleVisitor valueVisitor{}; + auto valueField = field.GetSubFields()[0]; + valueField->AcceptVisitor(valueVisitor); + auto type = valueVisitor.datatype; + this->datatype = arrow::fixed_size_list(type, size); + } + + void VisitRVecField(const ROOT::Experimental::RRVecField& field) override + { + RootNTupleVisitor valueVisitor{}; + auto valueField = field.GetSubFields()[0]; + valueField->AcceptVisitor(valueVisitor); + auto type = valueVisitor.datatype; + this->datatype = arrow::list(type); + } + + void VisitField(const ROOT::Experimental::RFieldBase& field) override + { + throw o2::framework::runtime_error_f("Unknown field %s with type %s", field.GetFieldName().c_str(), field.GetTypeName().c_str()); + } + + void VisitIntField(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::int32(); + } + + void VisitBoolField(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::boolean(); + } + + void VisitFloatField(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::float32(); + } + + void VisitDoubleField(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::float64(); + } + std::shared_ptr datatype; +}; + +auto arrowTypeFromRNTuple(ROOT::Experimental::RFieldBase const& field, int size) +{ + RootNTupleVisitor visitor; + field.AcceptVisitor(visitor); + return visitor.datatype; +} + namespace o2::framework { using arrow::Status; @@ -103,6 +171,11 @@ std::shared_ptr TFileFileSystem::GetSubFilesystem(arr return std::shared_ptr(new SingleTreeFileSystem(tree)); } + auto rntuple = (ROOT::Experimental::RNTuple*)mFile->Get(source.path().c_str()); + if (rntuple) { + return std::shared_ptr(new SingleRNTupleFileSystem(rntuple)); + } + auto directory = (TDirectoryFile*)mFile->GetObjectChecked(source.path().c_str(), TClass::GetClass()); if (directory) { return std::shared_ptr(new TFileFileSystem(directory, 50 * 1024 * 1024)); @@ -294,6 +367,18 @@ arrow::Result> TTreeFileFormat::Ma return std::dynamic_pointer_cast(fragment); } +arrow::Result> RNTupleFileFormat::MakeFragment( + arrow::dataset::FileSource source, arrow::compute::Expression partition_expression, + std::shared_ptr physical_schema) +{ + std::shared_ptr format = std::make_shared(mTotCompressedSize, mTotUncompressedSize); + + auto fragment = std::make_shared(std::move(source), std::move(format), + std::move(partition_expression), + std::move(physical_schema)); + return std::dynamic_pointer_cast(fragment); +} + // An arrow outputstream which allows to write to a ttree TDirectoryFileOutputStream::TDirectoryFileOutputStream(TDirectoryFile* f) : mDirectory(f) @@ -393,6 +478,37 @@ char const* rootSuffixFromArrow(arrow::Type::type id) } } +std::unique_ptr rootFieldFromArrow(std::shared_ptr field, std::string name) +{ + using namespace ROOT::Experimental; + switch (field->type()->id()) { + case arrow::Type::BOOL: + return std::make_unique>(name); + case arrow::Type::UINT8: + return std::make_unique>(name); + case arrow::Type::UINT16: + return std::make_unique>(name); + case arrow::Type::UINT32: + return std::make_unique>(name); + case arrow::Type::UINT64: + return std::make_unique>(name); + case arrow::Type::INT8: + return std::make_unique>(name); + case arrow::Type::INT16: + return std::make_unique>(name); + case arrow::Type::INT32: + return std::make_unique>(name); + case arrow::Type::INT64: + return std::make_unique>(name); + case arrow::Type::FLOAT: + return std::make_unique>(name); + case arrow::Type::DOUBLE: + return std::make_unique>(name); + default: + throw runtime_error("Unsupported arrow column type"); + } +} + class TTreeFileWriter : public arrow::dataset::FileWriter { std::vector branches; @@ -628,6 +744,212 @@ class TTreeFileWriter : public arrow::dataset::FileWriter }; }; +class RNTupleFileWriter : public arrow::dataset::FileWriter +{ + std::shared_ptr mWriter; + bool firstBatch = true; + std::vector> valueArrays; + std::vector> valueTypes; + std::vector valueCount; + + public: + RNTupleFileWriter(std::shared_ptr schema, std::shared_ptr options, + std::shared_ptr destination, + arrow::fs::FileLocator destination_locator) + : FileWriter(schema, options, destination, destination_locator) + { + using namespace ROOT::Experimental; + + auto model = RNTupleModel::Create(); + // Let's create a model from the physical schema + for (auto i = 0u; i < schema->fields().size(); ++i) { + auto& field = schema->field(i); + + // Construct all the needed branches. + switch (field->type()->id()) { + case arrow::Type::FIXED_SIZE_LIST: { + auto list = std::static_pointer_cast(field->type()); + auto valueField = field->type()->field(0); + model->AddField(std::make_unique(field->name(), rootFieldFromArrow(valueField, "_0"), list->list_size())); + } break; + case arrow::Type::LIST: { + auto valueField = field->type()->field(0); + model->AddField(std::make_unique(field->name(), rootFieldFromArrow(valueField, "_0"))); + } break; + default: { + model->AddField(rootFieldFromArrow(field, field->name())); + } break; + } + } + auto fileStream = std::dynamic_pointer_cast(destination_); + auto* file = dynamic_cast(fileStream->GetDirectory()); + mWriter = RNTupleWriter::Append(std::move(model), destination_locator_.path, *file, {}); + } + + arrow::Status Write(const std::shared_ptr& batch) override + { + if (firstBatch) { + firstBatch = false; + } + + // Support writing empty tables + if (batch->columns().empty() || batch->num_rows() == 0) { + return arrow::Status::OK(); + } + + for (auto i = 0u; i < batch->columns().size(); ++i) { + auto column = batch->column(i); + auto& field = batch->schema()->field(i); + + valueArrays.push_back(nullptr); + valueTypes.push_back(nullptr); + valueCount.push_back(1); + + switch (field->type()->id()) { + case arrow::Type::FIXED_SIZE_LIST: { + auto list = std::static_pointer_cast(column); + auto listType = std::static_pointer_cast(field->type()); + if (field->type()->field(0)->type()->id() == arrow::Type::BOOL) { + auto boolArray = std::static_pointer_cast(list->values()); + int64_t length = boolArray->length(); + arrow::UInt8Builder builder; + auto ok = builder.Reserve(length); + + for (int64_t i = 0; i < length; ++i) { + if (boolArray->IsValid(i)) { + // Expand each boolean value (true/false) to uint8 (1/0) + uint8_t value = boolArray->Value(i) ? 1 : 0; + auto ok = builder.Append(value); + } else { + // Append null for invalid entries + auto ok = builder.AppendNull(); + } + } + valueArrays.back() = *builder.Finish(); + valueTypes.back() = valueArrays.back()->type(); + } else { + valueArrays.back() = list->values(); + valueTypes.back() = field->type()->field(0)->type(); + } + valueCount.back() = listType->list_size(); + } break; + case arrow::Type::LIST: { + auto list = std::static_pointer_cast(column); + valueArrays.back() = list; + valueTypes.back() = field->type()->field(0)->type(); + valueCount.back() = -1; + } break; + case arrow::Type::BOOL: { + // We unpack the array + auto boolArray = std::static_pointer_cast(column); + int64_t length = boolArray->length(); + arrow::UInt8Builder builder; + auto ok = builder.Reserve(length); + + for (int64_t i = 0; i < length; ++i) { + if (boolArray->IsValid(i)) { + // Expand each boolean value (true/false) to uint8 (1/0) + uint8_t value = boolArray->Value(i) ? 1 : 0; + auto ok = builder.Append(value); + } else { + // Append null for invalid entries + auto ok = builder.AppendNull(); + } + } + valueArrays.back() = *builder.Finish(); + valueTypes.back() = valueArrays.back()->type(); + } break; + default: + valueArrays.back() = column; + valueTypes.back() = field->type(); + break; + } + } + + int64_t pos = 0; + + auto entry = mWriter->CreateEntry(); + while (pos < batch->num_rows()) { + for (size_t ci = 0; ci < batch->columns().size(); ++ci) { + auto type = batch->column(ci)->type(); + auto field = batch->schema()->field(ci); + auto token = entry->GetToken(field->name()); + + switch (type->id()) { + case arrow::Type::LIST: { + auto list = std::static_pointer_cast(valueArrays[ci]); + auto value_slice = list->value_slice(pos); + + valueCount[ci] = value_slice->length(); + auto bindValue = [&vc = valueCount, ci, token](auto array, std::unique_ptr& entry) -> void { + using value_type = std::decay_t::value_type; + auto v = std::make_shared>((value_type*)array->raw_values(), vc[ci]); + entry->BindValue(token, v); + }; + switch (valueTypes[ci]->id()) { + case arrow::Type::FLOAT: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::DOUBLE: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::INT8: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::INT16: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::INT32: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::INT64: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::UINT8: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::UINT16: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::UINT32: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::UINT64: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + default: { + throw runtime_error("Unsupported kind of VLA"); + } break; + } + } break; + case arrow::Type::FIXED_SIZE_LIST: { + entry->BindRawPtr(token, (void*)(valueArrays[ci]->data()->buffers[1]->data() + pos * valueCount[ci] * valueTypes[ci]->byte_width())); + } break; + case arrow::Type::BOOL: { + // Not sure we actually need this + entry->BindRawPtr(token, (bool*)(valueArrays[ci]->data()->buffers[1]->data() + pos * 1)); + } break; + default: + // By default we consider things scalars. + entry->BindRawPtr(token, (void*)(valueArrays[ci]->data()->buffers[1]->data() + pos * valueTypes[ci]->byte_width())); + break; + } + } + mWriter->Fill(*entry); + ++pos; + } + mWriter->CommitCluster(); + + return arrow::Status::OK(); + } + + arrow::Future<> + FinishInternal() override + { + return {}; + }; +}; + arrow::Result> TTreeFileFormat::MakeWriter(std::shared_ptr destination, std::shared_ptr schema, std::shared_ptr options, arrow::fs::FileLocator destination_locator) const { auto writer = std::make_shared(schema, options, destination, destination_locator); @@ -838,6 +1160,314 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( return generator; } +arrow::Result> RNTupleFileFormat::Inspect(const arrow::dataset::FileSource& source) const +{ + + auto fs = std::dynamic_pointer_cast(source.filesystem()); + // Actually get the TTree from the ROOT file. + auto ntupleFs = std::dynamic_pointer_cast(fs->GetSubFilesystem(source)); + if (!ntupleFs.get()) { + throw runtime_error_f("Unknown filesystem %s\n", source.filesystem()->type_name().c_str()); + } + ROOT::Experimental::RNTuple* rntuple = ntupleFs->GetRNTuple(source); + + auto inspector = ROOT::Experimental::RNTupleInspector::Create(rntuple); + + auto reader = ROOT::Experimental::RNTupleReader::Open(rntuple); + + auto& tupleField0 = reader->GetModel().GetFieldZero(); + std::vector> fields; + for (auto& tupleField : tupleField0.GetSubFields()) { + auto field = std::make_shared(tupleField->GetFieldName(), arrowTypeFromRNTuple(*tupleField, tupleField->GetValueSize())); + fields.push_back(field); + } + + return std::make_shared(fields); +} + +arrow::Result RNTupleFileFormat::ScanBatchesAsync( + const std::shared_ptr& options, + const std::shared_ptr& fragment) const +{ + auto dataset_schema = options->dataset_schema; + auto ntupleFragment = std::dynamic_pointer_cast(fragment); + + auto generator = [pool = options->pool, ntupleFragment, dataset_schema, &totalCompressedSize = mTotCompressedSize, + &totalUncompressedSize = mTotUncompressedSize]() -> arrow::Future> { + using namespace ROOT::Experimental; + std::vector> columns; + std::vector> fields = dataset_schema->fields(); + + auto containerFS = std::dynamic_pointer_cast(ntupleFragment->source().filesystem()); + auto fs = std::dynamic_pointer_cast(containerFS->GetSubFilesystem(ntupleFragment->source())); + + int64_t rows = -1; + ROOT::Experimental::RNTuple* rntuple = fs->GetRNTuple(ntupleFragment->source()); + auto reader = ROOT::Experimental::RNTupleReader::Open(rntuple); + auto& model = reader->GetModel(); + for (auto& physicalField : fields) { + auto bulk = model.CreateBulk(physicalField->name()); + + auto listType = std::dynamic_pointer_cast(physicalField->type()); + + auto& descriptor = reader->GetDescriptor(); + auto totalEntries = reader->GetNEntries(); + + if (rows == -1) { + rows = totalEntries; + } + if (rows != totalEntries) { + throw runtime_error_f("Unmatching number of rows for branch %s", physicalField->name().c_str()); + } + arrow::Status status; + int readEntries = 0; + std::shared_ptr array; + if (physicalField->type() == arrow::boolean() || + (listType && physicalField->type()->field(0)->type() == arrow::boolean())) { + if (listType) { + std::unique_ptr builder = nullptr; + auto status = arrow::MakeBuilder(pool, physicalField->type()->field(0)->type(), &builder); + if (!status.ok()) { + throw runtime_error("Cannot create value builder"); + } + auto listBuilder = std::make_unique(pool, std::move(builder), listType->list_size()); + auto valueBuilder = listBuilder.get()->value_builder(); + // boolean array special case: we need to use builder to create the bitmap + status = valueBuilder->Reserve(totalEntries * listType->list_size()); + status &= listBuilder->Reserve(totalEntries); + if (!status.ok()) { + throw runtime_error("Failed to reserve memory for array builder"); + } + auto clusterIt = descriptor.FindClusterId(0, 0); + // No adoption for now... + // bulk.AdoptBuffer(buffer, totalEntries) + while (clusterIt != kInvalidDescriptorId) { + auto& index = descriptor.GetClusterDescriptor(clusterIt); + auto mask = std::make_unique(index.GetNEntries()); + std::fill(mask.get(), mask.get() + index.GetNEntries(), true); + void* ptr = bulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()); + int readLast = index.GetNEntries(); + readEntries += readLast; + status &= static_cast(valueBuilder)->AppendValues(reinterpret_cast(ptr), readLast * listType->list_size()); + clusterIt = descriptor.FindNextClusterId(clusterIt); + } + status &= static_cast(listBuilder.get())->AppendValues(readEntries); + if (!status.ok()) { + throw runtime_error("Failed to append values to array"); + } + status &= listBuilder->Finish(&array); + if (!status.ok()) { + throw runtime_error("Failed to create array"); + } + } else if (listType == nullptr) { + std::unique_ptr builder = nullptr; + auto status = arrow::MakeBuilder(pool, physicalField->type(), &builder); + if (!status.ok()) { + throw runtime_error("Cannot create builder"); + } + auto valueBuilder = static_cast(builder.get()); + // boolean array special case: we need to use builder to create the bitmap + status = valueBuilder->Reserve(totalEntries); + if (!status.ok()) { + throw runtime_error("Failed to reserve memory for array builder"); + } + auto clusterIt = descriptor.FindClusterId(0, 0); + while (clusterIt != kInvalidDescriptorId) { + auto& index = descriptor.GetClusterDescriptor(clusterIt); + auto mask = std::make_unique(index.GetNEntries()); + std::fill(mask.get(), mask.get() + index.GetNEntries(), true); + void* ptr = bulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()); + int readLast = index.GetNEntries(); + readEntries += readLast; + status &= valueBuilder->AppendValues(reinterpret_cast(ptr), readLast); + clusterIt = descriptor.FindNextClusterId(clusterIt); + } + if (!status.ok()) { + throw runtime_error("Failed to append values to array"); + } + status &= valueBuilder->Finish(&array); + if (!status.ok()) { + throw runtime_error("Failed to create array"); + } + } + } else { + // other types: use serialized read to build arrays directly. + auto typeSize = physicalField->type()->byte_width(); + // FIXME: for now... + auto bytes = 0; + auto branchSize = bytes ? bytes : 1000000; + auto&& result = arrow::AllocateResizableBuffer(branchSize, pool); + if (!result.ok()) { + throw runtime_error("Cannot allocate values buffer"); + } + std::shared_ptr arrowValuesBuffer = std::move(result).ValueUnsafe(); + auto ptr = arrowValuesBuffer->mutable_data(); + if (ptr == nullptr) { + throw runtime_error("Invalid buffer"); + } + + std::unique_ptr offsetBuffer = nullptr; + + std::shared_ptr arrowOffsetBuffer; + std::span offsets; + int size = 0; + uint32_t totalSize = 0; + int64_t listSize = 1; + if (auto fixedSizeList = std::dynamic_pointer_cast(physicalField->type())) { + listSize = fixedSizeList->list_size(); + typeSize = fixedSizeList->field(0)->type()->byte_width(); + auto clusterIt = descriptor.FindClusterId(0, 0); + while (clusterIt != kInvalidDescriptorId) { + auto& index = descriptor.GetClusterDescriptor(clusterIt); + auto mask = std::make_unique(index.GetNEntries()); + std::fill(mask.get(), mask.get() + index.GetNEntries(), true); + void* inPtr = bulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()); + + int readLast = index.GetNEntries(); + if (listSize == -1) { + size = offsets[readEntries + readLast] - offsets[readEntries]; + } else { + size = readLast * listSize; + } + readEntries += readLast; + memcpy(ptr, inPtr, size * typeSize); + ptr += (ptrdiff_t)(size * typeSize); + clusterIt = descriptor.FindNextClusterId(clusterIt); + } + } else if (auto vlaListType = std::dynamic_pointer_cast(physicalField->type())) { + listSize = -1; + typeSize = vlaListType->field(0)->type()->byte_width(); + offsetBuffer = std::make_unique(TBuffer::EMode::kWrite, 4 * 1024 * 1024); + result = arrow::AllocateResizableBuffer((totalEntries + 1) * (int64_t)sizeof(int), pool); + if (!result.ok()) { + throw runtime_error("Cannot allocate offset buffer"); + } + arrowOffsetBuffer = std::move(result).ValueUnsafe(); + + // Offset bulk + auto offsetBulk = model.CreateBulk(physicalField->name()); + // Actual values are in a different place... + bulk = model.CreateBulk(physicalField->name()); + auto clusterIt = descriptor.FindClusterId(0, 0); + auto* ptrOffset = reinterpret_cast(arrowOffsetBuffer->mutable_data()); + auto* tPtrOffset = reinterpret_cast(ptrOffset); + offsets = std::span{tPtrOffset, tPtrOffset + totalEntries + 1}; + + auto copyOffsets = [&arrowValuesBuffer, &pool, &ptrOffset, &ptr, &totalSize](auto inPtr, size_t total) { + using value_type = typename std::decay_t::value_type; + for (size_t i = 0; i < total; i++) { + *ptrOffset++ = totalSize; + totalSize += inPtr[i].size(); + } + *ptrOffset = totalSize; + auto&& result = arrow::AllocateResizableBuffer(totalSize * sizeof(value_type), pool); + if (!result.ok()) { + throw runtime_error("Cannot allocate values buffer"); + } + arrowValuesBuffer = std::move(result).ValueUnsafe(); + ptr = (uint8_t*)(arrowValuesBuffer->mutable_data()); + // Calculate the size of the buffer here. + for (size_t i = 0; i < total; i++) { + int vlaSizeInBytes = inPtr[i].size() * sizeof(value_type); + if (vlaSizeInBytes == 0) { + continue; + } + memcpy(ptr, inPtr[i].data(), vlaSizeInBytes); + ptr += vlaSizeInBytes; + } + }; + + while (clusterIt != kInvalidDescriptorId) { + auto& index = descriptor.GetClusterDescriptor(clusterIt); + auto mask = std::make_unique(index.GetNEntries()); + std::fill(mask.get(), mask.get() + index.GetNEntries(), true); + int readLast = index.GetNEntries(); + switch (vlaListType->field(0)->type()->id()) { + case arrow::Type::FLOAT: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::DOUBLE: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::INT8: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::INT16: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::INT32: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::INT64: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::UINT8: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::UINT16: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::UINT32: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::UINT64: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + default: { + throw runtime_error("Unsupported kind of VLA"); + } break; + } + + readEntries += readLast; + clusterIt = descriptor.FindNextClusterId(clusterIt); + } + } else { + auto clusterIt = descriptor.FindClusterId(0, 0); + while (clusterIt != kInvalidDescriptorId) { + auto& index = descriptor.GetClusterDescriptor(clusterIt); + auto mask = std::make_unique(index.GetNEntries()); + std::fill(mask.get(), mask.get() + index.GetNEntries(), true); + void* inPtr = bulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()); + + int readLast = index.GetNEntries(); + if (listSize == -1) { + size = offsets[readEntries + readLast] - offsets[readEntries]; + } else { + size = readLast * listSize; + } + readEntries += readLast; + memcpy(ptr, inPtr, size * typeSize); + ptr += (ptrdiff_t)(size * typeSize); + clusterIt = descriptor.FindNextClusterId(clusterIt); + } + } + switch (listSize) { + case -1: { + auto varray = std::make_shared(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer); + array = std::make_shared(physicalField->type(), readEntries, arrowOffsetBuffer, varray); + } break; + case 1: { + totalSize = readEntries * listSize; + array = std::make_shared(physicalField->type(), readEntries, arrowValuesBuffer); + + } break; + default: { + totalSize = readEntries * listSize; + auto varray = std::make_shared(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer); + array = std::make_shared(physicalField->type(), readEntries, varray); + } + } + } + columns.push_back(array); + } + + auto batch = arrow::RecordBatch::Make(dataset_schema, rows, columns); + return batch; + }; + + return generator; +} + arrow::Result> TTreeFileSystem::OpenOutputStream( const std::string& path, const std::shared_ptr& metadata) @@ -850,6 +1480,21 @@ arrow::Result> TTreeFileSystem::OpenOut return std::make_shared(GetTree(source), ""); } +std::shared_ptr + RNTupleFileFormat::DefaultWriteOptions() +{ + return std::make_shared(shared_from_this()); +} + +arrow::Result> RNTupleFileFormat::MakeWriter(std::shared_ptr destination, + std::shared_ptr schema, + std::shared_ptr options, + arrow::fs::FileLocator destination_locator) const +{ + auto writer = std::make_shared(schema, options, destination, destination_locator); + return std::dynamic_pointer_cast(writer); +} + TBufferFileFS::TBufferFileFS(TBufferFile* f) : VirtualRootFileSystemBase(), mBuffer(f), @@ -859,6 +1504,8 @@ TBufferFileFS::TBufferFileFS(TBufferFile* f) TTreeFileSystem::~TTreeFileSystem() = default; +RNTupleFileSystem::~RNTupleFileSystem() = default; + arrow::Result TBufferFileFS::GetFileInfo(const std::string& path) { arrow::fs::FileInfo result; @@ -876,6 +1523,10 @@ arrow::Result TBufferFileFS::GetFileInfo(const std::string& result.set_type(arrow::fs::FileType::File); return result; } + if (std::dynamic_pointer_cast(mFilesystem)) { + result.set_type(arrow::fs::FileType::File); + return result; + } return result; } diff --git a/Framework/Core/test/test_Root2ArrowTable.cxx b/Framework/Core/test/test_Root2ArrowTable.cxx index 2b0ab9154250c..cd9a4b4685e7a 100644 --- a/Framework/Core/test/test_Root2ArrowTable.cxx +++ b/Framework/Core/test/test_Root2ArrowTable.cxx @@ -26,6 +26,13 @@ #include #include #include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -504,4 +511,40 @@ TEST_CASE("RootTree2Dataset") REQUIRE((*resultWritten)->num_rows() == 100); validateContents(*resultWritten); } + // Let's write back an RNTuple + auto rNtupleFormat = std::make_shared(totalSizeCompressed, totalSizeUncompressed); + arrow::fs::FileLocator rnTupleLocator{outFs, "/rntuple"}; + // We write an RNTuple in the same TMemFile, using /rntuple as a location + auto rntupleDestination = std::dynamic_pointer_cast(*destination); + + { + auto rNtupleWriter = rNtupleFormat->MakeWriter(*destination, schema, {}, rnTupleLocator); + auto rNtupleSuccess = rNtupleWriter->get()->Write(*result); + REQUIRE(rNtupleSuccess.ok()); + } + + // And now we can read back the RNTuple into a RecordBatch + arrow::dataset::FileSource writtenRntupleSource("/rntuple", outFs); + auto newRNTupleFS = outFs->GetSubFilesystem(writtenRntupleSource); + + REQUIRE(rNtupleFormat->IsSupported(writtenRntupleSource) == true); + + auto rntupleSchemaOpt = rNtupleFormat->Inspect(writtenRntupleSource); + REQUIRE(rntupleSchemaOpt.ok()); + auto rntupleSchemaWritten = *rntupleSchemaOpt; + REQUIRE(validateSchema(rntupleSchemaWritten)); + + auto rntupleFragmentWritten = rNtupleFormat->MakeFragment(writtenRntupleSource, {}, rntupleSchemaWritten); + REQUIRE(rntupleFragmentWritten.ok()); + auto rntupleOptionsWritten = std::make_shared(); + rntupleOptionsWritten->dataset_schema = rntupleSchemaWritten; + auto rntupleScannerWritten = rNtupleFormat->ScanBatchesAsync(rntupleOptionsWritten, *rntupleFragmentWritten); + REQUIRE(rntupleScannerWritten.ok()); + auto rntupleBatchesWritten = (*rntupleScannerWritten)(); + auto rntupleResultWritten = rntupleBatchesWritten.result(); + REQUIRE(rntupleResultWritten.ok()); + REQUIRE((*rntupleResultWritten)->columns().size() == 10); + REQUIRE(validateSchema((*rntupleResultWritten)->schema())); + REQUIRE((*rntupleResultWritten)->num_rows() == 100); + REQUIRE(validateContents(*rntupleResultWritten)); }