diff --git a/sklearn/ensemble/tests/test_voting.py b/sklearn/ensemble/tests/test_voting.py index 2a19bc9a64dc0..f207cbc06171a 100644 --- a/sklearn/ensemble/tests/test_voting.py +++ b/sklearn/ensemble/tests/test_voting.py @@ -19,7 +19,7 @@ from sklearn.svm import SVC from sklearn.multiclass import OneVsRestClassifier from sklearn.neighbors import KNeighborsClassifier -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin from sklearn.dummy import DummyRegressor @@ -348,6 +348,109 @@ def test_sample_weight(): assert_raise_message(ValueError, msg, eclf3.fit, X, y, sample_weight) +def test_none_estimator_with_sample_weight_classifier(): + class SupportsSampleWeightClassifier(ClassifierMixin, BaseEstimator): + def fit(self, X, y, sample_weight=None): + self.sample_weight_ = sample_weight + return self + + def predict(self, X): + return np.zeros(len(X), dtype=int) + + class AnotherClassifier(SupportsSampleWeightClassifier): + pass + + X_local = np.arange(6).reshape(-1, 1) + y_local = np.array([0, 1, 0, 1, 0, 1]) + sample_weight = np.ones_like(y_local, dtype=float) + + voter = VotingClassifier( + estimators=[ + ('to_drop', SupportsSampleWeightClassifier()), + ('keep', AnotherClassifier()) + ], + voting='hard' + ) + voter.fit(X_local, y_local, sample_weight=sample_weight) + voter.set_params(to_drop=None) + + voter.fit(X_local, y_local, sample_weight=sample_weight) + + assert len(voter.estimators_) == 1 + assert isinstance(voter.estimators_[0], AnotherClassifier) + assert np.array_equal(voter.estimators_[0].sample_weight_, sample_weight) + + +def test_none_estimator_with_sample_weight_regressor(): + class SupportsSampleWeightRegressor(RegressorMixin, BaseEstimator): + def fit(self, X, y, sample_weight=None): + self.sample_weight_ = sample_weight + return self + + def predict(self, X): + return np.zeros(len(X), dtype=float) + + class AnotherRegressor(SupportsSampleWeightRegressor): + pass + + X_local = np.arange(6).reshape(-1, 1) + y_local = np.arange(6, dtype=float) + sample_weight = np.ones_like(y_local, dtype=float) + + voter = VotingRegressor([ + ('to_drop', SupportsSampleWeightRegressor()), + ('keep', AnotherRegressor()) + ]) + voter.fit(X_local, y_local, sample_weight=sample_weight) + voter.set_params(to_drop=None) + + voter.fit(X_local, y_local, sample_weight=sample_weight) + + assert len(voter.estimators_) == 1 + assert isinstance(voter.estimators_[0], AnotherRegressor) + assert np.array_equal(voter.estimators_[0].sample_weight_, sample_weight) + + +def test_sample_weight_mixed_support_with_none(): + class SupportsSampleWeightClassifier(ClassifierMixin, BaseEstimator): + def fit(self, X, y, sample_weight=None): + self.sample_weight_ = sample_weight + return self + + def predict(self, X): + return np.zeros(len(X), dtype=int) + + class NoSampleWeightClassifier(ClassifierMixin, BaseEstimator): + def fit(self, X, y): + return self + + def predict(self, X): + return np.zeros(len(X), dtype=int) + + X_local = np.arange(6).reshape(-1, 1) + y_local = np.array([0, 1, 0, 1, 0, 1]) + sample_weight = np.ones_like(y_local, dtype=float) + + voter = VotingClassifier( + estimators=[ + ('to_drop', SupportsSampleWeightClassifier()), + ('no_sw', NoSampleWeightClassifier()) + ], + voting='hard' + ) + voter.set_params(to_drop=None) + + msg = "Underlying estimator 'no_sw' does not support sample weights." + assert_raise_message( + ValueError, + msg, + voter.fit, + X_local, + y_local, + sample_weight=sample_weight + ) + + def test_sample_weight_kwargs(): """Check that VotingClassifier passes sample_weight as kwargs""" class MockClassifier(BaseEstimator, ClassifierMixin): diff --git a/sklearn/ensemble/voting.py b/sklearn/ensemble/voting.py index 7afa7180cc5f3..d8e14b152d3ab 100644 --- a/sklearn/ensemble/voting.py +++ b/sklearn/ensemble/voting.py @@ -78,6 +78,8 @@ def fit(self, X, y, sample_weight=None): if sample_weight is not None: for name, step in self.estimators: + if step is None: + continue if not has_fit_parameter(step, 'sample_weight'): raise ValueError('Underlying estimator \'%s\' does not' ' support sample weights.' % name)