Thresholder (#154)
* added thresholder method and tests

* okay, now with style checks

* really small change

* added base for checking of properties
koaning authored Jun 19, 2019
1 parent 647dbfc commit b07f5aa
Showing 3 changed files with 132 additions and 2 deletions.
7 changes: 7 additions & 0 deletions sklego/base.py
@@ -0,0 +1,7 @@
class ProbabilisticClassifierMeta(type):
    def __instancecheck__(self, other):
        return hasattr(other, 'predict_proba')


class ProbabilisticClassifier(metaclass=ProbabilisticClassifierMeta):
    pass
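For context on the snippet above: the metaclass overrides __instancecheck__ so that isinstance reports True for any object exposing a predict_proba method, regardless of its actual class hierarchy. A minimal sketch of how this duck typing behaves (assuming scikit-learn is installed; this example is not part of the commit):

from sklearn.linear_model import LinearRegression, LogisticRegression

from sklego.base import ProbabilisticClassifier

# LogisticRegression defines predict_proba, so the duck-typed check passes;
# LinearRegression does not, so it fails.
assert isinstance(LogisticRegression(), ProbabilisticClassifier)
assert not isinstance(LinearRegression(), ProbabilisticClassifier)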
52 changes: 50 additions & 2 deletions sklego/meta.py
@@ -4,6 +4,7 @@
from sklearn.base import BaseEstimator, TransformerMixin, MetaEstimatorMixin
from sklearn.utils.validation import check_is_fitted, check_X_y, check_array, FLOAT_DTYPES

from sklego.base import ProbabilisticClassifier
from sklego.common import as_list


@@ -130,7 +131,6 @@ class DecayEstimator(BaseEstimator):
    The DecayEstimator will use exponential decay to weight the parameters.
    w_{t-1} = decay * w_{t}
    """

    def __init__(self, model, decay: float = 0.999, decay_func="exponential"):
@@ -163,7 +163,7 @@ def fit(self, X, y):

    def predict(self, X):
        """
        Predict new data by making random guesses.
        Predict new data.
        :param X: array-like, shape=(n_columns, n_samples,) training data.
        :return: array, shape=(n_samples,) the predicted data
@@ -175,3 +175,51 @@ def predict(self, X):

    def score(self, X, y):
        return self.estimator_.score(X, y)


class Thresholder(BaseEstimator):
    """
    Takes a two-class estimator and moves the threshold. This way you might
    design the algorithm to only accept a certain class if the probability
    for it is larger than, say, 90% instead of 50%.
    """

    def __init__(self, model, threshold: float):
        self.model = model
        self.threshold = threshold

    def fit(self, X, y):
        """
        Fit the data.
        :param X: array-like, shape=(n_columns, n_samples,) training data.
        :param y: array-like, shape=(n_samples,) training data.
        :return: Returns an instance of self.
        """
        X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES)
        self.estimator_ = clone(self.model)
        if not isinstance(self.estimator_, ProbabilisticClassifier):
            raise ValueError("The Thresholder meta model only works on classification models with .predict_proba.")
        self.estimator_.fit(X, y)
        self.classes_ = self.estimator_.classes_
        if len(self.classes_) != 2:
            raise ValueError("The Thresholder meta model only works on models with two classes.")
        return self

    def predict(self, X):
        """
        Predict new data.
        :param X: array-like, shape=(n_columns, n_samples,) training data.
        :return: array, shape=(n_samples,) the predicted data
        """
        check_is_fitted(self, ['classes_', 'estimator_'])
        predicate = self.estimator_.predict_proba(X)[:, 1] > self.threshold
        return np.where(predicate, self.classes_[1], self.classes_[0])

    def predict_proba(self, X):
        check_is_fitted(self, ['classes_', 'estimator_'])
        return self.estimator_.predict_proba(X)

    def score(self, X, y):
        return self.estimator_.score(X, y)
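As a quick illustration of the class above, here is a minimal usage sketch that mirrors the tests below (assuming numpy and scikit-learn are available; this example is not part of the commit). Raising the threshold can only reduce the number of positive predictions:

import numpy as np
from sklearn.linear_model import LogisticRegression

from sklego.meta import Thresholder

np.random.seed(42)
X = np.random.normal(0, 1, (100, 3))
y = np.random.normal(0, 1, (100,)) < 0  # boolean labels, so classes_ becomes [False, True]

default = Thresholder(LogisticRegression(), threshold=0.5).fit(X, y)
strict = Thresholder(LogisticRegression(), threshold=0.9).fit(X, y)

# The strict model predicts the positive class only when the fitted
# LogisticRegression is at least 90% confident.
assert strict.predict(X).sum() <= default.predict(X).sum()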
75 changes: 75 additions & 0 deletions tests/test_meta/test_thresholder.py
@@ -0,0 +1,75 @@
import pytest
import numpy as np
from sklearn.linear_model import LogisticRegression, LinearRegression


from sklego.common import flatten
from sklego.meta import Thresholder
from sklearn.utils import estimator_checks


@pytest.mark.parametrize("test_fn", flatten([
    # GENERAL CHECKS #
    # estimator_checks.check_fit2d_predict1d -> we only test for two classes
    # estimator_checks.check_methods_subset_invariance -> we only test for two classes
    estimator_checks.check_fit2d_1sample,
    estimator_checks.check_fit2d_1feature,
    estimator_checks.check_fit1d,
    estimator_checks.check_get_params_invariance,
    estimator_checks.check_set_params,
    estimator_checks.check_dict_unchanged,
    # estimator_checks.check_dont_overwrite_parameters -> we only test for two classes
    # CLASSIFIER CHECKS #
    estimator_checks.check_classifier_data_not_an_array,
    estimator_checks.check_classifiers_one_label,
    # estimator_checks.check_classifiers_classes -> we only test for two classes
    estimator_checks.check_estimators_partial_fit_n_features,
    # estimator_checks.check_classifiers_train -> we only test for two classes
    # estimator_checks.check_supervised_y_2d -> we only test for two classes
    estimator_checks.check_supervised_y_no_nan,
    estimator_checks.check_estimators_unfitted,
    estimator_checks.check_non_transformer_estimators_n_iter,
    estimator_checks.check_decision_proba_consistency,
]))
def test_standard_checks(test_fn):
    trf = Thresholder(LogisticRegression(), threshold=0.5)
    test_fn(Thresholder.__name__, trf)


def test_same_threshold():
    mod1 = Thresholder(LogisticRegression(), threshold=0.5)
    mod2 = LogisticRegression()
    X = np.random.normal(0, 1, (100, 3))
    y = np.random.normal(0, 1, (100,)) < 0
    assert (mod1.fit(X, y).predict(X) == mod2.fit(X, y).predict(X)).all()


def test_diff_threshold():
    mod1 = Thresholder(LogisticRegression(), threshold=0.5)
    mod2 = Thresholder(LogisticRegression(), threshold=0.7)
    mod3 = Thresholder(LogisticRegression(), threshold=0.9)
    np.random.seed(42)
    X = np.random.normal(0, 1, (100, 3))
    y = np.random.normal(0, 1, (100,)) < 0
    assert mod1.fit(X, y).predict(X).sum() >= mod2.fit(X, y).predict(X).sum()
    assert mod2.fit(X, y).predict(X).sum() >= mod3.fit(X, y).predict(X).sum()


def test_raise_error1():
    with pytest.raises(ValueError):
        # we only support classification models
        mod = Thresholder(LinearRegression(), threshold=0.7)
        np.random.seed(42)
        X = np.random.normal(0, 1, (100, 3))
        y = np.random.normal(0, 1, (100,)) < 0
        mod.fit(X, y)


def test_raise_error2():
    with pytest.raises(ValueError):
        mod = Thresholder(LogisticRegression(), threshold=0.7)
        np.random.seed(42)
        X = np.random.normal(0, 1, (1000, 3))
        # we only support two classes
        y = np.random.choice(["a", "b", "c"], 1000)
        mod.fit(X, y)
