Skip to content

Commit

Permalink
added bottleneck for nan calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanhausen committed Jul 30, 2024
1 parent 1e62fe3 commit bd22121
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions treeple/stats/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Optional, Tuple

import numpy as np
Expand All @@ -16,6 +17,15 @@

from treeple._lib.sklearn.ensemble._forest import BaseForest, ForestClassifier

# Optional acceleration: prefer bottleneck's NaN-aware reductions when the
# package is importable, otherwise fall back to the numpy equivalents.
# Only the import sits inside ``try`` so an unrelated error raised while
# wiring up the helpers is never mistaken for a missing dependency.
try:
    import bottleneck as bn
except ImportError:
    warnings.warn(
        "bottleneck is not installed, falling back to numpy for nanmean "
        "and anynan functions. This may be slower."
    )
    nanmean_f = np.nanmean

    def anynan_f(arr, axis=2):
        """Return a boolean array flagging NaN presence along ``axis``.

        Parameters
        ----------
        arr : array-like
            Input values; must have more than ``axis`` dimensions.
        axis : int, default=2
            Axis reduced by the "any NaN" check. The default matches the
            axis this module's callers have always reduced over.

        Returns
        -------
        numpy.ndarray of bool
            ``True`` wherever at least one NaN occurs along ``axis``.
        """
        return np.isnan(arr).any(axis=axis)

else:
    nanmean_f = bn.nanmean

    def anynan_f(arr, axis=2):
        """Return a boolean array flagging NaN presence along ``axis``.

        Parameters
        ----------
        arr : array-like
            Input values; must have more than ``axis`` dimensions.
        axis : int, default=2
            Axis reduced by the "any NaN" check. The default matches the
            axis this module's callers have always reduced over.

        Returns
        -------
        numpy.ndarray of bool
            ``True`` wherever at least one NaN occurs along ``axis``.
        """
        return bn.anynan(arr, axis=axis)


def _mutual_information(y_true: ArrayLike, y_pred_proba: ArrayLike) -> float:
"""Compute estimate of mutual information for supervised classification setting.
Expand Down Expand Up @@ -131,7 +141,7 @@ def _non_nan_samples(posterior_arr: ArrayLike) -> ArrayLike:
along axis=1.
"""
# Find the row indices with NaN values along the specified axis
nan_indices = np.isnan(posterior_arr).any(axis=2).all(axis=0)
nan_indices = anynan_f(posterior_arr).all(axis=0)

# Invert the boolean mask to get indices without NaN values
nonnan_indices = np.where(~nan_indices)[0]
Expand Down Expand Up @@ -320,8 +330,8 @@ def _parallel_build_null_forests(
# first_half_metric = metric_func(y_test[non_nan_samples, :], y_pred_first_half)
# second_half_metric = metric_func(y_test[non_nan_samples, :], y_pred_second_half)

y_pred_first_half = np.nanmean(first_forest_pred[:, first_forest_samples, :], axis=0)
y_pred_second_half = np.nanmean(second_forest_pred[:, second_forest_samples, :], axis=0)
y_pred_first_half = nanmean_f(first_forest_pred[:, first_forest_samples, :], axis=0)
y_pred_second_half = nanmean_f(second_forest_pred[:, second_forest_samples, :], axis=0)

# compute two instances of the metric from the sampled trees
first_half_metric = metric_func(
Expand Down

0 comments on commit bd22121

Please sign in to comment.