Skip to content

Commit

Permalink
wip PR revisions. updated docs, revised bottleneck function selection…
Browse files Browse the repository at this point in the history
… logic in test
  • Loading branch information
ryanhausen committed Aug 21, 2024
1 parent d2204e4 commit 766c6c4
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 16 deletions.
19 changes: 16 additions & 3 deletions treeple/stats/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def build_coleman_forest(
return_posteriors : bool, optional
Whether or not to return the posteriors, by default True.
use_sparse : bool, optional
Whether or not to use a sparse representation for the posteriors.
Whether or not to use a sparse representation for the calculation of the
permutation statistics, by default False. Doesn't affect return values.
**metric_kwargs : dict, optional
Additional keyword arguments to pass to the metric function.
Expand Down Expand Up @@ -173,11 +174,23 @@ def build_coleman_forest(
# sparse indices and values with an array
if return_posteriors:
n_trees = y_pred_proba_orig_perm.shape[0] // 2
n_samples = y_pred_proba_orig_perm.shape[1]
# slicing a csc matrix this way is not efficient, but
# it is only done once, so I am not sure if it is worth it to
# optimize this.
orig_forest_proba = y_pred_proba_orig_perm[:n_trees, :]
perm_forest_proba = y_pred_proba_orig_perm[n_trees:, :]
to_coords_data = lambda x: (x.row.astype(int), x.col.astype(int), x.data)

row, col, data = to_coords_data(y_pred_proba_orig_perm[:n_trees, :].tocoo())
orig_forest_proba = np.full((n_trees, n_samples), np.nan, dtype=np.float64)
orig_forest_proba[row, col] = data

row, col, data = to_coords_data(y_pred_proba_orig_perm[n_trees:, :].tocoo())
perm_forest_proba = np.full((n_trees, n_samples), np.nan, dtype=np.float64)
perm_forest_proba[row, col] = data

if y.shape[1] == 2:
orig_forest_proba = np.column_stack((orig_forest_proba, 1 - orig_forest_proba))
perm_forest_proba = np.column_stack((perm_forest_proba, 1 - perm_forest_proba))
else:
metric_star, metric_star_pi = _compute_null_distribution_coleman(
y,
Expand Down
19 changes: 11 additions & 8 deletions treeple/stats/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,21 +238,20 @@ def test_comight_repeated_feature_sets(seed):

@pytest.mark.parametrize(
("use_bottleneck", "use_sparse"),
itertools.product([True, False], [True, False]),
itertools.product([False, True], [False, True]),
)
def test_build_coleman_forest(use_bottleneck: bool, use_sparse: bool):
"""Simple test for building a Coleman forest.
Test the function under alternative and null hypothesis for a very simple dataset.
"""
if use_bottleneck and utils.DISABLE_BN_ENV_VAR in os.environ:
del os.environ[utils.DISABLE_BN_ENV_VAR]
importlib.reload(utils)
importlib.reload(stats)
else:
if not use_bottleneck and utils.DISABLE_BN_ENV_VAR not in os.environ:
os.environ[utils.DISABLE_BN_ENV_VAR] = "1"
importlib.reload(utils)
importlib.reload(stats)
elif use_bottleneck and utils.DISABLE_BN_ENV_VAR in os.environ:
del os.environ[utils.DISABLE_BN_ENV_VAR]

importlib.reload(utils)
importlib.reload(stats)

n_estimators = 100
n_samples = 30
Expand Down Expand Up @@ -436,3 +435,7 @@ def test_build_oob_random_forest():
assert len(np.unique(structure_samples[tree_idx])) + len(oob_samples_list[tree_idx]) == len(
samples
), f"{tree_idx} {len(structure_samples[tree_idx])} + {len(oob_samples_list[tree_idx])} != {len(samples)}"


if __name__ == "__main__":
test_build_coleman_forest(False, False)
22 changes: 17 additions & 5 deletions treeple/stats/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,6 @@ def _compute_null_distribution_coleman(
metric_star_pi : ArrayLike of shape (n_samples,)
An array of the metrics computed on the other half of the trees.
"""
if not BOTTLENECK_AVAILABLE:
warnings.warn(BOTTLENECK_WARNING)

# sample two sets of equal number of trees from the combined forest these are the posteriors
# (n_estimators * 2, n_samples, n_outputs)
all_y_pred = np.concatenate((y_pred_proba_normal, y_pred_proba_perm), axis=0)
Expand Down Expand Up @@ -337,9 +334,24 @@ def get_per_tree_oob_samples(est: BaseForest):
def _get_forest_preds_sparse(
all_y_pred: sp.csc_matrix, # (n_trees, n_samples)
all_y_indicator: sp.csc_matrix, # (n_trees, n_samples)
forest_indices: ArrayLike, # (n_trees,)
forest_indices: ArrayLike, # (n_trees/2,)
) -> ArrayLike:
"""Get the forest predictions for a set of trees using sparse matrices."""
"""Get the forest predictions for a set of trees using sparse matrices.
Parameters
----------
all_y_pred : sp.csc_matrix of shape (n_trees, n_samples)
The predicted posteriors from the forest.
all_y_indicator : sp.csc_matrix of shape (n_trees, n_samples)
The indicator matrix for the predictions.
forest_indices : ArrayLike of shape (n_trees/2,)
The indices of the trees in the forest that we are evaluating.
Returns
-------
ArrayLike of shape (n_samples,)
The averaged predictions for the forest.
"""
forest_indicator = np.zeros(len(forest_indices) * 2, dtype=np.uint8)
forest_indicator[forest_indices] = 1

Expand Down

0 comments on commit 766c6c4

Please sign in to comment.