-
Notifications
You must be signed in to change notification settings - Fork 119
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Adds initial computation of posterior class probabilities, given prior and confusion matrix * Adds tests for single posterior computations The other ones are to make sure all posterior class probabilities add up to 1 for a given class prediction. * Parameterizes test, adds several edge cases ... and breaks the posterior computation - it doesn't like all-zero columns in the confusion matrix. * Typo * Adds test assuring fit() stores confusion matrix as dataframe * Prevents pain from missing classes in prior * Adds handling of edge cases regarding confusion matrices When the confusion matrix has no false positives (i.e. filled only on the diagonal), the posterior for a class is simply 1 because the evidence equals the likelihood * prior. When the confusion matrix has all-zero columns (i.e. the inner classifier fails to make any prediction for a given class), the posterior for these classes equals their prior probability. This is necessary to make sure that the posterior probabilities of all classes, given a prediction, sum to 1. * Adds predict_proba() implementation * Adds guard against invalid model configuration * ok, ok, numerical edge cases * Flips expected/actual in tests to better understand failed tests * Corrects docstring * Adds sklearn boilerplate + checks * Adds check for inner estimator to be fitted * Adds evidence parameter to meta estimator * Adds weighing of posterior by 'probas' of inner estimator * Adds documentation for expected result of different evidence types * Adds helper to convert probas to discrete predictions * Adds optimization for predictions Instead of computing the posterior for every observation at predict_proba(), the possible posterior combinations (n_classes ** 2) are computed during fit() and stored in self as posterior_matrix_. During prediction, this matrix only needs to be multiplied with the (discrete) predictions, preventing lotsa recomputations of the same posteriors. * Moves param checking to fit() * Removes unnecessary references to inner estimator's classes_ attribute It already can be referenced from the meta estimator's property. * Adds docstring * Moves method docstring to correct location
- Loading branch information
Showing
2 changed files
with
244 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import numpy as np | ||
import pytest | ||
from sklearn.cluster import DBSCAN | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.linear_model import Ridge, LogisticRegression | ||
|
||
from sklego.common import flatten | ||
from sklego.meta import SubjectiveClassifier | ||
from tests.conftest import general_checks, classifier_checks | ||
|
||
|
||
@pytest.mark.parametrize("test_fn", flatten([ | ||
general_checks, | ||
classifier_checks | ||
])) | ||
def test_estimator_checks_classification(test_fn): | ||
if test_fn.__name__ == 'check_classifiers_classes': | ||
prior = {'one': 0.1, 'two': 0.1, 'three': 0.1, -1: 0.1, 1: 0.6} # nonsensical prior to make sklearn check pass | ||
else: | ||
prior = {0: 0.7, 1: 0.2, 2: 0.1} | ||
|
||
# Some of the sklearn checkers generate random y data with 3 classes, so prior needs to have these classes | ||
estimator = SubjectiveClassifier(LogisticRegression(), prior) | ||
test_fn(SubjectiveClassifier.__name__, estimator) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'classes, prior, cfm, first_class_posterior', [ | ||
([0, 1], [0.8, 0.2], [[90, 10], [10, 90]], 0.973), # numeric classes | ||
(['a', 'b'], [0.8, 0.2], [[90, 10], [10, 90]], 0.973), # char classes | ||
([False, True], [0.8, 0.2], [[90, 10], [10, 90]], 0.973), # bool classes | ||
(['a', 'b', 'c'], [0.8, 0.1, 0.1], [[80, 10, 10], [10, 90, 0], [0, 0, 100]], 0.985), # n classes | ||
([0, 1], [0.8, 0.2], [[100, 0], [0, 100]], 1.0), # "perfect" confusion matrix (no FP) -> prior is ignored | ||
([0, 1], [0.2, 0.8], [[0, 100], [0, 100]], 0.2), # failure to predict class by inner estimator | ||
([0, 1, 2], [0.1, 0.1, 0.8], [[0, 0, 100], [0, 0, 100], [0, 0, 100]], 0.1), # extremely biased, n classes | ||
([0, 1, 2], [0.2, 0.1, 0.7], [[80, 0, 20], [0, 0, 100], [10, 0, 90]], 0.696) # biased, n classes | ||
] | ||
) | ||
def test_posterior_computation(mocker, classes, prior, cfm, first_class_posterior): | ||
def mock_confusion_matrix(y, y_pred): | ||
return np.array(cfm) | ||
mocker.patch('sklego.meta.confusion_matrix', side_effect=mock_confusion_matrix) | ||
mock_estimator = mocker.Mock(RandomForestClassifier()) | ||
mock_estimator.classes_ = np.array(classes) | ||
subjective_model = SubjectiveClassifier(mock_estimator, dict(zip(classes, prior))) | ||
subjective_model.fit(np.zeros((10, 10)), np.array([classes[0]]*10)) | ||
assert pytest.approx(subjective_model.posterior_matrix_[0, 0], 0.001) == first_class_posterior | ||
assert np.isclose(subjective_model.posterior_matrix_.sum(axis=0), 1).all() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'prior, y', [ | ||
({'a': 0.8, 'b': 0.2}, ['a', 'c']), # class from train data not defined in prior | ||
({'a': 0.8, 'b': 0.2}, [0, 1]), # different data types | ||
] | ||
) | ||
def test_fit_y_data_inconsistent_with_prior_failure_conditions(prior, y): | ||
with pytest.raises(ValueError) as exc: | ||
SubjectiveClassifier(RandomForestClassifier(), prior).fit(np.zeros((len(y), 2)), np.array(y)) | ||
|
||
assert str(exc.value).startswith('Training data is inconsistent with prior') | ||
|
||
|
||
def test_to_discrete(): | ||
assert np.isclose( | ||
SubjectiveClassifier._to_discrete(np.array([[1, 0], [0.8, 0.2], [0.5, 0.5], [0.2, 0.8]])), | ||
np.array([[1, 0], [1, 0], [1, 0], [0, 1]]) | ||
).all() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'weights,y_hats,expected_probas', [ | ||
([0.8, 0.2], [[1, 0], [0.5, 0.5], [0.8, 0.2]], [[1, 0], [0.8, 0.2], [0.94, 0.06]]), | ||
([0.5, 0.5], [[1, 0], [0.5, 0.5], [0.8, 0.2]], [[1, 0], [0.5, 0.5], [0.8, 0.2]]), | ||
([[0.8, 0.2], [0.5, 0.5]], [[1, 0], [0.8, 0.2]], [[1, 0], [0.8, 0.2]]) | ||
] | ||
) | ||
def test_weighted_proba(weights, y_hats, expected_probas): | ||
assert np.isclose( | ||
SubjectiveClassifier._weighted_proba(np.array(weights), np.array(y_hats)), np.array(expected_probas), atol=1e-02 | ||
).all() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'evidence_type,expected_probas', [ | ||
('predict_proba', [[0.94, 0.06], [1, 0], [0.8, 0.2], [0.5, 0.5]]), | ||
('confusion_matrix', [[0.97, 0.03], [0.97, 0.03], [0.97, 0.03], [0.47, 0.53]]), | ||
('both', [[0.99, 0.01], [1, 0], [0.97, 0.03], [0.18, 0.82]]) | ||
] | ||
) | ||
def test_predict_proba(mocker, evidence_type, expected_probas): | ||
def mock_confusion_matrix(y, y_pred): | ||
return np.array([[80, 20], [10, 90]]) | ||
mocker.patch('sklego.meta.confusion_matrix', side_effect=mock_confusion_matrix) | ||
mock_inner_estimator = mocker.Mock(RandomForestClassifier) | ||
mock_inner_estimator.predict_proba.return_value = np.array([[0.8, 0.2], [1, 0], [0.5, 0.5], [0.2, 0.8]]) | ||
mock_inner_estimator.classes_ = np.array([0, 1]) | ||
subjective_model = SubjectiveClassifier(mock_inner_estimator, {0: 0.8, 1: 0.2}, evidence=evidence_type) | ||
subjective_model.fit(np.zeros((10, 10)), np.zeros(10)) | ||
posterior_probabilities = subjective_model.predict_proba(np.zeros((4, 2))) | ||
assert posterior_probabilities.shape == (4, 2) | ||
assert np.isclose(posterior_probabilities, np.array(expected_probas), atol=0.01).all() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'inner_estimator, prior, evidence, expected_error_msg', [ | ||
(DBSCAN(), {'a': 1}, 'both', 'Invalid inner estimator'), | ||
(Ridge(), {'a': 1}, 'predict_proba', 'Invalid inner estimator'), | ||
(RandomForestClassifier(), {'a': 0.8, 'b': 0.1}, 'confusion_matrix', 'Invalid prior'), | ||
(RandomForestClassifier(), {'a': 0.8, 'b': 0.2}, 'foo_evidence', 'Invalid evidence') | ||
] | ||
) | ||
def test_params_failure_conditions(inner_estimator, prior, evidence, expected_error_msg): | ||
with pytest.raises(ValueError) as exc: | ||
SubjectiveClassifier(inner_estimator, prior, evidence).fit(np.zeros((2, 2)), np.zeros(2)) | ||
|
||
assert str(exc.value).startswith(expected_error_msg) |