@@ -1062,7 +1062,7 @@ mod tests {
1062
1062
}
1063
1063
1064
1064
#[ test]
1065
- fn test_moe_drv_hard ( ) {
1065
+ fn test_moe_drv_smooth ( ) {
1066
1066
let rng = Xoshiro256Plus :: seed_from_u64 ( 0 ) ;
1067
1067
let xt = Lhs :: new ( & array ! [ [ 0. , 1. ] ] ) . sample ( 100 ) ;
1068
1068
let yt = f_test_1d ( & xt) ;
@@ -1071,7 +1071,7 @@ mod tests {
1071
1071
. n_clusters ( 3 )
1072
1072
. regression_spec ( RegressionSpec :: CONSTANT )
1073
1073
. correlation_spec ( CorrelationSpec :: SQUAREDEXPONENTIAL )
1074
- . recombination ( Recombination :: Hard )
1074
+ . recombination ( Recombination :: Smooth ( Some ( 0.5 ) ) )
1075
1075
. with_rng ( rng)
1076
1076
. fit ( & Dataset :: new ( xt, yt) )
1077
1077
. expect ( "MOE fitted" ) ;
@@ -1090,32 +1090,25 @@ mod tests {
1090
1090
for _ in 0 ..20 {
1091
1091
let x1: f64 = rng. gen_range ( 0.0 ..1.0 ) ;
1092
1092
1093
- if ( 0.39 < x1 && x1 < 0.41 ) || ( 0.79 < x1 && x1 < 0.81 ) {
1094
- // avoid testing hard on discoontinuity
1095
- continue ;
1093
+ let h = 1e-4 ;
1094
+ let xtest = array ! [ [ x1] ] ;
1095
+
1096
+ let x = array ! [ [ x1] , [ x1 + h] , [ x1 - h] ] ;
1097
+ let preds = moe. predict_derivatives ( & x) . unwrap ( ) ;
1098
+ let fdiff = preds[ [ 1 , 0 ] ] - preds[ [ 1 , 0 ] ] / 2. * h;
1099
+
1100
+ let drv = moe. predict_derivatives ( & xtest) . unwrap ( ) ;
1101
+ let df = df_test_1d ( & xtest) ;
1102
+
1103
+ let err = if drv[ [ 0 , 0 ] ] < 0.2 {
1104
+ ( drv[ [ 0 , 0 ] ] - fdiff) . abs ( )
1096
1105
} else {
1097
- let h = 1e-4 ;
1098
- let xtest = array ! [ [ x1] ] ;
1099
-
1100
- let x = array ! [ [ x1] , [ x1 + h] , [ x1 - h] ] ;
1101
- let preds = moe. predict_derivatives ( & x) . unwrap ( ) ;
1102
- let fdiff = preds[ [ 1 , 0 ] ] - preds[ [ 1 , 0 ] ] / 2. * h;
1103
-
1104
- let drv = moe. predict_derivatives ( & xtest) . unwrap ( ) ;
1105
- let df = df_test_1d ( & xtest) ;
1106
-
1107
- if ( df[ [ 0 , 0 ] ] - fdiff) . abs ( ) > 10.0 {
1108
- let err = if drv[ [ 0 , 0 ] ] < 0.2 {
1109
- ( drv[ [ 0 , 0 ] ] - fdiff) . abs ( )
1110
- } else {
1111
- ( drv[ [ 0 , 0 ] ] - fdiff) . abs ( ) / drv[ [ 0 , 0 ] ]
1112
- } ;
1113
- println ! (
1114
- "Test predicted derivatives at {xtest}: drv {drv}, true df {df}, fdiff {fdiff}"
1115
- ) ;
1116
- assert_abs_diff_eq ! ( err, 0.0 , epsilon = 2e-1 ) ;
1117
- }
1118
- }
1106
+ ( drv[ [ 0 , 0 ] ] - fdiff) . abs ( ) / drv[ [ 0 , 0 ] ]
1107
+ } ;
1108
+ println ! (
1109
+ "Test predicted derivatives at {xtest}: drv {drv}, true df {df}, fdiff {fdiff}"
1110
+ ) ;
1111
+ assert_abs_diff_eq ! ( err, 0.0 , epsilon = 2e-1 ) ;
1119
1112
}
1120
1113
}
1121
1114
0 commit comments