diff --git a/sklearn/mixture/base.py b/sklearn/mixture/base.py index bd34333c0630b..2b7a99a65221b 100644 --- a/sklearn/mixture/base.py +++ b/sklearn/mixture/base.py @@ -257,11 +257,6 @@ def fit_predict(self, X, y=None): best_params = self._get_parameters() best_n_iter = n_iter - # Always do a final e-step to guarantee that the labels returned by - # fit_predict(X) are always consistent with fit(X).predict(X) - # for any value of max_iter and tol (and any random_state). - _, log_resp = self._e_step(X) - if not self.converged_: warnings.warn('Initialization %d did not converge. ' 'Try different init parameters, ' @@ -273,7 +268,9 @@ def fit_predict(self, X, y=None): self.n_iter_ = best_n_iter self.lower_bound_ = max_lower_bound - return log_resp.argmax(axis=1) + # Compute labels through the same path as predict to keep outputs + # identical between fit_predict(X) and fit(X).predict(X). + return self.predict(X) def _e_step(self, X): """E step. diff --git a/sklearn/mixture/tests/test_gaussian_mixture.py b/sklearn/mixture/tests/test_gaussian_mixture.py index 4d549ccd7b9d1..fb1f665a2c25e 100644 --- a/sklearn/mixture/tests/test_gaussian_mixture.py +++ b/sklearn/mixture/tests/test_gaussian_mixture.py @@ -198,6 +198,49 @@ def test_check_X(): assert_array_equal(X, _check_X(X, n_components, n_features)) +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_fit_predict_matches_predict_across_seeds(seed): + rng = np.random.RandomState(seed) + X = rng.randn(200, 3) + + gmm = GaussianMixture( + n_components=3, + n_init=5, + covariance_type='full', + random_state=seed, + ) + + labels_fit_predict = gmm.fit_predict(X) + labels_predict = gmm.predict(X) + + assert_array_equal(labels_fit_predict, labels_predict) + + +def test_fit_predict_uses_best_init(): + rng = np.random.RandomState(0) + X = rng.randn(120, 2) + + gmm_fit_predict = GaussianMixture( + n_components=2, + n_init=5, + covariance_type='full', + random_state=0, + ) + + labels_fit_predict = gmm_fit_predict.fit_predict(X) + + gmm_fit_then_predict = GaussianMixture( + n_components=2, + n_init=5, + covariance_type='full', + random_state=0, + ) + + labels_predict = gmm_fit_then_predict.fit(X).predict(X) + + assert_array_equal(labels_fit_predict, labels_predict) + + def test_check_weights(): rng = np.random.RandomState(0) rand_data = RandomData(rng)