diff --git a/Framework/Core/include/Framework/RootArrowFilesystem.h b/Framework/Core/include/Framework/RootArrowFilesystem.h index 7c8385ccd2b9d..48d817bc9ddf2 100644 --- a/Framework/Core/include/Framework/RootArrowFilesystem.h +++ b/Framework/Core/include/Framework/RootArrowFilesystem.h @@ -17,6 +17,8 @@ #include #include +class TFile; +class TBranch; class TTree; class TBufferFile; class TDirectoryFile; @@ -227,11 +229,38 @@ class TTreeFileFormat : public arrow::dataset::FileFormat const std::shared_ptr& fragment) const override; }; -// An arrow outputstream which allows to write to a ttree +// An arrow outputstream which allows to write to a TDirectoryFile. +// This will point to the location of the file itself. You can +// specify the location of the actual object inside it by passing the +// associated path to the Write() API. +class TDirectoryFileOutputStream : public arrow::io::OutputStream +{ + public: + TDirectoryFileOutputStream(TDirectoryFile*); + + arrow::Status Close() override; + + arrow::Result Tell() const override; + + arrow::Status Write(const void* data, int64_t nbytes) override; + + bool closed() const override; + + TDirectoryFile* GetDirectory() + { + return mDirectory; + } + + private: + TDirectoryFile* mDirectory; +}; + +// An arrow outputstream which allows to write to a TTree. Eventually +// with a prefix for the branches. class TTreeOutputStream : public arrow::io::OutputStream { public: - TTreeOutputStream(TTree* t); + TTreeOutputStream(TTree*, std::string branchPrefix); arrow::Status Close() override; @@ -241,6 +270,8 @@ class TTreeOutputStream : public arrow::io::OutputStream bool closed() const override; + TBranch* CreateBranch(char const* branchName, char const* sizeBranch); + TTree* GetTree() { return mTree; @@ -248,6 +279,7 @@ class TTreeOutputStream : public arrow::io::OutputStream private: TTree* mTree; + std::string mBranchPrefix; }; } // namespace o2::framework diff --git a/Framework/Core/src/RootArrowFilesystem.cxx b/Framework/Core/src/RootArrowFilesystem.cxx index 7581ee57e5b9f..7e331814272a6 100644 --- a/Framework/Core/src/RootArrowFilesystem.cxx +++ b/Framework/Core/src/RootArrowFilesystem.cxx @@ -17,7 +17,6 @@ #include #include #include -#include #include #include #include @@ -28,8 +27,11 @@ #include #include #include +#include #include +#include +#include O2_DECLARE_DYNAMIC_LOG(root_arrow_fs); @@ -100,7 +102,6 @@ std::shared_ptr TFileFileSystem::GetSubFilesystem(arr return std::shared_ptr(new SingleTreeFileSystem(tree)); } - auto directory = (TDirectoryFile*)mFile->GetObjectChecked(source.path().c_str(), TClass::GetClass()); if (directory) { return std::shared_ptr(new TFileFileSystem(directory, 50 * 1024 * 1024)); @@ -129,8 +130,15 @@ arrow::Result> TFileFileSystem::OpenOut const std::string& path, const std::shared_ptr& metadata) { - auto* t = new TTree(path.c_str(), "should put a name here"); - auto stream = std::make_shared(t); + if (path == "/") { + return std::make_shared(this->GetFile()); + } + + auto* dir = dynamic_cast(this->GetFile()->Get(path.c_str())); + if (!dir) { + throw runtime_error_f("Unable to open directory %s in file %s", path.c_str(), GetFile()->GetName()); + } + auto stream = std::make_shared(dir); return stream; } @@ -286,13 +294,46 @@ arrow::Result> TTreeFileFormat::Ma } // An arrow outputstream which allows to write to a ttree -TTreeOutputStream::TTreeOutputStream(TTree* t) - : mTree(t) +TDirectoryFileOutputStream::TDirectoryFileOutputStream(TDirectoryFile* f) + : mDirectory(f) +{ +} + +arrow::Status TDirectoryFileOutputStream::Close() +{ + mDirectory->GetFile()->Close(); + return arrow::Status::OK(); +} + +arrow::Result TDirectoryFileOutputStream::Tell() const +{ + return arrow::Result(arrow::Status::NotImplemented("Cannot move")); +} + +arrow::Status TDirectoryFileOutputStream::Write(const void* data, int64_t nbytes) +{ + return arrow::Status::NotImplemented("Cannot write raw bytes to a TTree"); +} + +bool TDirectoryFileOutputStream::closed() const +{ + return mDirectory->GetFile()->IsOpen() == false; +} + +// An arrow outputstream which allows to write to a ttree +// @a branch prefix is to be used to identify a set of branches which all belong to +// the same table. +TTreeOutputStream::TTreeOutputStream(TTree* f, std::string branchPrefix) + : mTree(f), + mBranchPrefix(std::move(branchPrefix)) { } arrow::Status TTreeOutputStream::Close() { + if (mTree->GetCurrentFile() == nullptr) { + return arrow::Status::Invalid("Cannot close a tree not attached to a file"); + } mTree->GetCurrentFile()->Close(); return arrow::Status::OK(); } @@ -309,9 +350,18 @@ arrow::Status TTreeOutputStream::Write(const void* data, int64_t nbytes) bool TTreeOutputStream::closed() const { + // A standalone tree is never closed. + if (mTree->GetCurrentFile() == nullptr) { + return false; + } return mTree->GetCurrentFile()->IsOpen() == false; } +TBranch* TTreeOutputStream::CreateBranch(char const* branchName, char const* sizeBranch) +{ + return mTree->Branch((mBranchPrefix + "/" + branchName).c_str(), (char*)nullptr, (mBranchPrefix + sizeBranch).c_str()); +} + char const* rootSuffixFromArrow(arrow::Type::type id) { switch (id) { @@ -411,8 +461,24 @@ class TTreeFileWriter : public arrow::dataset::FileWriter : FileWriter(schema, options, destination, destination_locator) { // Batches have the same number of entries for each column. + auto directoryStream = std::dynamic_pointer_cast(destination_); auto treeStream = std::dynamic_pointer_cast(destination_); - TTree* tree = treeStream->GetTree(); + + if (directoryStream.get()) { + TDirectoryFile* dir = directoryStream->GetDirectory(); + dir->cd(); + auto* tree = new TTree(destination_locator_.path.c_str(), ""); + treeStream = std::make_shared(tree, ""); + } else if (treeStream.get()) { + // We already have a tree stream, let's derive a new one + // with the destination_locator_.path as prefix for the branches + // This way we can multiplex multiple tables in the same tree. + auto tree = treeStream->GetTree(); + treeStream = std::make_shared(tree, destination_locator_.path); + } else { + // I could simply set a prefix here to merge to an already existing tree. + throw std::runtime_error("Unsupported backend."); + } for (auto i = 0u; i < schema->fields().size(); ++i) { auto& field = schema->field(i); @@ -427,15 +493,15 @@ class TTreeFileWriter : public arrow::dataset::FileWriter valueTypes.push_back(field->type()->field(0)->type()); sizesBranches.push_back(nullptr); std::string leafList = fmt::format("{}[{}]{}", field->name(), listSizes.back(), rootSuffixFromArrow(valueTypes.back()->id())); - branches.push_back(tree->Branch(field->name().c_str(), (char*)nullptr, leafList.c_str())); + branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str())); } break; case arrow::Type::LIST: { valueTypes.push_back(field->type()->field(0)->type()); listSizes.back() = 0; // VLA, we need to calculate it on the fly; std::string leafList = fmt::format("{}[{}_size]{}", field->name(), field->name(), rootSuffixFromArrow(valueTypes.back()->id())); std::string sizeLeafList = field->name() + "_size/I"; - sizesBranches.push_back(tree->Branch((field->name() + "_size").c_str(), (char*)nullptr, sizeLeafList.c_str())); - branches.push_back(tree->Branch(field->name().c_str(), (char*)nullptr, leafList.c_str())); + sizesBranches.push_back(treeStream->CreateBranch((field->name() + "_size").c_str(), sizeLeafList.c_str())); + branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str())); // Notice that this could be replaced by a better guess of the // average size of the list elements, but this is not trivial. } break; @@ -443,7 +509,7 @@ class TTreeFileWriter : public arrow::dataset::FileWriter valueTypes.push_back(field->type()); std::string leafList = field->name() + rootSuffixFromArrow(valueTypes.back()->id()); sizesBranches.push_back(nullptr); - branches.push_back(tree->Branch(field->name().c_str(), (char*)nullptr, leafList.c_str())); + branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str())); } break; } } @@ -463,11 +529,18 @@ class TTreeFileWriter : public arrow::dataset::FileWriter } // Batches have the same number of entries for each column. + auto directoryStream = std::dynamic_pointer_cast(destination_); + TTree* tree = nullptr; + if (directoryStream.get()) { + TDirectoryFile* dir = directoryStream->GetDirectory(); + tree = (TTree*)dir->Get(destination_locator_.path.c_str()); + } auto treeStream = std::dynamic_pointer_cast(destination_); - TTree* tree = treeStream->GetTree(); - // Caches for the vectors of bools. - std::vector> caches; + if (!tree) { + // I could simply set a prefix here to merge to an already existing tree. + throw std::runtime_error("Unsupported backend."); + } for (auto i = 0u; i < batch->columns().size(); ++i) { auto column = batch->column(i); @@ -484,24 +557,11 @@ class TTreeFileWriter : public arrow::dataset::FileWriter auto list = std::static_pointer_cast(column); valueArrays.back() = list; } break; - default: - valueArrays.back() = column; - } - } - - int64_t pos = 0; - while (pos < batch->num_rows()) { - for (size_t bi = 0; bi < branches.size(); ++bi) { - auto* branch = branches[bi]; - auto* sizeBranch = sizesBranches[bi]; - auto array = batch->column(bi); - auto& field = batch->schema()->field(bi); - auto& listSize = listSizes[bi]; - auto valueType = valueTypes[bi]; - auto valueArray = valueArrays[bi]; + case arrow::Type::BOOL: { + // In case of arrays of booleans, we need to go back to their + // char based representation for ROOT to save them. + auto boolArray = std::static_pointer_cast(column); - if (field->type()->id() == arrow::Type::BOOL) { - auto boolArray = std::static_pointer_cast(array); int64_t length = boolArray->length(); arrow::UInt8Builder builder; auto ok = builder.Reserve(length); @@ -516,11 +576,24 @@ class TTreeFileWriter : public arrow::dataset::FileWriter auto ok = builder.AppendNull(); } } + valueArrays.back() = *builder.Finish(); + } break; + default: + valueArrays.back() = column; + } + } + + int64_t pos = 0; + while (pos < batch->num_rows()) { + for (size_t bi = 0; bi < branches.size(); ++bi) { + auto* branch = branches[bi]; + auto* sizeBranch = sizesBranches[bi]; + auto array = batch->column(bi); + auto& field = batch->schema()->field(bi); + auto& listSize = listSizes[bi]; + auto valueType = valueTypes[bi]; + auto valueArray = valueArrays[bi]; - ok = builder.Finish(&caches[bi]); - branch->SetAddress((void*)(caches[bi]->values()->data())); - continue; - } switch (field->type()->id()) { case arrow::Type::LIST: { auto list = std::static_pointer_cast(array); @@ -764,13 +837,16 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( return generator; } - arrow::Result> TTreeFileSystem::OpenOutputStream( const std::string& path, const std::shared_ptr& metadata) { - auto stream = std::make_shared(GetTree({path, shared_from_this()})); - return stream; + arrow::dataset::FileSource source{path, shared_from_this()}; + auto prefix = metadata->Get("branch_prefix"); + if (prefix.ok()) { + return std::make_shared(GetTree(source), *prefix); + } + return std::make_shared(GetTree(source), ""); } TBufferFileFS::TBufferFileFS(TBufferFile* f) @@ -782,7 +858,6 @@ TBufferFileFS::TBufferFileFS(TBufferFile* f) TTreeFileSystem::~TTreeFileSystem() = default; - arrow::Result TBufferFileFS::GetFileInfo(const std::string& path) { arrow::fs::FileInfo result; diff --git a/Framework/Core/test/test_Root2ArrowTable.cxx b/Framework/Core/test/test_Root2ArrowTable.cxx index 03f0977a4c0c4..a659d488ae24a 100644 --- a/Framework/Core/test/test_Root2ArrowTable.cxx +++ b/Framework/Core/test/test_Root2ArrowTable.cxx @@ -20,13 +20,18 @@ #include #include #include +#include #include #include #include #include #include +#include +#include +#include #include +#include #include #include #include @@ -259,6 +264,82 @@ TEST_CASE("RootTree2Fragment") REQUIRE((*result)->num_rows() == 1000); } +bool validateContents(std::shared_ptr batch) +{ + { + auto int_array = std::static_pointer_cast(batch->GetColumnByName("ev")); + REQUIRE(int_array->length() == 100); + for (int64_t j = 0; j < int_array->length(); j++) { + REQUIRE(int_array->Value(j) == j + 1); + } + } + + { + auto list_array = std::static_pointer_cast(batch->GetColumnByName("xyz")); + + REQUIRE(list_array->length() == 100); + // Iterate over the FixedSizeListArray + for (int64_t i = 0; i < list_array->length(); i++) { + auto value_slice = list_array->value_slice(i); + auto float_array = std::static_pointer_cast(value_slice); + + REQUIRE(float_array->Value(0) == 1); + REQUIRE(float_array->Value(1) == 2); + REQUIRE(float_array->Value(2) == i + 1); + } + } + + { + auto list_array = std::static_pointer_cast(batch->GetColumnByName("ij")); + + REQUIRE(list_array->length() == 100); + // Iterate over the FixedSizeListArray + for (int64_t i = 0; i < list_array->length(); i++) { + auto value_slice = list_array->value_slice(i); + auto int_array = std::static_pointer_cast(value_slice); + REQUIRE(int_array->Value(0) == i); + REQUIRE(int_array->Value(1) == i + 1); + } + } + + { + auto bool_array = std::static_pointer_cast(batch->GetColumnByName("bools")); + + REQUIRE(bool_array->length() == 100); + for (int64_t j = 0; j < bool_array->length(); j++) { + REQUIRE(bool_array->Value(j) == (j % 3 == 0)); + } + } + + { + auto list_array = std::static_pointer_cast(batch->GetColumnByName("manyBools")); + + REQUIRE(list_array->length() == 100); + for (int64_t i = 0; i < list_array->length(); i++) { + auto value_slice = list_array->value_slice(i); + auto bool_array = std::static_pointer_cast(value_slice); + REQUIRE(bool_array->Value(0) == (i % 4 == 0)); + REQUIRE(bool_array->Value(1) == (i % 5 == 0)); + } + } + return true; +} + +bool validateSchema(std::shared_ptr schema) +{ + REQUIRE(schema->num_fields() == 9); + REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id()); + REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id()); + REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id()); + REQUIRE(schema->field(3)->type()->id() == arrow::float64()->id()); + REQUIRE(schema->field(4)->type()->id() == arrow::int32()->id()); + REQUIRE(schema->field(5)->type()->id() == arrow::fixed_size_list(arrow::float32(), 3)->id()); + REQUIRE(schema->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id()); + REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id()); + REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id()); + return true; +} + TEST_CASE("RootTree2Dataset") { using namespace o2::framework; @@ -307,6 +388,9 @@ TEST_CASE("RootTree2Dataset") Float_t px = 0, py = 1, pz = 2; Double_t random; Int_t ev; + bool oneBool; + bool manyBool[2]; + t->Branch("px", &px, "px/F"); t->Branch("py", &py, "py/F"); t->Branch("pz", &pz, "pz/F"); @@ -314,6 +398,8 @@ TEST_CASE("RootTree2Dataset") t->Branch("ev", &ev, "ev/I"); t->Branch("xyz", xyz, "xyz[3]/F"); t->Branch("ij", ij, "ij[2]/I"); + t->Branch("bools", &oneBool, "bools/O"); + t->Branch("manyBools", &manyBool, "manyBools[2]/O"); // fill the tree for (Int_t i = 0; i < 100; i++) { xyz[0] = 1; @@ -326,6 +412,9 @@ TEST_CASE("RootTree2Dataset") ij[1] = i + 1; random = gRandom->Rndm(); ev = i + 1; + oneBool = (i % 3 == 0); + manyBool[0] = (i % 4 == 0); + manyBool[1] = (i % 5 == 0); t->Fill(); } } @@ -339,7 +428,7 @@ TEST_CASE("RootTree2Dataset") auto schemaOpt = format->Inspect(source); REQUIRE(schemaOpt.ok()); auto schema = *schemaOpt; - REQUIRE(schema->num_fields() == 7); + REQUIRE(schema->num_fields() == 9); REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id()); REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id()); REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id()); @@ -347,6 +436,9 @@ TEST_CASE("RootTree2Dataset") REQUIRE(schema->field(4)->type()->id() == arrow::int32()->id()); REQUIRE(schema->field(5)->type()->id() == arrow::fixed_size_list(arrow::float32(), 3)->id()); REQUIRE(schema->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id()); + REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id()); + REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id()); + auto fragment = format->MakeFragment(source, {}, schema); REQUIRE(fragment.ok()); auto options = std::make_shared(); @@ -356,7 +448,7 @@ TEST_CASE("RootTree2Dataset") auto batches = (*scanner)(); auto result = batches.result(); REQUIRE(result.ok()); - REQUIRE((*result)->columns().size() == 7); + REQUIRE((*result)->columns().size() == 9); REQUIRE((*result)->num_rows() == 100); { @@ -394,14 +486,16 @@ TEST_CASE("RootTree2Dataset") auto* output = new TMemFile("foo", "RECREATE"); auto outFs = std::make_shared(output, 0); - arrow::fs::FileLocator locator{outFs, "/DF_3"}; - auto destination = outFs->OpenOutputStream(locator.path, {}); + // Open a stream at toplevel + auto destination = outFs->OpenOutputStream("/", {}); REQUIRE(destination.ok()); + // Write to the /DF_3 tree at top level + arrow::fs::FileLocator locator{outFs, "/DF_3"}; auto writer = format->MakeWriter(*destination, schema, {}, locator); auto success = writer->get()->Write(*result); - auto rootDestination = std::dynamic_pointer_cast(*destination); + auto rootDestination = std::dynamic_pointer_cast(*destination); REQUIRE(success.ok()); // Let's read it back... @@ -413,14 +507,7 @@ TEST_CASE("RootTree2Dataset") auto schemaOptWritten = format->Inspect(source); REQUIRE(schemaOptWritten.ok()); auto schemaWritten = *schemaOptWritten; - REQUIRE(schemaWritten->num_fields() == 7); - REQUIRE(schemaWritten->field(0)->type()->id() == arrow::float32()->id()); - REQUIRE(schemaWritten->field(1)->type()->id() == arrow::float32()->id()); - REQUIRE(schemaWritten->field(2)->type()->id() == arrow::float32()->id()); - REQUIRE(schemaWritten->field(3)->type()->id() == arrow::float64()->id()); - REQUIRE(schemaWritten->field(4)->type()->id() == arrow::int32()->id()); - REQUIRE(schemaWritten->field(5)->type()->id() == arrow::fixed_size_list(arrow::float32(), 3)->id()); - REQUIRE(schemaWritten->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id()); + REQUIRE(validateSchema(schemaWritten)); auto fragmentWritten = format->MakeFragment(source, {}, schema); REQUIRE(fragmentWritten.ok()); @@ -431,39 +518,10 @@ TEST_CASE("RootTree2Dataset") auto batchesWritten = (*scanner)(); auto resultWritten = batches.result(); REQUIRE(resultWritten.ok()); - REQUIRE((*resultWritten)->columns().size() == 7); + REQUIRE((*resultWritten)->columns().size() == 9); REQUIRE((*resultWritten)->num_rows() == 100); + validateContents(*resultWritten); { - auto int_array = std::static_pointer_cast((*resultWritten)->GetColumnByName("ev")); - for (int64_t j = 0; j < int_array->length(); j++) { - REQUIRE(int_array->Value(j) == j + 1); - } - } - - { - auto list_array = std::static_pointer_cast((*result)->GetColumnByName("xyz")); - - // Iterate over the FixedSizeListArray - for (int64_t i = 0; i < list_array->length(); i++) { - auto value_slice = list_array->value_slice(i); - auto float_array = std::static_pointer_cast(value_slice); - - REQUIRE(float_array->Value(0) == 1); - REQUIRE(float_array->Value(1) == 2); - REQUIRE(float_array->Value(2) == i + 1); - } - } - - { - auto list_array = std::static_pointer_cast((*result)->GetColumnByName("ij")); - - // Iterate over the FixedSizeListArray - for (int64_t i = 0; i < list_array->length(); i++) { - auto value_slice = list_array->value_slice(i); - auto int_array = std::static_pointer_cast(value_slice); - REQUIRE(int_array->Value(0) == i); - REQUIRE(int_array->Value(1) == i + 1); - } } }