Skip to content

Commit a944ff7

Browse files
committed
Merge branch 'feature/MIR-678' into develop
2 parents c39c248 + f6d5942 commit a944ff7

File tree

10 files changed

+189
-32
lines changed

10 files changed

+189
-32
lines changed

src/mir/method/WeightMatrix.cc

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "mir/method/WeightMatrix.h"
1414

1515
#include <cmath>
16+
#include <sstream>
17+
#include <unordered_set>
1618

1719
#include "eckit/types/FloatCompare.h"
1820

@@ -101,54 +103,68 @@ void WeightMatrix::cleanup(const double& pruneEpsilon) {
101103

102104

103105
void WeightMatrix::validate(const char* when) const {
104-
constexpr size_t Nerrors = 50;
105-
constexpr size_t Nvalues = 10;
106+
constexpr size_t Nerrors = 10;
107+
size_t errors = 0;
106108

107-
size_t errors = 0;
109+
std::ostringstream what;
110+
const char* sep = "";
108111

112+
// check for weights out of bounds 0 <= W(i,j) <= 1, sum(W(i,:))=(0,1), and dulplicate column indices
109113
for (Size r = 0; r < rows(); r++) {
110-
111-
// check for W(i,j)<0, or W(i,j)>1, or sum(W(i,:))!=(0,1)
112-
double sum = 0.;
113-
bool ok = true;
114-
115-
for (const_iterator it = begin(r); it != end(r); ++it) {
116-
double a = *it;
117-
ok &= eckit::types::is_approximately_greater_or_equal(a, 0.) &&
118-
eckit::types::is_approximately_greater_or_equal(1., a);
114+
Scalar sum = 0.;
115+
std::unordered_set<Size> cols;
116+
117+
bool check_bounds = true;
118+
bool check_no_duplicates = true;
119+
for (auto it = begin(r); it != end(r); ++it) {
120+
auto a = *it;
121+
check_bounds &= eckit::types::is_approximately_greater_or_equal(a, 0.) &&
122+
eckit::types::is_approximately_greater_or_equal(1., a);
119123
sum += a;
120-
}
121124

122-
ok &= (eckit::types::is_approximately_equal(sum, 0.) || eckit::types::is_approximately_equal(sum, 1.));
123-
if (ok) {
124-
continue;
125+
check_no_duplicates &= cols.insert(it.col()).second;
125126
}
126127

127-
// log issues, per row
128-
if (Log::debug_active()) {
128+
auto check_sum = eckit::types::is_approximately_equal(sum, 0.) || eckit::types::is_approximately_equal(sum, 1.);
129+
130+
if (!check_bounds || !check_sum || !check_no_duplicates) {
129131
if (errors < Nerrors) {
130-
if (errors == 0) {
131-
Log::debug() << "WeightMatrix::validate(" << when << ") failed " << std::endl;
132+
what << sep << "row " << r << ": ";
133+
const char* s = "";
134+
135+
if (!check_bounds) {
136+
what << s << "weights out-of-bounds";
137+
s = ", ";
132138
}
133139

134-
Log::debug() << "Row: " << r;
135-
size_t n = 0;
136-
for (const_iterator it = begin(r); it != end(r); ++it, ++n) {
137-
if (n > Nvalues) {
138-
Log::debug() << " ...";
139-
break;
140-
}
141-
Log::debug() << " [" << *it << "]";
140+
if (!check_sum) {
141+
what << s << "weights sum not 0 or 1 (sum=" << sum << ", 1-sum=" << (1 - sum) << ")";
142+
s = ", ";
142143
}
143144

144-
Log::debug() << " sum=" << sum << ", 1-sum " << (1 - sum) << std::endl;
145-
}
146-
else if (errors == Nerrors) {
147-
Log::debug() << "..." << std::endl;
145+
if (!check_no_duplicates) {
146+
what << s << "duplicate indices";
147+
s = ", ";
148+
}
149+
150+
what << s << "contents: ";
151+
s = "";
152+
for (auto it = begin(r); it != end(r); ++it) {
153+
what << s << '(' << it.row() << ',' << it.col() << ',' << *it << ')';
154+
s = ", ";
155+
}
148156
}
157+
149158
errors++;
159+
sep = ", ";
150160
}
151161
}
162+
163+
if (errors > 0) {
164+
std::ostringstream errors_str;
165+
errors_str << Log::Pretty{errors, {"row error"}};
166+
throw exception::InvalidWeightMatrix{when, errors_str.str() + ", " + what.str()};
167+
}
152168
}
153169

154170

src/mir/method/knn/KNearestNeighbours.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,9 @@ void KNearestNeighbours::print(std::ostream& out) const {
162162
}
163163

164164

165+
bool KNearestNeighbours::validateMatrixWeights() const {
166+
return distanceWeighting().validateMatrixWeights();
167+
}
168+
169+
165170
} // namespace mir::method::knn

src/mir/method/knn/KNearestNeighbours.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ class KNearestNeighbours : public MethodWeighted {
5555

5656
virtual const pick::Pick& pick() const = 0;
5757
virtual const distance::DistanceWeighting& distanceWeighting() const = 0;
58+
59+
virtual bool validateMatrixWeights() const;
5860
};
5961

6062

src/mir/method/knn/distance/DistanceWeighting.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,9 @@ void DistanceWeightingFactory::list(std::ostream& out) {
8585
}
8686

8787

88+
bool DistanceWeighting::validateMatrixWeights() const {
89+
return true;
90+
}
91+
92+
8893
} // namespace mir::method::knn::distance

src/mir/method/knn/distance/DistanceWeighting.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class DistanceWeighting {
4747

4848
virtual void hash(eckit::MD5&) const = 0;
4949

50+
virtual bool validateMatrixWeights() const;
51+
5052
private:
5153
virtual void json(eckit::JSON&) const = 0;
5254
virtual void print(std::ostream&) const = 0;

src/mir/method/knn/distance/PseudoLaplace.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,12 @@ void PseudoLaplace::hash(eckit::MD5& h) const {
127127
}
128128

129129

130+
bool PseudoLaplace::validateMatrixWeights() const {
131+
// this method does not produce bounded interpolation weights
132+
return false;
133+
}
134+
135+
130136
static const DistanceWeightingBuilder<PseudoLaplace> __distance("pseudo-laplace");
131137

132138

src/mir/method/knn/distance/PseudoLaplace.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ struct PseudoLaplace : DistanceWeighting {
2828
void json(eckit::JSON&) const override;
2929
void print(std::ostream&) const override;
3030
void hash(eckit::MD5&) const override;
31+
32+
bool validateMatrixWeights() const override;
3133
};
3234

3335

src/mir/util/Exceptions.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,14 @@ class CannotConvert : public eckit::Exception {
5555
};
5656

5757

58+
class InvalidWeightMatrix : public eckit::Exception {
59+
public:
60+
InvalidWeightMatrix(const char* when, const std::string& what) {
61+
std::ostringstream os;
62+
os << "Invalid weight matrix (" << when << "): " << what;
63+
reason(os.str());
64+
}
65+
};
66+
67+
5868
} // namespace mir::exception

tests/unit/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ foreach(_t
3030
statistics
3131
style
3232
vector-space
33+
weight_matrix
3334
wind)
3435
ecbuild_add_test(
3536
TARGET mir_tests_unit_${_t}

tests/unit/weight_matrix.cc

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* (C) Copyright 1996- ECMWF.
3+
*
4+
* This software is licensed under the terms of the Apache Licence Version 2.0
5+
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6+
*
7+
* In applying this licence, ECMWF does not waive the privileges and immunities
8+
* granted to it by virtue of its status as an intergovernmental organisation nor
9+
* does it submit to any jurisdiction.
10+
*/
11+
12+
13+
#include "eckit/testing/Test.h"
14+
15+
#include "mir/method/WeightMatrix.h"
16+
#include "mir/util/Exceptions.h"
17+
18+
19+
namespace mir::tests::unit {
20+
21+
22+
CASE("WeightMatrix::validate") {
23+
SECTION("out-of-bounds") {
24+
const auto* when{"out-of-bounds"};
25+
const std::string what{
26+
"Invalid weight matrix (out-of-bounds): 1 row error, "
27+
"row 2: weights out-of-bounds, "
28+
"weights sum not 0 or 1 (sum=-0.1, 1-sum=1.1), contents: (2,2,-0.1)"};
29+
30+
method::WeightMatrix W(3, 3);
31+
W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.}, {2, 2, -0.1}});
32+
33+
try {
34+
W.validate(when);
35+
ASSERT(false);
36+
}
37+
catch (exception::InvalidWeightMatrix& e) {
38+
EXPECT(e.what() == what);
39+
}
40+
}
41+
42+
43+
SECTION("weights sum") {
44+
const auto* when{"weights sum"};
45+
const std::string what{
46+
"Invalid weight matrix (weights sum): 2 row errors, "
47+
"row 1: weights sum not 0 or 1 (sum=0.5, 1-sum=0.5), contents: (1,1,0.5), "
48+
"row 2: weights sum not 0 or 1 (sum=0.1, 1-sum=0.9), contents: (2,2,0.1)"};
49+
50+
method::WeightMatrix W(3, 3);
51+
W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.5}, {2, 2, 0.1}});
52+
53+
try {
54+
W.validate(when);
55+
ASSERT(false);
56+
}
57+
catch (exception::InvalidWeightMatrix& e) {
58+
EXPECT(e.what() == what);
59+
}
60+
}
61+
62+
63+
SECTION("duplicate indices") {
64+
const auto* when{"duplicate indices"};
65+
const std::string what{
66+
"Invalid weight matrix (duplicate indices): 1 row error, "
67+
"row 1: duplicate indices, contents: (1,1,0.5), (1,1,0.5)"};
68+
69+
method::WeightMatrix W(3, 3);
70+
W.setFromTriplets({{0, 0, 1.}, {1, 1, 0.5}, {1, 1, 0.5}});
71+
72+
try {
73+
W.validate(when);
74+
ASSERT(false);
75+
}
76+
catch (exception::InvalidWeightMatrix& e) {
77+
EXPECT(e.what() == what);
78+
}
79+
}
80+
81+
82+
SECTION("mixed") {
83+
const auto* when{"mixed"};
84+
const std::string what{
85+
"Invalid weight matrix (mixed): 2 row errors, "
86+
"row 0: duplicate indices, contents: (0,0,0.5), (0,0,0.5), "
87+
"row 1: weights sum not 0 or 1 (sum=0.5, 1-sum=0.5), contents: (1,1,0.5)"};
88+
89+
method::WeightMatrix W(3, 3);
90+
W.setFromTriplets({{0, 0, 0.5}, {0, 0, 0.5}, {1, 1, 0.5}});
91+
92+
try {
93+
W.validate(when);
94+
ASSERT(false);
95+
}
96+
catch (exception::InvalidWeightMatrix& e) {
97+
EXPECT(e.what() == what);
98+
}
99+
}
100+
}
101+
102+
103+
} // namespace mir::tests::unit
104+
105+
106+
int main(int argc, char** argv) {
107+
return eckit::testing::run_tests(argc, argv);
108+
}

0 commit comments

Comments
 (0)