Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions tree/dataframe/inc/ROOT/RDF/RSampleInfo.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class RSampleInfo {
std::pair<ULong64_t, ULong64_t> fEntryRange;

const ROOT::RDF::Experimental::RSample *fSample = nullptr; // non-owning
ULong64_t fTotalEntries = 0; // Number of entries in current file (if known).

void ThrowIfNoSample() const
{
Expand All @@ -48,8 +49,8 @@ class RSampleInfo {

public:
RSampleInfo(std::string_view id, std::pair<ULong64_t, ULong64_t> entryRange,
const ROOT::RDF::Experimental::RSample *sample = nullptr)
: fID(id), fEntryRange(entryRange), fSample(sample)
const ROOT::RDF::Experimental::RSample *sample = nullptr, ULong64_t totalEntries = 0)
: fID(id), fEntryRange(entryRange), fSample(sample), fTotalEntries{totalEntries}
{
}
RSampleInfo() = default;
Expand Down Expand Up @@ -121,9 +122,14 @@ public:
/// Multiple multi-threading tasks might process different entry ranges of the same sample.
std::pair<ULong64_t, ULong64_t> EntryRange() const { return fEntryRange; }

/// @brief Return the number of entries of this sample that is being taken into consideration.
/// @brief Return the number of entries of this range of the sample.
ULong64_t NEntries() const { return fEntryRange.second - fEntryRange.first; }

/// Return the total number of entries in the underlying dataset.
/// If the total number of entries is not known, the end of the current range is returned.
/// This can be larger than NEntries() if the sample is split in multiple ranges.
ULong64_t NEntriesTotal() const { return std::max(fTotalEntries, fEntryRange.second); }

bool operator==(const RSampleInfo &other) const { return fID == other.fID; }
bool operator!=(const RSampleInfo &other) const { return !(*this == other); }
};
Expand Down
134 changes: 31 additions & 103 deletions tree/dataframe/inc/ROOT/RDFHelpers.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,6 @@ void AddProgressBar(ROOT::RDataFrame df);
/// @param nThread Number of threads that share a TH3D.
void ThreadsPerTH3(unsigned int nThread = 1);

class ProgressBarAction;

/// RDF progress helper.
/// This class provides callback functions to the RDataFrame. The event statistics
/// (including elapsed time, currently processed file, currently processed events, the rate of event processing
Expand All @@ -333,132 +331,62 @@ class ProgressBarAction;
/// ~~~
class ProgressHelper {
private:
std::size_t ComputeTotalEvents() const;
double EvtPerSec() const;
void PrintProgressAndStats(std::ostream &stream, std::size_t currentEventCount,
std::chrono::seconds totalElapsedSeconds) const;
std::pair<std::size_t, std::chrono::seconds> RecordEvtCountAndTime();
void PrintStats(std::ostream &stream, std::size_t currentEventCount, std::chrono::seconds totalElapsedSeconds) const;
void PrintStatsFinal(std::ostream &stream, std::chrono::seconds totalElapsedSeconds) const;
void PrintProgressBar(std::ostream &stream, std::size_t currentEventCount) const;
void Update();

std::chrono::time_point<std::chrono::system_clock> fBeginTime = std::chrono::system_clock::now();
std::chrono::time_point<std::chrono::system_clock> fLastPrintTime = fBeginTime;
std::chrono::seconds fPrintInterval{1};
bool const fIsTTY;
bool const fUseShellColours;

std::atomic<std::size_t> fProcessedEvents{0};
std::size_t fLastProcessedEvents{0};
std::size_t fIncrement;

mutable std::mutex fSampleNameToEventEntriesMutex;
std::map<std::string, ULong64_t> fSampleNameToEventEntries; // Filename, events in the file
std::size_t const fIncrement;
unsigned int const fNColumns;
unsigned int const fTotalFiles;

std::array<double, 20> fEventsPerSecondStatistics;
std::size_t fEventsPerSecondStatisticsIndex{0};
std::array<double, 10> fEventsPerSecondStatistics;
unsigned int fEventsPerSecondStatisticsCounter{0};

unsigned int fBarWidth;
unsigned int fTotalFiles;
std::chrono::time_point<std::chrono::system_clock> const fBeginTime = std::chrono::system_clock::now();
std::chrono::time_point<std::chrono::system_clock> fLastPrintTime = fBeginTime;
std::chrono::seconds const fPrintInterval;

std::mutex fPrintMutex;
bool fIsTTY;
bool fUseShellColours;
// Mutex to ensure that only one thread updates the progress bar.
// Lock this mutex to update any of the members above:
std::mutex fUpdateMutex;

std::shared_ptr<TTree> fTree{nullptr};
mutable std::mutex fSampleNameToEventEntriesMutex; // Mutex to protect access to the below map
std::map<std::string, ULong64_t> fSampleNameToEventEntries; // Filename, events in the file

public:
/// Create a progress helper.
/// \param increment RDF callbacks are called every `n` events. Pass this `n` here.
/// \param totalFiles read total number of files in the RDF.
/// \param progressBarWidth Number of characters the progress bar will occupy.
/// \param printInterval Update every stats every `n` seconds.
/// \param totalFiles number of files read in the RDF.
/// \param printInterval Update stats every `n` seconds.
/// \param useColors Use shell colour codes to colour the output. Automatically disabled when
/// we are not writing to a tty.
ProgressHelper(std::size_t increment, unsigned int totalFiles = 1, unsigned int progressBarWidth = 40,
unsigned int printInterval = 1, bool useColors = true);

ProgressHelper(std::size_t increment, unsigned int totalFiles, unsigned int printInterval = 0,
bool useColors = true);
ProgressHelper(ProgressHelper const &) = delete; // The mutexes and atomics won't allow copy/move
ProgressHelper(ProgressHelper &&) = delete;
~ProgressHelper() = default;
ProgressHelper &operator=(ProgressHelper const &) = delete;
ProgressHelper &operator=(ProgressHelper &&) = delete;

friend class ProgressBarAction;

/// Register a new sample for completion statistics.
/// \see ROOT::RDF::RInterface::DefinePerSample().
/// The *id.AsString()* refers to the name of the currently processed file.
/// The idea is to populate the event entries in the *fSampleNameToEventEntries* map
/// by selecting the greater of the two values:
/// *id.EntryRange().second* which is the upper event entry range of the processed sample
/// and the current value of the event entries in the *fSampleNameToEventEntries* map.
/// In the single threaded case, the two numbers are the same as the entry range corresponds
/// to the number of events in an individual file (each sample is simply a single file).
/// In the multithreaded case, the idea is to accumulate the higher event entry value until
/// the total number of events in a given file is reached.
void registerNewSample(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo &id)
{
std::lock_guard<std::mutex> lock(fSampleNameToEventEntriesMutex);
fSampleNameToEventEntries[id.AsString()] =
std::max(id.EntryRange().second, fSampleNameToEventEntries[id.AsString()]);
}
void RegisterNewSample(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo &id);

/// Thread-safe callback for RDataFrame.
/// It will record elapsed times and event statistics, and print a progress bar every n seconds (set by the
/// fPrintInterval). \param slot Ignored. \param value Ignored.
template <typename T>
void operator()(unsigned int /*slot*/, T &value)
{
operator()(value);
}
// clang-format off
/// Thread-safe callback for RDataFrame.
/// It will record elapsed times and event statistics, and print a progress bar every n seconds (set by the fPrintInterval).
/// \param value Ignored.
// clang-format on
template <typename T>
void operator()(T & /*value*/)
{
using namespace std::chrono;
// ***************************************************
// Warning: Here, everything needs to be thread safe:
// ***************************************************
fProcessedEvents += fIncrement;

// We only print every n seconds.
if (duration_cast<seconds>(system_clock::now() - fLastPrintTime) < fPrintInterval) {
return;
}

// ***************************************************
// Protected by lock from here:
// ***************************************************
if (!fPrintMutex.try_lock())
return;
std::lock_guard<std::mutex> lockGuard(fPrintMutex, std::adopt_lock);

std::size_t eventCount;
seconds elapsedSeconds;
std::tie(eventCount, elapsedSeconds) = RecordEvtCountAndTime();

if (fIsTTY)
std::cout << "\r";

PrintProgressBar(std::cout, eventCount);
PrintStats(std::cout, eventCount, elapsedSeconds);

if (fIsTTY)
std::cout << std::flush;
else
std::cout << std::endl;
}

std::size_t ComputeNEventsSoFar() const
{
std::unique_lock<std::mutex> lock(fSampleNameToEventEntriesMutex);
std::size_t result = 0;
for (const auto &item : fSampleNameToEventEntries)
result += item.second;
return result;
}

unsigned int ComputeCurrentFileIdx() const
void operator()(unsigned int /*slot*/, T & /*value*/)
{
std::unique_lock<std::mutex> lock(fSampleNameToEventEntriesMutex);
return fSampleNameToEventEntries.size();
Update();
}
void PrintStatsFinal() const;
};
} // namespace Experimental
} // namespace RDF
Expand Down
Loading
Loading