diff --git a/Framework/Core/src/RootArrowFilesystem.cxx b/Framework/Core/src/RootArrowFilesystem.cxx index 7e331814272a6..5f2d21d942d37 100644 --- a/Framework/Core/src/RootArrowFilesystem.cxx +++ b/Framework/Core/src/RootArrowFilesystem.cxx @@ -13,6 +13,7 @@ #include "Framework/RuntimeError.h" #include "Framework/Signpost.h" #include +#include #include #include #include @@ -427,7 +428,7 @@ class TTreeFileWriter : public arrow::dataset::FileWriter O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s exists and uses %d bytes per entry.", branch->GetName(), valueSize); // This should probably lookup the - auto column = firstBatch->GetColumnByName(branch->GetName()); + auto column = firstBatch->GetColumnByName(schema_->field(i)->name()); auto list = std::static_pointer_cast(column); O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s needed. Associated size branch %s and there are %lli entries of size %d in that list.", branch->GetName(), sizeBranch->GetName(), list->length(), valueSize); @@ -497,8 +498,8 @@ class TTreeFileWriter : public arrow::dataset::FileWriter } break; case arrow::Type::LIST: { valueTypes.push_back(field->type()->field(0)->type()); - listSizes.back() = 0; // VLA, we need to calculate it on the fly; std::string leafList = fmt::format("{}[{}_size]{}", field->name(), field->name(), rootSuffixFromArrow(valueTypes.back()->id())); + listSizes.back() = -1; // VLA, we need to calculate it on the fly; std::string sizeLeafList = field->name() + "_size/I"; sizesBranches.push_back(treeStream->CreateBranch((field->name() + "_size").c_str(), sizeLeafList.c_str())); branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str())); @@ -765,7 +766,7 @@ arrow::Result TTreeFileFormat::ScanBatchesAsync( typeSize = fixedSizeList->field(0)->type()->byte_width(); } else if (auto vlaListType = std::dynamic_pointer_cast(physicalField->type())) { listSize = -1; - typeSize = fixedSizeList->field(0)->type()->byte_width(); + typeSize = vlaListType->field(0)->type()->byte_width(); } if (listSize == -1) { mSizeBranch = branch->GetTree()->GetBranch((std::string{branch->GetName()} + "_size").c_str()); diff --git a/Framework/Core/test/test_Root2ArrowTable.cxx b/Framework/Core/test/test_Root2ArrowTable.cxx index a659d488ae24a..2b0ab9154250c 100644 --- a/Framework/Core/test/test_Root2ArrowTable.cxx +++ b/Framework/Core/test/test_Root2ArrowTable.cxx @@ -322,12 +322,26 @@ bool validateContents(std::shared_ptr batch) REQUIRE(bool_array->Value(1) == (i % 5 == 0)); } } + + { + auto list_array = std::static_pointer_cast(batch->GetColumnByName("vla")); + + REQUIRE(list_array->length() == 100); + for (int64_t i = 0; i < list_array->length(); i++) { + auto value_slice = list_array->value_slice(i); + REQUIRE(value_slice->length() == (i % 10)); + auto int_array = std::static_pointer_cast(value_slice); + for (size_t j = 0; j < value_slice->length(); j++) { + REQUIRE(int_array->Value(j) == j); + } + } + } return true; } bool validateSchema(std::shared_ptr schema) { - REQUIRE(schema->num_fields() == 9); + REQUIRE(schema->num_fields() == 10); REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id()); REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id()); REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id()); @@ -337,6 +351,7 @@ bool validateSchema(std::shared_ptr schema) REQUIRE(schema->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id()); REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id()); REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id()); + REQUIRE(schema->field(9)->type()->id() == arrow::list(arrow::int32())->id()); return true; } @@ -390,6 +405,8 @@ TEST_CASE("RootTree2Dataset") Int_t ev; bool oneBool; bool manyBool[2]; + int vla[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + int vlaSize = 0; t->Branch("px", &px, "px/F"); t->Branch("py", &py, "py/F"); @@ -400,6 +417,8 @@ TEST_CASE("RootTree2Dataset") t->Branch("ij", ij, "ij[2]/I"); t->Branch("bools", &oneBool, "bools/O"); t->Branch("manyBools", &manyBool, "manyBools[2]/O"); + t->Branch("vla_size", &vlaSize, "vla_size/I"); + t->Branch("vla", vla, "vla[vla_size]/I"); // fill the tree for (Int_t i = 0; i < 100; i++) { xyz[0] = 1; @@ -415,9 +434,11 @@ TEST_CASE("RootTree2Dataset") oneBool = (i % 3 == 0); manyBool[0] = (i % 4 == 0); manyBool[1] = (i % 5 == 0); + vlaSize = i % 10; t->Fill(); } } + f->Write(); size_t totalSizeCompressed = 0; size_t totalSizeUncompressed = 0; @@ -428,16 +449,7 @@ TEST_CASE("RootTree2Dataset") auto schemaOpt = format->Inspect(source); REQUIRE(schemaOpt.ok()); auto schema = *schemaOpt; - REQUIRE(schema->num_fields() == 9); - REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id()); - REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id()); - REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id()); - REQUIRE(schema->field(3)->type()->id() == arrow::float64()->id()); - REQUIRE(schema->field(4)->type()->id() == arrow::int32()->id()); - REQUIRE(schema->field(5)->type()->id() == arrow::fixed_size_list(arrow::float32(), 3)->id()); - REQUIRE(schema->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id()); - REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id()); - REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id()); + validateSchema(schema); auto fragment = format->MakeFragment(source, {}, schema); REQUIRE(fragment.ok()); @@ -448,41 +460,9 @@ TEST_CASE("RootTree2Dataset") auto batches = (*scanner)(); auto result = batches.result(); REQUIRE(result.ok()); - REQUIRE((*result)->columns().size() == 9); + REQUIRE((*result)->columns().size() == 10); REQUIRE((*result)->num_rows() == 100); - - { - auto int_array = std::static_pointer_cast((*result)->GetColumnByName("ev")); - for (int64_t j = 0; j < int_array->length(); j++) { - REQUIRE(int_array->Value(j) == j + 1); - } - } - - { - auto list_array = std::static_pointer_cast((*result)->GetColumnByName("xyz")); - - // Iterate over the FixedSizeListArray - for (int64_t i = 0; i < list_array->length(); i++) { - auto value_slice = list_array->value_slice(i); - auto float_array = std::static_pointer_cast(value_slice); - - REQUIRE(float_array->Value(0) == 1); - REQUIRE(float_array->Value(1) == 2); - REQUIRE(float_array->Value(2) == i + 1); - } - } - - { - auto list_array = std::static_pointer_cast((*result)->GetColumnByName("ij")); - - // Iterate over the FixedSizeListArray - for (int64_t i = 0; i < list_array->length(); i++) { - auto value_slice = list_array->value_slice(i); - auto int_array = std::static_pointer_cast(value_slice); - REQUIRE(int_array->Value(0) == i); - REQUIRE(int_array->Value(1) == i + 1); - } - } + validateContents(*result); auto* output = new TMemFile("foo", "RECREATE"); auto outFs = std::make_shared(output, 0); @@ -497,31 +477,31 @@ TEST_CASE("RootTree2Dataset") auto success = writer->get()->Write(*result); auto rootDestination = std::dynamic_pointer_cast(*destination); - REQUIRE(success.ok()); - // Let's read it back... - arrow::dataset::FileSource source2("/DF_3", outFs); - auto newTreeFS = outFs->GetSubFilesystem(source2); - - REQUIRE(format->IsSupported(source) == true); - - auto schemaOptWritten = format->Inspect(source); - REQUIRE(schemaOptWritten.ok()); - auto schemaWritten = *schemaOptWritten; - REQUIRE(validateSchema(schemaWritten)); - - auto fragmentWritten = format->MakeFragment(source, {}, schema); - REQUIRE(fragmentWritten.ok()); - auto optionsWritten = std::make_shared(); - options->dataset_schema = schemaWritten; - auto scannerWritten = format->ScanBatchesAsync(optionsWritten, *fragment); - REQUIRE(scannerWritten.ok()); - auto batchesWritten = (*scanner)(); - auto resultWritten = batches.result(); - REQUIRE(resultWritten.ok()); - REQUIRE((*resultWritten)->columns().size() == 9); - REQUIRE((*resultWritten)->num_rows() == 100); - validateContents(*resultWritten); - + SECTION("Read tree") { + REQUIRE(success.ok()); + // Let's read it back... + arrow::dataset::FileSource source2("/DF_3", outFs); + auto newTreeFS = outFs->GetSubFilesystem(source2); + + REQUIRE(format->IsSupported(source) == true); + + auto schemaOptWritten = format->Inspect(source); + REQUIRE(schemaOptWritten.ok()); + auto schemaWritten = *schemaOptWritten; + REQUIRE(validateSchema(schemaWritten)); + + auto fragmentWritten = format->MakeFragment(source, {}, schema); + REQUIRE(fragmentWritten.ok()); + auto optionsWritten = std::make_shared(); + options->dataset_schema = schemaWritten; + auto scannerWritten = format->ScanBatchesAsync(optionsWritten, *fragment); + REQUIRE(scannerWritten.ok()); + auto batchesWritten = (*scanner)(); + auto resultWritten = batches.result(); + REQUIRE(resultWritten.ok()); + REQUIRE((*resultWritten)->columns().size() == 10); + REQUIRE((*resultWritten)->num_rows() == 100); + validateContents(*resultWritten); } }