Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion doc/modules/outlier_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,17 @@ Random partitioning produces noticeably shorter paths for anomalies.
Hence, when a forest of random trees collectively produce shorter path
lengths for particular samples, they are highly likely to be anomalies.

.. note::

Setting ``warm_start=True`` allows one to grow the same fitted
:class:`ensemble.IsolationForest` by increasing ``n_estimators`` between
successive calls to :meth:`fit`. Previously built trees are reused and only
the additional estimators are trained. Keeping ``n_estimators`` unchanged
raises a :class:`UserWarning`, while decreasing it raises a
:class:`ValueError`. Parameter changes such as ``max_samples`` or
``max_features`` only impact the newly added trees, so ensembles can become
heterogeneous. See :term:`the Glossary <warm_start>` for more details.

The implementation of :class:`ensemble.IsolationForest` is based on an ensemble
of :class:`tree.ExtraTreeRegressor`. Following Isolation Forest original paper,
the maximum depth of each tree is set to :math:`\lceil \log_2(n) \rceil` where
Expand Down Expand Up @@ -365,4 +376,3 @@ Novelty detection with Local Outlier Factor is illustrated below.
:target: ../auto_examples/neighbors/sphx_glr_plot_lof_novelty_detection.html
:align: center
:scale: 75%

4 changes: 4 additions & 0 deletions examples/ensemble/plot_isolation_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
Hence, when a forest of random trees collectively produce shorter path lengths
for particular samples, they are highly likely to be anomalies.

Enable ``warm_start=True`` to reuse previously fitted trees and incrementally
grow the ensemble when increasing ``n_estimators`` between calls to
``fit``.

"""
print(__doc__)

Expand Down
7 changes: 7 additions & 0 deletions sklearn/ensemble/iforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ class IsolationForest(BaseBagging, OutlierMixin):
data sampled with replacement. If False, sampling without replacement
is performed.

warm_start : bool, optional (default=False)
When set to ``True``, reuse the solution of the previous call to fit
and add more estimators to the ensemble; otherwise, fit a whole
new forest. See :term:`the Glossary <warm_start>`.

n_jobs : int or None, optional (default=None)
The number of jobs to run in parallel for both `fit` and `predict`.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
Expand Down Expand Up @@ -170,6 +175,7 @@ def __init__(self,
contamination="legacy",
max_features=1.,
bootstrap=False,
warm_start=False,
n_jobs=None,
behaviour='old',
random_state=None,
Expand All @@ -185,6 +191,7 @@ def __init__(self,
n_estimators=n_estimators,
max_samples=max_samples,
max_features=max_features,
warm_start=warm_start,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose)
Expand Down
108 changes: 108 additions & 0 deletions sklearn/ensemble/tests/test_iforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,114 @@ def test_iforest_parallel_regression():
assert_array_almost_equal(y1, y3)


def _generate_warm_start_data(n_samples=128, n_features=2, seed=0):
rng_local = np.random.RandomState(seed)
return rng_local.randn(n_samples, n_features)


def test_warm_start_grows_forest_and_matches_single_fit():
    # A forest grown in two warm-start steps (10 then 25 trees) must be
    # indistinguishable from a forest fitted once with 25 trees, because
    # the same random_state drives both tree sequences.
    X = _generate_warm_start_data()

    grown = IsolationForest(n_estimators=10, warm_start=True,
                            behaviour='new', contamination='auto',
                            random_state=0)
    grown.fit(X)
    assert_equal(len(grown.estimators_), 10)

    grown.set_params(n_estimators=25)
    grown.fit(X)
    assert_equal(len(grown.estimators_), 25)

    reference = IsolationForest(n_estimators=25, behaviour='new',
                                contamination='auto', random_state=0)
    reference.fit(X)

    assert_allclose(grown.decision_function(X),
                    reference.decision_function(X))


def test_warm_start_no_increase_warns():
    # Refitting with warm_start=True and an unchanged n_estimators must
    # emit a UserWarning and leave the ensemble untouched.
    X = _generate_warm_start_data()
    clf = IsolationForest(n_estimators=8, warm_start=True,
                          behaviour='new', contamination='auto',
                          random_state=0)
    clf.fit(X)
    assert_equal(len(clf.estimators_), 8)

    expected_msg = ('Warm-start fitting without increasing n_estimators '
                    'does not fit new trees.')
    assert_warns_message(UserWarning, expected_msg, clf.fit, X)
    assert_equal(len(clf.estimators_), 8)


def test_warm_start_decrease_raises():
    # Shrinking n_estimators between warm-start fits is invalid: the
    # already-built trees cannot be discarded, so fit must raise.
    X = _generate_warm_start_data()
    clf = IsolationForest(n_estimators=12, warm_start=True,
                          behaviour='new', contamination='auto',
                          random_state=0)
    clf.fit(X)

    clf.set_params(n_estimators=6)
    assert_raises(ValueError, clf.fit, X)


def test_warm_start_estimator_count_growth():
    # Each warm-start refit with a larger n_estimators grows the ensemble
    # to exactly the requested size (5 -> 9 -> 14).
    X = _generate_warm_start_data()
    clf = IsolationForest(n_estimators=5, warm_start=True,
                          behaviour='new', contamination='auto',
                          random_state=0)

    for target in (5, 9, 14):
        clf.set_params(n_estimators=target)
        clf.fit(X)
        assert_equal(len(clf.estimators_), target)


def test_warm_start_idempotent_predictions_when_unchanged():
    # A warned no-op refit (same n_estimators) must not alter the
    # decision function: the fitted trees are reused as-is.
    X = _generate_warm_start_data()
    clf = IsolationForest(n_estimators=7, warm_start=True,
                          behaviour='new', contamination='auto',
                          random_state=0)
    clf.fit(X)
    before = clf.decision_function(X)

    expected_msg = ('Warm-start fitting without increasing n_estimators '
                    'does not fit new trees.')
    assert_warns_message(UserWarning, expected_msg, clf.fit, X)

    assert_equal(len(clf.estimators_), 7)
    assert_allclose(clf.decision_function(X), before)


@pytest.mark.filterwarnings('ignore:default contamination')
@pytest.mark.filterwarnings('ignore:behaviour="old"')
def test_iforest_performance():
Expand Down