diff --git a/husarion_ugv_utils/include/husarion_ugv_utils/moving_average.hpp b/husarion_ugv_utils/include/husarion_ugv_utils/moving_average.hpp index 0a6c76363..fc1d6b4fe 100644 --- a/husarion_ugv_utils/include/husarion_ugv_utils/moving_average.hpp +++ b/husarion_ugv_utils/include/husarion_ugv_utils/moving_average.hpp @@ -16,6 +16,7 @@ #define HUSARION_UGV_UTILS_MOVING_AVERAGE_HPP_ #include +#include namespace husarion_ugv_utils { @@ -25,40 +26,40 @@ class MovingAverage { public: MovingAverage(const std::size_t window_size = 5, const T initial_value = T(0)) - : window_size_(window_size), initial_value_(initial_value), sum_(T(0)) + : window_size_(window_size), initial_value_(initial_value) { + if (window_size_ == 0) { + throw std::invalid_argument("Window size must be greater than 0"); + } } void Roll(const T value) { - values_.push_back(value); - sum_ += value; + buffer_.push_back(value); - if (values_.size() > window_size_) { - sum_ -= values_.front(); - values_.pop_front(); + if (buffer_.size() > window_size_) { + buffer_.pop_front(); } } - void Reset() - { - values_.erase(values_.begin(), values_.end()); - sum_ = T(0); - } + void Reset() { buffer_.erase(buffer_.begin(), buffer_.end()); } T GetAverage() const { - if (values_.size() == 0) { + if (buffer_.size() == 0) { return initial_value_; } - return sum_ / static_cast(values_.size()); + + T sum = std::accumulate(buffer_.begin(), buffer_.end(), T(0)); + T average = sum / static_cast(buffer_.size()); + + return average; } private: const std::size_t window_size_; - std::deque values_; + std::deque buffer_; const T initial_value_; - T sum_; }; } // namespace husarion_ugv_utils diff --git a/husarion_ugv_utils/test/test_moving_average.cpp b/husarion_ugv_utils/test/test_moving_average.cpp index 653483ce7..2934f28e9 100644 --- a/husarion_ugv_utils/test/test_moving_average.cpp +++ b/husarion_ugv_utils/test/test_moving_average.cpp @@ -66,7 +66,8 @@ TEST(TestMovingAverage, TestHighOverload) // test every 1000 rolls expected average if (i % window_len == 0) { - EXPECT_EQ(sum / double(window_len), ma.GetAverage()); + ASSERT_NEAR( + sum / double(window_len), ma.GetAverage(), std::numeric_limits::epsilon()); sum = 0.0; } } @@ -107,6 +108,25 @@ TEST(TestMovingAverage, TestResetToInitialValue) EXPECT_EQ(7.0, ma.GetAverage()); } +TEST(TestMovingAverage, TestInfInjectionHandling) +{ + husarion_ugv_utils::MovingAverage ma(4); + ma.Roll(1.0); + ma.Roll(2.0); + ma.Roll(3.0); + ma.Roll(4.0); + EXPECT_EQ(2.5, ma.GetAverage()); + + ma.Roll(std::numeric_limits::infinity()); + EXPECT_EQ(std::numeric_limits::infinity(), ma.GetAverage()); + + ma.Roll(1.0); + ma.Roll(2.0); + ma.Roll(3.0); + ma.Roll(4.0); + EXPECT_EQ(2.5, ma.GetAverage()); +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv);