Implement the feature importance for Decision Tree Classifier #275

Merged · 11 commits · Feb 25, 2024
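
This PR caches per-node sample counts and impurities while the tree is fitted and exposes them through a new compute_feature_importances method. A minimal usage sketch, mirroring the test added at the bottom of this diff — the fit/compute_feature_importances calls come from the diff itself, while the import paths (assumed smartcore 0.3) and the toy dataset are illustrative assumptions:

    use smartcore::linalg::basic::matrix::DenseMatrix;
    use smartcore::tree::decision_tree_classifier::DecisionTreeClassifier;

    fn main() {
        // Toy data: the label simply equals the first feature.
        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
            &[1., 1.], &[1., 0.], &[0., 1.], &[0., 0.],
        ]);
        let y: Vec<u32> = vec![1, 1, 0, 0];
        let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
        // Raw mean decrease in impurity per feature.
        println!("{:?}", tree.compute_feature_importances(false));
        // Same values rescaled to sum to 1.
        println!("{:?}", tree.compute_feature_importances(true));
    }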
108 changes: 85 additions & 23 deletions src/tree/decision_tree_classifier.rs
@@ -116,6 +116,7 @@ pub struct DecisionTreeClassifier<
num_classes: usize,
classes: Vec<TY>,
depth: u16,
+ num_features: usize,
_phantom_tx: PhantomData<TX>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
@@ -159,11 +160,13 @@ pub enum SplitCriterion {
#[derive(Debug, Clone)]
struct Node {
output: usize,
+ n_node_samples: usize,
split_feature: usize,
split_value: Option<f64>,
split_score: Option<f64>,
true_child: Option<usize>,
false_child: Option<usize>,
+ impurity: Option<f64>,
}
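
Note on the two new fields: n_node_samples is fixed when the node is created, while impurity starts as None and is only filled in during the split search, hence the Option<f64>. Both are cached so feature importances can be computed after fitting without revisiting the training data.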

impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
@@ -400,14 +403,16 @@ impl Default for DecisionTreeClassifierSearchParameters {
}

impl Node {
- fn new(output: usize) -> Self {
+ fn new(output: usize, n_node_samples: usize) -> Self {
Node {
output,
+ n_node_samples,
split_feature: 0,
split_value: Option::None,
split_score: Option::None,
true_child: Option::None,
false_child: Option::None,
+ impurity: Option::None,
}
}
}
@@ -507,6 +512,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
num_classes: 0usize,
classes: vec![],
depth: 0u16,
+ num_features: 0usize,
_phantom_tx: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
@@ -578,7 +584,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
count[yi[i]] += samples[i];
}

- let root = Node::new(which_max(&count));
+ let root = Node::new(which_max(&count), y_ncols);
change_nodes.push(root);
let mut order: Vec<Vec<usize>> = Vec::new();

@@ -593,6 +599,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
num_classes: k,
classes,
depth: 0u16,
+ num_features: num_attributes,
_phantom_tx: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
@@ -678,16 +685,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
}
}

- if is_pure {
-     return false;
- }

let n = visitor.samples.iter().sum();

- if n <= self.parameters().min_samples_split {
-     return false;
- }

let mut count = vec![0; self.num_classes];
let mut false_count = vec![0; self.num_classes];
for i in 0..n_rows {
@@ -696,7 +694,15 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
}
}

- let parent_impurity = impurity(&self.parameters().criterion, &count, n);
+ self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n));
+
+ if is_pure {
+     return false;
+ }
+
+ if n <= self.parameters().min_samples_split {
+     return false;
+ }
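
The node's impurity is now computed and cached before the is_pure and min_samples_split early returns, so nodes that end up as leaves still carry an impurity value; compute_feature_importances below relies on every node having one. As a side effect, find_best_split no longer needs parent_impurity threaded through as an argument and instead reads the cached value off the node.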

let mut variables = (0..n_attr).collect::<Vec<_>>();

@@ -705,14 +711,7 @@
}

for variable in variables.iter().take(mtry) {
- self.find_best_split(
-     visitor,
-     n,
-     &count,
-     &mut false_count,
-     parent_impurity,
-     *variable,
- );
+ self.find_best_split(visitor, n, &count, &mut false_count, *variable);
}

self.nodes()[visitor.node].split_score.is_some()
@@ -724,7 +723,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
n: usize,
count: &[usize],
false_count: &mut [usize],
- parent_impurity: f64,
j: usize,
) {
let mut true_count = vec![0; self.num_classes];
@@ -760,6 +758,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>

let true_label = which_max(&true_count);
let false_label = which_max(false_count);
+ let parent_impurity = self.nodes()[visitor.node].impurity.unwrap();
let gain = parent_impurity
- tc as f64 / n as f64
* impurity(&self.parameters().criterion, &true_count, tc)
@@ -827,9 +826,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>

let true_child_idx = self.nodes().len();

- self.nodes.push(Node::new(visitor.true_child_output));
+ self.nodes.push(Node::new(visitor.true_child_output, tc));
let false_child_idx = self.nodes().len();
- self.nodes.push(Node::new(visitor.false_child_output));
+ self.nodes.push(Node::new(visitor.false_child_output, fc));
self.nodes[visitor.node].true_child = Some(true_child_idx);
self.nodes[visitor.node].false_child = Some(false_child_idx);
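
The children are created with tc and fc, the numbers of training samples routed to the true and false branches respectively, so the n_node_samples bookkeeping stays consistent all the way down the tree.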

@@ -863,6 +862,33 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>

true
}

+ /// Compute feature importances for the fitted tree.
+ pub fn compute_feature_importances(&self, normalize: bool) -> Vec<f64> {
+     let mut importances = vec![0f64; self.num_features];
+
+     for node in self.nodes().iter() {
+         if node.true_child.is_none() && node.false_child.is_none() {
+             continue;
+         }
+         let left = &self.nodes()[node.true_child.unwrap()];
+         let right = &self.nodes()[node.false_child.unwrap()];
+
+         importances[node.split_feature] += node.n_node_samples as f64 * node.impurity.unwrap()
+             - left.n_node_samples as f64 * left.impurity.unwrap()
+             - right.n_node_samples as f64 * right.impurity.unwrap();
+     }
+     for item in importances.iter_mut() {
+         *item /= self.nodes()[0].n_node_samples as f64;
+     }
+     if normalize {
+         let sum = importances.iter().sum::<f64>();
+         for importance in importances.iter_mut() {
+             *importance /= sum;
+         }
+     }
+     importances
+ }
}
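
This is the classic mean-decrease-in-impurity importance. Writing N for the root's n_node_samples, each split node contributes

    importance[split_feature] += (n_node * impurity_node
                                  - n_left * impurity_left
                                  - n_right * impurity_right) / N

with the division by N applied in the second loop, and normalize = true additionally rescales the vector to sum to 1. Leaf nodes are skipped via the true_child/false_child check since they have no split.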

#[cfg(test)]
@@ -1016,6 +1042,42 @@ mod tests {
);
}

+ #[test]
+ fn test_compute_feature_importances() {
+     let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
+         &[1., 1., 1., 0.],
+         &[1., 1., 1., 0.],
+         &[1., 1., 1., 1.],
+         &[1., 1., 0., 0.],
+         &[1., 1., 0., 1.],
+         &[1., 0., 1., 0.],
+         &[1., 0., 1., 0.],
+         &[1., 0., 1., 1.],
+         &[1., 0., 0., 0.],
+         &[1., 0., 0., 1.],
+         &[0., 1., 1., 0.],
+         &[0., 1., 1., 0.],
+         &[0., 1., 1., 1.],
+         &[0., 1., 0., 0.],
+         &[0., 1., 0., 1.],
+         &[0., 0., 1., 0.],
+         &[0., 0., 1., 0.],
+         &[0., 0., 1., 1.],
+         &[0., 0., 0., 0.],
+         &[0., 0., 0., 1.],
+     ]);
+     let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
+     let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
+     assert_eq!(
+         tree.compute_feature_importances(false),
+         vec![0., 0., 0.21333333333333332, 0.26666666666666666]
+     );
+     assert_eq!(
+         tree.compute_feature_importances(true),
+         vec![0., 0., 0.4444444444444444, 0.5555555555555556]
+     );
+ }
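
Sanity check on the expected values: the raw importances sum to 0.21333… + 0.26666… = 0.48, and 0.21333…/0.48 = 4/9 ≈ 0.4444 while 0.26666…/0.48 = 5/9 ≈ 0.5556, which matches the normalized vector exactly.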

#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test