Skip to content

Commit

Permalink
DPL Analysis: improve arrow::Dataset support for TTree (AliceO2Group#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ktf authored Dec 2, 2024
1 parent feea3ad commit 950b8b7
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 73 deletions.
7 changes: 4 additions & 3 deletions Framework/Core/src/RootArrowFilesystem.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "Framework/RuntimeError.h"
#include "Framework/Signpost.h"
#include <Rtypes.h>
#include <arrow/array/array_nested.h>
#include <arrow/array/array_primitive.h>
#include <arrow/array/builder_nested.h>
#include <arrow/array/builder_primitive.h>
Expand Down Expand Up @@ -427,7 +428,7 @@ class TTreeFileWriter : public arrow::dataset::FileWriter
O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s exists and uses %d bytes per entry.",
branch->GetName(), valueSize);
// This should probably lookup the
auto column = firstBatch->GetColumnByName(branch->GetName());
auto column = firstBatch->GetColumnByName(schema_->field(i)->name());
auto list = std::static_pointer_cast<arrow::ListArray>(column);
O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s needed. Associated size branch %s and there are %lli entries of size %d in that list.",
branch->GetName(), sizeBranch->GetName(), list->length(), valueSize);
Expand Down Expand Up @@ -497,8 +498,8 @@ class TTreeFileWriter : public arrow::dataset::FileWriter
} break;
case arrow::Type::LIST: {
valueTypes.push_back(field->type()->field(0)->type());
listSizes.back() = 0; // VLA, we need to calculate it on the fly;
std::string leafList = fmt::format("{}[{}_size]{}", field->name(), field->name(), rootSuffixFromArrow(valueTypes.back()->id()));
listSizes.back() = -1; // VLA, we need to calculate it on the fly;
std::string sizeLeafList = field->name() + "_size/I";
sizesBranches.push_back(treeStream->CreateBranch((field->name() + "_size").c_str(), sizeLeafList.c_str()));
branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str()));
Expand Down Expand Up @@ -765,7 +766,7 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
typeSize = fixedSizeList->field(0)->type()->byte_width();
} else if (auto vlaListType = std::dynamic_pointer_cast<arrow::ListType>(physicalField->type())) {
listSize = -1;
typeSize = fixedSizeList->field(0)->type()->byte_width();
typeSize = vlaListType->field(0)->type()->byte_width();
}
if (listSize == -1) {
mSizeBranch = branch->GetTree()->GetBranch((std::string{branch->GetName()} + "_size").c_str());
Expand Down
120 changes: 50 additions & 70 deletions Framework/Core/test/test_Root2ArrowTable.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,26 @@ bool validateContents(std::shared_ptr<arrow::RecordBatch> batch)
REQUIRE(bool_array->Value(1) == (i % 5 == 0));
}
}

{
auto list_array = std::static_pointer_cast<arrow::ListArray>(batch->GetColumnByName("vla"));

REQUIRE(list_array->length() == 100);
for (int64_t i = 0; i < list_array->length(); i++) {
auto value_slice = list_array->value_slice(i);
REQUIRE(value_slice->length() == (i % 10));
auto int_array = std::static_pointer_cast<arrow::Int32Array>(value_slice);
for (size_t j = 0; j < value_slice->length(); j++) {
REQUIRE(int_array->Value(j) == j);
}
}
}
return true;
}

bool validateSchema(std::shared_ptr<arrow::Schema> schema)
{
REQUIRE(schema->num_fields() == 9);
REQUIRE(schema->num_fields() == 10);
REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id());
Expand All @@ -337,6 +351,7 @@ bool validateSchema(std::shared_ptr<arrow::Schema> schema)
REQUIRE(schema->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id());
REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id());
REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id());
REQUIRE(schema->field(9)->type()->id() == arrow::list(arrow::int32())->id());
return true;
}

Expand Down Expand Up @@ -390,6 +405,8 @@ TEST_CASE("RootTree2Dataset")
Int_t ev;
bool oneBool;
bool manyBool[2];
int vla[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int vlaSize = 0;

t->Branch("px", &px, "px/F");
t->Branch("py", &py, "py/F");
Expand All @@ -400,6 +417,8 @@ TEST_CASE("RootTree2Dataset")
t->Branch("ij", ij, "ij[2]/I");
t->Branch("bools", &oneBool, "bools/O");
t->Branch("manyBools", &manyBool, "manyBools[2]/O");
t->Branch("vla_size", &vlaSize, "vla_size/I");
t->Branch("vla", vla, "vla[vla_size]/I");
// fill the tree
for (Int_t i = 0; i < 100; i++) {
xyz[0] = 1;
Expand All @@ -415,9 +434,11 @@ TEST_CASE("RootTree2Dataset")
oneBool = (i % 3 == 0);
manyBool[0] = (i % 4 == 0);
manyBool[1] = (i % 5 == 0);
vlaSize = i % 10;
t->Fill();
}
}
f->Write();

size_t totalSizeCompressed = 0;
size_t totalSizeUncompressed = 0;
Expand All @@ -428,16 +449,7 @@ TEST_CASE("RootTree2Dataset")
auto schemaOpt = format->Inspect(source);
REQUIRE(schemaOpt.ok());
auto schema = *schemaOpt;
REQUIRE(schema->num_fields() == 9);
REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(3)->type()->id() == arrow::float64()->id());
REQUIRE(schema->field(4)->type()->id() == arrow::int32()->id());
REQUIRE(schema->field(5)->type()->id() == arrow::fixed_size_list(arrow::float32(), 3)->id());
REQUIRE(schema->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id());
REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id());
REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id());
validateSchema(schema);

auto fragment = format->MakeFragment(source, {}, schema);
REQUIRE(fragment.ok());
Expand All @@ -448,41 +460,9 @@ TEST_CASE("RootTree2Dataset")
auto batches = (*scanner)();
auto result = batches.result();
REQUIRE(result.ok());
REQUIRE((*result)->columns().size() == 9);
REQUIRE((*result)->columns().size() == 10);
REQUIRE((*result)->num_rows() == 100);

{
auto int_array = std::static_pointer_cast<arrow::Int32Array>((*result)->GetColumnByName("ev"));
for (int64_t j = 0; j < int_array->length(); j++) {
REQUIRE(int_array->Value(j) == j + 1);
}
}

{
auto list_array = std::static_pointer_cast<arrow::FixedSizeListArray>((*result)->GetColumnByName("xyz"));

// Iterate over the FixedSizeListArray
for (int64_t i = 0; i < list_array->length(); i++) {
auto value_slice = list_array->value_slice(i);
auto float_array = std::static_pointer_cast<arrow::FloatArray>(value_slice);

REQUIRE(float_array->Value(0) == 1);
REQUIRE(float_array->Value(1) == 2);
REQUIRE(float_array->Value(2) == i + 1);
}
}

{
auto list_array = std::static_pointer_cast<arrow::FixedSizeListArray>((*result)->GetColumnByName("ij"));

// Iterate over the FixedSizeListArray
for (int64_t i = 0; i < list_array->length(); i++) {
auto value_slice = list_array->value_slice(i);
auto int_array = std::static_pointer_cast<arrow::Int32Array>(value_slice);
REQUIRE(int_array->Value(0) == i);
REQUIRE(int_array->Value(1) == i + 1);
}
}
validateContents(*result);

auto* output = new TMemFile("foo", "RECREATE");
auto outFs = std::make_shared<TFileFileSystem>(output, 0);
Expand All @@ -497,31 +477,31 @@ TEST_CASE("RootTree2Dataset")
auto success = writer->get()->Write(*result);
auto rootDestination = std::dynamic_pointer_cast<TDirectoryFileOutputStream>(*destination);

REQUIRE(success.ok());
// Let's read it back...
arrow::dataset::FileSource source2("/DF_3", outFs);
auto newTreeFS = outFs->GetSubFilesystem(source2);

REQUIRE(format->IsSupported(source) == true);

auto schemaOptWritten = format->Inspect(source);
REQUIRE(schemaOptWritten.ok());
auto schemaWritten = *schemaOptWritten;
REQUIRE(validateSchema(schemaWritten));

auto fragmentWritten = format->MakeFragment(source, {}, schema);
REQUIRE(fragmentWritten.ok());
auto optionsWritten = std::make_shared<arrow::dataset::ScanOptions>();
options->dataset_schema = schemaWritten;
auto scannerWritten = format->ScanBatchesAsync(optionsWritten, *fragment);
REQUIRE(scannerWritten.ok());
auto batchesWritten = (*scanner)();
auto resultWritten = batches.result();
REQUIRE(resultWritten.ok());
REQUIRE((*resultWritten)->columns().size() == 9);
REQUIRE((*resultWritten)->num_rows() == 100);
validateContents(*resultWritten);

SECTION("Read tree")
{
REQUIRE(success.ok());
// Let's read it back...
arrow::dataset::FileSource source2("/DF_3", outFs);
auto newTreeFS = outFs->GetSubFilesystem(source2);

REQUIRE(format->IsSupported(source) == true);

auto schemaOptWritten = format->Inspect(source);
REQUIRE(schemaOptWritten.ok());
auto schemaWritten = *schemaOptWritten;
REQUIRE(validateSchema(schemaWritten));

auto fragmentWritten = format->MakeFragment(source, {}, schema);
REQUIRE(fragmentWritten.ok());
auto optionsWritten = std::make_shared<arrow::dataset::ScanOptions>();
options->dataset_schema = schemaWritten;
auto scannerWritten = format->ScanBatchesAsync(optionsWritten, *fragment);
REQUIRE(scannerWritten.ok());
auto batchesWritten = (*scanner)();
auto resultWritten = batches.result();
REQUIRE(resultWritten.ok());
REQUIRE((*resultWritten)->columns().size() == 10);
REQUIRE((*resultWritten)->num_rows() == 100);
validateContents(*resultWritten);
}
}

0 comments on commit 950b8b7

Please sign in to comment.