Skip to content

Commit 0cd5e8d

Browse files
committed
Merge branch 'master' of https://github.com/csinva/imodels
2 parents 1da1fda + 48ca668 commit 0cd5e8d

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

imodels/tree/figs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,10 @@ def _importances(node: Node):
378378
# require the tree to have more than 1 node, otherwise just leave importance_data_tree as zeros
379379
if 1 < next(node_counter):
380380
tree_samples = _importances(tree_)
381-
importance_data_tree /= tree_samples
381+
if tree_samples != 0:
382+
importance_data_tree /= tree_samples
383+
else:
384+
importance_data_tree = 0
382385

383386
importance_data.append(importance_data_tree)
384387

imodels/tree/hierarchical_shrinkage.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,21 @@ def __repr__(self):
183183

184184

185185
class HSTreeRegressor(HSTree, RegressorMixin):
186-
...
186+
def __init__(self, estimator_: BaseEstimator = DecisionTreeRegressor(max_leaf_nodes=20),
187+
reg_param: float = 1, shrinkage_scheme_: str = 'node_based'):
188+
super().__init__(estimator_=estimator_,
189+
reg_param=reg_param,
190+
shrinkage_scheme_=shrinkage_scheme_,
191+
)
187192

188193

189194
class HSTreeClassifier(HSTree, ClassifierMixin):
190-
...
195+
def __init__(self, estimator_: BaseEstimator = DecisionTreeClassifier(max_leaf_nodes=20),
196+
reg_param: float = 1, shrinkage_scheme_: str = 'node_based'):
197+
super().__init__(estimator_=estimator_,
198+
reg_param=reg_param,
199+
shrinkage_scheme_=shrinkage_scheme_,
200+
)
191201

192202

193203
def _get_cv_criterion(scorer):

0 commit comments

Comments
 (0)