Skip to content

Commit

Permalink
hs working for tree ensembles
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Jul 28, 2022
1 parent b9c2a39 commit a647816
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 13 deletions.
22 changes: 21 additions & 1 deletion docs/shrinkage.html
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ <h2>How does Hierarchical shrinkage work? </h2>

<h2>An example using HS</h2>

HS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the <code>fit</code> and <code>predict</code> methods. Here's a full example of using it on a sample clinical dataset.
HS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the <code>fit</code> and <code>predict</code> methods.
Here's a full example of using it on a sample clinical dataset.

<pre>
<code>
Expand All @@ -73,6 +74,9 @@ <h2>An example using HS</h2>
</code>
</pre>

Here we used <code>HSTreeClassifierCV</code>, which selects the amount of regularization to use via cross-validation, but we can also use <code>HSTreeClassifier</code> if we want to specify a particular amount of regularization.
For regression, we can use the corresponding classes: <code>HSTreeRegressorCV</code> and <code>HSTreeRegressor</code>.

<p style="text-align:center;">
<a href="https://github.com/csinva/imodels"><img src="https://demos.csinva.io/shrinkage/shrinkage_csi_model.svg?sanitize=True" width="100%"></a>
<br>
Expand Down Expand Up @@ -103,6 +107,22 @@ <h2>Examples with HS on synthetic data</h2>
</p>
</div>

<h2>Applying HS to tree ensembles</h2>

HS can also be used on tree ensembles to regularize each tree in an ensemble (e.g. in a Random Forest).
We must simply pass the desired estimator during initialization.

<pre>
<code>
from sklearn.ensemble import RandomForestClassifier # also works with ExtraTreesClassifier, GradientBoostingClassifier
from imodels import HSTreeClassifier
ensemble = RandomForestClassifier()
model = HSTreeClassifier(estimator_=ensemble)
model = model.fit(X_train, y_train)
</code>
</pre>



</section>

Expand Down
39 changes: 36 additions & 3 deletions docs/tree/hierarchical_shrinkage.html
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,18 @@
self.feature_names = kwargs.pop(&#39;feature_names&#39;, None) # None returned if not passed
self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs)
self._shrink()
self.complexity_ = compute_tree_complexity(self.estimator_.tree_)

# compute complexity
if hasattr(self.estimator_, &#39;tree_&#39;):
self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
elif hasattr(self.estimator_, &#39;estimators_&#39;):
self.complexity_ = 0
for i in range(len(self.estimator_.estimators_)):
t = deepcopy(self.estimator_.estimators_[i])
if isinstance(t, np.ndarray):
assert t.size == 1, &#39;multiple trees stored under tree_?&#39;
t = t[0]
self.complexity_ += compute_tree_complexity(t.tree_)
return self

def _shrink_tree(self, tree, reg_param, i=0, parent_val=None, parent_num=None, cum_sum=0):
Expand Down Expand Up @@ -406,7 +417,18 @@ <h2 id="params">Params</h2>
self.feature_names = kwargs.pop(&#39;feature_names&#39;, None) # None returned if not passed
self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs)
self._shrink()
self.complexity_ = compute_tree_complexity(self.estimator_.tree_)

# compute complexity
if hasattr(self.estimator_, &#39;tree_&#39;):
self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
elif hasattr(self.estimator_, &#39;estimators_&#39;):
self.complexity_ = 0
for i in range(len(self.estimator_.estimators_)):
t = deepcopy(self.estimator_.estimators_[i])
if isinstance(t, np.ndarray):
assert t.size == 1, &#39;multiple trees stored under tree_?&#39;
t = t[0]
self.complexity_ += compute_tree_complexity(t.tree_)
return self

def _shrink_tree(self, tree, reg_param, i=0, parent_val=None, parent_num=None, cum_sum=0):
Expand Down Expand Up @@ -544,7 +566,18 @@ <h3>Methods</h3>
self.feature_names = kwargs.pop(&#39;feature_names&#39;, None) # None returned if not passed
self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs)
self._shrink()
self.complexity_ = compute_tree_complexity(self.estimator_.tree_)

# compute complexity
if hasattr(self.estimator_, &#39;tree_&#39;):
self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
elif hasattr(self.estimator_, &#39;estimators_&#39;):
self.complexity_ = 0
for i in range(len(self.estimator_.estimators_)):
t = deepcopy(self.estimator_.estimators_[i])
if isinstance(t, np.ndarray):
assert t.size == 1, &#39;multiple trees stored under tree_?&#39;
t = t[0]
self.complexity_ += compute_tree_complexity(t.tree_)
return self</code></pre>
</details>
</dd>
Expand Down
13 changes: 12 additions & 1 deletion imodels/tree/hierarchical_shrinkage.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,18 @@ def fit(self, X, y, *args, **kwargs):
self.feature_names = kwargs.pop('feature_names', None) # None returned if not passed
self.estimator_ = self.estimator_.fit(X, y, *args, **kwargs)
self._shrink()
self.complexity_ = compute_tree_complexity(self.estimator_.tree_)

# compute complexity
if hasattr(self.estimator_, 'tree_'):
self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
elif hasattr(self.estimator_, 'estimators_'):
self.complexity_ = 0
for i in range(len(self.estimator_.estimators_)):
t = deepcopy(self.estimator_.estimators_[i])
if isinstance(t, np.ndarray):
assert t.size == 1, 'multiple trees stored under tree_?'
t = t[0]
self.complexity_ += compute_tree_complexity(t.tree_)
return self

def _shrink_tree(self, tree, reg_param, i=0, parent_val=None, parent_num=None, cum_sum=0):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

setuptools.setup(
name="imodels",
version="1.3.2",
version="1.3.3",
author="Chandan Singh, Keyan Nasseri, Bin Yu, and others",
author_email="chandan_singh@berkeley.edu",
description="Implementations of various interpretable models",
Expand Down
12 changes: 6 additions & 6 deletions tests/estimator_checks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@


class TestCheckEstimators(unittest.TestCase):
'''Checks that estimators conform to sklearn checks
'''
"""Checks that estimators conform to sklearn checks
"""

def test_check_classifier_compatibility(self):
'''Test classifiers are properly sklearn-compatible
'''
"""Test classifiers are properly sklearn-compatible
"""
for classifier in [imodels.SLIMClassifier]: # BoostedRulesClassifier (multi-class not supported)
check_estimator(classifier())
assert 'passed check_estimator for ' + str(classifier)

def test_check_regressor_compatibility(self):
'''Test regressors are properly sklearn-compatible
'''
"""Test regressors are properly sklearn-compatible
"""
for regr in []: # SLIMRegressor fails acc screening for boston dset
check_estimator(regr())
assert 'passed check_estimator for ' + str(regr)
Expand Down
11 changes: 10 additions & 1 deletion tests/shrinkage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import partial

import numpy as np
from sklearn.ensemble import VotingRegressor
from sklearn.ensemble import VotingRegressor, RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

from imodels import HSTreeClassifier, HSTreeClassifierCV, \
Expand Down Expand Up @@ -46,8 +46,11 @@ def test_classification_shrinkage(self):
'''

for model_type in [
partial(HSTreeClassifier, estimator_=DecisionTreeClassifier()),
partial(HSTreeClassifier, estimator_=GradientBoostingClassifier()),
partial(HSTreeClassifier, estimator_=DecisionTreeClassifier()),
partial(HSTreeClassifierCV, estimator_=DecisionTreeClassifier()),
partial(HSTreeClassifierCV, estimator_=RandomForestClassifier()),
partial(HSC45TreeClassifierCV, estimator_=C45TreeClassifier()),
HSTreeClassifierCV, # default estimator is Decision tree with 25 max_leaf_nodes
# partial(HSOptimalTreeClassifierCV, estimator_=OptimalTreeClassifier()),
Expand All @@ -74,6 +77,9 @@ def test_classification_shrinkage(self):
# print(type(m), m, 'final acc', acc_train)
assert acc_train > 0.8, 'acc greater than 0.8'

# complexity
assert m.complexity_ > 0, 'complexity is greater than 0'

def test_recognized_by_sklearn(self):
base_models = [('hs', HSTreeRegressor(DecisionTreeRegressor())),
('dt', DecisionTreeRegressor())]
Expand All @@ -97,6 +103,9 @@ def test_regression_shrinkage(self):
mse = np.mean(np.square(preds - self.y_regression))
assert mse < 1, 'mse less than 1'

# complexity
assert m.complexity_ > 0, 'complexity is greater than 0'


if __name__ == '__main__':
t = TestShrinkage()
Expand Down

0 comments on commit a647816

Please sign in to comment.