Skip to content

Commit

Permalink
added warning back in and fixed issue with checking bottleneck import
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanhausen committed Aug 14, 2024
1 parent 91f6271 commit 68e8f9e
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions treeple/stats/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib.util
import os
import sys
import warnings
from typing import Optional, Tuple

import numpy as np
Expand All @@ -19,11 +20,14 @@
from treeple._lib.sklearn.ensemble._forest import BaseForest, ForestClassifier

BOTTLENECK_AVAILABLE = False
if "bottleneck" in sys.modules:
if importlib.util.find_spec("bottleneck"):
import bottleneck as bn

BOTTLENECK_AVAILABLE = True

BOTTLENECK_WARNING = (
"Not using bottleneck for calculations involvings nans. Expect slower performance."
)
DISABLE_BN_ENV_VAR = "TREEPLE_NO_BOTTLENECK"

if BOTTLENECK_AVAILABLE and DISABLE_BN_ENV_VAR not in os.environ:
Expand Down Expand Up @@ -257,6 +261,9 @@ def _compute_null_distribution_coleman(
metric_star_pi : ArrayLike of shape (n_samples,)
An array of the metrics computed on the other half of the trees.
"""
if not BOTTLENECK_AVAILABLE:
warnings.warn(BOTTLENECK_WARNING)

# sample two sets of equal number of trees from the combined forest these are the posteriors
# (n_estimators * 2, n_samples, n_outputs)
all_y_pred = np.concatenate((y_pred_proba_normal, y_pred_proba_perm), axis=0)
Expand Down

0 comments on commit 68e8f9e

Please sign in to comment.