Skip to content

Commit

Permalink
added clarifying comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanhausen committed Aug 26, 2024
1 parent 251392f commit d733508
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 3 additions & 4 deletions treeple/stats/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,12 @@ def build_coleman_forest(
)

# if we are returning the posteriors, then we need to replace the
# sparse indices and values with an array
# sparse indices and values with an array. We convert the sparse data
# to dense data, so that the function returns results in a consistent format.
if return_posteriors:
n_trees = y_pred_proba_orig_perm.shape[0] // 2
n_samples = y_pred_proba_orig_perm.shape[1]
# slicing a csc matrix this way is not efficient, but it is
# it is only done once, so I am not sure if it is worth it to
# optimize this.

to_coords_data = lambda x: (x.row.astype(int), x.col.astype(int), x.data)

row, col, data = to_coords_data(y_pred_proba_orig_perm[:n_trees, :].tocoo())
Expand Down
2 changes: 2 additions & 0 deletions treeple/stats/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def test_build_coleman_forest(use_bottleneck: bool, use_sparse: bool):
elif use_bottleneck and utils.DISABLE_BN_ENV_VAR in os.environ:
del os.environ[utils.DISABLE_BN_ENV_VAR]

# We need to reload the modules after changing the environment variable
# because an environment variable is used to disable bottleneck
importlib.reload(utils)
importlib.reload(stats)

Expand Down

0 comments on commit d733508

Please sign in to comment.