diff --git a/doc/modules/outlier_detection.rst b/doc/modules/outlier_detection.rst index b27b0c8a59643..1353845ac25f6 100644 --- a/doc/modules/outlier_detection.rst +++ b/doc/modules/outlier_detection.rst @@ -239,6 +239,17 @@ Random partitioning produces noticeably shorter paths for anomalies. Hence, when a forest of random trees collectively produce shorter path lengths for particular samples, they are highly likely to be anomalies. +.. note:: + + Setting ``warm_start=True`` allows one to grow the same fitted + :class:`ensemble.IsolationForest` by increasing ``n_estimators`` between + successive calls to :meth:`fit`. Previously built trees are reused and only + the additional estimators are trained. Keeping ``n_estimators`` unchanged + raises a :class:`UserWarning` while decreasing it raises a + :class:`ValueError`. Parameter changes such as ``max_samples`` or + ``max_features`` only impact the newly added trees, so ensembles can become + heterogeneous. See :term:`the Glossary <warm_start>` for more details. + The implementation of :class:`ensemble.IsolationForest` is based on an ensemble of :class:`tree.ExtraTreeRegressor`. Following Isolation Forest original paper, the maximum depth of each tree is set to :math:`\lceil \log_2(n) \rceil` where @@ -365,4 +376,3 @@ Novelty detection with Local Outlier Factor is illustrated below. :target: ../auto_examples/neighbors/sphx_glr_plot_lof_novelty_detection.html :align: center :scale: 75% - diff --git a/examples/ensemble/plot_isolation_forest.py b/examples/ensemble/plot_isolation_forest.py index 1b79072dff64f..0b58f9e282c9b 100644 --- a/examples/ensemble/plot_isolation_forest.py +++ b/examples/ensemble/plot_isolation_forest.py @@ -21,6 +21,10 @@ Hence, when a forest of random trees collectively produce shorter path lengths for particular samples, they are highly likely to be anomalies. 
+Enable ``warm_start=True`` to reuse previously fitted trees and incrementally +grow the ensemble when increasing ``n_estimators`` between calls to +``fit``. + """ print(__doc__) diff --git a/sklearn/ensemble/iforest.py b/sklearn/ensemble/iforest.py index 8a1bd36259e48..4db5504b7d74e 100644 --- a/sklearn/ensemble/iforest.py +++ b/sklearn/ensemble/iforest.py @@ -85,6 +85,11 @@ class IsolationForest(BaseBagging, OutlierMixin): data sampled with replacement. If False, sampling without replacement is performed. + warm_start : bool, optional (default=False) + When set to ``True``, reuse the solution of the previous call to fit + and add more estimators to the ensemble, otherwise, just fit a whole + new forest. See :term:`the Glossary <warm_start>`. + n_jobs : int or None, optional (default=None) The number of jobs to run in parallel for both `fit` and `predict`. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. @@ -170,6 +175,7 @@ def __init__(self, contamination="legacy", max_features=1., bootstrap=False, + warm_start=False, n_jobs=None, behaviour='old', random_state=None, @@ -185,6 +191,7 @@ def __init__(self, n_estimators=n_estimators, max_samples=max_samples, max_features=max_features, + warm_start=warm_start, n_jobs=n_jobs, random_state=random_state, verbose=verbose) diff --git a/sklearn/ensemble/tests/test_iforest.py b/sklearn/ensemble/tests/test_iforest.py index 67ba2d7f933e3..049857dbe790f 100644 --- a/sklearn/ensemble/tests/test_iforest.py +++ b/sklearn/ensemble/tests/test_iforest.py @@ -194,6 +194,114 @@ def test_iforest_parallel_regression(): assert_array_almost_equal(y1, y3) +def _generate_warm_start_data(n_samples=128, n_features=2, seed=0): + rng_local = np.random.RandomState(seed) + return rng_local.randn(n_samples, n_features) + + +def test_warm_start_grows_forest_and_matches_single_fit(): + X = _generate_warm_start_data() + + warm_clf = IsolationForest( + n_estimators=10, + warm_start=True, + behaviour='new', + contamination='auto', + 
random_state=0 + ) + warm_clf.fit(X) + assert_equal(len(warm_clf.estimators_), 10) + + warm_clf.set_params(n_estimators=25) + warm_clf.fit(X) + assert_equal(len(warm_clf.estimators_), 25) + + cold_clf = IsolationForest( + n_estimators=25, + behaviour='new', + contamination='auto', + random_state=0 + ).fit(X) + + assert_allclose(warm_clf.decision_function(X), + cold_clf.decision_function(X)) + + +def test_warm_start_no_increase_warns(): + X = _generate_warm_start_data() + clf = IsolationForest( + n_estimators=8, + warm_start=True, + behaviour='new', + contamination='auto', + random_state=0 + ).fit(X) + assert_equal(len(clf.estimators_), 8) + + assert_warns_message(UserWarning, + 'Warm-start fitting without increasing n_estimators ' + 'does not fit new trees.', + clf.fit, X) + assert_equal(len(clf.estimators_), 8) + + +def test_warm_start_decrease_raises(): + X = _generate_warm_start_data() + clf = IsolationForest( + n_estimators=12, + warm_start=True, + behaviour='new', + contamination='auto', + random_state=0 + ).fit(X) + + clf.set_params(n_estimators=6) + assert_raises(ValueError, clf.fit, X) + + +def test_warm_start_estimator_count_growth(): + X = _generate_warm_start_data() + clf = IsolationForest( + n_estimators=5, + warm_start=True, + behaviour='new', + contamination='auto', + random_state=0 + ) + + clf.fit(X) + assert_equal(len(clf.estimators_), 5) + + clf.set_params(n_estimators=9) + clf.fit(X) + assert_equal(len(clf.estimators_), 9) + + clf.set_params(n_estimators=14) + clf.fit(X) + assert_equal(len(clf.estimators_), 14) + + +def test_warm_start_idempotent_predictions_when_unchanged(): + X = _generate_warm_start_data() + clf = IsolationForest( + n_estimators=7, + warm_start=True, + behaviour='new', + contamination='auto', + random_state=0 + ) + + clf.fit(X) + baseline = clf.decision_function(X) + + assert_warns_message(UserWarning, + 'Warm-start fitting without increasing n_estimators ' + 'does not fit new trees.', + clf.fit, X) + 
assert_equal(len(clf.estimators_), 7) + assert_allclose(clf.decision_function(X), baseline) + + @pytest.mark.filterwarnings('ignore:default contamination') @pytest.mark.filterwarnings('ignore:behaviour="old"') def test_iforest_performance():