Skip to content

Commit

Permalink
DPL: improve arrow::Dataset integration
Browse files Browse the repository at this point in the history
- Modularise filesystem to allow easier navigation and support for
  multiple formats.
- Add initial support to multiplex multiple tables on top of the same tree.
- Improve support for writing boolean fields.
  • Loading branch information
ktf committed Nov 24, 2024
1 parent c3ffb66 commit cc0b2f8
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 85 deletions.
36 changes: 34 additions & 2 deletions Framework/Core/include/Framework/RootArrowFilesystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <arrow/type_fwd.h>
#include <memory>

class TFile;
class TBranch;
class TTree;
class TBufferFile;
class TDirectoryFile;
Expand Down Expand Up @@ -227,11 +229,38 @@ class TTreeFileFormat : public arrow::dataset::FileFormat
const std::shared_ptr<arrow::dataset::FileFragment>& fragment) const override;
};

// An arrow outputstream which allows to write to a ttree
// An arrow outputstream which allows to write to a TDirectoryFile.
// This will point to the location of the file itself. You can
// specify the location of the actual object inside it by passing the
// associated path to the Write() API.
class TDirectoryFileOutputStream : public arrow::io::OutputStream
{
public:
TDirectoryFileOutputStream(TDirectoryFile*);

arrow::Status Close() override;

arrow::Result<int64_t> Tell() const override;

arrow::Status Write(const void* data, int64_t nbytes) override;

bool closed() const override;

TDirectoryFile* GetDirectory()
{
return mDirectory;
}

private:
TDirectoryFile* mDirectory;
};

// An arrow outputstream which allows to write to a TTree. Eventually
// with a prefix for the branches.
class TTreeOutputStream : public arrow::io::OutputStream
{
public:
TTreeOutputStream(TTree* t);
TTreeOutputStream(TTree*, std::string branchPrefix);

arrow::Status Close() override;

Expand All @@ -241,13 +270,16 @@ class TTreeOutputStream : public arrow::io::OutputStream

bool closed() const override;

TBranch* CreateBranch(char const* branchName, char const* sizeBranch);

TTree* GetTree()
{
return mTree;
}

private:
TTree* mTree;
std::string mBranchPrefix;
};

} // namespace o2::framework
Expand Down
153 changes: 114 additions & 39 deletions Framework/Core/src/RootArrowFilesystem.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include <arrow/array/builder_nested.h>
#include <arrow/array/builder_primitive.h>
#include <memory>
#include <stdexcept>
#include <TFile.h>
#include <TLeaf.h>
#include <TBufferFile.h>
Expand All @@ -28,8 +27,11 @@
#include <arrow/dataset/file_base.h>
#include <arrow/result.h>
#include <arrow/status.h>
#include <arrow/util/key_value_metadata.h>
#include <fmt/format.h>

#include <stdexcept>
#include <utility>

O2_DECLARE_DYNAMIC_LOG(root_arrow_fs);

Expand Down Expand Up @@ -100,7 +102,6 @@ std::shared_ptr<VirtualRootFileSystemBase> TFileFileSystem::GetSubFilesystem(arr
return std::shared_ptr<VirtualRootFileSystemBase>(new SingleTreeFileSystem(tree));
}


auto directory = (TDirectoryFile*)mFile->GetObjectChecked(source.path().c_str(), TClass::GetClass<TDirectory>());
if (directory) {
return std::shared_ptr<VirtualRootFileSystemBase>(new TFileFileSystem(directory, 50 * 1024 * 1024));
Expand Down Expand Up @@ -129,8 +130,15 @@ arrow::Result<std::shared_ptr<arrow::io::OutputStream>> TFileFileSystem::OpenOut
const std::string& path,
const std::shared_ptr<const arrow::KeyValueMetadata>& metadata)
{
auto* t = new TTree(path.c_str(), "should put a name here");
auto stream = std::make_shared<TTreeOutputStream>(t);
if (path == "/") {
return std::make_shared<TDirectoryFileOutputStream>(this->GetFile());
}

auto* dir = dynamic_cast<TDirectoryFile*>(this->GetFile()->Get(path.c_str()));
if (!dir) {
throw runtime_error_f("Unable to open directory %s in file %s", path.c_str(), GetFile()->GetName());
}
auto stream = std::make_shared<TDirectoryFileOutputStream>(dir);
return stream;
}

Expand Down Expand Up @@ -286,13 +294,46 @@ arrow::Result<std::shared_ptr<arrow::dataset::FileFragment>> TTreeFileFormat::Ma
}

// An arrow outputstream which allows to write to a ttree
TTreeOutputStream::TTreeOutputStream(TTree* t)
: mTree(t)
TDirectoryFileOutputStream::TDirectoryFileOutputStream(TDirectoryFile* f)
: mDirectory(f)
{
}

arrow::Status TDirectoryFileOutputStream::Close()
{
mDirectory->GetFile()->Close();
return arrow::Status::OK();
}

arrow::Result<int64_t> TDirectoryFileOutputStream::Tell() const
{
return arrow::Result<int64_t>(arrow::Status::NotImplemented("Cannot move"));
}

arrow::Status TDirectoryFileOutputStream::Write(const void* data, int64_t nbytes)
{
return arrow::Status::NotImplemented("Cannot write raw bytes to a TTree");
}

bool TDirectoryFileOutputStream::closed() const
{
return mDirectory->GetFile()->IsOpen() == false;
}

// An arrow outputstream which allows to write to a ttree
// @a branch prefix is to be used to identify a set of branches which all belong to
// the same table.
TTreeOutputStream::TTreeOutputStream(TTree* f, std::string branchPrefix)
: mTree(f),
mBranchPrefix(std::move(branchPrefix))
{
}

arrow::Status TTreeOutputStream::Close()
{
if (mTree->GetCurrentFile() == nullptr) {
return arrow::Status::Invalid("Cannot close a tree not attached to a file");
}
mTree->GetCurrentFile()->Close();
return arrow::Status::OK();
}
Expand All @@ -309,9 +350,18 @@ arrow::Status TTreeOutputStream::Write(const void* data, int64_t nbytes)

bool TTreeOutputStream::closed() const
{
// A standalone tree is never closed.
if (mTree->GetCurrentFile() == nullptr) {
return false;
}
return mTree->GetCurrentFile()->IsOpen() == false;
}

TBranch* TTreeOutputStream::CreateBranch(char const* branchName, char const* sizeBranch)
{
return mTree->Branch((mBranchPrefix + "/" + branchName).c_str(), (char*)nullptr, (mBranchPrefix + sizeBranch).c_str());
}

char const* rootSuffixFromArrow(arrow::Type::type id)
{
switch (id) {
Expand Down Expand Up @@ -411,8 +461,24 @@ class TTreeFileWriter : public arrow::dataset::FileWriter
: FileWriter(schema, options, destination, destination_locator)
{
// Batches have the same number of entries for each column.
auto directoryStream = std::dynamic_pointer_cast<TDirectoryFileOutputStream>(destination_);
auto treeStream = std::dynamic_pointer_cast<TTreeOutputStream>(destination_);
TTree* tree = treeStream->GetTree();

if (directoryStream.get()) {
TDirectoryFile* dir = directoryStream->GetDirectory();
dir->cd();
auto* tree = new TTree(destination_locator_.path.c_str(), "");
treeStream = std::make_shared<TTreeOutputStream>(tree, "");
} else if (treeStream.get()) {
// We already have a tree stream, let's derive a new one
// with the destination_locator_.path as prefix for the branches
// This way we can multiplex multiple tables in the same tree.
auto tree = treeStream->GetTree();
treeStream = std::make_shared<TTreeOutputStream>(tree, destination_locator_.path);
} else {
// I could simply set a prefix here to merge to an already existing tree.
throw std::runtime_error("Unsupported backend.");
}

for (auto i = 0u; i < schema->fields().size(); ++i) {
auto& field = schema->field(i);
Expand All @@ -427,23 +493,23 @@ class TTreeFileWriter : public arrow::dataset::FileWriter
valueTypes.push_back(field->type()->field(0)->type());
sizesBranches.push_back(nullptr);
std::string leafList = fmt::format("{}[{}]{}", field->name(), listSizes.back(), rootSuffixFromArrow(valueTypes.back()->id()));
branches.push_back(tree->Branch(field->name().c_str(), (char*)nullptr, leafList.c_str()));
branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str()));
} break;
case arrow::Type::LIST: {
valueTypes.push_back(field->type()->field(0)->type());
listSizes.back() = 0; // VLA, we need to calculate it on the fly;
std::string leafList = fmt::format("{}[{}_size]{}", field->name(), field->name(), rootSuffixFromArrow(valueTypes.back()->id()));
std::string sizeLeafList = field->name() + "_size/I";
sizesBranches.push_back(tree->Branch((field->name() + "_size").c_str(), (char*)nullptr, sizeLeafList.c_str()));
branches.push_back(tree->Branch(field->name().c_str(), (char*)nullptr, leafList.c_str()));
sizesBranches.push_back(treeStream->CreateBranch((field->name() + "_size").c_str(), sizeLeafList.c_str()));
branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str()));
// Notice that this could be replaced by a better guess of the
// average size of the list elements, but this is not trivial.
} break;
default: {
valueTypes.push_back(field->type());
std::string leafList = field->name() + rootSuffixFromArrow(valueTypes.back()->id());
sizesBranches.push_back(nullptr);
branches.push_back(tree->Branch(field->name().c_str(), (char*)nullptr, leafList.c_str()));
branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str()));
} break;
}
}
Expand All @@ -463,11 +529,18 @@ class TTreeFileWriter : public arrow::dataset::FileWriter
}

// Batches have the same number of entries for each column.
auto directoryStream = std::dynamic_pointer_cast<TDirectoryFileOutputStream>(destination_);
TTree* tree = nullptr;
if (directoryStream.get()) {
TDirectoryFile* dir = directoryStream->GetDirectory();
tree = (TTree*)dir->Get(destination_locator_.path.c_str());
}
auto treeStream = std::dynamic_pointer_cast<TTreeOutputStream>(destination_);
TTree* tree = treeStream->GetTree();

// Caches for the vectors of bools.
std::vector<std::shared_ptr<arrow::UInt8Array>> caches;
if (!tree) {
// I could simply set a prefix here to merge to an already existing tree.
throw std::runtime_error("Unsupported backend.");
}

for (auto i = 0u; i < batch->columns().size(); ++i) {
auto column = batch->column(i);
Expand All @@ -484,24 +557,11 @@ class TTreeFileWriter : public arrow::dataset::FileWriter
auto list = std::static_pointer_cast<arrow::ListArray>(column);
valueArrays.back() = list;
} break;
default:
valueArrays.back() = column;
}
}

int64_t pos = 0;
while (pos < batch->num_rows()) {
for (size_t bi = 0; bi < branches.size(); ++bi) {
auto* branch = branches[bi];
auto* sizeBranch = sizesBranches[bi];
auto array = batch->column(bi);
auto& field = batch->schema()->field(bi);
auto& listSize = listSizes[bi];
auto valueType = valueTypes[bi];
auto valueArray = valueArrays[bi];
case arrow::Type::BOOL: {
// In case of arrays of booleans, we need to go back to their
// char based representation for ROOT to save them.
auto boolArray = std::static_pointer_cast<arrow::BooleanArray>(column);

if (field->type()->id() == arrow::Type::BOOL) {
auto boolArray = std::static_pointer_cast<arrow::BooleanArray>(array);
int64_t length = boolArray->length();
arrow::UInt8Builder builder;
auto ok = builder.Reserve(length);
Expand All @@ -516,11 +576,24 @@ class TTreeFileWriter : public arrow::dataset::FileWriter
auto ok = builder.AppendNull();
}
}
valueArrays.back() = *builder.Finish();
} break;
default:
valueArrays.back() = column;
}
}

int64_t pos = 0;
while (pos < batch->num_rows()) {
for (size_t bi = 0; bi < branches.size(); ++bi) {
auto* branch = branches[bi];
auto* sizeBranch = sizesBranches[bi];
auto array = batch->column(bi);
auto& field = batch->schema()->field(bi);
auto& listSize = listSizes[bi];
auto valueType = valueTypes[bi];
auto valueArray = valueArrays[bi];

ok = builder.Finish(&caches[bi]);
branch->SetAddress((void*)(caches[bi]->values()->data()));
continue;
}
switch (field->type()->id()) {
case arrow::Type::LIST: {
auto list = std::static_pointer_cast<arrow::ListArray>(array);
Expand Down Expand Up @@ -764,13 +837,16 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
return generator;
}


arrow::Result<std::shared_ptr<arrow::io::OutputStream>> TTreeFileSystem::OpenOutputStream(
const std::string& path,
const std::shared_ptr<const arrow::KeyValueMetadata>& metadata)
{
auto stream = std::make_shared<TTreeOutputStream>(GetTree({path, shared_from_this()}));
return stream;
arrow::dataset::FileSource source{path, shared_from_this()};
auto prefix = metadata->Get("branch_prefix");
if (prefix.ok()) {
return std::make_shared<TTreeOutputStream>(GetTree(source), *prefix);
}
return std::make_shared<TTreeOutputStream>(GetTree(source), "");
}

TBufferFileFS::TBufferFileFS(TBufferFile* f)
Expand All @@ -782,7 +858,6 @@ TBufferFileFS::TBufferFileFS(TBufferFile* f)

TTreeFileSystem::~TTreeFileSystem() = default;


arrow::Result<arrow::fs::FileInfo> TBufferFileFS::GetFileInfo(const std::string& path)
{
arrow::fs::FileInfo result;
Expand Down
Loading

0 comments on commit cc0b2f8

Please sign in to comment.