Summary of this patch (free text placed before the first "diff --git" header
is skipped by git-apply and patch).

* sklearn/model_selection/tests/test_validation.py
  - imports MultiOutputClassifier and LinearDiscriminantAnalysis;
  - adds three tests that call cross_val_predict on a MultiOutputClassifier:
    two with method='predict_proba' (asserting a list with one
    (n_samples, n_classes_i) array per output whose rows sum to 1) and one
    with the default method (asserting predictions.shape == y.shape).
* sklearn/multioutput.py
  - documents a ``classes_`` attribute on MultiOutputClassifier;
  - adds fit / partial_fit overrides that call the parent implementation and
    then set ``classes_`` to the list of each sub-estimator's ``classes_``
    (partial_fit leaves it untouched unless every sub-estimator has one);
  - changes only the whitespace of the ``raise ValueError("invalid order")``
    line.

NOTE(review): line breaks, blank lines and indentation below were
reconstructed from the hunk line counts. The exact leading whitespace of
context lines and of the -/+ ``raise`` pair is an assumption -- verify against
the blobs named on the "index" lines before applying.
NOTE(review): the new public fit / partial_fit overrides carry no docstrings.

diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index 6fa2e4fee5ed7..b976dbcef14b8 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -60,6 +60,8 @@
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.svm import SVC
 from sklearn.cluster import KMeans
+from sklearn.multioutput import MultiOutputClassifier
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
 
 from sklearn.impute import SimpleImputer
 
@@ -936,6 +938,56 @@ def test_cross_val_predict_input_types():
     assert_array_equal(predictions.shape, (150,))
 
 
+def test_cross_val_predict_multioutput_predict_proba_list():
+    X, y = make_multilabel_classification(n_samples=120, n_features=10,
+                                          n_classes=4, random_state=0)
+    estimator = MultiOutputClassifier(LinearDiscriminantAnalysis())
+    proba = cross_val_predict(estimator, X, y, cv=3, method='predict_proba')
+
+    assert isinstance(proba, list)
+    assert len(proba) == y.shape[1]
+
+    expected_row_sums = np.ones(X.shape[0])
+    for idx, output_proba in enumerate(proba):
+        n_classes = np.unique(y[:, idx]).shape[0]
+        assert output_proba.shape == (X.shape[0], n_classes)
+        assert_allclose(output_proba.sum(axis=1), expected_row_sums,
+                        atol=1e-7)
+
+
+def test_cross_val_predict_multioutput_predict_proba_mixed_classes():
+    X_bin, y_binary = make_classification(n_samples=150, n_features=5,
+                                          n_informative=5, n_redundant=0,
+                                          n_classes=2, random_state=0)
+    X_multi, y_multiclass = make_classification(
+        n_samples=150, n_features=3, n_informative=3, n_redundant=0,
+        n_classes=3, n_clusters_per_class=1, random_state=1)
+    X = np.hstack([X_bin, X_multi])
+    y = np.column_stack([y_binary, y_multiclass])
+
+    estimator = MultiOutputClassifier(
+        LogisticRegression(max_iter=2000, solver="lbfgs"))
+    proba = cross_val_predict(estimator, X, y, cv=3, method='predict_proba')
+
+    assert isinstance(proba, list)
+    assert len(proba) == 2
+    assert proba[0].shape == (X.shape[0], 2)
+    assert proba[1].shape == (X.shape[0], 3)
+
+    expected_row_sums = np.ones(X.shape[0])
+    assert_allclose(proba[0].sum(axis=1), expected_row_sums, atol=1e-7)
+    assert_allclose(proba[1].sum(axis=1), expected_row_sums, atol=1e-7)
+
+
+def test_cross_val_predict_multioutput_without_method():
+    X, y = make_multilabel_classification(n_samples=90, n_features=8,
+                                          n_classes=3, random_state=1)
+    estimator = MultiOutputClassifier(LinearDiscriminantAnalysis())
+    predictions = cross_val_predict(estimator, X, y, cv=3)
+
+    assert predictions.shape == y.shape
+
+
 @pytest.mark.filterwarnings('ignore: Using or importing the ABCs from')
 # python3.7 deprecation warnings in pandas via matplotlib :-/
 def test_cross_val_predict_pandas():
diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py
index 463b72d40f47a..3e3dd86332427 100644
--- a/sklearn/multioutput.py
+++ b/sklearn/multioutput.py
@@ -320,11 +320,28 @@ class MultiOutputClassifier(MultiOutputEstimator, ClassifierMixin):
     ----------
     estimators_ : list of ``n_output`` estimators
         Estimators used for predictions.
+    classes_ : list of arrays
+        Class labels for each output, of shape ``(n_outputs,)``. Each entry is
+        an array of shape ``(n_classes_i,)`` containing the class labels for
+        the corresponding output.
     """
 
     def __init__(self, estimator, n_jobs=None):
         super().__init__(estimator, n_jobs)
 
+    def fit(self, X, y, sample_weight=None):
+        super().fit(X, y, sample_weight=sample_weight)
+        self.classes_ = [est.classes_ for est in self.estimators_]
+        return self
+
+    @if_delegate_has_method('estimator')
+    def partial_fit(self, X, y, classes=None, sample_weight=None):
+        super().partial_fit(X, y, classes=classes, sample_weight=sample_weight)
+        if not all(hasattr(est, 'classes_') for est in self.estimators_):
+            return self
+        self.classes_ = [est.classes_ for est in self.estimators_]
+        return self
+
     def predict_proba(self, X):
         """Probability estimates.
         Returns prediction probabilities for each class of each output.
@@ -420,7 +437,7 @@ def fit(self, X, Y):
             if self.order_ == 'random':
                 self.order_ = random_state.permutation(Y.shape[1])
         elif sorted(self.order_) != list(range(Y.shape[1])):
-                raise ValueError("invalid order")
+            raise ValueError("invalid order")
 
         self.estimators_ = [clone(self.base_estimator)
                             for _ in range(Y.shape[1])]