Commit 70aecce

nikml and nnansters authored

Average precision and AUROC update (#374)

* add realized perf AP metric
* add CBPE BC AP implementation
* update CBPE metrics _common_cleaning and estimate_auroc/accuracy
* ap sampling error update
* update docs

Co-authored-by: Niels <94110348+nnansters@users.noreply.github.com>
Co-authored-by: Niels Nuyttens <niels@nannyml.com>

1 parent 0ec1fc8 commit 70aecce

File tree

17 files changed: 1252 additions & 1613 deletions

docs/_static/tutorials/performance_calculation/binary/tutorial-performance-calculation-binary-car-loan-analysis.svg

Lines changed: 0 additions & 1 deletion
This file was deleted.

docs/_static/tutorials/performance_calculation/binary/tutorial-standard-metrics-calculation-binary-car-loan-analysis.svg

Lines changed: 1 addition & 1 deletion

docs/example_notebooks/Tutorial - Calculating Standard Metrics - Binary Classification.ipynb

Lines changed: 205 additions & 207 deletions
Large diffs are not rendered by default.

docs/example_notebooks/Tutorial - Realized Performance - Binary Classification.ipynb

Lines changed: 0 additions & 1054 deletions
This file was deleted.

docs/tutorials/performance_calculation/binary_performance_calculation.rst

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ We currently support the following **standard** metrics for bianry classificatio
 * **roc_auc**
 * **f1**
 * **precision**
+* **average_precision**
 * **recall**
 * **specificity**
 * **accuracy**

docs/tutorials/performance_estimation/binary_performance_estimation.rst

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ We currently support the following **standard** metrics for bianry classificatio
 * **roc_auc**
 * **f1**
 * **precision**
+* **average_precision**
 * **recall**
 * **specificity**
 * **accuracy**

nannyml/drift/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
 - Domain Classifer: detects drift by looking at how performance a domain classifier is at distinguising
   between the reference and the chunk datasets.
 """
-from .multivariate.domain_classifier import DomainClassifierCalculator
 from .multivariate.data_reconstruction import DataReconstructionDriftCalculator
+from .multivariate.domain_classifier import DomainClassifierCalculator
 from .ranker import AlertCountRanker, CorrelationRanker
 from .univariate import FeatureType, Method, MethodFactory, UnivariateDriftCalculator

nannyml/performance_calculation/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
     'accuracy',
     'confusion_matrix',
     'business_value',
+    'average_precision',
 ]
 
 SUPPORTED_REGRESSION_METRIC_VALUES = [

nannyml/performance_calculation/calculator.py

Lines changed: 4 additions & 2 deletions
@@ -28,7 +28,7 @@
 ...     y_true='repaid',
 ...     timestamp_column_name='timestamp',
 ...     problem_type='classification_binary',
-...     metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy'],
+...     metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy', 'average_precision'],
 ...     chunk_size=5000)
 >>> calc.fit(reference_df)
 >>> results = calc.calculate(analysis_df)
@@ -62,6 +62,7 @@
     'roc_auc': StandardDeviationThreshold(),
     'f1': StandardDeviationThreshold(),
     'precision': StandardDeviationThreshold(),
+    'average_precision': StandardDeviationThreshold(),
     'recall': StandardDeviationThreshold(),
     'specificity': StandardDeviationThreshold(),
     'accuracy': StandardDeviationThreshold(),
@@ -128,6 +129,7 @@ def __init__(
             'roc_auc': StandardDeviationThreshold(),
             'f1': StandardDeviationThreshold(),
             'precision': StandardDeviationThreshold(),
+            'average_precision': StandardDeviationThreshold(),
             'recall': StandardDeviationThreshold(),
             'specificity': StandardDeviationThreshold(),
             'accuracy': StandardDeviationThreshold(),
@@ -187,7 +189,7 @@ def __init__(
 ...     y_true='repaid',
 ...     timestamp_column_name='timestamp',
 ...     problem_type='classification_binary',
-...     metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy'],
+...     metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy', 'average_precision'],
 ...     chunk_size=5000)
 >>> calc.fit(reference_df)
 >>> results = calc.calculate(analysis_df)
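The docstring change above doubles as usage guidance. For reference, a minimal runnable sketch of calculating realized average precision with the updated calculator; this is a sketch only, where reference_df and analysis_df are placeholder DataFrames and the column names are taken from the docstring example:

import nannyml as nml

# Sketch assuming the column names from the docstring example above;
# reference_df and analysis_df are placeholder DataFrames holding model
# outputs, predictions and targets for the reference and analysis periods.
calc = nml.PerformanceCalculator(
    y_pred_proba='y_pred_proba',
    y_pred='y_pred',
    y_true='repaid',
    timestamp_column_name='timestamp',
    problem_type='classification_binary',
    metrics=['roc_auc', 'average_precision'],  # 'average_precision' is new in this commit
    chunk_size=5000,
)
calc.fit(reference_df)                 # fitting computes thresholds and sampling-error components
results = calc.calculate(analysis_df)  # realized AP per chunk, with sampling error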

nannyml/performance_calculation/metrics/binary_classification.py

Lines changed: 97 additions & 2 deletions
@@ -1,12 +1,20 @@
 # Author: Niels Nuyttens <niels@nannyml.com>
 #
 # License: Apache Software License 2.0
+"""Module containing implemenations for binary classification metrics and utilities."""
 import warnings
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
-from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, roc_auc_score
+from sklearn.metrics import (
+    average_precision_score,
+    confusion_matrix,
+    f1_score,
+    precision_score,
+    recall_score,
+    roc_auc_score,
+)
 
 from nannyml._typing import ProblemType
 from nannyml.base import _list_missing, _remove_nans
@@ -16,6 +24,8 @@
 from nannyml.sampling_error.binary_classification import (
     accuracy_sampling_error,
     accuracy_sampling_error_components,
+    ap_sampling_error,
+    ap_sampling_error_components,
     auroc_sampling_error,
     auroc_sampling_error_components,
     business_value_sampling_error,
@@ -64,7 +74,7 @@ def __init__(
             The Threshold instance that determines how the lower and upper threshold values will be calculated.
         y_pred_proba: Optional[str], default=None
             Name(s) of the column(s) containing your model output. For binary classification, pass a single string
-            refering to the model output column.
+            referring to the model output column.
         """
         super().__init__(
             name='roc_auc',
@@ -81,9 +91,11 @@ def __init__(
         self._sampling_error_components: Tuple = ()
 
     def __str__(self):
+        """Metric string."""
         return "roc_auc"
 
     def _fit(self, reference_data: pd.DataFrame):
+        """Metric _fit implementation on reference data."""
         _list_missing([self.y_true, self.y_pred_proba], list(reference_data.columns))
         self._sampling_error_components = auroc_sampling_error_components(
             y_true_reference=reference_data[self.y_true],
@@ -111,6 +123,88 @@ def _sampling_error(self, data: pd.DataFrame) -> float:
         return auroc_sampling_error(self._sampling_error_components, data)
 
 
+@MetricFactory.register(metric='average_precision', use_case=ProblemType.CLASSIFICATION_BINARY)
+class BinaryClassificationAP(Metric):
+    """Average Precision metric.
+
+    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html
+    """
+
+    def __init__(
+        self,
+        y_true: str,
+        y_pred: str,
+        threshold: Threshold,
+        y_pred_proba: Optional[str] = None,
+        **kwargs,
+    ):
+        """Creates a new AP instance.
+
+        Parameters
+        ----------
+        y_true: str
+            The name of the column containing target values.
+        y_pred: str
+            The name of the column containing your model predictions.
+        threshold: Threshold
+            The Threshold instance that determines how the lower and upper threshold values will be calculated.
+        y_pred_proba: Optional[str], default=None
+            Name(s) of the column(s) containing your model output. For binary classification, pass a single string
+            referring to the model output column.
+        """
+        super().__init__(
+            name='average_precision',
+            y_true=y_true,
+            y_pred=y_pred,
+            threshold=threshold,
+            y_pred_proba=y_pred_proba,
+            lower_threshold_limit=0,
+            upper_threshold_limit=1,
+            components=[('Average Precision', 'average_precision')],
+        )
+
+        # sampling error
+        self._sampling_error_components: Tuple = ()
+
+    def __str__(self):
+        """Metric string."""
+        return "average_precision"
+
+    def _fit(self, reference_data: pd.DataFrame):
+        """Metric _fit implementation on reference data."""
+        _list_missing([self.y_true, self.y_pred_proba], list(reference_data.columns))
+        # we don't want to count missing rows for sampling error
+        reference_data = _remove_nans(reference_data, (self.y_true, self.y_pred))
+
+        if 1 not in reference_data[self.y_true].unique():
+            self._sampling_error_components = np.NaN, 0
+        else:
+            self._sampling_error_components = ap_sampling_error_components(
+                y_true_reference=reference_data[self.y_true],
+                y_pred_proba_reference=reference_data[self.y_pred_proba],
+            )
+
+    def _calculate(self, data: pd.DataFrame):
+        """Redefine to handle NaNs and edge cases."""
+        _list_missing([self.y_true, self.y_pred_proba], list(data.columns))
+        data = _remove_nans(data, (self.y_true, self.y_pred))
+
+        y_true = data[self.y_true]
+        y_pred_proba = data[self.y_pred_proba]
+
+        if 1 not in y_true.unique():
+            warnings.warn(
+                f"'{self.y_true}' does not contain positive class for chunk, cannot calculate {self.display_name}. "
+                f"Returning NaN."
+            )
+            return np.NaN
+        else:
+            return average_precision_score(y_true, y_pred_proba)
+
+    def _sampling_error(self, data: pd.DataFrame) -> float:
+        return ap_sampling_error(self._sampling_error_components, data)
+
+
 @MetricFactory.register(metric='f1', use_case=ProblemType.CLASSIFICATION_BINARY)
 class BinaryClassificationF1(Metric):
     """F1 score metric."""
@@ -156,6 +250,7 @@ def __str__(self):
 
     def _fit(self, reference_data: pd.DataFrame):
         _list_missing([self.y_true, self.y_pred], list(reference_data.columns))
+        # TODO: maybe handle data quality issues here and pass clean data to sampling error calculation?
        self._sampling_error_components = f1_sampling_error_components(
             y_true_reference=reference_data[self.y_true],
             y_pred_reference=reference_data[self.y_pred],
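The _calculate override above exists because average precision is undefined for a chunk with no positive labels. A standalone sketch of that guard (synthetic data; safe_average_precision is an illustrative helper, not part of the commit):

import warnings

import numpy as np
from sklearn.metrics import average_precision_score

def safe_average_precision(y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
    # Mirrors the guard in BinaryClassificationAP._calculate: with no
    # positive labels the precision-recall curve is degenerate, so
    # return NaN instead of a score.
    if 1 not in np.unique(y_true):
        warnings.warn("no positive class in chunk, cannot calculate average precision; returning NaN")
        return np.nan
    return average_precision_score(y_true, y_pred_proba)

# Scores taken from the sklearn documentation linked in the class docstring.
print(safe_average_precision(np.array([0, 0, 1, 1]), np.array([0.1, 0.4, 0.35, 0.8])))  # ~0.83
print(safe_average_precision(np.array([0, 0, 0, 0]), np.array([0.1, 0.4, 0.35, 0.8])))  # nan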

nannyml/performance_estimation/confidence_based/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
 SUPPORTED_METRIC_VALUES = [
     'roc_auc',
     'f1',
+    'average_precision',
     'precision',
     'recall',
     'specificity',

nannyml/performance_estimation/confidence_based/cbpe.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@
     'accuracy': StandardDeviationThreshold(),
     'confusion_matrix': StandardDeviationThreshold(),
     'business_value': StandardDeviationThreshold(),
+    'average_precision': StandardDeviationThreshold(),
 }
 
 
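With the metric added to SUPPORTED_METRIC_VALUES and given a default threshold, CBPE can now estimate average precision without analysis-period targets. A minimal sketch, assuming the same placeholder frames and column names as the calculator example above:

import nannyml as nml

# Sketch only: reference_df / analysis_df are the placeholder DataFrames
# from the calculator example; analysis_df needs no ground-truth labels.
estimator = nml.CBPE(
    y_pred_proba='y_pred_proba',
    y_pred='y_pred',
    y_true='repaid',
    problem_type='classification_binary',
    metrics=['average_precision'],  # accepted as of this commit
    chunk_size=5000,
)
estimator.fit(reference_df)                           # calibrates on reference data
estimated_results = estimator.estimate(analysis_df)   # estimated AP per chunk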