From a647816e3d80876f3a794891e9314269f50555c8 Mon Sep 17 00:00:00 2001 From: Chandan Singh Date: Thu, 28 Jul 2022 14:05:04 -0700 Subject: [PATCH] hs working for tree ensembles --- docs/shrinkage.html | 22 ++++++++++++++- docs/tree/hierarchical_shrinkage.html | 39 ++++++++++++++++++++++++-- imodels/tree/hierarchical_shrinkage.py | 13 ++++++++- setup.py | 2 +- tests/estimator_checks_test.py | 12 ++++---- tests/shrinkage_test.py | 11 +++++++- 6 files changed, 86 insertions(+), 13 deletions(-) diff --git a/docs/shrinkage.html b/docs/shrinkage.html index 0a67c1a5..8072ffb9 100644 --- a/docs/shrinkage.html +++ b/docs/shrinkage.html @@ -49,7 +49,8 @@

How does Hierarchical shrinkage work?

An example using HS

-HS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the fit and predict methods. Here's a full example of using it on a sample clinical dataset. +HS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the fit and predict methods. +Here's a full example of using it on a sample clinical dataset.
     
@@ -73,6 +74,9 @@ 

An example using HS

+Here we used HSTreeClassifierCV, which selects the amount of regularization to use via cross-validation, but we can also use HSTreeClassifier if we want to specify a particular amount of regularization. +For regression, we can use the corresponding classes: HSTreeRegressorCV and HSTreeRegressor. +


@@ -103,6 +107,22 @@

Examples with HS on synthetic data

+

Applying HS to tree ensembles

+ +HS can also be used on tree ensembles to regularize each tree in an ensemble (e.g. in a Random Forest). +We must simply pass the desired estimator during initialization. + +
+    
+from sklearn.ensemble import RandomForestClassifier # also works with ExtraTreesClassifier, GradientBoostingClassifier
+from imodels import HSTreeClassifier
+ensemble = RandomForestClassifier()
+model = HSTreeClassifier(estimator_=ensemble)
+model = model.fit(X_train, y_train)
+        
+
+ + diff --git a/docs/tree/hierarchical_shrinkage.html b/docs/tree/hierarchical_shrinkage.html index 41846717..77ef8c24 100644 --- a/docs/tree/hierarchical_shrinkage.html +++ b/docs/tree/hierarchical_shrinkage.html @@ -78,7 +78,18 @@ self.feature_names = kwargs.pop('feature_names', None) # None returned if not passed self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs) self._shrink() - self.complexity_ = compute_tree_complexity(self.estimator_.tree_) + + # compute complexity + if hasattr(self.estimator_, 'tree_'): + self.complexity_ = compute_tree_complexity(self.estimator_.tree_) + elif hasattr(self.estimator_, 'estimators_'): + self.complexity_ = 0 + for i in range(len(self.estimator_.estimators_)): + t = deepcopy(self.estimator_.estimators_[i]) + if isinstance(t, np.ndarray): + assert t.size == 1, 'multiple trees stored under tree_?' + t = t[0] + self.complexity_ += compute_tree_complexity(t.tree_) return self def _shrink_tree(self, tree, reg_param, i=0, parent_val=None, parent_num=None, cum_sum=0): @@ -406,7 +417,18 @@

Params

self.feature_names = kwargs.pop('feature_names', None) # None returned if not passed self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs) self._shrink() - self.complexity_ = compute_tree_complexity(self.estimator_.tree_) + + # compute complexity + if hasattr(self.estimator_, 'tree_'): + self.complexity_ = compute_tree_complexity(self.estimator_.tree_) + elif hasattr(self.estimator_, 'estimators_'): + self.complexity_ = 0 + for i in range(len(self.estimator_.estimators_)): + t = deepcopy(self.estimator_.estimators_[i]) + if isinstance(t, np.ndarray): + assert t.size == 1, 'multiple trees stored under tree_?' + t = t[0] + self.complexity_ += compute_tree_complexity(t.tree_) return self def _shrink_tree(self, tree, reg_param, i=0, parent_val=None, parent_num=None, cum_sum=0): @@ -544,7 +566,18 @@

Methods

self.feature_names = kwargs.pop('feature_names', None) # None returned if not passed self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs) self._shrink() - self.complexity_ = compute_tree_complexity(self.estimator_.tree_) + + # compute complexity + if hasattr(self.estimator_, 'tree_'): + self.complexity_ = compute_tree_complexity(self.estimator_.tree_) + elif hasattr(self.estimator_, 'estimators_'): + self.complexity_ = 0 + for i in range(len(self.estimator_.estimators_)): + t = deepcopy(self.estimator_.estimators_[i]) + if isinstance(t, np.ndarray): + assert t.size == 1, 'multiple trees stored under tree_?' + t = t[0] + self.complexity_ += compute_tree_complexity(t.tree_) return self diff --git a/imodels/tree/hierarchical_shrinkage.py b/imodels/tree/hierarchical_shrinkage.py index 2f3f046c..6b38efa4 100644 --- a/imodels/tree/hierarchical_shrinkage.py +++ b/imodels/tree/hierarchical_shrinkage.py @@ -55,7 +55,18 @@ def fit(self, X, y, *args, **kwargs): self.feature_names = kwargs.pop('feature_names', None) # None returned if not passed self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs) self._shrink() - self.complexity_ = compute_tree_complexity(self.estimator_.tree_) + + # compute complexity + if hasattr(self.estimator_, 'tree_'): + self.complexity_ = compute_tree_complexity(self.estimator_.tree_) + elif hasattr(self.estimator_, 'estimators_'): + self.complexity_ = 0 + for i in range(len(self.estimator_.estimators_)): + t = deepcopy(self.estimator_.estimators_[i]) + if isinstance(t, np.ndarray): + assert t.size == 1, 'multiple trees stored under tree_?' + t = t[0] + self.complexity_ += compute_tree_complexity(t.tree_) return self def _shrink_tree(self, tree, reg_param, i=0, parent_val=None, parent_num=None, cum_sum=0): diff --git a/setup.py b/setup.py index f7f31ae9..115cca60 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ setuptools.setup( name="imodels", - version="1.3.2", + version="1.3.3", author="Chandan Singh, Keyan Nasseri, Bin Yu, and others", author_email="chandan_singh@berkeley.edu", description="Implementations of various interpretable models", diff --git a/tests/estimator_checks_test.py b/tests/estimator_checks_test.py index ff17c4c2..867ad9de 100644 --- a/tests/estimator_checks_test.py +++ b/tests/estimator_checks_test.py @@ -6,19 +6,19 @@ class TestCheckEstimators(unittest.TestCase): - '''Checks that estimators conform to sklearn checks - ''' + """Checks that estimators conform to sklearn checks + """ def test_check_classifier_compatibility(self): - '''Test classifiers are properly sklearn-compatible - ''' + """Test classifiers are properly sklearn-compatible + """ for classifier in [imodels.SLIMClassifier]: # BoostedRulesClassifier (multi-class not supported) check_estimator(classifier()) assert 'passed check_estimator for ' + str(classifier) def test_check_regressor_compatibility(self): - '''Test regressors are properly sklearn-compatible - ''' + """Test regressors are properly sklearn-compatible + """ for regr in []: # SLIMRegressor fails acc screening for boston dset check_estimator(regr()) assert 'passed check_estimator for ' + str(regr) diff --git a/tests/shrinkage_test.py b/tests/shrinkage_test.py index bfb0664a..72615f37 100644 --- a/tests/shrinkage_test.py +++ b/tests/shrinkage_test.py @@ -2,7 +2,7 @@ from functools import partial import numpy as np -from sklearn.ensemble import VotingRegressor +from sklearn.ensemble import VotingRegressor, RandomForestClassifier, GradientBoostingClassifier from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from imodels import HSTreeClassifier, HSTreeClassifierCV, \ @@ -46,8 +46,11 @@ def test_classification_shrinkage(self): ''' for model_type in [ + partial(HSTreeClassifier, estimator_=DecisionTreeClassifier()), + partial(HSTreeClassifier, estimator_=GradientBoostingClassifier()), partial(HSTreeClassifier, estimator_=DecisionTreeClassifier()), partial(HSTreeClassifierCV, estimator_=DecisionTreeClassifier()), + partial(HSTreeClassifierCV, estimator_=RandomForestClassifier()), partial(HSC45TreeClassifierCV, estimator_=C45TreeClassifier()), HSTreeClassifierCV, # default estimator is Decision tree with 25 max_leaf_nodes # partial(HSOptimalTreeClassifierCV, estimator_=OptimalTreeClassifier()), @@ -74,6 +77,9 @@ def test_classification_shrinkage(self): # print(type(m), m, 'final acc', acc_train) assert acc_train > 0.8, 'acc greater than 0.8' + # complexity + assert m.complexity_ > 0, 'complexity is greater than 0' + def test_recognized_by_sklearn(self): base_models = [('hs', HSTreeRegressor(DecisionTreeRegressor())), ('dt', DecisionTreeRegressor())] @@ -97,6 +103,9 @@ def test_regression_shrinkage(self): mse = np.mean(np.square(preds - self.y_regression)) assert mse < 1, 'mse less than 1' + # complexity + assert m.complexity_ > 0, 'complexity is greater than 0' + if __name__ == '__main__': t = TestShrinkage()