Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions sklearn/mixture/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,6 @@ def fit_predict(self, X, y=None):
best_params = self._get_parameters()
best_n_iter = n_iter

# Always do a final e-step to guarantee that the labels returned by
# fit_predict(X) are always consistent with fit(X).predict(X)
# for any value of max_iter and tol (and any random_state).
_, log_resp = self._e_step(X)

if not self.converged_:
warnings.warn('Initialization %d did not converge. '
'Try different init parameters, '
Expand All @@ -273,7 +268,9 @@ def fit_predict(self, X, y=None):
self.n_iter_ = best_n_iter
self.lower_bound_ = max_lower_bound

return log_resp.argmax(axis=1)
# Compute labels through the same path as predict to keep outputs
# identical between fit_predict(X) and fit(X).predict(X).
return self.predict(X)

def _e_step(self, X):
"""E step.
Expand Down
43 changes: 43 additions & 0 deletions sklearn/mixture/tests/test_gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,49 @@ def test_check_X():
assert_array_equal(X, _check_X(X, n_components, n_features))


@pytest.mark.parametrize("seed", [0, 1, 2])
def test_fit_predict_matches_predict_across_seeds(seed):
    # With several initializations (n_init=5), the labels returned by
    # fit_predict(X) must equal those returned by predict(X) on the same
    # fitted estimator, for each parametrized seed.
    data = np.random.RandomState(seed).randn(200, 3)
    model = GaussianMixture(n_components=3, n_init=5,
                            covariance_type='full', random_state=seed)

    # fit_predict fits the model in place; predict then reuses that fit.
    from_fit_predict = model.fit_predict(data)
    from_predict = model.predict(data)

    assert_array_equal(from_fit_predict, from_predict)


def test_fit_predict_uses_best_init():
    # Two identically configured estimators are trained on the same data:
    # one through fit_predict(X), the other through fit(X).predict(X).
    # Both paths must produce the same labels.
    data = np.random.RandomState(0).randn(120, 2)
    config = dict(n_components=2, n_init=5, covariance_type='full',
                  random_state=0)

    one_shot_labels = GaussianMixture(**config).fit_predict(data)
    two_step_labels = GaussianMixture(**config).fit(data).predict(data)

    assert_array_equal(one_shot_labels, two_step_labels)


def test_check_weights():
rng = np.random.RandomState(0)
rand_data = RandomData(rng)
Expand Down