From 4cf2d8ce18379f8121efdcf00aaa0018d25ecd9b Mon Sep 17 00:00:00 2001 From: tushushu Date: Sun, 21 Jan 2024 17:45:51 +0800 Subject: [PATCH 01/11] store impurity in the node --- src/tree/decision_tree_classifier.rs | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 4f36e5b9..a61bfa02 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -164,6 +164,7 @@ struct Node { split_score: Option, true_child: Option, false_child: Option, + impurity: Option, } impl, Y: Array1> PartialEq @@ -408,6 +409,7 @@ impl Node { split_score: Option::None, true_child: Option::None, false_child: Option::None, + impurity: Option::None, } } } @@ -696,7 +698,7 @@ impl, Y: Array1> } } - let parent_impurity = impurity(&self.parameters().criterion, &count, n); + self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n)); let mut variables = (0..n_attr).collect::>(); @@ -705,14 +707,7 @@ impl, Y: Array1> } for variable in variables.iter().take(mtry) { - self.find_best_split( - visitor, - n, - &count, - &mut false_count, - parent_impurity, - *variable, - ); + self.find_best_split(visitor, n, &count, &mut false_count, *variable); } self.nodes()[visitor.node].split_score.is_some() @@ -724,7 +719,6 @@ impl, Y: Array1> n: usize, count: &[usize], false_count: &mut [usize], - parent_impurity: f64, j: usize, ) { let mut true_count = vec![0; self.num_classes]; @@ -760,6 +754,7 @@ impl, Y: Array1> let true_label = which_max(&true_count); let false_label = which_max(false_count); + let parent_impurity = self.nodes()[visitor.node].impurity.unwrap(); let gain = parent_impurity - tc as f64 / n as f64 * impurity(&self.parameters().criterion, &true_count, tc) From e8aabbfa9347e115554090a48ea99b38ff93f4d2 Mon Sep 17 00:00:00 2001 From: tushushu Date: Sun, 21 Jan 2024 18:09:02 +0800 Subject: [PATCH 02/11] add number of features --- src/tree/decision_tree_classifier.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index a61bfa02..86a9f813 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -116,6 +116,7 @@ pub struct DecisionTreeClassifier< num_classes: usize, classes: Vec, depth: u16, + num_features: usize, _phantom_tx: PhantomData, _phantom_x: PhantomData, _phantom_y: PhantomData, @@ -509,6 +510,7 @@ impl, Y: Array1> num_classes: 0usize, classes: vec![], depth: 0u16, + num_features: 0usize, _phantom_tx: PhantomData, _phantom_x: PhantomData, _phantom_y: PhantomData, @@ -595,6 +597,7 @@ impl, Y: Array1> num_classes: k, classes, depth: 0u16, + num_features: num_attributes, _phantom_tx: PhantomData, _phantom_x: PhantomData, _phantom_y: PhantomData, From 656bd2303098f490e615f68d26e0eb3c1150b856 Mon Sep 17 00:00:00 2001 From: tushushu Date: Sun, 21 Jan 2024 18:14:31 +0800 Subject: [PATCH 03/11] add a TODO --- src/tree/decision_tree_classifier.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 86a9f813..619e6bf8 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -683,6 +683,8 @@ impl, Y: Array1> } } + // TODO: Always calculate the impurity. It is needed for the feature importance. + if is_pure { return false; } From 965a19f25892f1f8f8a06ee83236f4ac8299eacc Mon Sep 17 00:00:00 2001 From: tushushu Date: Sat, 27 Jan 2024 18:07:37 +0800 Subject: [PATCH 04/11] draft feature importance --- src/tree/decision_tree_classifier.rs | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 619e6bf8..5ee3ecdc 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -863,6 +863,35 @@ impl, Y: Array1> true } + + pub fn compute_feature_importances(&self) -> f64 { + let mut importances = vec![0f64; self.num_features]; + + for node in self.nodes().iter() { + if node.true_child.is_none() && node.false_child.is_none() { + continue; + } + let left = self.nodes()[node.true_child.unwrap()]; + let right = self.nodes()[node.false_child.unwrap()]; + + importances[node.split_feature] += (node.impurity.unwrap() + * node.impurity.unwrap() + * node.impurity.unwrap() + - left.impurity.unwrap() * left.impurity.unwrap() * left.impurity.unwrap() + - right.impurity.unwrap() * right.impurity.unwrap() * right.impurity.unwrap(),) + } + + let sum = importances.iter().sum::(); + + for importance in importances.iter_mut() { + *importance /= sum; + } + + *importances + .iter() + .max_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap() + } } #[cfg(test)] From aca711c401bdf02a0fb3ecfff70e6ea2fb0ea9eb Mon Sep 17 00:00:00 2001 From: tushushu Date: Sun, 28 Jan 2024 17:22:33 +0800 Subject: [PATCH 05/11] feat --- src/tree/decision_tree_classifier.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 5ee3ecdc..f0872027 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -166,6 +166,7 @@ struct Node { true_child: Option, false_child: Option, impurity: Option, + n_node_samples: usize, } impl, Y: Array1> PartialEq @@ -411,6 +412,7 @@ impl Node { true_child: Option::None, false_child: Option::None, impurity: Option::None, + n_node_samples: 0, } } } From 1746a7f0ef1529243af440ae7db8e9819a7f39df Mon Sep 17 00:00:00 2001 From: tushushu Date: Sun, 28 Jan 2024 20:01:06 +0800 Subject: [PATCH 06/11] n_samples of node --- src/tree/decision_tree_classifier.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index f0872027..c20c2013 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -160,13 +160,13 @@ pub enum SplitCriterion { #[derive(Debug, Clone)] struct Node { output: usize, + n_node_samples: usize, split_feature: usize, split_value: Option, split_score: Option, true_child: Option, false_child: Option, impurity: Option, - n_node_samples: usize, } impl, Y: Array1> PartialEq @@ -403,16 +403,16 @@ impl Default for DecisionTreeClassifierSearchParameters { } impl Node { - fn new(output: usize) -> Self { + fn new(output: usize, n_node_samples: usize) -> Self { Node { output, + n_node_samples, split_feature: 0, split_value: Option::None, split_score: Option::None, true_child: Option::None, false_child: Option::None, impurity: Option::None, - n_node_samples: 0, } } } @@ -584,7 +584,7 @@ impl, Y: Array1> count[yi[i]] += samples[i]; } - let root = Node::new(which_max(&count)); + let root = Node::new(which_max(&count), y_ncols); change_nodes.push(root); let mut order: Vec> = Vec::new(); @@ -829,9 +829,9 @@ impl, Y: Array1> let true_child_idx = self.nodes().len(); - self.nodes.push(Node::new(visitor.true_child_output)); + self.nodes.push(Node::new(visitor.true_child_output, tc)); let false_child_idx = self.nodes().len(); - self.nodes.push(Node::new(visitor.false_child_output)); + self.nodes.push(Node::new(visitor.false_child_output, fc)); self.nodes[visitor.node].true_child = Some(true_child_idx); self.nodes[visitor.node].false_child = Some(false_child_idx); From 016290167979b5bfdfeac86967984f95fa850262 Mon Sep 17 00:00:00 2001 From: tushushu Date: Sun, 28 Jan 2024 20:06:22 +0800 Subject: [PATCH 07/11] compute_feature_importances --- src/tree/decision_tree_classifier.rs | 36 +++++++++++++--------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index c20c2013..f1901669 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -866,33 +866,31 @@ impl, Y: Array1> true } - pub fn compute_feature_importances(&self) -> f64 { + pub fn compute_feature_importances(&self, normalize: bool) -> Vec { let mut importances = vec![0f64; self.num_features]; for node in self.nodes().iter() { if node.true_child.is_none() && node.false_child.is_none() { continue; } - let left = self.nodes()[node.true_child.unwrap()]; - let right = self.nodes()[node.false_child.unwrap()]; - - importances[node.split_feature] += (node.impurity.unwrap() - * node.impurity.unwrap() - * node.impurity.unwrap() - - left.impurity.unwrap() * left.impurity.unwrap() * left.impurity.unwrap() - - right.impurity.unwrap() * right.impurity.unwrap() * right.impurity.unwrap(),) - } - - let sum = importances.iter().sum::(); + let left = &self.nodes()[node.true_child.unwrap()]; + let right = &self.nodes()[node.false_child.unwrap()]; - for importance in importances.iter_mut() { - *importance /= sum; + importances[node.split_feature] += node.n_node_samples as f64 + * (node.impurity.unwrap() + - left.n_node_samples as f64 * left.impurity.unwrap() + - right.n_node_samples as f64 * right.impurity.unwrap()); } - - *importances - .iter() - .max_by(|a, b| a.partial_cmp(b).unwrap()) - .unwrap() + for i in 0..self.num_features { + importances[i] /= self.nodes()[0].n_node_samples as f64; + } + if normalize { + let sum = importances.iter().sum::(); + for importance in importances.iter_mut() { + *importance /= sum; + } + } + importances } } From 2c2da127e8967e2b92020eb162d025d2e1a76ba1 Mon Sep 17 00:00:00 2001 From: tushushu Date: Fri, 2 Feb 2024 19:31:05 +0800 Subject: [PATCH 08/11] unit tests --- src/tree/decision_tree_classifier.rs | 37 ++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index f1901669..3a3e416f 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -866,6 +866,7 @@ impl, Y: Array1> true } + /// Compute feature importances for the fitted tree. pub fn compute_feature_importances(&self, normalize: bool) -> Vec { let mut importances = vec![0f64; self.num_features]; @@ -1045,6 +1046,42 @@ mod tests { ); } + #[test] + fn test_compute_feature_importances() { + let x: DenseMatrix = DenseMatrix::from_2d_array(&[ + &[1., 1., 1., 0.], + &[1., 1., 1., 0.], + &[1., 1., 1., 1.], + &[1., 1., 0., 0.], + &[1., 1., 0., 1.], + &[1., 0., 1., 0.], + &[1., 0., 1., 0.], + &[1., 0., 1., 1.], + &[1., 0., 0., 0.], + &[1., 0., 0., 1.], + &[0., 1., 1., 0.], + &[0., 1., 1., 0.], + &[0., 1., 1., 1.], + &[0., 1., 0., 0.], + &[0., 1., 0., 1.], + &[0., 0., 1., 0.], + &[0., 0., 1., 0.], + &[0., 0., 1., 1.], + &[0., 0., 0., 0.], + &[0., 0., 0., 1.], + ]); + let y: Vec = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0]; + let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); + assert_eq!( + tree.compute_feature_importances(false), + vec![0.0, 0.0, 0.0, 0.0] + ); + assert_eq!( + tree.compute_feature_importances(true), + vec![0.0, 0.0, 0.0, 0.0] + ); + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test From 584389be2769fe1f02a680f1c36b2556860be284 Mon Sep 17 00:00:00 2001 From: tushushu Date: Sat, 3 Feb 2024 19:56:22 +0800 Subject: [PATCH 09/11] always calculate impurity --- src/tree/decision_tree_classifier.rs | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 3a3e416f..ecd9110d 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -685,18 +685,7 @@ impl, Y: Array1> } } - // TODO: Always calculate the impurity. It is needed for the feature importance. - - if is_pure { - return false; - } - let n = visitor.samples.iter().sum(); - - if n <= self.parameters().min_samples_split { - return false; - } - let mut count = vec![0; self.num_classes]; let mut false_count = vec![0; self.num_classes]; for i in 0..n_rows { @@ -707,6 +696,14 @@ impl, Y: Array1> self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n)); + if is_pure { + return false; + } + + if n <= self.parameters().min_samples_split { + return false; + } + let mut variables = (0..n_attr).collect::>(); if mtry < n_attr { @@ -1078,7 +1075,7 @@ mod tests { ); assert_eq!( tree.compute_feature_importances(true), - vec![0.0, 0.0, 0.0, 0.0] + vec![0., 0., 0.44444444, 0.55555556] ); } From 1faae6401456de2f82c1b925388883e870d6b34a Mon Sep 17 00:00:00 2001 From: tushushu Date: Sat, 3 Feb 2024 20:28:50 +0800 Subject: [PATCH 10/11] fix bug --- src/tree/decision_tree_classifier.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index ecd9110d..2fe283e1 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -874,10 +874,9 @@ impl, Y: Array1> let left = &self.nodes()[node.true_child.unwrap()]; let right = &self.nodes()[node.false_child.unwrap()]; - importances[node.split_feature] += node.n_node_samples as f64 - * (node.impurity.unwrap() - - left.n_node_samples as f64 * left.impurity.unwrap() - - right.n_node_samples as f64 * right.impurity.unwrap()); + importances[node.split_feature] += node.n_node_samples as f64 * node.impurity.unwrap() + - left.n_node_samples as f64 * left.impurity.unwrap() + - right.n_node_samples as f64 * right.impurity.unwrap(); } for i in 0..self.num_features { importances[i] /= self.nodes()[0].n_node_samples as f64; @@ -1071,11 +1070,11 @@ mod tests { let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap(); assert_eq!( tree.compute_feature_importances(false), - vec![0.0, 0.0, 0.0, 0.0] + vec![0., 0., 0.21333333333333332, 0.26666666666666666] ); assert_eq!( tree.compute_feature_importances(true), - vec![0., 0., 0.44444444, 0.55555556] + vec![0., 0., 0.4444444444444444, 0.5555555555555556] ); } From 42c49492ee5dc06f5991799c414b3d1c3c7db90f Mon Sep 17 00:00:00 2001 From: tushushu Date: Wed, 7 Feb 2024 09:14:01 +0800 Subject: [PATCH 11/11] fix linter --- src/tree/decision_tree_classifier.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 2fe283e1..3c9deaf7 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -878,8 +878,8 @@ impl, Y: Array1> - left.n_node_samples as f64 * left.impurity.unwrap() - right.n_node_samples as f64 * right.impurity.unwrap(); } - for i in 0..self.num_features { - importances[i] /= self.nodes()[0].n_node_samples as f64; + for item in importances.iter_mut() { + *item /= self.nodes()[0].n_node_samples as f64; } if normalize { let sum = importances.iter().sum::();