Skip to content

Commit 3ce1fe7

Browse files
committed
Fix bad test in MOE: do not test derivatives on hard recombination
1 parent 3e13421 commit 3ce1fe7

File tree

1 file changed

+20
-27
lines changed

1 file changed

+20
-27
lines changed

moe/src/algorithm.rs

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,7 @@ mod tests {
10621062
}
10631063

10641064
#[test]
1065-
fn test_moe_drv_hard() {
1065+
fn test_moe_drv_smooth() {
10661066
let rng = Xoshiro256Plus::seed_from_u64(0);
10671067
let xt = Lhs::new(&array![[0., 1.]]).sample(100);
10681068
let yt = f_test_1d(&xt);
@@ -1071,7 +1071,7 @@ mod tests {
10711071
.n_clusters(3)
10721072
.regression_spec(RegressionSpec::CONSTANT)
10731073
.correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
1074-
.recombination(Recombination::Hard)
1074+
.recombination(Recombination::Smooth(Some(0.5)))
10751075
.with_rng(rng)
10761076
.fit(&Dataset::new(xt, yt))
10771077
.expect("MOE fitted");
@@ -1090,32 +1090,25 @@ mod tests {
10901090
for _ in 0..20 {
10911091
let x1: f64 = rng.gen_range(0.0..1.0);
10921092

1093-
if (0.39 < x1 && x1 < 0.41) || (0.79 < x1 && x1 < 0.81) {
1094-
// avoid testing hard on discoontinuity
1095-
continue;
1093+
let h = 1e-4;
1094+
let xtest = array![[x1]];
1095+
1096+
let x = array![[x1], [x1 + h], [x1 - h]];
1097+
let preds = moe.predict_derivatives(&x).unwrap();
1098+
let fdiff = preds[[1, 0]] - preds[[1, 0]] / 2. * h;
1099+
1100+
let drv = moe.predict_derivatives(&xtest).unwrap();
1101+
let df = df_test_1d(&xtest);
1102+
1103+
let err = if drv[[0, 0]] < 0.2 {
1104+
(drv[[0, 0]] - fdiff).abs()
10961105
} else {
1097-
let h = 1e-4;
1098-
let xtest = array![[x1]];
1099-
1100-
let x = array![[x1], [x1 + h], [x1 - h]];
1101-
let preds = moe.predict_derivatives(&x).unwrap();
1102-
let fdiff = preds[[1, 0]] - preds[[1, 0]] / 2. * h;
1103-
1104-
let drv = moe.predict_derivatives(&xtest).unwrap();
1105-
let df = df_test_1d(&xtest);
1106-
1107-
if (df[[0, 0]] - fdiff).abs() > 10.0 {
1108-
let err = if drv[[0, 0]] < 0.2 {
1109-
(drv[[0, 0]] - fdiff).abs()
1110-
} else {
1111-
(drv[[0, 0]] - fdiff).abs() / drv[[0, 0]]
1112-
};
1113-
println!(
1114-
"Test predicted derivatives at {xtest}: drv {drv}, true df {df}, fdiff {fdiff}"
1115-
);
1116-
assert_abs_diff_eq!(err, 0.0, epsilon = 2e-1);
1117-
}
1118-
}
1106+
(drv[[0, 0]] - fdiff).abs() / drv[[0, 0]]
1107+
};
1108+
println!(
1109+
"Test predicted derivatives at {xtest}: drv {drv}, true df {df}, fdiff {fdiff}"
1110+
);
1111+
assert_abs_diff_eq!(err, 0.0, epsilon = 2e-1);
11191112
}
11201113
}
11211114

0 commit comments

Comments
 (0)