Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion sklearn/ensemble/_iforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,9 @@ def fit(self, X, y=None, sample_weight=None):
return self

# else, define offset_ wrt contamination parameter
self.offset_ = np.percentile(self.score_samples(X), 100.0 * self.contamination)
self.offset_ = np.percentile(
self._score_samples_no_validation(X), 100.0 * self.contamination
)

return self

Expand Down Expand Up @@ -435,6 +437,14 @@ def score_samples(self, X):
# Check data
X = self._validate_data(X, accept_sparse="csr", dtype=np.float32, reset=False)

return self._score_samples_no_validation(X)

def _score_samples_no_validation(self, X):
    """Compute anomaly scores for ``X`` without running ``_validate_data``.

    ``score_samples`` validates its input and then delegates here; ``fit``
    calls this helper directly on its training data when deriving
    ``offset_`` from ``contamination``. Skipping the validation step means
    the caller is responsible for passing data of the right shape.
    NOTE(review): the accompanying tests indicate the purpose is to avoid
    the "X does not have valid feature names" warning during ``fit``.

    Parameters
    ----------
    X : array-like or sparse matrix
        Samples to score. Sparse input is converted to CSR; dense input is
        converted to an ndarray. Both are cast to ``np.float32``.

    Returns
    -------
    scores : ndarray
        The negated result of ``_compute_chunked_score_samples``; the
        lower, the more abnormal.
    """
    if issparse(X):
        # Keep both branches consistent: CSR layout *and* float32 values,
        # matching what ``score_samples`` hands over after validation.
        X = X.tocsr().astype(np.float32, copy=False)
    else:
        X = np.asarray(X, dtype=np.float32)

    # Take the opposite of the scores as bigger is better (here less
    # abnormal)
    return -self._compute_chunked_score_samples(X)
Expand Down
52 changes: 52 additions & 0 deletions sklearn/ensemble/tests/test_iforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings

import numpy as np
import pandas as pd

from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
Expand All @@ -33,6 +34,12 @@
diabetes = load_diabetes()


# Text of the UserWarning about missing feature names; the tests in this
# module pass it as the ``message``/``match`` pattern when filtering or
# expecting that warning.
FEATURE_NAMES_WARNING = (
    "X does not have valid feature names, "
    "but IsolationForest was fitted with feature names"
)


def test_iforest(global_random_seed):
"""Check Isolation Forest for various parameter settings."""
X_train = np.array([[0, 1], [1, 2]])
Expand Down Expand Up @@ -75,6 +82,51 @@ def test_iforest_sparse(global_random_seed):
assert_array_equal(sparse_results, dense_results)


def test_iforest_fit_dataframe_contamination_no_warning():
    """Check that fitting on a DataFrame with numeric contamination is silent.

    With a numeric ``contamination``, ``fit`` scores the training data to
    compute ``offset_``; that internal scoring must not emit the
    feature-names warning.
    """
    # pandas is an optional dependency: skip instead of erroring without it.
    pd = pytest.importorskip("pandas")
    X = pd.DataFrame({"a": [-1.1, 0.3, 0.5, 100]})

    iso = IsolationForest(random_state=0, contamination=0.05)

    with warnings.catch_warnings():
        # Turn only the feature-names warning into an error so that the
        # test fails if ``fit`` raises it.
        warnings.filterwarnings(
            "error", message=FEATURE_NAMES_WARNING, category=UserWarning
        )
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        iso.fit(X)


def test_iforest_dataframe_then_ndarray_warns_on_score_and_predict():
    """Check that the feature-names warning still fires on public methods.

    After fitting on a DataFrame, calling ``score_samples`` or ``predict``
    with a plain ndarray must emit the feature-names warning.
    """
    # pandas is an optional dependency: skip instead of erroring without it.
    pd = pytest.importorskip("pandas")
    X = pd.DataFrame({"a": [-1.1, 0.3, 0.5, 100]})

    iso = IsolationForest(random_state=0, contamination=0.05)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=DeprecationWarning)
        iso.fit(X)

    # Same values, but without column names.
    X_array = X.to_numpy()

    with pytest.warns(UserWarning, match=FEATURE_NAMES_WARNING):
        iso.score_samples(X_array)

    with pytest.warns(UserWarning, match=FEATURE_NAMES_WARNING):
        iso.predict(X_array)


def test_iforest_fit_dataframe_auto_no_warning_and_offset():
    """Check ``contamination="auto"`` on a DataFrame: silent fit, fixed offset.

    Fitting must not emit the feature-names warning, and ``offset_`` must
    equal ``-0.5``.
    """
    # pandas is an optional dependency: skip instead of erroring without it.
    pd = pytest.importorskip("pandas")
    X = pd.DataFrame({"a": [-1.1, 0.3, 0.5, 100]})
    iso = IsolationForest(random_state=0, contamination="auto")

    with warnings.catch_warnings():
        # Turn only the feature-names warning into an error so that the
        # test fails if ``fit`` raises it.
        warnings.filterwarnings(
            "error", message=FEATURE_NAMES_WARNING, category=UserWarning
        )
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        iso.fit(X)

    assert iso.offset_ == -0.5


def test_iforest_error():
"""Test that it gives proper exception on deficient input."""
X = iris.data
Expand Down