Skip to content

Commit

Permalink
subjective classification (#213)
Browse files Browse the repository at this point in the history
* Adds initial computation of posterior class probabilities, given prior and confusion matrix

* Adds tests for single posterior computations

The other ones are to make sure all posterior class probabilities add up to 1 for a given class prediction.

* Parameterizes test, adds several edge cases

... and breaks the posterior computation - it doesn't like all-zero columns in the confusion matrix.

* Typo

* Adds test assuring fit() stores confusion matrix as dataframe

* Prevents pain from missing classes in prior

* Adds handling of edge cases regarding confusion matrices

When the confusion matrix has no false positives (i.e. filled only on the diagonal), the posterior
for a class is simply 1 because the evidence equals the likelihood * prior.

When the confusion matrix has all-zero columns (i.e. the inner classifier fails to make any
prediction for a given class), the posterior for these classes equals their prior probability.
This is necessary to make sure that the posterior probabilities of all classes, given a
prediction, sum to 1.

* Adds predict_proba() implementation

* Adds guard against invalid model configuration

* ok, ok, numerical edge cases

* Flips expected/actual in tests to better understand failed tests

* Corrects docstring

* Adds sklearn boilerplate + checks

* Adds check for inner estimator to be fitted

* Adds evidence parameter to meta estimator

* Adds weighing of posterior by 'probas' of inner estimator

* Adds documentation for expected result of different evidence types

* Adds helper to convert probas to discrete predictions

* Adds optimization for predictions

Instead of computing the posterior for every observation at predict_proba(), the possible
posterior combinations (n_classes ** 2) are computed during fit() and stored in self as
posterior_matrix_. During prediction, this matrix only needs to be multiplied with the
(discrete) predictions, preventing lotsa recomputations of the same posteriors.

* Moves param checking to fit()

* Removes unnecessary references to inner estimator's classes_ attribute

It already can be referenced from the meta estimator's property.

* Adds docstring

* Moves method docstring to correct location
  • Loading branch information
jsamoocha authored and koaning committed Oct 18, 2019
1 parent c68fbe7 commit 1747a10
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 2 deletions.
129 changes: 127 additions & 2 deletions sklego/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sklearn import clone
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin, MetaEstimatorMixin
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_is_fitted, check_X_y, check_array, FLOAT_DTYPES

Expand Down Expand Up @@ -518,7 +519,7 @@ def fit(self, X, y):

def predict_proba(self, X):
"""
Predict new data.
Predict new data, with probabilities
:param X: array-like, shape=(n_columns, n_samples,) training data.
:return: array, shape=(n_samples, n_classes) the predicted data
Expand All @@ -529,11 +530,135 @@ def predict_proba(self, X):

def predict(self, X):
"""
Predict new data, with probabilities
Predict new data.
:param X: array-like, shape=(n_columns, n_samples,) training data.
:return: array, shape=(n_samples,) the predicted data
"""
check_is_fitted(self, ['cfm_', 'classes_'])
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
return self.classes_[self.predict_proba(X).argmax(axis=1)]


class SubjectiveClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
"""
Corrects predictions of the inner classifier by taking into account a (subjective) prior distribution of the
classes.
This can be useful when there is a difference in class distribution between the training data set and
the real world. Using the confusion matrix of the inner classifier and the prior, the posterior probability for a
class, given the prediction of the inner classifier, can be computed. The background for this posterior estimation
is given `in this article <https://lucdemortier.github.io/articles/16/PerformanceMetrics>_`.
Based on the `evidence` attribute, this meta estimator's predictions are based on simple weighing of the inner
estimator's `predict_proba()` results, the posterior probabilities based on the confusion matrix, or a combination
of the two approaches.
:param estimator: An sklearn-compatible classifier estimator
:param prior: A dict of class->frequency representing the prior (a.k.a. subjective real-world) class
distribution. The class frequencies should sum to 1.
:param evidence: A string indicating which evidence should be used to correct the inner estimator's predictions.
Should be one of 'predict_proba', 'confusion_matrix', or 'both' (default). If `predict_proba`, the inner estimator's
`predict_proba()` results are multiplied by the prior distribution. In case of `confusion_matrix`, the inner
estimator's discrete predictions are converted to posterior probabilities using the prior and the inner estimator's
confusion matrix (obtained from the train data used in `fit()`). In case of `both` (default), the the inner
estimator's `predict_proba()` results are multiplied by the posterior probabilities.
"""
def __init__(self, estimator, prior, evidence='both'):
self.estimator = estimator
self.prior = prior
self.evidence = evidence

def _likelihood(self, predicted_class, given_class, cfm):
return cfm[given_class, predicted_class] / cfm[given_class, :].sum()

def _evidence(self, predicted_class, cfm):
return sum([
self._likelihood(predicted_class, given_class, cfm) * self.prior[self.classes_[given_class]]
for given_class in range(cfm.shape[0])
])

def _posterior(self, y, y_hat, cfm):
y_hat_evidence = self._evidence(y_hat, cfm)
return (
(self._likelihood(y_hat, y, cfm) * self.prior[self.classes_[y]] / y_hat_evidence)
if y_hat_evidence > 0
else self.prior[y] # in case confusion matrix has all-zero column for y_hat
)

def fit(self, X, y):
"""
Fits the inner estimator based on the data.
Raises a `ValueError` if the `y` vector contains classes that are not specified in the prior, or if the prior is
not a valid probability distribution (i.e. does not sum to 1).
:param X: array-like, shape=(n_columns, n_samples,) training data.
:param y: array-like, shape=(n_samples,) training data.
:return: Returns an instance of self.
"""
if not isinstance(self.estimator, ClassifierMixin):
raise ValueError(
'Invalid inner estimator: the SubjectiveClassifier meta model only works on classification models'
)

if not np.isclose(sum(self.prior.values()), 1):
raise ValueError('Invalid prior: the prior probabilities of all classes should sum to 1')

valid_evidence_types = ['predict_proba', 'confusion_matrix', 'both']
if self.evidence not in valid_evidence_types:
raise ValueError(f'Invalid evidence: the provided evidence should be one of {valid_evidence_types}')

X, y = check_X_y(X, y, estimator=self.estimator, dtype=FLOAT_DTYPES)
if set(y) - set(self.prior.keys()):
raise ValueError(f'Training data is inconsistent with prior: no prior defined for classes '
f'{set(y) - set(self.prior.keys())}')
self.estimator.fit(X, y)
cfm = confusion_matrix(y, self.estimator.predict(X))
self.posterior_matrix_ = np.array([
[self._posterior(y, y_hat, cfm) for y_hat in range(cfm.shape[0])] for y in range(cfm.shape[0])
])
return self

@staticmethod
def _weighted_proba(weights, y_hat_probas):
return normalize(weights * y_hat_probas, norm='l1')

@staticmethod
def _to_discrete(y_hat_probas):
y_hat_discrete = np.zeros(y_hat_probas.shape)
y_hat_discrete[np.arange(y_hat_probas.shape[0]), y_hat_probas.argmax(axis=1)] = 1
return y_hat_discrete

def predict_proba(self, X):
"""
Returns probability distribution of the class, based on the provided data.
:param X: array-like, shape=(n_columns, n_samples,) training data.
:return: array, shape=(n_samples, n_classes) the predicted data
"""
check_is_fitted(self, ['posterior_matrix_'])
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
y_hats = self.estimator.predict_proba(X) # these are ignorant of the prior

if self.evidence == 'predict_proba':
prior_weights = np.array([self.prior[klass] for klass in self.classes_])
return self._weighted_proba(prior_weights, y_hats)
else:
posterior_probas = self._to_discrete(y_hats) @ self.posterior_matrix_.T
return self._weighted_proba(posterior_probas, y_hats) if self.evidence == 'both' else posterior_probas

def predict(self, X):
"""
Returns predicted class, based on the provided data.
:param X: array-like, shape=(n_columns, n_samples,) training data.
:return: array, shape=(n_samples, n_classes) the predicted data
"""
check_is_fitted(self, ['posterior_matrix_'])
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
return self.classes_[self.predict_proba(X).argmax(axis=1)]

@property
def classes_(self):
return self.estimator.classes_
117 changes: 117 additions & 0 deletions tests/test_meta/test_subjective_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import numpy as np
import pytest
from sklearn.cluster import DBSCAN
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import Ridge, LogisticRegression

from sklego.common import flatten
from sklego.meta import SubjectiveClassifier
from tests.conftest import general_checks, classifier_checks


@pytest.mark.parametrize("test_fn", flatten([
general_checks,
classifier_checks
]))
def test_estimator_checks_classification(test_fn):
if test_fn.__name__ == 'check_classifiers_classes':
prior = {'one': 0.1, 'two': 0.1, 'three': 0.1, -1: 0.1, 1: 0.6} # nonsensical prior to make sklearn check pass
else:
prior = {0: 0.7, 1: 0.2, 2: 0.1}

# Some of the sklearn checkers generate random y data with 3 classes, so prior needs to have these classes
estimator = SubjectiveClassifier(LogisticRegression(), prior)
test_fn(SubjectiveClassifier.__name__, estimator)


@pytest.mark.parametrize(
'classes, prior, cfm, first_class_posterior', [
([0, 1], [0.8, 0.2], [[90, 10], [10, 90]], 0.973), # numeric classes
(['a', 'b'], [0.8, 0.2], [[90, 10], [10, 90]], 0.973), # char classes
([False, True], [0.8, 0.2], [[90, 10], [10, 90]], 0.973), # bool classes
(['a', 'b', 'c'], [0.8, 0.1, 0.1], [[80, 10, 10], [10, 90, 0], [0, 0, 100]], 0.985), # n classes
([0, 1], [0.8, 0.2], [[100, 0], [0, 100]], 1.0), # "perfect" confusion matrix (no FP) -> prior is ignored
([0, 1], [0.2, 0.8], [[0, 100], [0, 100]], 0.2), # failure to predict class by inner estimator
([0, 1, 2], [0.1, 0.1, 0.8], [[0, 0, 100], [0, 0, 100], [0, 0, 100]], 0.1), # extremely biased, n classes
([0, 1, 2], [0.2, 0.1, 0.7], [[80, 0, 20], [0, 0, 100], [10, 0, 90]], 0.696) # biased, n classes
]
)
def test_posterior_computation(mocker, classes, prior, cfm, first_class_posterior):
def mock_confusion_matrix(y, y_pred):
return np.array(cfm)
mocker.patch('sklego.meta.confusion_matrix', side_effect=mock_confusion_matrix)
mock_estimator = mocker.Mock(RandomForestClassifier())
mock_estimator.classes_ = np.array(classes)
subjective_model = SubjectiveClassifier(mock_estimator, dict(zip(classes, prior)))
subjective_model.fit(np.zeros((10, 10)), np.array([classes[0]]*10))
assert pytest.approx(subjective_model.posterior_matrix_[0, 0], 0.001) == first_class_posterior
assert np.isclose(subjective_model.posterior_matrix_.sum(axis=0), 1).all()


@pytest.mark.parametrize(
'prior, y', [
({'a': 0.8, 'b': 0.2}, ['a', 'c']), # class from train data not defined in prior
({'a': 0.8, 'b': 0.2}, [0, 1]), # different data types
]
)
def test_fit_y_data_inconsistent_with_prior_failure_conditions(prior, y):
with pytest.raises(ValueError) as exc:
SubjectiveClassifier(RandomForestClassifier(), prior).fit(np.zeros((len(y), 2)), np.array(y))

assert str(exc.value).startswith('Training data is inconsistent with prior')


def test_to_discrete():
assert np.isclose(
SubjectiveClassifier._to_discrete(np.array([[1, 0], [0.8, 0.2], [0.5, 0.5], [0.2, 0.8]])),
np.array([[1, 0], [1, 0], [1, 0], [0, 1]])
).all()


@pytest.mark.parametrize(
'weights,y_hats,expected_probas', [
([0.8, 0.2], [[1, 0], [0.5, 0.5], [0.8, 0.2]], [[1, 0], [0.8, 0.2], [0.94, 0.06]]),
([0.5, 0.5], [[1, 0], [0.5, 0.5], [0.8, 0.2]], [[1, 0], [0.5, 0.5], [0.8, 0.2]]),
([[0.8, 0.2], [0.5, 0.5]], [[1, 0], [0.8, 0.2]], [[1, 0], [0.8, 0.2]])
]
)
def test_weighted_proba(weights, y_hats, expected_probas):
assert np.isclose(
SubjectiveClassifier._weighted_proba(np.array(weights), np.array(y_hats)), np.array(expected_probas), atol=1e-02
).all()


@pytest.mark.parametrize(
'evidence_type,expected_probas', [
('predict_proba', [[0.94, 0.06], [1, 0], [0.8, 0.2], [0.5, 0.5]]),
('confusion_matrix', [[0.97, 0.03], [0.97, 0.03], [0.97, 0.03], [0.47, 0.53]]),
('both', [[0.99, 0.01], [1, 0], [0.97, 0.03], [0.18, 0.82]])
]
)
def test_predict_proba(mocker, evidence_type, expected_probas):
def mock_confusion_matrix(y, y_pred):
return np.array([[80, 20], [10, 90]])
mocker.patch('sklego.meta.confusion_matrix', side_effect=mock_confusion_matrix)
mock_inner_estimator = mocker.Mock(RandomForestClassifier)
mock_inner_estimator.predict_proba.return_value = np.array([[0.8, 0.2], [1, 0], [0.5, 0.5], [0.2, 0.8]])
mock_inner_estimator.classes_ = np.array([0, 1])
subjective_model = SubjectiveClassifier(mock_inner_estimator, {0: 0.8, 1: 0.2}, evidence=evidence_type)
subjective_model.fit(np.zeros((10, 10)), np.zeros(10))
posterior_probabilities = subjective_model.predict_proba(np.zeros((4, 2)))
assert posterior_probabilities.shape == (4, 2)
assert np.isclose(posterior_probabilities, np.array(expected_probas), atol=0.01).all()


@pytest.mark.parametrize(
'inner_estimator, prior, evidence, expected_error_msg', [
(DBSCAN(), {'a': 1}, 'both', 'Invalid inner estimator'),
(Ridge(), {'a': 1}, 'predict_proba', 'Invalid inner estimator'),
(RandomForestClassifier(), {'a': 0.8, 'b': 0.1}, 'confusion_matrix', 'Invalid prior'),
(RandomForestClassifier(), {'a': 0.8, 'b': 0.2}, 'foo_evidence', 'Invalid evidence')
]
)
def test_params_failure_conditions(inner_estimator, prior, evidence, expected_error_msg):
with pytest.raises(ValueError) as exc:
SubjectiveClassifier(inner_estimator, prior, evidence).fit(np.zeros((2, 2)), np.zeros(2))

assert str(exc.value).startswith(expected_error_msg)

0 comments on commit 1747a10

Please sign in to comment.