Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions sklearn/model_selection/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.cluster import KMeans
from sklearn.multioutput import MultiOutputClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from sklearn.impute import SimpleImputer

Expand Down Expand Up @@ -936,6 +938,56 @@ def test_cross_val_predict_input_types():
assert_array_equal(predictions.shape, (150,))


def test_cross_val_predict_multioutput_predict_proba_list():
X, y = make_multilabel_classification(n_samples=120, n_features=10,
n_classes=4, random_state=0)
estimator = MultiOutputClassifier(LinearDiscriminantAnalysis())
proba = cross_val_predict(estimator, X, y, cv=3, method='predict_proba')

assert isinstance(proba, list)
assert len(proba) == y.shape[1]

expected_row_sums = np.ones(X.shape[0])
for idx, output_proba in enumerate(proba):
n_classes = np.unique(y[:, idx]).shape[0]
assert output_proba.shape == (X.shape[0], n_classes)
assert_allclose(output_proba.sum(axis=1), expected_row_sums,
atol=1e-7)


def test_cross_val_predict_multioutput_predict_proba_mixed_classes():
X_bin, y_binary = make_classification(n_samples=150, n_features=5,
n_informative=5, n_redundant=0,
n_classes=2, random_state=0)
X_multi, y_multiclass = make_classification(
n_samples=150, n_features=3, n_informative=3, n_redundant=0,
n_classes=3, n_clusters_per_class=1, random_state=1)
X = np.hstack([X_bin, X_multi])
y = np.column_stack([y_binary, y_multiclass])

estimator = MultiOutputClassifier(
LogisticRegression(max_iter=2000, solver="lbfgs"))
proba = cross_val_predict(estimator, X, y, cv=3, method='predict_proba')

assert isinstance(proba, list)
assert len(proba) == 2
assert proba[0].shape == (X.shape[0], 2)
assert proba[1].shape == (X.shape[0], 3)

expected_row_sums = np.ones(X.shape[0])
assert_allclose(proba[0].sum(axis=1), expected_row_sums, atol=1e-7)
assert_allclose(proba[1].sum(axis=1), expected_row_sums, atol=1e-7)


def test_cross_val_predict_multioutput_without_method():
X, y = make_multilabel_classification(n_samples=90, n_features=8,
n_classes=3, random_state=1)
estimator = MultiOutputClassifier(LinearDiscriminantAnalysis())
predictions = cross_val_predict(estimator, X, y, cv=3)

assert predictions.shape == y.shape


@pytest.mark.filterwarnings('ignore: Using or importing the ABCs from')
# python3.7 deprecation warnings in pandas via matplotlib :-/
def test_cross_val_predict_pandas():
Expand Down
19 changes: 18 additions & 1 deletion sklearn/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,28 @@ class MultiOutputClassifier(MultiOutputEstimator, ClassifierMixin):
----------
estimators_ : list of ``n_output`` estimators
Estimators used for predictions.
classes_ : list of arrays
Class labels for each output, of shape ``(n_outputs,)``. Each entry is
an array of shape ``(n_classes_i,)`` containing the class labels for
the corresponding output.
"""

def __init__(self, estimator, n_jobs=None):
super().__init__(estimator, n_jobs)

def fit(self, X, y, sample_weight=None):
super().fit(X, y, sample_weight=sample_weight)
self.classes_ = [est.classes_ for est in self.estimators_]
return self

@if_delegate_has_method('estimator')
def partial_fit(self, X, y, classes=None, sample_weight=None):
super().partial_fit(X, y, classes=classes, sample_weight=sample_weight)
if not all(hasattr(est, 'classes_') for est in self.estimators_):
return self
self.classes_ = [est.classes_ for est in self.estimators_]
return self

def predict_proba(self, X):
"""Probability estimates.
Returns prediction probabilities for each class of each output.
Expand Down Expand Up @@ -420,7 +437,7 @@ def fit(self, X, Y):
if self.order_ == 'random':
self.order_ = random_state.permutation(Y.shape[1])
elif sorted(self.order_) != list(range(Y.shape[1])):
raise ValueError("invalid order")
raise ValueError("invalid order")

self.estimators_ = [clone(self.base_estimator)
for _ in range(Y.shape[1])]
Expand Down