Skip to content

Commit

Permalink
Bin count in ProbabilityEvolution
Browse files Browse the repository at this point in the history
  • Loading branch information
PKua007 committed Oct 13, 2023
1 parent a0a79f3 commit 9549a38
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 14 deletions.
13 changes: 9 additions & 4 deletions src/core/observables/correlation/ProbabilityEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ ProbabilityEvolution::ProbabilityEvolution(double maxDistance, std::size_t numDi
std::shared_ptr<PairEnumerator> pairEnumerator,
std::pair<double, double> functionRange,
std::size_t numFunctionBins, std::shared_ptr<CorrelationFunction> function,
Normalization normalization, std::size_t numThreads)
Normalization normalization, bool printCount, std::size_t numThreads)
: PairConsumer(numThreads), maxDistance{maxDistance}, functionRange{functionRange},
histogramBuilder({0, functionRange.first}, {maxDistance, functionRange.second},
{numDistanceBins, numFunctionBins}, numThreads),
pairEnumerator{std::move(pairEnumerator)}, function{std::move(function)}, normalization{normalization}
pairEnumerator{std::move(pairEnumerator)}, function{std::move(function)}, normalization{normalization},
printCount{printCount}
{
Expects(maxDistance > 0);
}
Expand All @@ -29,8 +30,12 @@ void ProbabilityEvolution::print(std::ostream &out) const {
Histogram2D histogram = this->histogramBuilder.dumpHistogram(ReductionMethod::SUM);
this->renormalizeHistogram(histogram);

for (auto [xy, z, count] : histogram.dumpValues())
out << xy[0] << " " << xy[1] << " " << z << std::endl;
for (auto [xy, z, count] : histogram.dumpValues()) {
out << xy[0] << " " << xy[1] << " " << z;
if (this->printCount)
out << " " << count;
out << std::endl;
}
}

void ProbabilityEvolution::clear() {
Expand Down
17 changes: 11 additions & 6 deletions src/core/observables/correlation/ProbabilityEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class ProbabilityEvolution : public BulkObservable, public PairConsumer {
std::shared_ptr<PairEnumerator> pairEnumerator;
std::shared_ptr<CorrelationFunction> function;
Normalization normalization{};
bool printCount{};

void renormalizeHistogram(Histogram2D &histogram) const;
void consumePair(const Packing &packing, const std::pair<std::size_t, std::size_t> &idxPair,
Expand All @@ -64,30 +65,34 @@ class ProbabilityEvolution : public BulkObservable, public PairConsumer {
* @param numFunctionBins number of bins for the CorrelationFunction values
* @param function the CorrelationFunctions to compute
* @param normalization type of normalization
* @param printCount if @a true, raw bin count will be printed in addition to function values
* @param numThreads number of threads used to generate the histogram. If 0, all available threads will be used
*/
ProbabilityEvolution(double maxDistance, std::size_t numDistanceBins,
std::shared_ptr<PairEnumerator> pairEnumerator, std::pair<double, double> functionRange,
std::size_t numFunctionBins, std::shared_ptr<CorrelationFunction> function,
Normalization normalization = Normalization::PDF, std::size_t numThreads = 1);
Normalization normalization = Normalization::PDF, bool printCount = false,
std::size_t numThreads = 1);

void addSnapshot(const Packing &packing, double temperature, double pressure,
const ShapeTraits &shapeTraits) override;

/**
* @brief Output the histogram.
* @brief Outputs the histogram.
* @details The format is
* @code
* [distance 1] [function value 1] [probability/count 1,1]
* [distance 1] [function value 1] [probability/count 1,1] [raw count 1,1]
* ...
* [distance 1] [function value N] [probability/count 1,N]
* [distance 1] [function value N] [probability/count 1,N] [raw count 1,N]
*
* ...
*
* [distance M] [function value 1] [probability/count M,1]
* [distance M] [function value 1] [probability/count M,1] [raw count M,1]
* ...
* [distance M] [function value N] [probability/count M,N]
* [distance M] [function value N] [probability/count M,N] [raw count M,N]
* @endcode
*
* The last column is outputted only if @a printCount was set @a true.
*/
void print(std::ostream &out) const override;

Expand Down
7 changes: 5 additions & 2 deletions src/frontend/matchers/ObservablesMatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,8 @@ namespace {
{"fun_range", functionRange},
{"n_bins_fun", nBins},
{"function", correlationFunction},
{"normalization", normalization, "None"}})
{"normalization", normalization, "None"},
{"print_count", MatcherBoolean{}, "False"}})
.mapTo([maxThreads](const DataclassData &probEvolution) -> std::shared_ptr<BulkObservable> {
auto maxR = probEvolution["max_r"].as<double>();
auto nBinsR = probEvolution["n_bins_r"].as<std::size_t>();
Expand All @@ -458,8 +459,10 @@ namespace {
auto nBinsFun = probEvolution["n_bins_fun"].as<std::size_t>();
auto function = probEvolution["function"].as<std::shared_ptr<CorrelationFunction>>();
auto normalization = probEvolution["normalization"].as<Normalization>();
auto printCount = probEvolution["print_count"].as<bool>();

return std::make_shared<ProbabilityEvolution>(
maxR, nBinsR, binning, funRange, nBinsFun, function, normalization, maxThreads
maxR, nBinsR, binning, funRange, nBinsFun, function, normalization, printCount, maxThreads
);
});
}
Expand Down
4 changes: 2 additions & 2 deletions test/unit_tests/core/observables/ProbabilityEvolutionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TEST_CASE("ProbabilityEvolution") {
}

SECTION("normalization: PDF") {
ProbabilityEvolution evolution(2, 2, enumerator, {0, 2}, 2, function, Normalization::PDF);
ProbabilityEvolution evolution(2, 2, enumerator, {0, 2}, 2, function, Normalization::PDF, false);
evolution.addSnapshot(packing, 1, 1, traits);
evolution.addSnapshot(packing, 1, 1, traits);

Expand All @@ -93,7 +93,7 @@ TEST_CASE("ProbabilityEvolution") {
}

SECTION("normalization: UNIT") {
ProbabilityEvolution evolution(2, 2, enumerator, {0, 2}, 2, function, Normalization::UNIT);
ProbabilityEvolution evolution(2, 2, enumerator, {0, 2}, 2, function, Normalization::UNIT, false);
evolution.addSnapshot(packing, 1, 1, traits);
evolution.addSnapshot(packing, 1, 1, traits);

Expand Down

0 comments on commit 9549a38

Please sign in to comment.