forked from scikit-learn/scikit-learn
Open
HistGradientBoostingClassifier does not work with string target when early stopping turned on
Description
During early stopping, the scorer used under the hood receives y_true as integer-encoded targets, while y_pred contains the original classes (i.e. strings). y_true needs to be converted back to the original classes each time the score is computed.
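The mismatch described above can be illustrated with plain NumPy. This is only a sketch: the encoding step via `np.unique` and the variable names are assumptions for illustration, not the estimator's actual internals.

```python
import numpy as np

# Original string targets, as a user would pass them to fit().
y = np.array(["x", "y", "x", "y"], dtype=object)

# Assumed encoding step: map each label to an integer index.
classes, y_encoded = np.unique(y, return_inverse=True)
y_true = y_encoded.astype(np.float64)  # what the scorer ends up seeing

# Predictions come back as original class labels (perfect predictions here).
y_pred = classes[y_encoded]

# A naive element-wise comparison never matches: floats vs. strings.
naive_accuracy = float(np.mean(y_true == y_pred))

# Ordering the mixed labels (e.g. to collect the set of unique labels) fails.
try:
    sorted(y_true.tolist() + y_pred.tolist())
    raised_type_error = False
except TypeError:
    raised_type_error = True
```

Here `naive_accuracy` is 0.0 even though every prediction is correct, and sorting the mixed labels raises a `TypeError` of the same kind as the one reported in this issue.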
Steps/Code to Reproduce
import numpy as np
from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingClassifier
X = np.random.randn(100, 10)
y = np.array(['x'] * 50 + ['y'] * 50, dtype=object)
gbrt = HistGradientBoostingClassifier(n_iter_no_change=10)
gbrt.fit(X, y)
Expected Results
No error is thrown
Actual Results
TypeError: '<' not supported between instances of 'str' and 'float'
Additional reproduction & full stack trace (from PR #46)
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier
X, y = make_classification(random_state=0)
y = np.where(y == 0, 'class_a', 'class_b')
HistGradientBoostingClassifier(
    max_iter=30,
    n_iter_no_change=5,
    scoring='accuracy',
    validation_fraction=0.2,
    random_state=0,
).fit(X, y)
Traceback (most recent call last):
  ...
  File "sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py", line 437, in _maybe_do_early_stopping
    self.scorer_(self, X_binned_val, y_val)
  File "sklearn/metrics/_scorer.py", line 88, in __call__
    return self._score(
  File "sklearn/metrics/_scorer.py", line 252, in _score
    return scorer(y_true, y_pred, sample_weight=sample_weight)
  File "sklearn/metrics/_classification.py", line 1871, in accuracy_score
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "sklearn/metrics/_classification.py", line 101, in _check_targets
    raise TypeError("< not supported between instances of str and float")
TypeError: '<' not supported between instances of 'str' and 'float'
Potential resolution
Convert the integer-encoded y_true back to the original labels using self.classes_ before computing early-stopping scores, so that the scorer compares y_true and y_pred in the same label space.
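A minimal sketch of that idea, assuming the encoded targets are indices into `classes_`. The helper name `decode_targets` and the sample arrays are hypothetical and only illustrate the conversion; they are not the actual patch.

```python
import numpy as np


def decode_targets(y_encoded, classes):
    """Map integer-encoded targets back to the original class labels.

    Hypothetical helper: assumes each encoded value is an index into `classes`.
    """
    return np.asarray(classes)[np.asarray(y_encoded).astype(int)]


# Stand-in for the fitted estimator's classes_ attribute.
classes_ = np.array(["class_a", "class_b"], dtype=object)

# Validation targets as they are stored internally (encoded as floats).
y_val_encoded = np.array([0.0, 1.0, 1.0, 0.0])

# Decode before scoring so both arrays hold original labels.
y_val_labels = decode_targets(y_val_encoded, classes_)

# Illustrative predictions in the original label space (one mistake).
y_pred = np.array(["class_a", "class_b", "class_a", "class_a"], dtype=object)

accuracy = float(np.mean(y_val_labels == y_pred))
```

With both arrays in the original label space, the comparison is well defined and `accuracy` is 0.75 for this example. The same decoding would need to be applied to any training subset that is scored during early stopping.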
ping @NicolasHug @ogrisel