From a647816e3d80876f3a794891e9314269f50555c8 Mon Sep 17 00:00:00 2001
From: Chandan Singh
Date: Thu, 28 Jul 2022 14:05:04 -0700
Subject: [PATCH] hs working for tree ensembles
---
docs/shrinkage.html | 22 ++++++++++++++-
docs/tree/hierarchical_shrinkage.html | 39 ++++++++++++++++++++++++--
imodels/tree/hierarchical_shrinkage.py | 13 ++++++++-
setup.py | 2 +-
tests/estimator_checks_test.py | 12 ++++----
tests/shrinkage_test.py | 11 +++++++-
6 files changed, 86 insertions(+), 13 deletions(-)
diff --git a/docs/shrinkage.html b/docs/shrinkage.html
index 0a67c1a5..8072ffb9 100644
--- a/docs/shrinkage.html
+++ b/docs/shrinkage.html
@@ -49,7 +49,8 @@ How does Hierarchical shrinkage work?
An example using HS
-HS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the fit
and predict
methods. Here's a full example of using it on a sample clinical dataset.
+HS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the fit
and predict
methods.
+Here's a full example of using it on a sample clinical dataset.
@@ -73,6 +74,9 @@ An example using HS
+Here we used HSTreeClassifierCV
, which selects the amount of regularization to use via cross-validation, but we can also use HSTreeClassifier
if we want to specify a particular amount of regularization.
+For regression, we can use the corresponding classes: HSTreeRegressorCV
and HSTreeRegressor
.
+
@@ -103,6 +107,22 @@
Examples with HS on synthetic data
+Applying HS to tree ensembles
+
+HS can also be used on tree ensembles to regularize each tree in an ensemble (e.g. in a Random Forest).
+We must simply pass the desired estimator during initialization.
+
+
+
+from sklearn.ensemble import RandomForestClassifier # also works with ExtraTreesClassifier, GradientBoostingClassifier
+from imodels import HSTreeClassifier
+ensemble = RandomForestClassifier()
+model = HSTreeClassifier(estimator_=ensemble)
+model = model.fit(X_train, y_train)
+
+
+
+
diff --git a/docs/tree/hierarchical_shrinkage.html b/docs/tree/hierarchical_shrinkage.html
index 41846717..77ef8c24 100644
--- a/docs/tree/hierarchical_shrinkage.html
+++ b/docs/tree/hierarchical_shrinkage.html
@@ -78,7 +78,18 @@
self.feature_names = kwargs.pop('feature_names', None) # None returned if not passed
self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs)
self._shrink()
- self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
+
+ # compute complexity
+ if hasattr(self.estimator_, 'tree_'):
+ self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
+ elif hasattr(self.estimator_, 'estimators_'):
+ self.complexity_ = 0
+ for i in range(len(self.estimator_.estimators_)):
+ t = deepcopy(self.estimator_.estimators_[i])
+ if isinstance(t, np.ndarray):
+ assert t.size == 1, 'multiple trees stored under tree_?'
+ t = t[0]
+ self.complexity_ += compute_tree_complexity(t.tree_)
return self
def _shrink_tree(self, tree, reg_param, i=0, parent_val=None, parent_num=None, cum_sum=0):
@@ -406,7 +417,18 @@ Params
self.feature_names = kwargs.pop('feature_names', None) # None returned if not passed
self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs)
self._shrink()
- self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
+
+ # compute complexity
+ if hasattr(self.estimator_, 'tree_'):
+ self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
+ elif hasattr(self.estimator_, 'estimators_'):
+ self.complexity_ = 0
+ for i in range(len(self.estimator_.estimators_)):
+ t = deepcopy(self.estimator_.estimators_[i])
+ if isinstance(t, np.ndarray):
+ assert t.size == 1, 'multiple trees stored under tree_?'
+ t = t[0]
+ self.complexity_ += compute_tree_complexity(t.tree_)
return self
def _shrink_tree(self, tree, reg_param, i=0, parent_val=None, parent_num=None, cum_sum=0):
@@ -544,7 +566,18 @@ Methods
self.feature_names = kwargs.pop('feature_names', None) # None returned if not passed
self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs)
self._shrink()
- self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
+
+ # compute complexity
+ if hasattr(self.estimator_, 'tree_'):
+ self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
+ elif hasattr(self.estimator_, 'estimators_'):
+ self.complexity_ = 0
+ for i in range(len(self.estimator_.estimators_)):
+ t = deepcopy(self.estimator_.estimators_[i])
+ if isinstance(t, np.ndarray):
+ assert t.size == 1, 'multiple trees stored under tree_?'
+ t = t[0]
+ self.complexity_ += compute_tree_complexity(t.tree_)
return self
diff --git a/imodels/tree/hierarchical_shrinkage.py b/imodels/tree/hierarchical_shrinkage.py
index 2f3f046c..6b38efa4 100644
--- a/imodels/tree/hierarchical_shrinkage.py
+++ b/imodels/tree/hierarchical_shrinkage.py
@@ -55,7 +55,18 @@ def fit(self, X, y, *args, **kwargs):
self.feature_names = kwargs.pop('feature_names', None) # None returned if not passed
self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs)
self._shrink()
- self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
+
+ # compute complexity
+ if hasattr(self.estimator_, 'tree_'):
+ self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
+ elif hasattr(self.estimator_, 'estimators_'):
+ self.complexity_ = 0
+ for i in range(len(self.estimator_.estimators_)):
+ t = deepcopy(self.estimator_.estimators_[i])
+ if isinstance(t, np.ndarray):
+ assert t.size == 1, 'multiple trees stored under tree_?'
+ t = t[0]
+ self.complexity_ += compute_tree_complexity(t.tree_)
return self
def _shrink_tree(self, tree, reg_param, i=0, parent_val=None, parent_num=None, cum_sum=0):
diff --git a/setup.py b/setup.py
index f7f31ae9..115cca60 100644
--- a/setup.py
+++ b/setup.py
@@ -26,7 +26,7 @@
setuptools.setup(
name="imodels",
- version="1.3.2",
+ version="1.3.3",
author="Chandan Singh, Keyan Nasseri, Bin Yu, and others",
author_email="chandan_singh@berkeley.edu",
description="Implementations of various interpretable models",
diff --git a/tests/estimator_checks_test.py b/tests/estimator_checks_test.py
index ff17c4c2..867ad9de 100644
--- a/tests/estimator_checks_test.py
+++ b/tests/estimator_checks_test.py
@@ -6,19 +6,19 @@
class TestCheckEstimators(unittest.TestCase):
- '''Checks that estimators conform to sklearn checks
- '''
+ """Checks that estimators conform to sklearn checks
+ """
def test_check_classifier_compatibility(self):
- '''Test classifiers are properly sklearn-compatible
- '''
+ """Test classifiers are properly sklearn-compatible
+ """
for classifier in [imodels.SLIMClassifier]: # BoostedRulesClassifier (multi-class not supported)
check_estimator(classifier())
assert 'passed check_estimator for ' + str(classifier)
def test_check_regressor_compatibility(self):
- '''Test regressors are properly sklearn-compatible
- '''
+ """Test regressors are properly sklearn-compatible
+ """
for regr in []: # SLIMRegressor fails acc screening for boston dset
check_estimator(regr())
assert 'passed check_estimator for ' + str(regr)
diff --git a/tests/shrinkage_test.py b/tests/shrinkage_test.py
index bfb0664a..72615f37 100644
--- a/tests/shrinkage_test.py
+++ b/tests/shrinkage_test.py
@@ -2,7 +2,7 @@
from functools import partial
import numpy as np
-from sklearn.ensemble import VotingRegressor
+from sklearn.ensemble import VotingRegressor, RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from imodels import HSTreeClassifier, HSTreeClassifierCV, \
@@ -46,8 +46,11 @@ def test_classification_shrinkage(self):
'''
for model_type in [
+ partial(HSTreeClassifier, estimator_=DecisionTreeClassifier()),
+ partial(HSTreeClassifier, estimator_=GradientBoostingClassifier()),
partial(HSTreeClassifier, estimator_=DecisionTreeClassifier()),
partial(HSTreeClassifierCV, estimator_=DecisionTreeClassifier()),
+ partial(HSTreeClassifierCV, estimator_=RandomForestClassifier()),
partial(HSC45TreeClassifierCV, estimator_=C45TreeClassifier()),
HSTreeClassifierCV, # default estimator is Decision tree with 25 max_leaf_nodes
# partial(HSOptimalTreeClassifierCV, estimator_=OptimalTreeClassifier()),
@@ -74,6 +77,9 @@ def test_classification_shrinkage(self):
# print(type(m), m, 'final acc', acc_train)
assert acc_train > 0.8, 'acc greater than 0.8'
+ # complexity
+ assert m.complexity_ > 0, 'complexity is greater than 0'
+
def test_recognized_by_sklearn(self):
base_models = [('hs', HSTreeRegressor(DecisionTreeRegressor())),
('dt', DecisionTreeRegressor())]
@@ -97,6 +103,9 @@ def test_regression_shrinkage(self):
mse = np.mean(np.square(preds - self.y_regression))
assert mse < 1, 'mse less than 1'
+ # complexity
+ assert m.complexity_ > 0, 'complexity is greater than 0'
+
if __name__ == '__main__':
t = TestShrinkage()