Skip to content

Commit

Permalink
fix test conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
Mec-iS committed Jan 22, 2025
1 parent 4878042 commit 4aee603
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/ensemble/random_forest_classifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
}
}

let n_trees = self.trees.as_ref().unwrap().len() as f64;
let n_trees: f64 = self.trees.as_ref().unwrap().len() as f64;
probas.mul_scalar_mut(1.0 / n_trees);

Ok(probas)
Expand Down Expand Up @@ -884,22 +884,22 @@ mod tests {
// These values are approximate and based on typical random forest behavior
for i in 0..(pro_n_rows / 2) {
assert!(
*probas.get((i, 0)) > f64::from_f32(0.6).unwrap(),
f64::from_f32(0.6).unwrap().lt(probas.get((i, 0))),
"Class 0 samples should have high probability for class 0"
);
assert!(
*probas.get((i, 1)) < f64::from_f32(0.4).unwrap(),
f64::from_f32(0.4).unwrap().gt(probas.get((i, 1))),
"Class 0 samples should have low probability for class 1"
);
}

for i in (pro_n_rows / 2)..pro_n_rows {
assert!(
*probas.get((i, 1)) > f64::from_f32(0.6).unwrap(),
f64::from_f32(0.6).unwrap().lt(probas.get((i, 1))),
"Class 1 samples should have high probability for class 1"
);
assert!(
*probas.get((i, 0)) < f64::from_f32(0.4).unwrap(),
f64::from_f32(0.4).unwrap().gt(probas.get((i, 0))),
"Class 1 samples should have low probability for class 0"
);
}
Expand Down

0 comments on commit 4aee603

Please sign in to comment.