ENH, FIX i) build_oob_forest backwards compatibility with sklearn and ii) HonestForest stratification during bootstrap #283

Merged: 28 commits, Jun 24, 2024
41 changes: 24 additions & 17 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -1,23 +1,30 @@
<!-- To ensure we can review your pull request promptly please complete this template entirely. -->
<!--
Thanks for contributing a pull request! Please ensure you have taken a look at
the contribution guidelines from scikit-learn: https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md
-->

<!-- Please reference the issue number here. You can replace "Fixes" with "Closes" if it makes more sense. -->
Fixes #
#### Reference Issues/PRs
<!--
Example: Fixes #1234. See also #3456.
Please use keywords (e.g., Fixes) to create links to the issues or pull requests
you resolved, so that they will automatically be closed when your pull request
is merged. See https://github.com/blog/1506-closing-issues-via-pull-requests
-->

Changes proposed in this pull request:
<!-- Please list all changes/additions here. -->
-

## Before submitting
#### What does this implement/fix? Explain your changes.

<!-- Please complete this checklist BEFORE submitting your PR to speed along the review process. -->
- [ ] I've read and followed all steps in the [Making a pull request](https://github.com/py-why/pywhy-graphs/blob/main/CONTRIBUTING.md#making-a-pull-request)
section of the `CONTRIBUTING` docs.
- [ ] I've updated or added any relevant docstrings following the syntax described in the
[Writing docstrings](https://github.com/py-why/pywhy-graphs/blob/main/CONTRIBUTING.md#writing-docstrings) section of the `CONTRIBUTING` docs.
- [ ] If this PR fixes a bug, I've added a test that will fail without my fix.
- [ ] If this PR adds a new feature, I've added tests that sufficiently cover my new functionality.

## After submitting
#### Any other comments?

<!-- Please complete this checklist AFTER submitting your PR to speed along the review process. -->
- [ ] All GitHub Actions jobs for my pull request have passed.

<!--
Please be aware that we are a loose team of volunteers so patience is
necessary; assistance handling other issues is very welcome. We value
all user contributions, no matter how minor they are.

See https://github.com/neurodata/scikit-tree/blob/main/CONTRIBUTING.md for more
information on contributing.

Thanks for contributing!
-->
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 24.1.1
rev: 24.4.2
hooks:
- id: black
args: [--quiet]
@@ -15,14 +15,14 @@ repos:
types: [cython]

- repo: https://github.com/MarcoGorelli/cython-lint
rev: v0.16.0
rev: v0.16.2
hooks:
- id: cython-lint
- id: double-quote-cython-strings

# Ruff sktree
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
rev: v0.4.10
hooks:
- id: ruff
name: ruff sktree
@@ -31,7 +31,7 @@

# Ruff tutorials and examples
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
rev: v0.4.10
hooks:
- id: ruff
name: ruff tutorials and examples
@@ -42,7 +42,7 @@

# Codespell
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
rev: v2.3.0
hooks:
- id: codespell
additional_dependencies:
@@ -52,7 +52,7 @@

# yamllint
- repo: https://github.com/adrienverge/yamllint.git
rev: v1.33.0
rev: v1.35.1
hooks:
- id: yamllint
args: [--strict, -c, .yamllint.yml]
@@ -67,7 +67,7 @@

# mypy
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
rev: v1.10.0
hooks:
- id: mypy
# Avoid the conflict between mne/__init__.py and mne/__init__.pyi by ignoring the former
@@ -84,4 +84,4 @@ repos:
files: ^(?!doc/use\.rst$).*\.(rst|inc)$

ci:
autofix_prs: false
autofix_prs: true
4 changes: 4 additions & 0 deletions doc/whats_new/v0.8.rst
@@ -25,6 +25,10 @@ Changelog
argument. By `Adam Li`_ (:pr:`#274`)
- |API| Removed all instances of ``FeatureImportanceForestClassifier`` and outdated
MIGHT code. By `Adam Li`_ (:pr:`#274`)
- |Fix| Fixed a bug in ``sktree.HonestForestClassifier`` where posteriors
  estimated on out-of-bag (OOB) samples were biased for small sample sizes
  with imbalanced classes when ``bootstrap=True``.
  By `Adam Li`_ (:pr:`#283`)
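A minimal NumPy sketch (illustrative only, not code from this PR) of why the fix matters: with a plain bootstrap on a small, imbalanced sample, the minority class can be absent from a tree's bootstrap entirely, whereas a per-class (stratified) bootstrap preserves the class counts:

```python
import numpy as np

rng = np.random.default_rng(0)
# Small, imbalanced sample: 18 majority, 2 minority.
y = np.array([0] * 18 + [1] * 2)
n = len(y)

# Plain bootstrap: count how often the minority class vanishes.
trials = 1000
empty = 0
for _ in range(trials):
    idx = rng.integers(0, n, size=n)  # draw n indices with replacement
    if (y[idx] == 1).sum() == 0:
        empty += 1
print(f"minority class absent in {empty / trials:.1%} of plain bootstraps")

# Stratified bootstrap: resample within each class, preserving proportions.
def stratified_bootstrap(y, rng):
    idx = np.arange(len(y))
    parts = [
        rng.choice(idx[y == c], size=(y == c).sum(), replace=True)
        for c in np.unique(y)
    ]
    return np.concatenate(parts)

strat = stratified_bootstrap(y, rng)
print("minority count after stratified bootstrap:", (y[strat] == 1).sum())
```

With 18:2 odds, a plain bootstrap of size 20 omits the minority class with probability (18/20)^20 ≈ 12%, so a noticeable fraction of trees would see only one class; the stratified draw always keeps both.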

Code and Documentation Contributors
-----------------------------------
55 changes: 0 additions & 55 deletions sktree/cv.py

This file was deleted.

13 changes: 6 additions & 7 deletions sktree/datasets/hyppo.py
@@ -388,13 +388,11 @@ def make_trunk_mixture_classification(
mixture_idx = rng.choice(2, n_samples // 2, replace=True, shuffle=True, p=[mix, 1 - mix]) # type: ignore

norm_params = [[mu_0_vec, cov], [mu_1_vec, cov]]
X_mixture = np.fromiter(
(
rng_children.multivariate_normal(*(norm_params[i]), size=1, method=method)
for i, rng_children in zip(mixture_idx, rng.spawn(n_samples // 2))
),
dtype=np.dtype((float, n_informative)),
)
dim_sample = len(norm_params[0][0]) # Dimensionality of the samples
X_mixture = np.empty((n_samples // 2, dim_sample)) # Pre-allocate array for samples
for idx, (i, rng_child) in enumerate(zip(mixture_idx, rng.spawn(n_samples // 2))):
mean, cov = norm_params[i]
X_mixture[idx, :] = rng_child.multivariate_normal(mean, cov, size=1, method=method)

# create new generator instance to ensure reproducibility with multiple runs with the same seed
rng_F = np.random.default_rng(seed=seed)
@@ -474,6 +472,7 @@ def make_trunk_classification(
Either 'ma', or 'ar'.
return_params : bool, optional
Whether or not to return the distribution parameters of the classes' normal distributions.
By default False.
scaling_factor : float, optional
The scaling factor for the covariance matrix. By default 1.
seed : int, optional
136 changes: 109 additions & 27 deletions sktree/ensemble/_honest_forest.py
@@ -9,8 +9,8 @@
from joblib import Parallel, delayed
from sklearn.base import _fit_context, clone
from sklearn.ensemble._base import _partition_estimators, _set_random_states
from sklearn.utils import compute_sample_weight
from sklearn.utils._param_validation import Interval, RealNotInt
from sklearn.utils import compute_sample_weight, resample
from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions
from sklearn.utils.validation import check_is_fitted

from .._lib.sklearn.ensemble._forest import ForestClassifier
@@ -31,41 +31,67 @@
n_samples_bootstrap=None,
missing_values_in_feature_mask=None,
classes=None,
stratify=False,
):
"""
Private function used to fit a single tree in parallel.

XXX: this is copied over from scikit-learn and modified to allow sampling with
XXX:
1. this is copied over from scikit-learn and modified to allow sampling with
and without replacement given ``bootstrap``.

2. Overrides the scikit-learn implementation to allow for stratification during bootstrapping
via the `stratify` parameter.
"""
if verbose > 1:
print("building tree %d of %d" % (tree_idx + 1, n_trees))

n_samples = X.shape[0]
if sample_weight is None:
curr_sample_weight = np.ones((n_samples,), dtype=np.float64)
if bootstrap:
n_samples = X.shape[0]
if sample_weight is None:
curr_sample_weight = np.ones((n_samples,), dtype=np.float64)
else:
curr_sample_weight = sample_weight.copy()

if stratify:
indices = resample(
np.arange(n_samples),
n_samples=n_samples_bootstrap,
stratify=y,
replace=True,
random_state=tree.random_state,
)
else:
indices = _generate_sample_indices(
tree.random_state, n_samples, n_samples_bootstrap, bootstrap=bootstrap
)
sample_counts = np.bincount(indices, minlength=n_samples)
curr_sample_weight *= sample_counts

if class_weight == "subsample":
with catch_warnings():
simplefilter("ignore", DeprecationWarning)
curr_sample_weight *= compute_sample_weight("auto", y, indices=indices)

elif class_weight == "balanced_subsample":
curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices)

tree._fit(
X,
y,
sample_weight=curr_sample_weight,
check_input=False,
missing_values_in_feature_mask=missing_values_in_feature_mask,
classes=classes,
)
else:
curr_sample_weight = sample_weight.copy()

indices = _generate_sample_indices(tree.random_state, n_samples, n_samples_bootstrap, bootstrap)
sample_counts = np.bincount(indices, minlength=n_samples)
curr_sample_weight *= sample_counts

if class_weight == "subsample":
with catch_warnings():
simplefilter("ignore", DeprecationWarning)
curr_sample_weight *= compute_sample_weight("auto", y, indices=indices)
elif class_weight == "balanced_subsample":
curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices)

tree._fit(
X,
y,
sample_weight=curr_sample_weight,
check_input=False,
missing_values_in_feature_mask=missing_values_in_feature_mask,
classes=classes,
)
tree._fit(
X,
y,
sample_weight=sample_weight,
check_input=False,
missing_values_in_feature_mask=missing_values_in_feature_mask,
classes=classes,
)

return tree
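The stratified branch above delegates the index draw to `sklearn.utils.resample`; a standalone sketch of that call (assuming scikit-learn is installed; toy labels):

```python
import numpy as np
from sklearn.utils import resample

y = np.array([0] * 8 + [1] * 2)  # imbalanced labels, 8:2

indices = resample(
    np.arange(len(y)),
    n_samples=len(y),   # bootstrap of the same size as the sample
    replace=True,       # draw with replacement
    stratify=y,         # sample per class, preserving the class ratio
    random_state=0,
)
# Both classes are guaranteed to appear in the resampled labels,
# unlike a plain bootstrap, which can drop the minority class.
print(np.bincount(y[indices]))
```

The returned `indices` feed the same `np.bincount`-based sample-weight update as the unstratified path, so only the index generation differs between the two branches.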

@@ -254,6 +280,8 @@

stratify : bool
Whether or not to stratify samples when considering structure and leaf indices.
This will also stratify samples when bootstrap sampling is used. For more
information, see :func:`sklearn.utils.resample`.
By default False.

**tree_estimator_params : dict
@@ -389,6 +417,11 @@
Interval(RealNotInt, 0.0, None, closed="right"),
Interval(Integral, 1, None, closed="left"),
]
_parameter_constraints["honest_fraction"] = [Interval(RealNotInt, 0.0, 1.0, closed="both")]
_parameter_constraints["honest_prior"] = [
StrOptions({"empirical", "uniform", "ignore"}),
]
_parameter_constraints["stratify"] = ["boolean"]

def __init__(
self,
@@ -515,6 +548,55 @@

return self

def _construct_trees(
self,
X,
y,
sample_weight,
random_state,
n_samples_bootstrap,
missing_values_in_feature_mask,
classes,
n_more_estimators,
):
"""Override construction of trees to allow stratification during bootstrapping."""
trees = [
self._make_estimator(append=False, random_state=random_state)
for i in range(n_more_estimators)
]

# Parallel loop: we prefer the threading backend as the Cython code
# for fitting the trees is internally releasing the Python GIL
# making threading more efficient than multiprocessing in
# that case. However, for joblib 0.12+ we respect any
# parallel_backend contexts set at a higher level,
# since correctness does not rely on using threads.
trees = Parallel(
n_jobs=self.n_jobs,
verbose=self.verbose,
prefer="threads",
)(
delayed(_parallel_build_trees)(
t,
self.bootstrap,
X,
y,
sample_weight,
i,
len(trees),
verbose=self.verbose,
class_weight=self.class_weight,
n_samples_bootstrap=n_samples_bootstrap,
missing_values_in_feature_mask=missing_values_in_feature_mask,
classes=classes,
stratify=self.stratify,
)
for i, t in enumerate(trees)
)

# Collect newly grown trees
self.estimators_.extend(trees)
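The `Parallel(prefer="threads")` dispatch pattern used above can be sketched standalone (hypothetical `fit_one` stand-in for `_parallel_build_trees`; assumes joblib is installed):

```python
from joblib import Parallel, delayed

def fit_one(i):
    # Stand-in for fitting one tree. The threading backend pays off when
    # the work releases the GIL (as the Cython tree builders do); pure
    # Python work like this toy function would not benefit.
    return i * i

# prefer="threads" is a soft hint: an outer parallel_backend context
# set by the caller still takes precedence.
results = Parallel(n_jobs=2, prefer="threads")(
    delayed(fit_one)(i) for i in range(4)
)
print(results)  # [0, 1, 4, 9]
```

As in the override, the per-task results come back in submission order, so the fitted trees can simply be extended onto `estimators_`.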

def _make_estimator(self, append=True, random_state=None):
"""Make and configure a copy of the `estimator_` attribute.
