Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions compression/test_util-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_

#include <stddef.h>
#include <stdio.h>

#include <vector>

Expand Down Expand Up @@ -219,6 +220,8 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
// magnitude, but also to f32 accumulation of rows in A and B.
const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch);
const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
HWY_ASSERT(hn::AllTrue(df, hn::IsFinite(hn::Set(df, norm))));
HWY_ASSERT(hn::AllTrue(df, hn::IsFinite(hn::Set(df, max_abs))));
const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
// Dot() uses double-precision summation.
Expand All @@ -232,10 +235,7 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
if (tolerance > 500.0) {
HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
}
const double rel_tolerance =
1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());

double max_rel = 0.0;
double worst_l1 = 0.0;
size_t worst_r = 0;
size_t worst_c = 0;
double worst_actual = 0.0;
Expand All @@ -247,34 +247,45 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
for (size_t c = 0; c < B.Rows(); c++) {
const double expected_value = static_cast<double>(expected_row[c]);
const double actual_value = static_cast<double>(actual_row[c]);
const bool in_range = expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance;
const double l1 = hwy::ScalarAbs(expected_value - actual_value);
if (l1 > HWY_MAX(tolerance, tolerance * hwy::ScalarAbs(expected_value))) {
fprintf(stderr, "%zu,%zu\n", r, c);
++num_outside;
}

if (!in_range) {
const double max = HWY_MAX(expected_value, actual_value);
const double min = HWY_MIN(expected_value, actual_value);
const double rel = max / HWY_MAX(min, 1E-6);
if (rel > max_rel) {
worst_expected = expected_value;
worst_actual = actual_value;
worst_r = r;
worst_c = c;
max_rel = rel;
++num_outside;
}
if (l1 > worst_l1) {
worst_l1 = l1;
worst_expected = expected_value;
worst_actual = actual_value;
worst_r = r;
worst_c = c;
}
}
}

if (max_rel > rel_tolerance) {
if (num_outside > 0) {
const size_t r_begin = worst_r >= 1 ? worst_r - 1 : 0;
const size_t r_end = HWY_MIN(r_begin + 3, A.Rows());
const size_t c_begin = worst_c >= 3 ? worst_c - 3 : 0;
const size_t c_end = HWY_MIN(c_begin + 7, B.Rows());
fprintf(stderr,
"%zu outside. Printing rows [%zu, %zu) and columns [%zu, %zu)\n",
num_outside, r_begin, r_end, c_begin, c_end);
for (size_t r = r_begin; r < r_end; r++) {
const float* expected_row = c_slow_batch.Row(r);
const float* actual_row = c_batch.Row(r);
for (size_t c = c_begin; c < c_end; c++) {
fprintf(stderr, "%6.3f=%6.3f ", expected_row[c], actual_row[c]);
}
fprintf(stderr, "\n");
}

hwy::Abort(__FILE__, line,
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f rel %E max_rel %E num_outside %zu\n",
"tolerance %f worst_l1 %E\n",
worst_r, worst_c, worst_expected, worst_actual, norm, max_abs,
tolerance, max_rel, rel_tolerance, num_outside);
tolerance, worst_l1);
}
HWY_ASSERT(hn::AllFalse(
df, hn::IsEitherNaN(hn::Set(df, norm), hn::Set(df, max_abs))));
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
Expand Down
Loading