diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index 4f36e5b9..3c9deaf7 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -116,6 +116,7 @@ pub struct DecisionTreeClassifier<
     num_classes: usize,
     classes: Vec<TY>,
     depth: u16,
+    num_features: usize,
     _phantom_tx: PhantomData<TX>,
     _phantom_x: PhantomData<X>,
     _phantom_y: PhantomData<Y>,
@@ -159,11 +160,13 @@ pub enum SplitCriterion {
 #[derive(Debug, Clone)]
 struct Node {
     output: usize,
+    n_node_samples: usize,
     split_feature: usize,
     split_value: Option<f64>,
     split_score: Option<f64>,
     true_child: Option<usize>,
     false_child: Option<usize>,
+    impurity: Option<f64>,
 }
 
 impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
@@ -400,14 +403,16 @@ impl Default for DecisionTreeClassifierSearchParameters {
 }
 
 impl Node {
-    fn new(output: usize) -> Self {
+    fn new(output: usize, n_node_samples: usize) -> Self {
         Node {
             output,
+            n_node_samples,
             split_feature: 0,
             split_value: Option::None,
             split_score: Option::None,
             true_child: Option::None,
             false_child: Option::None,
+            impurity: Option::None,
         }
     }
 }
@@ -507,6 +512,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             num_classes: 0usize,
             classes: vec![],
             depth: 0u16,
+            num_features: 0usize,
             _phantom_tx: PhantomData,
             _phantom_x: PhantomData,
             _phantom_y: PhantomData,
@@ -578,7 +584,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             count[yi[i]] += samples[i];
         }
 
-        let root = Node::new(which_max(&count));
+        let root = Node::new(which_max(&count), y_ncols);
         change_nodes.push(root);
 
         let mut order: Vec<Vec<usize>> = Vec::new();
@@ -593,6 +599,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             num_classes: k,
             classes,
             depth: 0u16,
+            num_features: num_attributes,
             _phantom_tx: PhantomData,
             _phantom_x: PhantomData,
             _phantom_y: PhantomData,
@@ -678,16 +685,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             }
         }
 
-        if is_pure {
-            return false;
-        }
-
         let n = visitor.samples.iter().sum();
-
-        if n <= self.parameters().min_samples_split {
-            return false;
-        }
-
         let mut count = vec![0; self.num_classes];
         let mut false_count = vec![0; self.num_classes];
         for i in 0..n_rows {
@@ -696,7 +694,15 @@
             }
         }
 
-        let parent_impurity = impurity(&self.parameters().criterion, &count, n);
+        self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n));
+
+        if is_pure {
+            return false;
+        }
+
+        if n <= self.parameters().min_samples_split {
+            return false;
+        }
 
         let mut variables = (0..n_attr).collect::<Vec<usize>>();
 
@@ -705,14 +711,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         }
 
         for variable in variables.iter().take(mtry) {
-            self.find_best_split(
-                visitor,
-                n,
-                &count,
-                &mut false_count,
-                parent_impurity,
-                *variable,
-            );
+            self.find_best_split(visitor, n, &count, &mut false_count, *variable);
         }
 
         self.nodes()[visitor.node].split_score.is_some()
@@ -724,7 +723,6 @@
         n: usize,
         count: &[usize],
         false_count: &mut [usize],
-        parent_impurity: f64,
         j: usize,
     ) {
         let mut true_count = vec![0; self.num_classes];
@@ -760,6 +758,7 @@
         let true_label = which_max(&true_count);
         let false_label = which_max(false_count);
 
+        let parent_impurity = self.nodes()[visitor.node].impurity.unwrap();
         let gain = parent_impurity
             - tc as f64 / n as f64
                 * impurity(&self.parameters().criterion, &true_count, tc)
@@ -827,9 +826,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
 
         let true_child_idx = self.nodes().len();
 
-        self.nodes.push(Node::new(visitor.true_child_output));
+        self.nodes.push(Node::new(visitor.true_child_output, tc));
         let false_child_idx = self.nodes().len();
-        self.nodes.push(Node::new(visitor.false_child_output));
+        self.nodes.push(Node::new(visitor.false_child_output, fc));
         self.nodes[visitor.node].true_child = Some(true_child_idx);
         self.nodes[visitor.node].false_child = Some(false_child_idx);
 
@@ -863,6 +862,33 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
 
         true
     }
+
+    /// Compute feature importances for the fitted tree.
+    pub fn compute_feature_importances(&self, normalize: bool) -> Vec<f64> {
+        let mut importances = vec![0f64; self.num_features];
+
+        for node in self.nodes().iter() {
+            if node.true_child.is_none() && node.false_child.is_none() {
+                continue;
+            }
+            let left = &self.nodes()[node.true_child.unwrap()];
+            let right = &self.nodes()[node.false_child.unwrap()];
+
+            importances[node.split_feature] += node.n_node_samples as f64 * node.impurity.unwrap()
+                - left.n_node_samples as f64 * left.impurity.unwrap()
+                - right.n_node_samples as f64 * right.impurity.unwrap();
+        }
+        for item in importances.iter_mut() {
+            *item /= self.nodes()[0].n_node_samples as f64;
+        }
+        if normalize {
+            let sum = importances.iter().sum::<f64>();
+            for importance in importances.iter_mut() {
+                *importance /= sum;
+            }
+        }
+        importances
+    }
 }
 
 #[cfg(test)]
@@ -1016,6 +1042,42 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_compute_feature_importances() {
+        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
+            &[1., 1., 1., 0.],
+            &[1., 1., 1., 0.],
+            &[1., 1., 1., 1.],
+            &[1., 1., 0., 0.],
+            &[1., 1., 0., 1.],
+            &[1., 0., 1., 0.],
+            &[1., 0., 1., 0.],
+            &[1., 0., 1., 1.],
+            &[1., 0., 0., 0.],
+            &[1., 0., 0., 1.],
+            &[0., 1., 1., 0.],
+            &[0., 1., 1., 0.],
+            &[0., 1., 1., 1.],
+            &[0., 1., 0., 0.],
+            &[0., 1., 0., 1.],
+            &[0., 0., 1., 0.],
+            &[0., 0., 1., 0.],
+            &[0., 0., 1., 1.],
+            &[0., 0., 0., 0.],
+            &[0., 0., 0., 1.],
+        ]);
+        let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
+        let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
+        assert_eq!(
+            tree.compute_feature_importances(false),
+            vec![0., 0., 0.21333333333333332, 0.26666666666666666]
+        );
+        assert_eq!(
+            tree.compute_feature_importances(true),
+            vec![0., 0., 0.4444444444444444, 0.5555555555555556]
+        );
+    }
+
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test