Skip to content

Commit bb356e6

Browse files
committed
fix test
1 parent 52b797d commit bb356e6

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/ensemble/random_forest_classifier.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -833,8 +833,9 @@ mod tests {
833833
)]
834834
#[test]
835835
fn test_random_forest_predict_proba() {
836+
use num_traits::FromPrimitive;
836837
// Iris-like dataset (subset)
837-
let x = DenseMatrix::from_2d_array(&[
838+
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
838839
&[5.1, 3.5, 1.4, 0.2],
839840
&[4.9, 3.0, 1.4, 0.2],
840841
&[4.7, 3.2, 1.3, 0.2],
@@ -881,21 +882,22 @@ mod tests {
881882
// These values are approximate and based on typical random forest behavior
882883
for i in 0..5 {
883884
assert!(
884-
*probas.get((i, 0)) > 0.6,
885+
*probas.get((i, 0)) > f64::from_f32(0.6).unwrap(),
885886
"Class 0 samples should have high probability for class 0"
886887
);
887888
assert!(
888-
*probas.get((i, 1)) < 0.4,
889+
*probas.get((i, 1)) < f64::from_f32(0.4).unwrap(),
889890
"Class 0 samples should have low probability for class 1"
890891
);
891892
}
893+
892894
for i in 5..10 {
893895
assert!(
894-
*probas.get((i, 1)) > 0.6,
896+
*probas.get((i, 1)) > f64::from_f32(0.6).unwrap(),
895897
"Class 1 samples should have high probability for class 1"
896898
);
897899
assert!(
898-
*probas.get((i, 0)) < 0.4,
900+
*probas.get((i, 0)) < f64::from_f32(0.4).unwrap(),
899901
"Class 1 samples should have low probability for class 0"
900902
);
901903
}

0 commit comments

Comments
 (0)