@@ -833,8 +833,9 @@ mod tests {
833
833
) ]
834
834
#[ test]
835
835
fn test_random_forest_predict_proba ( ) {
836
+ use num_traits:: FromPrimitive ;
836
837
// Iris-like dataset (subset)
837
- let x = DenseMatrix :: from_2d_array ( & [
838
+ let x: DenseMatrix < f64 > = DenseMatrix :: from_2d_array ( & [
838
839
& [ 5.1 , 3.5 , 1.4 , 0.2 ] ,
839
840
& [ 4.9 , 3.0 , 1.4 , 0.2 ] ,
840
841
& [ 4.7 , 3.2 , 1.3 , 0.2 ] ,
@@ -881,21 +882,22 @@ mod tests {
881
882
// These values are approximate and based on typical random forest behavior
882
883
for i in 0 ..5 {
883
884
assert ! (
884
- * probas. get( ( i, 0 ) ) > 0.6 ,
885
+ * probas. get( ( i, 0 ) ) > f64 :: from_f32 ( 0.6 ) . unwrap ( ) ,
885
886
"Class 0 samples should have high probability for class 0"
886
887
) ;
887
888
assert ! (
888
- * probas. get( ( i, 1 ) ) < 0.4 ,
889
+ * probas. get( ( i, 1 ) ) < f64 :: from_f32 ( 0.4 ) . unwrap ( ) ,
889
890
"Class 0 samples should have low probability for class 1"
890
891
) ;
891
892
}
893
+
892
894
for i in 5 ..10 {
893
895
assert ! (
894
- * probas. get( ( i, 1 ) ) > 0.6 ,
896
+ * probas. get( ( i, 1 ) ) > f64 :: from_f32 ( 0.6 ) . unwrap ( ) ,
895
897
"Class 1 samples should have high probability for class 1"
896
898
) ;
897
899
assert ! (
898
- * probas. get( ( i, 0 ) ) < 0.4 ,
900
+ * probas. get( ( i, 0 ) ) < f64 :: from_f32 ( 0.4 ) . unwrap ( ) ,
899
901
"Class 1 samples should have low probability for class 0"
900
902
) ;
901
903
}
0 commit comments