From fc7f2e61d9eb7785017d855a1c803cf6d0a2e814 Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS
Date: Mon, 20 Jan 2025 15:27:39 +0000
Subject: [PATCH] format

---
 src/algorithm/neighbour/fastpair.rs  |  4 +++-
 src/linalg/basic/matrix.rs           |  4 +---
 src/linalg/ndarray/matrix.rs         |  8 ++-----
 src/preprocessing/numerical.rs       | 14 +++++++------
 src/tree/decision_tree_classifier.rs | 31 +++++++++++++++-------------
 5 files changed, 31 insertions(+), 30 deletions(-)

diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs
index 671517df..4e99261b 100644
--- a/src/algorithm/neighbour/fastpair.rs
+++ b/src/algorithm/neighbour/fastpair.rs
@@ -212,7 +212,9 @@ mod tests_fastpair {
     use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
 
     /// Brute force algorithm, used only for comparison and testing
-    pub fn closest_pair_brute(fastpair: &FastPair<'_, f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
+    pub fn closest_pair_brute(
+        fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
+    ) -> PairwiseDistance<f64> {
         use itertools::Itertools;
 
         let m = fastpair.samples.shape().0;
diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs
index 979e5a55..88a0849c 100644
--- a/src/linalg/basic/matrix.rs
+++ b/src/linalg/basic/matrix.rs
@@ -579,9 +579,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix<T>
     }
 }
 
-impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
-    for DenseMatrixMutView<'_, T>
-{
+impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
     fn set(&mut self, pos: (usize, usize), x: T) {
         if self.column_major {
             self.values[pos.0 + pos.1 * self.stride] = x;
diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs
index e406a198..5040497a 100644
--- a/src/linalg/ndarray/matrix.rs
+++ b/src/linalg/ndarray/matrix.rs
@@ -146,9 +146,7 @@ impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2>
 
 impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
 
-impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
-    for ArrayViewMut<'_, T, Ix2>
-{
+impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
     fn get(&self, pos: (usize, usize)) -> &T {
         &self[[pos.0, pos.1]]
     }
@@ -175,9 +173,7 @@ impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
     }
 }
 
-impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
-    for ArrayViewMut<'_, T, Ix2>
-{
+impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
     fn set(&mut self, pos: (usize, usize), x: T) {
         self[[pos.0, pos.1]] = x
     }
diff --git a/src/preprocessing/numerical.rs b/src/preprocessing/numerical.rs
index 8593d9f8..674f6814 100644
--- a/src/preprocessing/numerical.rs
+++ b/src/preprocessing/numerical.rs
@@ -172,12 +172,14 @@ where
     T: Number + RealNumber,
     M: Array2<T>,
 {
-    columns.first().cloned().map(|output_matrix| columns
-        .iter()
-        .skip(1)
-        .fold(output_matrix, |current_matrix, new_colum| {
-            current_matrix.h_stack(new_colum)
-        }))
+    columns.first().cloned().map(|output_matrix| {
+        columns
+            .iter()
+            .skip(1)
+            .fold(output_matrix, |current_matrix, new_colum| {
+                current_matrix.h_stack(new_colum)
+            })
+    })
 }
 
 #[cfg(test)]
diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index 712cd87d..f63cc2d9 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -77,9 +77,9 @@ use serde::{Deserialize, Serialize};
 
 use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::Failed;
+use crate::linalg::basic::arrays::MutArray;
 use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
 use crate::linalg::basic::matrix::DenseMatrix;
-use crate::linalg::basic::arrays::MutArray;
 use crate::numbers::basenum::Number;
 use crate::rand_custom::get_rng_impl;
 
@@ -890,7 +890,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         importances
     }
 
-
     /// Predict class probabilities for the input samples.
     ///
     /// # Arguments
@@ -933,7 +932,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
     /// of the input sample belonging to each class.
     fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
         let mut node = 0;
-
+
         while let Some(current_node) = self.nodes().get(node) {
             if current_node.true_child.is_none() && current_node.false_child.is_none() {
                 // Leaf node reached
                 let mut probs = vec![0.0; self.classes().len()];
                 probs[current_node.output] = 1.0;
                 return probs;
             }
-
+
             let split_feature = current_node.split_feature;
             let split_value = current_node.split_value.unwrap_or(f64::NAN);
-
+
             if x.get((row, split_feature)).to_f64().unwrap() <= split_value {
                 node = current_node.true_child.unwrap();
             } else {
                 node = current_node.false_child.unwrap();
             }
         }
-
+
         // This should never happen if the tree is properly constructed
         vec![0.0; self.classes().len()]
     }
@@ -960,8 +959,8 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::linalg::basic::matrix::DenseMatrix;
     use crate::linalg::basic::arrays::Array;
+    use crate::linalg::basic::matrix::DenseMatrix;
 
     #[test]
     fn search_parameters() {
@@ -1020,24 +1019,28 @@ mod tests {
             &[6.9, 3.1, 4.9, 1.5],
             &[5.5, 2.3, 4.0, 1.3],
             &[6.5, 2.8, 4.6, 1.5],
-        ]).unwrap();
+        ])
+        .unwrap();
         let y: Vec<u32> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
-
+
         let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
         let probabilities = tree.predict_proba(&x).unwrap();
-
+
         assert_eq!(probabilities.shape(), (10, 2));
-
+
         for row in 0..10 {
             let row_sum: f64 = probabilities.get_row(row).sum();
-            assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1");
+            assert!(
+                (row_sum - 1.0).abs() < 1e-6,
+                "Row probabilities should sum to 1"
+            );
         }
-
+
         // Check if the first 5 samples have higher probability for class 0
         for i in 0..5 {
             assert!(probabilities.get((i, 0)) > probabilities.get((i, 1)));
        }
-
+
         // Check if the last 5 samples have higher probability for class 1
         for i in 5..10 {
             assert!(probabilities.get((i, 1)) > probabilities.get((i, 0)));
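
Reviewer note, not part of the patch: a minimal usage sketch of the predict_proba path exercised by the test above. The crate paths mirror the modules touched in this diff; the data values, the u32 label type, and the main() wrapper are illustrative assumptions rather than anything prescribed by the patch.

    use smartcore::linalg::basic::arrays::Array;
    use smartcore::linalg::basic::matrix::DenseMatrix;
    use smartcore::tree::decision_tree_classifier::DecisionTreeClassifier;

    fn main() {
        // Two toy classes; the shape of the call mirrors the test added in this diff.
        let x = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[6.9, 3.1, 4.9, 1.5],
            &[6.5, 2.8, 4.6, 1.5],
        ])
        .unwrap();
        let y: Vec<u32> = vec![0, 0, 1, 1];

        let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();

        // predict_proba yields an (n_samples, n_classes) matrix; for a single
        // decision tree each row is a one-hot distribution that sums to 1.
        let probabilities = tree.predict_proba(&x).unwrap();
        assert_eq!(probabilities.shape(), (4, 2));
    }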