From 463a3682be768244acecc51e4f6f9987b32c5ba8 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 23 Feb 2026 08:55:04 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 874097322 --- compression/test_util-inl.h | 57 ++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index 99b34b5b..bb2fadb0 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -18,6 +18,7 @@ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_ #include +#include #include @@ -219,6 +220,8 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, // magnitude, but also to f32 accumulation of rows in A and B. const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch); const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch); + HWY_ASSERT(hn::AllTrue(df, hn::IsFinite(hn::Set(df, norm)))); + HWY_ASSERT(hn::AllTrue(df, hn::IsFinite(hn::Set(df, max_abs)))); const double eps_bf16 = hwy::ConvertScalarTo(hwy::Epsilon()); const double eps_f32 = hwy::ConvertScalarTo(hwy::Epsilon()); // Dot() uses double-precision summation. @@ -232,10 +235,7 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, if (tolerance > 500.0) { HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs); } - const double rel_tolerance = - 1.0 + hwy::ConvertScalarTo(hwy::Epsilon()); - - double max_rel = 0.0; + double worst_l1 = 0.0; size_t worst_r = 0; size_t worst_c = 0; double worst_actual = 0.0; @@ -247,34 +247,45 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, for (size_t c = 0; c < B.Rows(); c++) { const double expected_value = static_cast(expected_row[c]); const double actual_value = static_cast(actual_row[c]); - const bool in_range = expected_value - tolerance <= actual_value && - actual_value <= expected_value + tolerance; + const double l1 = hwy::ScalarAbs(expected_value - actual_value); + if (l1 > HWY_MAX(tolerance, tolerance * hwy::ScalarAbs(expected_value))) { + fprintf(stderr, "%zu,%zu\n", r, c); + ++num_outside; + } - if (!in_range) { - const double max = HWY_MAX(expected_value, actual_value); - const double min = HWY_MIN(expected_value, actual_value); - const double rel = max / HWY_MAX(min, 1E-6); - if (rel > max_rel) { - worst_expected = expected_value; - worst_actual = actual_value; - worst_r = r; - worst_c = c; - max_rel = rel; - ++num_outside; - } + if (l1 > worst_l1) { + worst_l1 = l1; + worst_expected = expected_value; + worst_actual = actual_value; + worst_r = r; + worst_c = c; } } } - if (max_rel > rel_tolerance) { + if (num_outside > 0) { + const size_t r_begin = worst_r >= 1 ? worst_r - 1 : 0; + const size_t r_end = HWY_MIN(r_begin + 3, A.Rows()); + const size_t c_begin = worst_c >= 3 ? worst_c - 3 : 0; + const size_t c_end = HWY_MIN(c_begin + 7, B.Rows()); + fprintf(stderr, + "%zu outside. Printing rows [%zu, %zu) and columns [%zu, %zu)\n", + num_outside, r_begin, r_end, c_begin, c_end); + for (size_t r = r_begin; r < r_end; r++) { + const float* expected_row = c_slow_batch.Row(r); + const float* actual_row = c_batch.Row(r); + for (size_t c = c_begin; c < c_end; c++) { + fprintf(stderr, "%6.3f=%6.3f ", expected_row[c], actual_row[c]); + } + fprintf(stderr, "\n"); + } + hwy::Abort(__FILE__, line, "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f " - "tolerance %f rel %E max_rel %E num_outside %zu\n", + "tolerance %f worst_l1 %E\n", worst_r, worst_c, worst_expected, worst_actual, norm, max_abs, - tolerance, max_rel, rel_tolerance, num_outside); + tolerance, worst_l1); } - HWY_ASSERT(hn::AllFalse( - df, hn::IsEitherNaN(hn::Set(df, norm), hn::Set(df, max_abs)))); } // NOLINTNEXTLINE(google-readability-namespace-comments)