Skip to content

Commit

Permalink
Create a new class StatWithPercentiles that inherits Stat.
Browse files Browse the repository at this point in the history
A StatWithPercentiles object keeps track of the values added to it, and supports computing percentile values.

PiperOrigin-RevId: 688045781
  • Loading branch information
Google-ML-Automation committed Oct 21, 2024
1 parent 2b5a381 commit 8bb8575
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
32 changes: 32 additions & 0 deletions xla/tsl/util/stats_calculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,38 @@ class Stat {
HighPrecisionValueType squared_sum_ = 0;
};

// A `StatWithPercentiles` inherited from `Stat`, also keeps track of the
// values added and can be used to compute the percentile values.
template <typename HighPrecisionValueType = double>
class StatWithPercentiles : public Stat<int64_t, HighPrecisionValueType> {
public:
void UpdateStat(int64_t v) {
Stat<int64_t, HighPrecisionValueType>::UpdateStat(v);
values_.push_back(v);
}

// Returns the percentile value.
int64_t percentile(int percentile) const {
if (this->count() == 0) {
return std::numeric_limits<int>::quiet_NaN();
}
std::vector<int64_t> values = values_;
std::nth_element(values.begin(),
values.begin() + values.size() * percentile / 100,
values.end());
return values[values_.size() * percentile / 100];
}

void OutputToStream(std::ostream* stream) const {
Stat<int64_t, HighPrecisionValueType>::OutputToStream(stream);
*stream << " p10=" << percentile(10) << " median=" << percentile(50)
<< " p90=" << percentile(90);
}

private:
std::vector<int64_t> values_;
};

// A StatsCalculator assists in performance analysis of Graph executions.
//
// It summarizes time spent executing (on GPU/CPU), memory used etc for
Expand Down
34 changes: 34 additions & 0 deletions xla/tsl/util/stats_calculator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "xla/tsl/util/stats_calculator.h"

#include <cfloat>
#include <cmath>

#include "tsl/platform/test.h"

Expand Down Expand Up @@ -104,5 +105,38 @@ TEST(StatsCalculatorTest, UpdateStat) {
EXPECT_NEAR(43.30704330706496060826769, stat.std_deviation(), FLT_EPSILON);
}

TEST(StatsCalculatorTest, StatWithPercentiles) {
StatWithPercentiles<double> stat;
EXPECT_TRUE(stat.empty());
EXPECT_TRUE(stat.all_same());
stat.UpdateStat(1);
EXPECT_TRUE(stat.all_same());
stat.UpdateStat(-1.0);
EXPECT_FALSE(stat.all_same());
stat.UpdateStat(100);
stat.UpdateStat(0);
EXPECT_EQ(4, stat.count());
EXPECT_EQ(-1, stat.min());
EXPECT_EQ(100, stat.max());
EXPECT_EQ(25, stat.avg());
EXPECT_EQ(1, stat.first());
EXPECT_EQ(0, stat.newest());
EXPECT_EQ(10002, stat.squared_sum());
EXPECT_EQ(625, stat.avg() * stat.avg());
// Sample variance
EXPECT_EQ(7502 / 3, stat.sample_variance());
// Sample standard deviation, from WolframAlpha
EXPECT_EQ(50, std::sqrt(stat.sample_variance()));
// Population variance
EXPECT_EQ(7502 / 4, stat.variance());
// Population standard deviation, from WolframAlpha
EXPECT_EQ(43, stat.std_deviation());
EXPECT_EQ(1, stat.percentile(50));
EXPECT_EQ(100, stat.percentile(90));
stat.UpdateStat(150);
EXPECT_EQ(1, stat.percentile(50));
EXPECT_EQ(150, stat.percentile(90));
}

} // namespace
} // namespace tsl

0 comments on commit 8bb8575

Please sign in to comment.