From fae7e5cf384cd8a654fbb280588921d25af7bc4c Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Mon, 15 Jul 2024 12:40:16 -0400 Subject: [PATCH] FIX ignore nan values when summing posteriors (#291) * FIX ignore nan values when summing posteriors --------- Co-authored-by: Adam Li --- doc/whats_new/v0.8.rst | 2 +- doc/whats_new/v0.9.rst | 8 ++++++-- treeple/ensemble/_honest_forest.py | 15 ++++++--------- treeple/stats/tests/test_forestht.py | 8 +++++--- treeple/tests/test_extensions.py | 5 ++++- treeple/tests/test_honest_forest.py | 17 +++++++++++++++++ 6 files changed, 39 insertions(+), 16 deletions(-) diff --git a/doc/whats_new/v0.8.rst b/doc/whats_new/v0.8.rst index b0403460c..69be7f84b 100644 --- a/doc/whats_new/v0.8.rst +++ b/doc/whats_new/v0.8.rst @@ -40,4 +40,4 @@ Thanks to everyone who has contributed to the maintenance and improvement of the project since version inception, including: * `Adam Li`_ - +* `Sambit Panda`_ diff --git a/doc/whats_new/v0.9.rst b/doc/whats_new/v0.9.rst index 937f94350..a3688111b 100644 --- a/doc/whats_new/v0.9.rst +++ b/doc/whats_new/v0.9.rst @@ -10,7 +10,7 @@ Version 0.9 **In Development** -This release include a rename of the package to from ``scikit-learn`` to ``treeple`` +This release includes a rename of the package from ``scikit-tree`` to ``treeple`` The users can replace the previous usage as follows: ``import sktree`` to ``import treeple`` ``from sktree import tree`` to ``from treeple import tree`` @@ -21,7 +21,10 @@ Changelog --------- - |API| Rename the package to ``treeple``. By `SUKI-O`_ (:pr:`#292`) - +- |Fix| Fixed a bug in the predict_proba function of the :class:`treeple.HonestForestClassifier` where posteriors + estimated on empty leaf with ``ignore`` prior would result in ``np.nan`` + values for all trees on that sample. 
+ By `Haoyin Xu`_ (:pr:`#291`) Code and Documentation Contributors ----------------------------------- @@ -31,3 +34,4 @@ the project since version inception, including: * `Adam Li`_ * `SUKI-O`_ +* `Haoyin Xu`_ diff --git a/treeple/ensemble/_honest_forest.py b/treeple/ensemble/_honest_forest.py index 5d1e082c5..96c010625 100644 --- a/treeple/ensemble/_honest_forest.py +++ b/treeple/ensemble/_honest_forest.py @@ -259,7 +259,7 @@ class HonestForestClassifier(ForestClassifier, ForestClassifierMixin): - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. - honest_prior : {"ignore", "uniform", "empirical"}, default="empirical" + honest_prior : {"ignore", "uniform", "empirical"}, default="ignore" Method for dealing with empty leaves during evaluation of a test sample. If "ignore", the tree is ignored. If "uniform", the prior tree posterior is 1/(number of classes). If "empirical", the prior tree @@ -444,7 +444,7 @@ def __init__( class_weight=None, ccp_alpha=0.0, max_samples=None, - honest_prior="empirical", + honest_prior="ignore", honest_fraction=0.5, tree_estimator=None, stratify=False, @@ -648,7 +648,7 @@ def predict_proba(self, X): """ return self._predict_proba(X) - def _predict_proba(self, X, indices=None, impute_missing=None): + def _predict_proba(self, X, indices=None, impute_missing=np.nan): """predict_proba helper class""" check_is_fitted(self) X = self._validate_X_predict(X) @@ -672,10 +672,7 @@ def _predict_proba(self, X, indices=None, impute_missing=None): zero_mask = posteriors.sum(2) == 0 posteriors[~zero_mask] /= posteriors[~zero_mask].sum(1, keepdims=True) - if impute_missing is None: - pass - else: - posteriors[zero_mask] = impute_missing + posteriors[zero_mask] = impute_missing # preserve shape of multi-outputs if self.n_outputs_ > 1: @@ -823,7 +820,7 @@ def _accumulate_prediction(predict, X, out, lock, indices=None): with lock: if len(out) == 1: - out[0][indices] += proba + out[0][indices] = 
np.nansum([out[0][indices], proba], axis=0) else: for i in range(len(out)): - out[i][indices] += proba[i] + out[i][indices] = np.nansum([out[i][indices], proba[i]], axis=0) diff --git a/treeple/stats/tests/test_forestht.py b/treeple/stats/tests/test_forestht.py index a0b1fbb60..0eea2c257 100644 --- a/treeple/stats/tests/test_forestht.py +++ b/treeple/stats/tests/test_forestht.py @@ -119,11 +119,10 @@ def test_small_dataset_independent(seed): @flaky(max_runs=3) @pytest.mark.parametrize("seed", [10, 0]) def test_small_dataset_dependent(seed): - n_samples = 50 + n_samples = 100 n_features = 5 rng = np.random.default_rng(seed) - X = rng.uniform(size=(n_samples, n_features)) X = rng.uniform(size=(n_samples // 2, n_features)) X2 = X + 3 X = np.vstack([X, X2]) @@ -157,12 +156,15 @@ def test_small_dataset_dependent(seed): n_repeats=1000, metric="mi", return_posteriors=False, + seed=seed, ) assert ~np.isnan(result.pvalue) assert ~np.isnan(result.observe_test_stat) assert result.pvalue <= 0.05 - result = build_coleman_forest(clf, perm_clf, X, y, metric="mi", return_posteriors=False) + result = build_coleman_forest( + clf, perm_clf, X, y, metric="mi", return_posteriors=False, seed=seed + ) assert result.pvalue <= 0.05 diff --git a/treeple/tests/test_extensions.py b/treeple/tests/test_extensions.py index e8072af26..bcd736590 100644 --- a/treeple/tests/test_extensions.py +++ b/treeple/tests/test_extensions.py @@ -32,7 +32,10 @@ def test_predict_proba_per_tree(Forest, n_classes): ) # Call the method being tested - est = Forest(n_estimators=10, bootstrap=True, random_state=0) + if Forest == HonestForestClassifier: + est = Forest(n_estimators=10, bootstrap=True, random_state=0, honest_prior="empirical") + else: + est = Forest(n_estimators=10, bootstrap=True, random_state=0) est.fit(X, y) proba_per_tree = est.predict_proba_per_tree(X) diff --git a/treeple/tests/test_honest_forest.py b/treeple/tests/test_honest_forest.py index 0735c55ca..f1e392109 100644 --- 
a/treeple/tests/test_honest_forest.py +++ b/treeple/tests/test_honest_forest.py @@ -263,6 +263,23 @@ def test_impute_posteriors(honest_prior, val): ), f"Failed with {honest_prior}, prior {clf.estimators_[0].empirical_prior_}" +def test_honestforest_predict_proba_with_honest_prior(): + X = rng.normal(0, 1, (100, 2)) + y = [0] * 75 + [1] * 25 + honest_prior = "ignore" + clf = HonestForestClassifier( + honest_fraction=0.5, random_state=0, honest_prior=honest_prior, n_estimators=100 + ) + clf = clf.fit(X, y) + + y_proba = clf.predict_proba(X) + + # With enough trees no nan values should exist + assert ( + len(np.where(np.isnan(y_proba[:, 0]))[0]) == 0 + ), f"Failed with {honest_prior}, prior {clf.estimators_[0].empirical_prior_}" + + @pytest.mark.parametrize( "honest_fraction, val", [