Skip to content

Commit 7d75aad

Browse files
committed
MIR-678 improved weight matrix validation (exception throwing, added check for duplicates), fine grained method-specific matrix validation
1 parent a944ff7 commit 7d75aad

16 files changed

+60
-56
lines changed

src/mir/caching/WeightCache.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,7 @@ void WeightCacheTraits::load(const eckit::CacheManagerBase& manager, value_type&
8383
value_type tmp(matrix::MatrixLoaderFactory::build(manager.loader(), path));
8484
w.swap(tmp);
8585

86-
static bool matrixValidate = eckit::Resource<bool>("$MIR_MATRIX_VALIDATE", false);
87-
if (matrixValidate) {
88-
w.validate("fromCache");
89-
}
86+
w.validate("fromCache"); // check matrix structure (only)
9087
}
9188

9289

src/mir/method/MethodWeighted.cc

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ MethodWeighted::MethodWeighted(const param::MIRParametrisation& parametrisation)
8181
parametrisation_.get("pole-displacement-in-degree", poleDisplacement_);
8282
ASSERT(poleDisplacement_ >= 0);
8383

84-
matrixValidate_ = eckit::Resource<bool>("$MIR_MATRIX_VALIDATE", false);
8584
matrixAssemble_ = parametrisation_.userParametrisation().has("filter");
8685

8786
std::string nonLinear = "missing-if-heaviest-missing";
@@ -165,14 +164,15 @@ void MethodWeighted::createMatrix(context::Context& ctx, const repres::Represent
165164
const repres::Representation& out, WeightMatrix& W, const lsm::LandSeaMasks& masks,
166165
const Cropping& /*cropping*/) const {
167166
trace::ResourceUsage usage(std::string("MethodWeighted::createMatrix [") + name() + "]");
167+
const auto checks = validateMatrixWeights();
168168

169-
computeMatrixWeights(ctx, in, out, W, validateMatrixWeights());
169+
// matrix validation always happens after creation, because the matrix can/will be cached
170+
computeMatrixWeights(ctx, in, out, W);
171+
W.validate("computeMatrixWeights", checks);
170172

171173
if (masks.active() && masks.cacheable()) {
172174
applyMasks(W, masks);
173-
if (matrixValidate_) {
174-
W.validate("applyMasks");
175-
}
175+
W.validate("applyMasks", checks);
176176
}
177177
}
178178

@@ -265,9 +265,7 @@ const WeightMatrix& MethodWeighted::getMatrix(context::Context& ctx, const repre
265265
// it will be cached in memory nevertheless
266266
if (masks.active() && !masks.cacheable()) {
267267
applyMasks(W, masks);
268-
if (matrixValidate_) {
269-
W.validate("applyMasks");
270-
}
268+
W.validate("applyMasks", validateMatrixWeights());
271269
}
272270

273271

@@ -404,8 +402,8 @@ lsm::LandSeaMasks MethodWeighted::getMasks(const repres::Representation& in, con
404402
}
405403

406404

407-
bool MethodWeighted::validateMatrixWeights() const {
408-
return true;
405+
WeightMatrix::Check MethodWeighted::validateMatrixWeights() const {
406+
return {};
409407
}
410408

411409

@@ -494,9 +492,7 @@ void MethodWeighted::execute(context::Context& ctx, const repres::Representation
494492
trace::Timer t(str.str());
495493

496494
if (n->treatment(A, M, B, field.values(i), missingValue)) {
497-
if (matrixValidate_) {
498-
M.validate(str.str().c_str());
499-
}
495+
M.validate(str.str().c_str(), validateMatrixWeights());
500496
}
501497
}
502498

@@ -528,7 +524,7 @@ void MethodWeighted::execute(context::Context& ctx, const repres::Representation
528524

529525

530526
void MethodWeighted::computeMatrixWeights(context::Context& ctx, const repres::Representation& in,
531-
const repres::Representation& out, WeightMatrix& W, bool validate) const {
527+
const repres::Representation& out, WeightMatrix& W) const {
532528
auto timing(ctx.statistics().computeMatrixTimer());
533529

534530
if (in.sameAs(out) && !matrixAssemble_) {
@@ -565,11 +561,6 @@ void MethodWeighted::computeMatrixWeights(context::Context& ctx, const repres::R
565561
W.swap(w);
566562
}
567563
}
568-
569-
// matrix validation always happens after creation, because the matrix can/will be cached
570-
if (validate) {
571-
W.validate("computeMatrixWeights");
572-
}
573564
}
574565

575566

src/mir/method/MethodWeighted.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ class MethodWeighted : public Method {
114114
std::unique_ptr<const reorder::Reorder> reorderRows_;
115115
std::unique_ptr<const reorder::Reorder> reorderCols_;
116116

117-
bool matrixValidate_;
118117
bool matrixAssemble_;
119118

120119
// -- Methods
@@ -127,10 +126,10 @@ class MethodWeighted : public Method {
127126

128127
virtual void applyMasks(WeightMatrix&, const lsm::LandSeaMasks&) const;
129128
virtual lsm::LandSeaMasks getMasks(const repres::Representation& in, const repres::Representation& out) const;
130-
virtual bool validateMatrixWeights() const;
129+
virtual WeightMatrix::Check validateMatrixWeights() const;
131130

132131
void computeMatrixWeights(context::Context&, const repres::Representation& in, const repres::Representation& out,
133-
WeightMatrix&, bool validate) const;
132+
WeightMatrix&) const;
134133
void createMatrix(context::Context&, const repres::Representation& in, const repres::Representation& out,
135134
WeightMatrix&, const lsm::LandSeaMasks&, const Cropping&) const;
136135

src/mir/method/WeightMatrix.cc

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <sstream>
1717
#include <unordered_set>
1818

19+
#include "eckit/config/Resource.h"
1920
#include "eckit/types/FloatCompare.h"
2021

2122
#include "mir/util/Exceptions.h"
@@ -102,7 +103,12 @@ void WeightMatrix::cleanup(const double& pruneEpsilon) {
102103
}
103104

104105

105-
void WeightMatrix::validate(const char* when) const {
106+
void WeightMatrix::validate(const char* when, Check check) const {
107+
static bool matrixValidate = eckit::Resource<bool>("$MIR_MATRIX_VALIDATE", true);
108+
if (!matrixValidate || (!check.duplicates && !check.bounds && !check.sum)) {
109+
return;
110+
}
111+
106112
constexpr size_t Nerrors = 10;
107113
size_t errors = 0;
108114

@@ -114,20 +120,25 @@ void WeightMatrix::validate(const char* when) const {
114120
Scalar sum = 0.;
115121
std::unordered_set<Size> cols;
116122

117-
bool check_bounds = true;
118-
bool check_no_duplicates = true;
123+
auto check_bounds = true;
124+
auto check_duplicates = true;
119125
for (auto it = begin(r); it != end(r); ++it) {
120126
auto a = *it;
121127
check_bounds &= eckit::types::is_approximately_greater_or_equal(a, 0.) &&
122128
eckit::types::is_approximately_greater_or_equal(1., a);
123129
sum += a;
124130

125-
check_no_duplicates &= cols.insert(it.col()).second;
131+
check_duplicates &= cols.insert(it.col()).second;
126132
}
127133

128134
auto check_sum = eckit::types::is_approximately_equal(sum, 0.) || eckit::types::is_approximately_equal(sum, 1.);
129135

130-
if (!check_bounds || !check_sum || !check_no_duplicates) {
136+
// ignore checks as required
137+
check_duplicates |= !check.duplicates;
138+
check_bounds |= !check.bounds;
139+
check_sum |= !check.sum;
140+
141+
if (!check_bounds || !check_sum || !check_duplicates) {
131142
if (errors < Nerrors) {
132143
what << sep << "row " << r << ": ";
133144
const char* s = "";
@@ -142,7 +153,7 @@ void WeightMatrix::validate(const char* when) const {
142153
s = ", ";
143154
}
144155

145-
if (!check_no_duplicates) {
156+
if (!check_duplicates) {
146157
what << s << "duplicate indices";
147158
s = ", ";
148159
}

src/mir/method/WeightMatrix.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,18 @@ namespace mir::method {
2222

2323

2424
class WeightMatrix final : public eckit::linalg::SparseMatrix {
25-
public: // types
25+
public:
2626
using Triplet = eckit::linalg::Triplet;
2727
using Scalar = eckit::linalg::Scalar;
2828
using Size = eckit::linalg::Size;
2929

30-
public: // methods
30+
struct Check {
31+
bool duplicates = true;
32+
bool bounds = true;
33+
bool sum = true;
34+
};
35+
36+
public:
3137
WeightMatrix(SparseMatrix::Allocator* = nullptr);
3238

3339
WeightMatrix(const eckit::PathName&);
@@ -38,9 +44,10 @@ class WeightMatrix final : public eckit::linalg::SparseMatrix {
3844

3945
void cleanup(const double& pruneEpsilon = 0);
4046

41-
void validate(const char* when) const;
47+
// Validate interpolation weights (default check matrix structure only)
48+
void validate(const char* when, Check = {true, false, false}) const;
4249

43-
private: // members
50+
private:
4451
void print(std::ostream&) const;
4552

4653
friend std::ostream& operator<<(std::ostream& out, const WeightMatrix& m) {

src/mir/method/gridbox/GridBoxMethod.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ void GridBoxMethod::print(std::ostream& out) const {
4545
}
4646

4747

48-
bool GridBoxMethod::validateMatrixWeights() const {
49-
return false;
48+
WeightMatrix::Check GridBoxMethod::validateMatrixWeights() const {
49+
return {true, true, false};
5050
}
5151

5252

src/mir/method/gridbox/GridBoxMethod.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class GridBoxMethod : public MethodWeighted {
2727
bool sameAs(const Method&) const override;
2828
void json(eckit::JSON&) const override;
2929
void print(std::ostream&) const override;
30-
bool validateMatrixWeights() const override;
30+
WeightMatrix::Check validateMatrixWeights() const override;
3131
};
3232

3333

src/mir/method/knn/KNearestNeighbours.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ void KNearestNeighbours::print(std::ostream& out) const {
162162
}
163163

164164

165-
bool KNearestNeighbours::validateMatrixWeights() const {
165+
WeightMatrix::Check KNearestNeighbours::validateMatrixWeights() const {
166166
return distanceWeighting().validateMatrixWeights();
167167
}
168168

src/mir/method/knn/KNearestNeighbours.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class KNearestNeighbours : public MethodWeighted {
5656
virtual const pick::Pick& pick() const = 0;
5757
virtual const distance::DistanceWeighting& distanceWeighting() const = 0;
5858

59-
virtual bool validateMatrixWeights() const;
59+
WeightMatrix::Check validateMatrixWeights() const override;
6060
};
6161

6262

src/mir/method/knn/distance/DistanceWeighting.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ void DistanceWeightingFactory::list(std::ostream& out) {
8585
}
8686

8787

88-
bool DistanceWeighting::validateMatrixWeights() const {
89-
return true;
88+
WeightMatrix::Check DistanceWeighting::validateMatrixWeights() const {
89+
return {};
9090
}
9191

9292

src/mir/method/knn/distance/DistanceWeighting.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class DistanceWeighting {
4747

4848
virtual void hash(eckit::MD5&) const = 0;
4949

50-
virtual bool validateMatrixWeights() const;
50+
virtual WeightMatrix::Check validateMatrixWeights() const;
5151

5252
private:
5353
virtual void json(eckit::JSON&) const = 0;

src/mir/method/knn/distance/PseudoLaplace.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ void PseudoLaplace::hash(eckit::MD5& h) const {
127127
}
128128

129129

130-
bool PseudoLaplace::validateMatrixWeights() const {
130+
WeightMatrix::Check PseudoLaplace::validateMatrixWeights() const {
131131
// this method does not produce bounded interpolation weights
132-
return false;
132+
return {true, false, false};
133133
}
134134

135135

src/mir/method/knn/distance/PseudoLaplace.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct PseudoLaplace : DistanceWeighting {
2929
void print(std::ostream&) const override;
3030
void hash(eckit::MD5&) const override;
3131

32-
bool validateMatrixWeights() const override;
32+
WeightMatrix::Check validateMatrixWeights() const override;
3333
};
3434

3535

src/mir/method/voronoi/VoronoiMethod.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ void VoronoiMethod::print(std::ostream& out) const {
168168
}
169169

170170

171-
bool VoronoiMethod::validateMatrixWeights() const {
172-
return false;
171+
WeightMatrix::Check VoronoiMethod::validateMatrixWeights() const {
172+
return {true, true, false};
173173
}
174174

175175

src/mir/method/voronoi/VoronoiMethod.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class VoronoiMethod : public MethodWeighted {
3030
bool sameAs(const Method&) const override;
3131
void json(eckit::JSON&) const override;
3232
void print(std::ostream&) const override;
33-
bool validateMatrixWeights() const override;
33+
WeightMatrix::Check validateMatrixWeights() const override;
3434
const char* name() const override;
3535

3636
knn::pick::NClosestOrNearest pick_;

tests/unit/weight_matrix.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,13 @@ CASE("WeightMatrix::validate") {
2424
const auto* when{"out-of-bounds"};
2525
const std::string what{
2626
"Invalid weight matrix (out-of-bounds): 1 row error, "
27-
"row 2: weights out-of-bounds, "
28-
"weights sum not 0 or 1 (sum=-0.1, 1-sum=1.1), contents: (2,2,-0.1)"};
27+
"row 2: weights out-of-bounds, contents: (2,2,-0.1)"};
2928

3029
method::WeightMatrix W(3, 3);
3130
W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.}, {2, 2, -0.1}});
3231

3332
try {
34-
W.validate(when);
33+
W.validate(when, {false, true, false});
3534
ASSERT(false);
3635
}
3736
catch (exception::InvalidWeightMatrix& e) {
@@ -51,7 +50,7 @@ CASE("WeightMatrix::validate") {
5150
W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.5}, {2, 2, 0.1}});
5251

5352
try {
54-
W.validate(when);
53+
W.validate(when, {false, false, true});
5554
ASSERT(false);
5655
}
5756
catch (exception::InvalidWeightMatrix& e) {
@@ -70,7 +69,7 @@ CASE("WeightMatrix::validate") {
7069
W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.5}, {1, 1, 0.5}});
7170

7271
try {
73-
W.validate(when);
72+
W.validate(when, {true, false, false});
7473
ASSERT(false);
7574
}
7675
catch (exception::InvalidWeightMatrix& e) {
@@ -90,7 +89,7 @@ CASE("WeightMatrix::validate") {
9089
W.setFromTriplets({{0, 0, 0.5}, {0, 0, 0.5}, {1, 1, 0.5}});
9190

9291
try {
93-
W.validate(when);
92+
W.validate(when, {true, false, true});
9493
ASSERT(false);
9594
}
9695
catch (exception::InvalidWeightMatrix& e) {

0 commit comments

Comments
 (0)