Skip to content

Commit

Permalink
DPL: improve handling of RNTuple
Browse files Browse the repository at this point in the history
- Support more integer types, including tests.
- Add ability to support objects which are not grouped in a TDirectory
  • Loading branch information
ktf committed Dec 19, 2024
1 parent 6fa29aa commit 331d446
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 7 deletions.
27 changes: 27 additions & 0 deletions Framework/AnalysisSupport/src/RNTuplePlugin.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,31 @@ struct RootNTupleVisitor : public ROOT::Experimental::Detail::RFieldVisitor {
this->datatype = arrow::int32();
}

void VisitInt8Field(const ROOT::Experimental::RField<std::int8_t>& field) override
{
this->datatype = arrow::int8();
}

void VisitInt16Field(const ROOT::Experimental::RField<std::int16_t>& field) override
{
this->datatype = arrow::int16();
}

void VisitUInt32Field(const ROOT::Experimental::RField<std::uint32_t>& field) override
{
this->datatype = arrow::uint32();
}

void VisitUInt8Field(const ROOT::Experimental::RField<std::uint8_t>& field) override
{
this->datatype = arrow::uint8();
}

void VisitUInt16Field(const ROOT::Experimental::RField<std::uint16_t>& field) override
{
this->datatype = arrow::int16();
}

void VisitBoolField(const ROOT::Experimental::RField<bool>& field) override
{
this->datatype = arrow::boolean();
Expand Down Expand Up @@ -240,6 +265,8 @@ std::unique_ptr<ROOT::Experimental::RFieldBase> rootFieldFromArrow(std::shared_p
return std::make_unique<RField<float>>(name);
case arrow::Type::DOUBLE:
return std::make_unique<RField<double>>(name);
case arrow::Type::STRING:
return std::make_unique<RField<std::string>>(name);
default:
throw runtime_error("Unsupported arrow column type");
}
Expand Down
5 changes: 5 additions & 0 deletions Framework/Core/include/Framework/RootArrowFilesystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ struct RootArrowFactoryPlugin {
struct RootObjectReadingCapability {
// The unique name of this capability
std::string name = "unknown";
// Convert a logical filename to an actual object to be read
// This can be used, e.g. to read an RNTuple stored in
// a flat directory structure in a TFile vs a TTree stored inside
// a TDirectory (e.g. /DF_1000/o2tracks).
std::function<std::string(std::string)> lfn2objectPath;
// Given a TFile, return the object which this capability support
// Use a void * in order not to expose the kind of object to the
// generic reading code. This is also where we load the plugin
Expand Down
14 changes: 12 additions & 2 deletions Framework/Core/src/Plugin.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,14 @@ struct ImplementationContext {

std::function<void*(TDirectoryFile*, std::string const&)> getHandleByClass(char const* classname)
{
return [classname](TDirectoryFile* file, std::string const& path) { return file->GetObjectChecked(path.c_str(), TClass::GetClass(classname)); };
return [c = TClass::GetClass(classname)](TDirectoryFile* file, std::string const& path) {

return file->GetObjectChecked(path.c_str(), c); };
}

std::function<void*(TBufferFile*, std::string const&)> getBufferHandleByClass(char const* classname)
{
return [classname](TBufferFile* buffer, std::string const& path) { buffer->Reset(); return buffer->ReadObjectAny(TClass::GetClass(classname)); };
return [c = TClass::GetClass(classname)](TBufferFile* buffer, std::string const& path) { buffer->Reset(); return buffer->ReadObjectAny(c); };
}

void lazyLoadFactory(std::vector<RootArrowFactory>& implementations, char const* specs)
Expand All @@ -210,6 +212,13 @@ struct RNTupleObjectReadingCapability : o2::framework::RootObjectReadingCapabili

return new RootObjectReadingCapability{
.name = "rntuple",
.lfn2objectPath = [](std::string s) {
std::replace(s.begin()+1, s.end(), '/', '-');
if (s.starts_with("/")) {
return s;
} else {
return "/" + s;
} },
.getHandle = getHandleByClass("ROOT::Experimental::RNTuple"),
.getBufferHandle = getBufferHandleByClass("ROOT::Experimental::RNTuple"),
.factory = [context]() -> RootArrowFactory& {
Expand All @@ -226,6 +235,7 @@ struct TTreeObjectReadingCapability : o2::framework::RootObjectReadingCapability

return new RootObjectReadingCapability{
.name = "ttree",
.lfn2objectPath = [](std::string s) { return s; },
.getHandle = getHandleByClass("TTree"),
.getBufferHandle = getBufferHandleByClass("TTree"),
.factory = [context]() -> RootArrowFactory& {
Expand Down
4 changes: 3 additions & 1 deletion Framework/Core/src/RootArrowFilesystem.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ std::shared_ptr<VirtualRootFileSystemBase> TFileFileSystem::GetSubFilesystem(arr
// file, so that we can support TTree and RNTuple at the same time
// without having to depend on both.
for (auto& capability : mObjectFactory.capabilities) {
void* handle = capability.getHandle(mFile, source.path());
auto objectPath = capability.lfn2objectPath(source.path());
void* handle = capability.getHandle(mFile, objectPath);
if (!handle) {
continue;
}
Expand Down Expand Up @@ -238,6 +239,7 @@ std::shared_ptr<VirtualRootFileSystemBase> TBufferFileFS::GetSubFilesystem(arrow
// file, so that we can support TTree and RNTuple at the same time
// without having to depend on both.
for (auto& capability : mObjectFactory.capabilities) {

void* handle = capability.getBufferHandle(mBuffer, source.path());
if (handle) {
mFilesystem = capability.factory().getSubFilesystem(handle);
Expand Down
12 changes: 8 additions & 4 deletions Framework/Core/test/test_Root2ArrowTable.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ bool validateContents(std::shared_ptr<arrow::RecordBatch> batch)

bool validateSchema(std::shared_ptr<arrow::Schema> schema)
{
REQUIRE(schema->num_fields() == 10);
REQUIRE(schema->num_fields() == 11);
REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id());
Expand All @@ -380,6 +380,7 @@ bool validateSchema(std::shared_ptr<arrow::Schema> schema)
REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id());
REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id());
REQUIRE(schema->field(9)->type()->id() == arrow::list(arrow::int32())->id());
REQUIRE(schema->field(10)->type()->id() == arrow::int8()->id());
return true;
}

Expand Down Expand Up @@ -435,6 +436,7 @@ TEST_CASE("RootTree2Dataset")
bool manyBool[2];
int vla[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int vlaSize = 0;
char byte;

t->Branch("px", &px, "px/F");
t->Branch("py", &py, "py/F");
Expand All @@ -447,6 +449,7 @@ TEST_CASE("RootTree2Dataset")
t->Branch("manyBools", &manyBool, "manyBools[2]/O");
t->Branch("vla_size", &vlaSize, "vla_size/I");
t->Branch("vla", vla, "vla[vla_size]/I");
t->Branch("byte", &byte, "byte/B");
// fill the tree
for (Int_t i = 0; i < 100; i++) {
xyz[0] = 1;
Expand All @@ -463,6 +466,7 @@ TEST_CASE("RootTree2Dataset")
manyBool[0] = (i % 4 == 0);
manyBool[1] = (i % 5 == 0);
vlaSize = i % 10;
byte = i;
t->Fill();
}
}
Expand Down Expand Up @@ -512,7 +516,7 @@ TEST_CASE("RootTree2Dataset")
auto batches = (*scanner)();
auto result = batches.result();
REQUIRE(result.ok());
REQUIRE((*result)->columns().size() == 10);
REQUIRE((*result)->columns().size() == 11);
REQUIRE((*result)->num_rows() == 100);
validateContents(*result);

Expand Down Expand Up @@ -552,7 +556,7 @@ TEST_CASE("RootTree2Dataset")
auto batchesWritten = (*scanner)();
auto resultWritten = batches.result();
REQUIRE(resultWritten.ok());
REQUIRE((*resultWritten)->columns().size() == 10);
REQUIRE((*resultWritten)->columns().size() == 11);
REQUIRE((*resultWritten)->num_rows() == 100);
validateContents(*resultWritten);
}
Expand Down Expand Up @@ -586,7 +590,7 @@ TEST_CASE("RootTree2Dataset")
auto rntupleBatchesWritten = (*rntupleScannerWritten)();
auto rntupleResultWritten = rntupleBatchesWritten.result();
REQUIRE(rntupleResultWritten.ok());
REQUIRE((*rntupleResultWritten)->columns().size() == 10);
REQUIRE((*rntupleResultWritten)->columns().size() == 11);
REQUIRE(validateSchema((*rntupleResultWritten)->schema()));
REQUIRE((*rntupleResultWritten)->num_rows() == 100);
REQUIRE(validateContents(*rntupleResultWritten));
Expand Down

0 comments on commit 331d446

Please sign in to comment.