-
-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] Add conditional resampling using nearest-neighbor algorithm (#170)
* Add conditional resampling technique based on nearest-neighbors estimates * Fix coleman unit-test simulations by upping tree count --------- Signed-off-by: Adam Li <adam2392@gmail.com>
- Loading branch information
Showing
12 changed files
with
457 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -244,6 +244,7 @@ | |
"fit", | ||
"apply", | ||
"TreeBuilder", | ||
"joint_rank", | ||
} | ||
|
||
# validation | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
from typing import Optional | ||
|
||
import numpy as np | ||
from numpy.typing import ArrayLike | ||
from scipy.sparse import issparse | ||
from sklearn.neighbors import NearestNeighbors | ||
from sklearn.utils import _approximate_mode, _safe_indexing, check_array, check_consistent_length | ||
|
||
|
||
def _conditional_shuffle(nbrs: ArrayLike, replace: bool = False, seed=None) -> ArrayLike: | ||
"""Compute a permutation of neighbors with restrictions. | ||
Parameters | ||
---------- | ||
nbrs : ArrayLike of shape (n_samples, k) | ||
The k-nearest-neighbors for each sample index. Each row corresponds to the | ||
original sample. Each element corresponds to another sample index that is deemed | ||
as the k-nearest neighbors with respect to the original sample. | ||
replace : bool, optional | ||
Whether or not to allow replacement of samples, by default False. | ||
seed : int, optional | ||
Random seed, by default None. | ||
Returns | ||
------- | ||
restricted_perm : ArrayLike of shape (n_samples) | ||
The final permutation order of the sample indices. There may be | ||
repeating samples. See Notes for details. | ||
Notes | ||
----- | ||
Restricted permutation goes through random samples and looks at the k-nearest | ||
neighbors (columns of ``nbrs``) and shuffles the closest neighbor index only | ||
if it has not been used to permute another sample. If it has been, then the | ||
algorithm looks at the next nearest-neighbor and so on. If all k-nearest | ||
neighbors of a sample has been checked, then a random neighbor is chosen. In this | ||
manner, the algorithm tries to perform permutation without replacement, but | ||
if necessary, will choose a repeating neighbor sample. | ||
""" | ||
n_samples, k_dims = nbrs.shape | ||
rng = np.random.default_rng(seed=seed) | ||
|
||
# initialize the final permutation order | ||
restricted_perm = np.zeros((n_samples,), dtype=np.intp) | ||
|
||
# generate a random order of samples to go through | ||
random_order = rng.permutation(n_samples) | ||
|
||
# keep track of values we have already used | ||
used = set() | ||
|
||
# go through the random order | ||
for idx in random_order: | ||
if replace: | ||
possible_nbrs = nbrs[idx, :] | ||
restricted_perm[idx] = rng.choice(possible_nbrs, size=1).squeeze() | ||
else: | ||
m = 0 | ||
use_idx = nbrs[idx, m] | ||
|
||
# if the current nbr is already used, continue incrementing | ||
# until we have either found a new sample to use, or if | ||
# we have reach the maximum number of shuffles to consider | ||
while (use_idx in used) and (m < k_dims - 1): | ||
m += 1 | ||
use_idx = nbrs[idx, m] | ||
|
||
# check whether or not we have exhaustively checked all kNN | ||
if use_idx in used and m == k_dims: | ||
# XXX: Note this step is not in the original paper | ||
# choose a random neighbor to permute | ||
restricted_perm[idx] = rng.choice(nbrs[idx, :], size=1) | ||
else: | ||
# permute with the existing neighbor | ||
restricted_perm[idx] = use_idx | ||
used.add(use_idx) | ||
return restricted_perm | ||
|
||
|
||
def conditional_resample( | ||
conditional_array: ArrayLike, | ||
*arrays, | ||
nn_estimator=None, | ||
replace: bool = True, | ||
replace_nbrs: bool = True, | ||
n_samples: Optional[int] = None, | ||
random_state: Optional[int] = None, | ||
stratify: Optional[ArrayLike] = None, | ||
): | ||
"""Conditionally resample arrays or sparse matrices in a consistent way. | ||
The default strategy implements one step of the bootstrapping | ||
procedure. Conditional resampling is a modification of the bootstrap | ||
technique that preserves the conditional distribution of the data. This | ||
is done by fitting a nearest neighbors estimator on the conditional array | ||
and then resampling the nearest neighbors of each sample. | ||
Parameters | ||
---------- | ||
conditional_array : array-like of shape (n_samples, n_features) | ||
The array, which we preserve the conditional distribution of. | ||
*arrays : sequence of array-like of shape (n_samples,) or \ | ||
(n_samples, n_outputs) | ||
Indexable data-structures can be arrays, lists, dataframes or scipy | ||
sparse matrices with consistent first dimension. | ||
nn_estimator : estimator object, default=None | ||
The nearest neighbors estimator to use. If None, then a | ||
:class:`sklearn.neighbors.NearestNeighbors` instance is used. | ||
replace : bool, default=True | ||
Implements resampling with replacement. If False, this will implement | ||
(sliced) random permutations. The replacement will take place at the level | ||
of the sample index. | ||
replace_nbrs : bool, default=True | ||
Implements resampling with replacement at the level of the nearest neighbors. | ||
n_samples : int, default=None | ||
Number of samples to generate. If left to None this is | ||
automatically set to the first dimension of the arrays. | ||
If replace is False it should not be larger than the length of | ||
arrays. | ||
random_state : int, RandomState instance or None, default=None | ||
Determines random number generation for shuffling | ||
the data. | ||
Pass an int for reproducible results across multiple function calls. | ||
See :term:`Glossary <random_state>`. | ||
stratify : array-like of shape (n_samples,) or (n_samples, n_outputs), \ | ||
default=None | ||
If not None, data is split in a stratified fashion, using this as | ||
the class labels. | ||
Returns | ||
------- | ||
resampled_arrays : sequence of array-like of shape (n_samples,) or \ | ||
(n_samples, n_outputs) | ||
Sequence of resampled copies of the collections. The original arrays | ||
are not impacted. | ||
""" | ||
max_n_samples = n_samples | ||
rng = np.random.default_rng(random_state) | ||
|
||
if len(arrays) == 0: | ||
return None | ||
|
||
first = arrays[0] | ||
n_samples = first.shape[0] if hasattr(first, "shape") else len(first) | ||
|
||
if max_n_samples is None: | ||
max_n_samples = n_samples | ||
elif (max_n_samples > n_samples) and (not replace): | ||
raise ValueError( | ||
f"Cannot sample {max_n_samples} out of arrays with dim " | ||
f"{n_samples} when replace is False" | ||
) | ||
|
||
check_consistent_length(conditional_array, *arrays) | ||
|
||
# fit nearest neighbors onto the conditional array | ||
if nn_estimator is None: | ||
nn_estimator = NearestNeighbors() | ||
nn_estimator.fit(conditional_array) | ||
|
||
if stratify is None: | ||
if replace: | ||
indices = rng.integers(0, n_samples, size=(max_n_samples,)) | ||
else: | ||
indices = np.arange(n_samples) | ||
rng.shuffle(indices) | ||
indices = indices[:max_n_samples] | ||
else: | ||
# Code adapted from StratifiedShuffleSplit() | ||
y = check_array(stratify, ensure_2d=False, dtype=None) | ||
if y.ndim == 2: | ||
# for multi-label y, map each distinct row to a string repr | ||
# using join because str(row) uses an ellipsis if len(row) > 1000 | ||
y = np.array([" ".join(row.astype("str")) for row in y]) | ||
|
||
classes, y_indices = np.unique(y, return_inverse=True) | ||
n_classes = classes.shape[0] | ||
|
||
class_counts = np.bincount(y_indices) | ||
|
||
# Find the sorted list of instances for each class: | ||
# (np.unique above performs a sort, so code is O(n logn) already) | ||
class_indices = np.split( | ||
np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1] | ||
) | ||
|
||
n_i = _approximate_mode(class_counts, max_n_samples, random_state) | ||
|
||
indices = [] | ||
|
||
for i in range(n_classes): | ||
indices_i = rng.choice(class_indices[i], n_i[i], replace=replace) | ||
indices.extend(indices_i) | ||
|
||
indices = rng.permutation(indices) | ||
|
||
# now get the kNN indices for each sample (n_samples, n_neighbors) | ||
sample_nbrs = nn_estimator.kneighbors(X=conditional_array[indices, :], return_distance=False) | ||
|
||
# actually sample the indices using a conditional permutation | ||
indices = _conditional_shuffle(sample_nbrs, replace=replace_nbrs, seed=rng) | ||
|
||
# convert sparse matrices to CSR for row-based indexing | ||
arrays_ = [a.tocsr() if issparse(a) else a for a in arrays] | ||
resampled_arrays = [_safe_indexing(a, indices) for a in arrays_] | ||
|
||
if len(resampled_arrays) == 1: | ||
# syntactic sugar for the unit argument case | ||
return resampled_arrays[0] | ||
else: | ||
return resampled_arrays |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.