diff --git a/xla/tsl/util/stats_calculator.h b/xla/tsl/util/stats_calculator.h index 84045fb6ceece2..b9d296171477d0 100644 --- a/xla/tsl/util/stats_calculator.h +++ b/xla/tsl/util/stats_calculator.h @@ -117,6 +117,38 @@ class Stat { HighPrecisionValueType squared_sum_ = 0; }; +// A `StatWithPercentiles` inherited from `Stat`, also keeps track of the +// values added and can be used to compute the percentile values. +template +class StatWithPercentiles : public Stat { + public: + void UpdateStat(int64_t v) { + Stat::UpdateStat(v); + values_.push_back(v); + } + + // Returns the percentile value. + int64_t percentile(int percentile) const { + if (this->count() == 0) { + return std::numeric_limits::quiet_NaN(); + } + std::vector values = values_; + std::nth_element(values.begin(), + values.begin() + values.size() * percentile / 100, + values.end()); + return values[values_.size() * percentile / 100]; + } + + void OutputToStream(std::ostream* stream) const { + Stat::OutputToStream(stream); + *stream << " p10=" << percentile(10) << " median=" << percentile(50) + << " p90=" << percentile(90); + } + + private: + std::vector values_; +}; + // A StatsCalculator assists in performance analysis of Graph executions. // // It summarizes time spent executing (on GPU/CPU), memory used etc for diff --git a/xla/tsl/util/stats_calculator_test.cc b/xla/tsl/util/stats_calculator_test.cc index d58186630598f0..47b5eefa884936 100644 --- a/xla/tsl/util/stats_calculator_test.cc +++ b/xla/tsl/util/stats_calculator_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/tsl/util/stats_calculator.h" #include +#include #include "tsl/platform/test.h" @@ -104,5 +105,38 @@ TEST(StatsCalculatorTest, UpdateStat) { EXPECT_NEAR(43.30704330706496060826769, stat.std_deviation(), FLT_EPSILON); } +TEST(StatsCalculatorTest, StatWithPercentiles) { + StatWithPercentiles stat; + EXPECT_TRUE(stat.empty()); + EXPECT_TRUE(stat.all_same()); + stat.UpdateStat(1); + EXPECT_TRUE(stat.all_same()); + stat.UpdateStat(-1.0); + EXPECT_FALSE(stat.all_same()); + stat.UpdateStat(100); + stat.UpdateStat(0); + EXPECT_EQ(4, stat.count()); + EXPECT_EQ(-1, stat.min()); + EXPECT_EQ(100, stat.max()); + EXPECT_EQ(25, stat.avg()); + EXPECT_EQ(1, stat.first()); + EXPECT_EQ(0, stat.newest()); + EXPECT_EQ(10002, stat.squared_sum()); + EXPECT_EQ(625, stat.avg() * stat.avg()); + // Sample variance + EXPECT_EQ(7502 / 3, stat.sample_variance()); + // Sample standard deviation, from WolframAlpha + EXPECT_EQ(50, std::sqrt(stat.sample_variance())); + // Population variance + EXPECT_EQ(7502 / 4, stat.variance()); + // Population standard deviation, from WolframAlpha + EXPECT_EQ(43, stat.std_deviation()); + EXPECT_EQ(1, stat.percentile(50)); + EXPECT_EQ(100, stat.percentile(90)); + stat.UpdateStat(150); + EXPECT_EQ(1, stat.percentile(50)); + EXPECT_EQ(150, stat.percentile(90)); +} + } // namespace } // namespace tsl