From a52283e67c414e834e17f1b20833968312d2f9dd Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Fri, 26 Dec 2025 02:54:41 +0000 Subject: [PATCH] fix(linear_model): stabilize logregcv refit paths --- sklearn/linear_model/logistic.py | 139 +++++++++++++------- sklearn/linear_model/tests/test_logistic.py | 92 +++++++++++++ 2 files changed, 185 insertions(+), 46 deletions(-) diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index 1999f1dae7ba3..71b586aad5b67 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -2003,6 +2003,7 @@ def fit(self, X, y, sample_weight=None): multi_class = _check_multi_class(self.multi_class, solver, len(classes)) + self.multi_class = multi_class if solver in ['sag', 'saga']: max_squared_sum = row_norms(X, squared=True).max() @@ -2081,52 +2082,89 @@ def fit(self, X, y, sample_weight=None): # - n_iter is of shape # (n_classes, n_folds, n_Cs . n_l1_ratios) or # (1, n_folds, n_Cs . n_l1_ratios) - coefs_paths, Cs, scores, n_iter_ = zip(*fold_coefs_) - self.Cs_ = Cs[0] + coefs_paths_blocks, Cs_blocks, scores_blocks, n_iter_blocks = zip( + *fold_coefs_) + coefs_paths = np.stack(coefs_paths_blocks, axis=0) + Cs_blocks = np.stack(Cs_blocks, axis=0) + scores = np.stack(scores_blocks, axis=0) + n_iter_ = np.stack(n_iter_blocks, axis=0) + + self.Cs_ = Cs_blocks[0] + if not np.allclose(Cs_blocks, self.Cs_): + raise ValueError( + "Inconsistent regularization grids across folds.") + + n_folds = len(folds) + n_l1 = len(l1_ratios_) + n_cs = self.Cs_.size + n_classes_total = len(classes) + expected_blocks = len(iter_encoded_labels) * n_folds * n_l1 + + if (coefs_paths.shape[0] != expected_blocks or + scores.shape[0] != expected_blocks or + n_iter_.shape[0] != expected_blocks): + raise ValueError( + "Unexpected internal LogisticRegressionCV shapes: " + "expected %d blocks, got coefs=%d, scores=%d, n_iter=%d" + % (expected_blocks, coefs_paths.shape[0], + scores.shape[0], n_iter_.shape[0])) + if multi_class == 'multinomial': - coefs_paths = np.reshape( - coefs_paths, - (len(folds), len(l1_ratios_) * len(self.Cs_), n_classes, -1) - ) - # equiv to coefs_paths = np.moveaxis(coefs_paths, (0, 1, 2, 3), - # (1, 2, 0, 3)) - coefs_paths = np.swapaxes(coefs_paths, 0, 1) - coefs_paths = np.swapaxes(coefs_paths, 0, 2) - self.n_iter_ = np.reshape( - n_iter_, - (1, len(folds), len(self.Cs_) * len(l1_ratios_)) - ) - # repeat same scores across all classes - scores = np.tile(scores, (n_classes, 1, 1)) + coefs_paths = coefs_paths.reshape( + n_folds, n_l1, n_cs, n_classes_total, -1) + coefs_paths = np.moveaxis(coefs_paths, 3, 0) + coefs_paths = coefs_paths.reshape( + n_classes_total, n_folds, n_l1 * n_cs, -1) + + scores = scores.reshape(n_folds, n_l1, n_cs) + scores = scores.reshape(1, n_folds, n_l1 * n_cs) + scores = np.tile(scores, (n_classes_total, 1, 1)) + + n_iter_ = n_iter_.reshape(1, n_folds, n_l1 * n_cs) else: - coefs_paths = np.reshape( - coefs_paths, - (n_classes, len(folds), len(self.Cs_) * len(l1_ratios_), - -1) - ) - self.n_iter_ = np.reshape( - n_iter_, - (n_classes, len(folds), len(self.Cs_) * len(l1_ratios_)) - ) - scores = np.reshape(scores, (n_classes, len(folds), -1)) - self.scores_ = dict(zip(classes, scores)) - self.coefs_paths_ = dict(zip(classes, coefs_paths)) + n_targets = len(iter_encoded_labels) + coefs_paths = coefs_paths.reshape( + n_targets, n_folds, n_l1, n_cs, -1) + coefs_paths = coefs_paths.reshape( + n_targets, n_folds, n_l1 * n_cs, -1) + + scores = scores.reshape(n_targets, n_folds, n_l1, n_cs) + scores = scores.reshape(n_targets, n_folds, n_l1 * n_cs) + + n_iter_ = n_iter_.reshape(n_targets, n_folds, n_l1 * n_cs) + + if coefs_paths.ndim != 4: + raise ValueError( + "LogisticRegressionCV expected 4D coefficient paths but " + "got %d dimensions" % coefs_paths.ndim) + + scores_per_class = scores + coefs_paths_per_class = coefs_paths + + if self.penalty == 'elasticnet': + l1_ratios_array = np.asarray(l1_ratios_, dtype=float) + else: + l1_ratios_array = None + + self.scores_ = dict(zip(classes, scores_per_class)) + self.coefs_paths_ = dict(zip(classes, coefs_paths_per_class)) + self.n_iter_ = n_iter_ self.C_ = list() self.l1_ratio_ = list() self.coef_ = np.empty((n_classes, X.shape[1])) self.intercept_ = np.zeros(n_classes) + fold_indices = np.arange(len(folds)) for index, (cls, encoded_label) in enumerate( zip(iter_classes, iter_encoded_labels)): if multi_class == 'ovr': - scores = self.scores_[cls] - coefs_paths = self.coefs_paths_[cls] + class_scores = self.scores_[cls] + class_coefs_paths = self.coefs_paths_[cls] else: - # For multinomial, all scores are the same across classes - scores = scores[0] - # coefs_paths will keep its original shape because - # logistic_regression_path expects it this way + # For multinomial, reuse the shared structures across classes + class_scores = scores_per_class[0] + class_coefs_paths = coefs_paths_per_class if self.refit: # best_index is between 0 and (n_Cs . n_l1_ratios - 1) @@ -2134,21 +2172,25 @@ def fit(self, X, y, sample_weight=None): # the layout of scores is # [c1, c2, c1, c2, c1, c2] # l1_1 , l1_2 , l1_3 - best_index = scores.sum(axis=0).argmax() + best_index = class_scores.sum(axis=0).argmax() best_index_C = best_index % len(self.Cs_) C_ = self.Cs_[best_index_C] self.C_.append(C_) best_index_l1 = best_index // len(self.Cs_) - l1_ratio_ = l1_ratios_[best_index_l1] + if self.penalty == 'elasticnet': + l1_ratio_ = float(l1_ratios_array[best_index_l1]) + else: + l1_ratio_ = None self.l1_ratio_.append(l1_ratio_) if multi_class == 'multinomial': - coef_init = np.mean(coefs_paths[:, :, best_index, :], - axis=1) + coef_init = np.mean( + class_coefs_paths[:, :, best_index, :], axis=1) else: - coef_init = np.mean(coefs_paths[:, best_index, :], axis=0) + coef_init = np.mean( + class_coefs_paths[:, best_index, :], axis=0) # Note that y is label encoded and hence pos_class must be # the encoded label / None (for 'multinomial') @@ -2169,19 +2211,24 @@ def fit(self, X, y, sample_weight=None): else: # Take the best scores across every fold and the average of # all coefficients corresponding to the best scores. - best_indices = np.argmax(scores, axis=1) - if self.multi_class == 'ovr': - w = np.mean([coefs_paths[i, best_indices[i], :] - for i in range(len(folds))], axis=0) + best_indices = np.argmax(class_scores, axis=1) + if multi_class == 'ovr': + w = class_coefs_paths[fold_indices, best_indices, :] + w = w.mean(axis=0) else: - w = np.mean([coefs_paths[:, i, best_indices[i], :] - for i in range(len(folds))], axis=0) + w = class_coefs_paths[:, fold_indices, best_indices, :] + w = w.mean(axis=1) best_indices_C = best_indices % len(self.Cs_) self.C_.append(np.mean(self.Cs_[best_indices_C])) best_indices_l1 = best_indices // len(self.Cs_) - self.l1_ratio_.append(np.mean(l1_ratios_[best_indices_l1])) + if self.penalty == 'elasticnet': + mean_l1_ratio = float( + np.mean(l1_ratios_array[best_indices_l1])) + else: + mean_l1_ratio = None + self.l1_ratio_.append(mean_l1_ratio) if multi_class == 'multinomial': self.C_ = np.tile(self.C_, n_classes) diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index 6ad9a4ec99d77..f7a01b8bccf86 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -702,6 +702,98 @@ def test_ovr_multinomial_iris(): assert_equal(scores.shape, (3, n_cv, 10)) +@pytest.mark.parametrize( + "solver, penalty, l1_ratios", + [ + ("saga", "elasticnet", [0.3, 0.7]), + ("liblinear", "l2", None), + ], +) +def test_logistic_regression_cv_refit_false_binary_shapes( + solver, penalty, l1_ratios): + X, y = make_classification( + n_samples=80, n_features=6, n_informative=5, + n_redundant=0, n_classes=2, random_state=0) + + params = dict( + solver=solver, + penalty=penalty, + Cs=[0.1, 1.0], + cv=3, + refit=False, + tol=1e-3, + max_iter=3000 if solver == 'saga' else 200, + random_state=0, + ) + if l1_ratios is not None: + params["l1_ratios"] = l1_ratios + + clf = LogisticRegressionCV(**params) + with ignore_warnings(category=ConvergenceWarning): + clf.fit(X, y) + + assert_equal(clf.coef_.shape, (1, X.shape[1])) + assert_equal(clf.intercept_.shape, (1,)) + + assert_equal(len(clf.coefs_paths_), 1) + coefs_path = next(iter(clf.coefs_paths_.values())) + class_scores = next(iter(clf.scores_.values())) + + expected_features = X.shape[1] + int(clf.fit_intercept) + n_l1 = clf.l1_ratios_.size + n_cs = clf.Cs_.size + + if clf.penalty == 'elasticnet': + assert_equal(coefs_path.shape, + (params["cv"], n_cs, n_l1, expected_features)) + assert_equal(class_scores.shape, + (params["cv"], n_cs, n_l1)) + flattened = coefs_path.reshape(params["cv"], n_cs * n_l1, + expected_features) + assert_equal(flattened.shape[1], n_cs * n_l1) + else: + assert_equal(coefs_path.shape, + (params["cv"], n_cs, expected_features)) + assert_equal(class_scores.shape, + (params["cv"], n_cs)) + + +def test_logistic_regression_cv_refit_false_multinomial(): + X, y = make_classification( + n_samples=120, n_features=5, n_informative=4, + n_redundant=0, n_classes=3, random_state=0) + + clf = LogisticRegressionCV( + solver='saga', + penalty='elasticnet', + l1_ratios=[0.2, 0.8], + Cs=[0.1, 1.0], + cv=3, + refit=False, + multi_class='multinomial', + tol=1e-3, + max_iter=4000, + random_state=0, + ) + + with ignore_warnings(category=ConvergenceWarning): + clf.fit(X, y) + + n_classes = len(clf.classes_) + expected_features = X.shape[1] + int(clf.fit_intercept) + assert_equal(clf.coef_.shape, (n_classes, X.shape[1])) + assert_equal(clf.intercept_.shape, (n_classes,)) + + for coefs_path, class_scores in zip(clf.coefs_paths_.values(), + clf.scores_.values()): + assert_equal(coefs_path.shape, + (class_scores.shape[0], clf.Cs_.size, + clf.l1_ratios_.size, expected_features)) + assert_equal(class_scores.shape, + (class_scores.shape[0], clf.Cs_.size, + clf.l1_ratios_.size)) + + def test_logistic_regression_solvers(): X, y = make_classification(n_features=10, n_informative=5, random_state=0)