Skip to content

Commit ed8cb08

Browse files
authored
Calibrator factory decorator (#341)
* Refactor CalibratorFactory to use decorator * formatting, linting and whatnot * Fix broken tests
1 parent 0478b24 commit ed8cb08

35 files changed

+10986
-10901
lines changed

docs/_static/butterfly-scatterplot.svg

Lines changed: 2088 additions & 2088 deletions
Loading

docs/_static/example_california_latitude_longitude_scatter.svg

Lines changed: 586 additions & 586 deletions
Loading

docs/_static/example_california_performance_estimation_tmp.svg

Lines changed: 689 additions & 689 deletions
Loading

docs/_static/example_green_taxi_feature_importance.svg

Lines changed: 694 additions & 694 deletions
Loading

docs/_static/example_green_taxi_tip_amount_boxplot.svg

Lines changed: 288 additions & 288 deletions
Loading

docs/_static/example_green_taxi_tip_amount_distribution.svg

Lines changed: 377 additions & 377 deletions
Loading

docs/_static/how-it-works-dle-data.svg

Lines changed: 471 additions & 471 deletions
Loading

docs/_static/how-it-works-dle-regression-PI.svg

Lines changed: 719 additions & 719 deletions
Loading

docs/_static/how-it-works-dle-regression-abs-errors-hist.svg

Lines changed: 823 additions & 823 deletions
Loading

docs/_static/how-it-works-dle-regression-errors-hist.svg

Lines changed: 737 additions & 737 deletions
Loading

docs/_static/how-it-works-dle-regression.svg

Lines changed: 592 additions & 592 deletions
Loading

docs/_static/how-it-works/chunks_stability_of_accuracy.svg

Lines changed: 665 additions & 665 deletions
Loading

docs/_static/how-it-works/ranking-abs-perf-features-compare.svg

Lines changed: 754 additions & 754 deletions
Loading

docs/_static/how-it-works/ranking-abs-perf.svg

Lines changed: 683 additions & 683 deletions
Loading

docs/_static/tutorials/performance_estimation/binary/tutorial-custom-metric-estimation-binary-car-loan-analysis-with-ref.svg

Lines changed: 667 additions & 667 deletions
Loading

docs/examples.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ Examples
77
:maxdepth: 2
88

99
examples/california_housing
10-
examples/green_taxi
10+
examples/green_taxi

docs/tutorials/performance_calculation/multiclass_performance_calculation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ For more information about estimating these metrics, refer to the :ref:`multicla
1717

1818
We also support the following *complex* metric for multiclass classification performance calculation:
1919

20-
* **confusion_matrix**
20+
* **confusion_matrix**
2121

2222
For more information about estimating this metrics, refer to the :ref:`multiclass-confusion-matrix-estimation` section.
2323

docs/tutorials/performance_calculation/multiclass_performance_calculation/confusion_matrix_calculation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ The results can be plotted for visual inspection. Our plot contains several key
128128
* *The purple step plot* shows the performance in each chunk of the analysis period. Thick squared point
129129
markers indicate the middle of these chunks.
130130

131-
* *The blue step plot* shows the performance in each chunk of the reference period. Thick squared point markers indicate
131+
* *The blue step plot* shows the performance in each chunk of the reference period. Thick squared point markers indicate
132132
the middle of these chunks.
133133

134134
* *The gray vertical line* splits the reference and analysis periods.

docs/tutorials/performance_calculation/multiclass_performance_calculation/standard_metric_calculation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,4 +133,4 @@ what feature changes may be contributing to any performance changes. We can also
133133
and :ref:`compare it with the estimated results<compare_estimated_and_realized_performance>`.
134134

135135
It is also wise to check whether the model's performance is satisfactory
136-
according to business requirements. This is an ad-hoc investigation that is not covered by NannyML.
136+
according to business requirements. This is an ad-hoc investigation that is not covered by NannyML.

nannyml/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def _column_is_categorical(column: pd.Series) -> bool:
537537
def _remove_nans(data: pd.Series) -> pd.Series:
538538
...
539539

540+
540541
@overload
541542
def _remove_nans(data: pd.DataFrame, columns: Optional[Iterable[Union[str, Iterable[str]]]]) -> pd.DataFrame:
542543
...

nannyml/calibration.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
"""Calibrating model scores into probabilities."""
77
import abc
8-
from typing import Any, Callable, List, Optional, Tuple
8+
import warnings
9+
from typing import Any, Callable, Dict, List, Tuple, Type
910

1011
import numpy as np
1112
import pandas as pd
@@ -45,10 +46,10 @@ def calibrate(self, y_pred_proba: np.ndarray):
4546
class CalibratorFactory:
4647
"""Factory class to aid in construction of Calibrators."""
4748

48-
_calibrators = {'isotonic': lambda args: IsotonicCalibrator()}
49+
_registry: Dict[str, Type[Calibrator]] = {}
4950

5051
@classmethod
51-
def register_calibrator(cls, key: str, create_calibrator: Callable):
52+
def register_calibrator(cls, key: str, calibrator: Type[Calibrator]):
5253
"""Registers a new calibrator to the index.
5354
5455
This index associates a certain key with a function that can be used to construct a new Calibrator instance.
@@ -58,17 +59,28 @@ def register_calibrator(cls, key: str, create_calibrator: Callable):
5859
key: str
5960
The key used to retrieve a Calibrator. When providing a key that is already in the index, the value
6061
will be overwritten.
61-
create_calibrator: Callable
62+
calibrator: Type[Calibrator]
6263
A function that - given a ``**kwargs`` argument - create a new instance of a Calibrator subclass.
6364
6465
Examples
6566
--------
66-
>>> CalibratorFactory.register_calibrator('isotonic', lambda kwargs: IsotonicCalibrator())
67+
>>> CalibratorFactory.register_calibrator('isotonic', IsotonicCalibrator)
6768
"""
68-
cls._calibrators[key] = create_calibrator
69+
cls._registry[key] = calibrator
6970

7071
@classmethod
71-
def create(cls, key: Optional[str], **kwargs):
72+
def register(cls, key: str) -> Callable:
73+
def inner_wrapper(wrapped_class: Type[Calibrator]) -> Type[Calibrator]:
74+
if key in cls._registry:
75+
warnings.warn(f"re-registering calibrator with key '{key}'")
76+
77+
cls._registry[key] = wrapped_class
78+
return wrapped_class
79+
80+
return inner_wrapper
81+
82+
@classmethod
83+
def create(cls, key: str = 'isotonic', **kwargs):
7284
"""Creates a new Calibrator given a key value and optional keyword args.
7385
7486
If the provided key equals ``None``, then a new instance of the default Calibrator (IsotonicCalibrator)
@@ -78,7 +90,7 @@ def create(cls, key: Optional[str], **kwargs):
7890
7991
Parameters
8092
----------
81-
key : str
93+
key : str, default='isotonic'
8294
The key used to retrieve a Calibrator. When providing a key that is already in the index, the value
8395
will be overwritten.
8496
kwargs : dict
@@ -94,18 +106,18 @@ def create(cls, key: Optional[str], **kwargs):
94106
--------
95107
>>> calibrator = CalibratorFactory.create('isotonic', kwargs={'foo': 'bar'})
96108
"""
97-
default = IsotonicCalibrator()
98-
if key is None:
99-
return default
100-
101-
if key not in cls._calibrators:
109+
if key not in cls._registry:
102110
raise InvalidArgumentsException(
103-
f"calibrator {key} unknown. " f"Please provide one of the following: {cls._calibrators.keys()}"
111+
f"calibrator '{key}' unknown. " f"Please provide one of the following: {cls._registry.keys()}"
104112
)
105113

106-
return cls._calibrators.get(key, default)
114+
calibrator_class = cls._registry.get(key)
115+
assert calibrator_class
116+
117+
return calibrator_class(**kwargs)
107118

108119

120+
@CalibratorFactory.register('isotonic')
109121
class IsotonicCalibrator(Calibrator):
110122
"""Calibrates using IsotonicRegression model."""
111123

nannyml/drift/ranker.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def _validate_drift_result(rankable_result: RankableResult):
8080
raise InvalidArgumentsException('rankable_result contains no data to use for ranking')
8181

8282
if isinstance(rankable_result, UnivariateResults):
83-
8483
if len(rankable_result.categorical_method_names) > 1:
8584
raise InvalidArgumentsException(
8685
f"Only one categorical drift method should be present in the univariate results."

nannyml/drift/univariate/methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from scipy.stats import chi2_contingency, ks_2samp, wasserstein_distance
3030

3131
from nannyml._typing import Self
32-
from nannyml.base import _remove_nans, _column_is_categorical
32+
from nannyml.base import _column_is_categorical, _remove_nans
3333
from nannyml.chunk import Chunker
3434
from nannyml.exceptions import InvalidArgumentsException, NotFittedException
3535
from nannyml.thresholds import Threshold, calculate_threshold_values

nannyml/performance_calculation/metrics/binary_classification.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, roc_auc_score
1010

1111
from nannyml._typing import ProblemType
12-
from nannyml.base import _remove_nans, _list_missing
12+
from nannyml.base import _list_missing, _remove_nans
1313
from nannyml.chunk import Chunk, Chunker
1414
from nannyml.exceptions import InvalidArgumentsException
1515
from nannyml.performance_calculation.metrics.base import Metric, MetricFactory
@@ -544,17 +544,15 @@ def _calculate(self, data: pd.DataFrame):
544544
tn_value = self.business_value_matrix[0, 0]
545545
fp_value = self.business_value_matrix[0, 1]
546546
fn_value = self.business_value_matrix[1, 0]
547-
bv_array = np.array(
548-
[[tn_value,fp_value], [fn_value,tp_value]]
549-
)
547+
bv_array = np.array([[tn_value, fp_value], [fn_value, tp_value]])
550548

551549
cm = confusion_matrix(y_true, y_pred)
552550
if self.normalize_business_value == 'per_prediction':
553551
with np.errstate(all="ignore"):
554552
cm = cm / cm.sum(axis=0, keepdims=True)
555553
cm = np.nan_to_num(cm)
556554

557-
return (bv_array*cm).sum()
555+
return (bv_array * cm).sum()
558556

559557
def _sampling_error(self, data: pd.DataFrame) -> float:
560558
return business_value_sampling_error(self._sampling_error_components, data)

nannyml/performance_calculation/metrics/multiclass_classification.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sklearn.preprocessing import LabelBinarizer, label_binarize
2525

2626
from nannyml._typing import ProblemType, class_labels, model_output_column_names
27-
from nannyml.base import _remove_nans, _list_missing
27+
from nannyml.base import _list_missing, _remove_nans
2828
from nannyml.chunk import Chunker
2929
from nannyml.exceptions import InvalidArgumentsException
3030
from nannyml.performance_calculation.metrics.base import Metric, MetricFactory
@@ -35,14 +35,14 @@
3535
auroc_sampling_error_components,
3636
f1_sampling_error,
3737
f1_sampling_error_components,
38+
multiclass_confusion_matrix_sampling_error,
39+
multiclass_confusion_matrix_sampling_error_components,
3840
precision_sampling_error,
3941
precision_sampling_error_components,
4042
recall_sampling_error,
4143
recall_sampling_error_components,
4244
specificity_sampling_error,
4345
specificity_sampling_error_components,
44-
multiclass_confusion_matrix_sampling_error,
45-
multiclass_confusion_matrix_sampling_error_components,
4646
)
4747
from nannyml.thresholds import Threshold, calculate_threshold_values
4848

@@ -588,7 +588,6 @@ def __init__(
588588
normalize_confusion_matrix: Optional[str] = None,
589589
**kwargs,
590590
):
591-
592591
"""Creates a new confusion matrix instance."""
593592
super().__init__(
594593
name='confusion_matrix',
@@ -607,7 +606,6 @@ def __str__(self):
607606
return "confusion_matrix"
608607

609608
def fit(self, reference_data: pd.DataFrame, chunker: Chunker):
610-
611609
# _fit
612610
# realized perf on chunks
613611
# set thresholds
@@ -700,7 +698,6 @@ def _calculate(self, data: pd.DataFrame) -> Union[np.ndarray, float]:
700698
return cm
701699

702700
def get_chunk_record(self, chunk_data: pd.DataFrame) -> Dict[str, Union[float, bool]]:
703-
704701
if self.classes is None:
705702
raise ValueError("classes must be set before calling this method")
706703

@@ -714,7 +711,6 @@ def get_chunk_record(self, chunk_data: pd.DataFrame) -> Dict[str, Union[float, b
714711

715712
for true_class in self.classes:
716713
for pred_class in self.classes:
717-
718714
column_name = f'true_{true_class}_pred_{pred_class}'
719715

720716
chunk_record[f"{column_name}_sampling_error"] = sampling_errors[

nannyml/performance_calculation/metrics/regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414

1515
from nannyml._typing import ProblemType
16-
from nannyml.base import _remove_nans, _list_missing, _raise_exception_for_negative_values
16+
from nannyml.base import _list_missing, _raise_exception_for_negative_values, _remove_nans
1717
from nannyml.performance_calculation.metrics.base import Metric, MetricFactory
1818
from nannyml.sampling_error.regression import (
1919
mae_sampling_error,

nannyml/performance_estimation/confidence_based/cbpe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
chunk_number: Optional[int] = None,
8484
chunk_period: Optional[str] = None,
8585
chunker: Optional[Chunker] = None,
86-
calibration: Optional[str] = None,
86+
calibration: str = 'isotonic',
8787
calibrator: Optional[Calibrator] = None,
8888
thresholds: Optional[Dict[str, Threshold]] = None,
8989
normalize_confusion_matrix: Optional[str] = None,

nannyml/performance_estimation/confidence_based/metrics.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,16 +1573,14 @@ def _realized_performance(self, data: pd.DataFrame) -> float:
15731573
tn_value = self.business_value_matrix[0, 0]
15741574
fp_value = self.business_value_matrix[0, 1]
15751575
fn_value = self.business_value_matrix[1, 0]
1576-
bv_array = np.array(
1577-
[[tn_value,fp_value], [fn_value,tp_value]]
1578-
)
1576+
bv_array = np.array([[tn_value, fp_value], [fn_value, tp_value]])
15791577

15801578
cm = confusion_matrix(y_true, y_pred)
15811579
if self.normalize_business_value == 'per_prediction':
15821580
with np.errstate(all="ignore"):
15831581
cm = cm / cm.sum(axis=0, keepdims=True)
15841582
cm = np.nan_to_num(cm)
1585-
return (bv_array*cm).sum()
1583+
return (bv_array * cm).sum()
15861584

15871585
def _estimate(self, chunk_data: pd.DataFrame) -> float:
15881586
y_pred_proba = chunk_data[self.y_pred_proba]
@@ -1630,9 +1628,7 @@ def estimate_business_value(
16301628
est_tp_ratio = np.mean(np.where(y_pred == 1, y_pred_proba, 0))
16311629
est_fp_ratio = np.mean(np.where(y_pred == 1, 1 - y_pred_proba, 0))
16321630
est_fn_ratio = np.mean(np.where(y_pred == 0, y_pred_proba, 0))
1633-
cm = np.array(
1634-
[[est_tn_ratio, est_fp_ratio], [est_fn_ratio, est_tp_ratio]]
1635-
)*len(y_pred)
1631+
cm = np.array([[est_tn_ratio, est_fp_ratio], [est_fn_ratio, est_tp_ratio]]) * len(y_pred)
16361632
if normalize_business_value == 'per_prediction':
16371633
with np.errstate(all="ignore"):
16381634
cm = cm / cm.sum(axis=0, keepdims=True)
@@ -1642,11 +1638,9 @@ def estimate_business_value(
16421638
tn_value = business_value_matrix[0, 0]
16431639
fp_value = business_value_matrix[0, 1]
16441640
fn_value = business_value_matrix[1, 0]
1645-
bv_array = np.array(
1646-
[[tn_value,fp_value], [fn_value,tp_value]]
1647-
)
1641+
bv_array = np.array([[tn_value, fp_value], [fn_value, tp_value]])
16481642

1649-
return (bv_array*cm).sum()
1643+
return (bv_array * cm).sum()
16501644

16511645

16521646
def _get_binarized_multiclass_predictions(data: pd.DataFrame, y_pred: str, y_pred_proba: ModelOutputsType):
@@ -2108,7 +2102,6 @@ def __init__(
21082102
normalize_confusion_matrix: Optional[str] = None,
21092103
**kwargs,
21102104
):
2111-
21122105
if isinstance(y_pred_proba, str):
21132106
raise ValueError(
21142107
"y_pred_proba must be a dictionary with class labels as keys and pred_proba column names as values"
@@ -2167,7 +2160,6 @@ def fit(self, reference_data: pd.DataFrame): # override the superclass fit meth
21672160
return
21682161

21692162
def _fit(self, reference_data: pd.DataFrame):
2170-
21712163
self._confusion_matrix_sampling_error_components = mse.multiclass_confusion_matrix_sampling_error_components(
21722164
y_true_reference=reference_data[self.y_true],
21732165
y_pred_reference=reference_data[self.y_pred],
@@ -2177,7 +2169,6 @@ def _fit(self, reference_data: pd.DataFrame):
21772169
def _multiclass_confusion_matrix_alert_thresholds(
21782170
self, reference_chunks: List[Chunk]
21792171
) -> Dict[str, Tuple[Optional[float], Optional[float]]]:
2180-
21812172
realized_chunk_performance = np.asarray(
21822173
[self._multi_class_confusion_matrix_realized_performance(chunk.data) for chunk in reference_chunks]
21832174
)
@@ -2224,22 +2215,19 @@ def _multiclass_confusion_matrix_confidence_deviations(
22242215
self,
22252216
reference_chunks: List[Chunk],
22262217
) -> Dict[str, float]:
2227-
22282218
confidence_deviations = {}
22292219

22302220
num_classes = len(self.classes)
22312221

22322222
for i in range(num_classes):
22332223
for j in range(num_classes):
2234-
22352224
confidence_deviations[f'true_{self.classes[i]}_pred_{self.classes[j]}'] = np.std(
22362225
[self._get_multiclass_confusion_matrix_estimate(chunk.data)[i, j] for chunk in reference_chunks]
22372226
)
22382227

22392228
return confidence_deviations
22402229

22412230
def _get_multiclass_confusion_matrix_estimate(self, chunk_data: pd.DataFrame) -> np.ndarray:
2242-
22432231
if isinstance(self.y_pred_proba, str):
22442232
raise ValueError(
22452233
"y_pred_proba must be a dictionary with class labels as keys and pred_proba column names as values"
@@ -2282,7 +2270,6 @@ def _get_multiclass_confusion_matrix_estimate(self, chunk_data: pd.DataFrame) ->
22822270
return normalized_est_confusion_matrix
22832271

22842272
def get_chunk_record(self, chunk_data: pd.DataFrame) -> Dict:
2285-
22862273
chunk_record = {}
22872274

22882275
estimated_cm = self._get_multiclass_confusion_matrix_estimate(chunk_data)
@@ -2295,7 +2282,6 @@ def get_chunk_record(self, chunk_data: pd.DataFrame) -> Dict:
22952282

22962283
for true_class in self.classes:
22972284
for pred_class in self.classes:
2298-
22992285
chunk_record[f'estimated_true_{true_class}_pred_{pred_class}'] = estimated_cm[
23002286
self.classes.index(true_class), self.classes.index(pred_class)
23012287
]

nannyml/plots/blueprints/comparisons.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,6 @@ def _plot_compare_step_to_step( # noqa: C901
225225
metric_2_color=Colors.BLUE_SKY_CRAYOLA,
226226
**kwargs,
227227
) -> Figure:
228-
229228
_metric_1_kwargs = {k.replace('metric_1_', ''): v for k, v in kwargs.items() if k.startswith('metric_1_')}
230229
_metric_2_kwargs = {k.replace('metric_2_', ''): v for k, v in kwargs.items() if k.startswith('metric_2_')}
231230

@@ -442,7 +441,6 @@ def _plot_compare_step_to_step( # noqa: C901
442441
# endregion
443442

444443
if has_analysis_results:
445-
446444
# region analysis metric 1
447445

448446
_hover = hover or Hover(
@@ -693,7 +691,6 @@ def _is_estimated_result(result: Result) -> bool:
693691

694692
class ResultComparison:
695693
def __init__(self, result: Result, other: Result, plot_kwargs: Dict[str, Any], title: Optional[str] = None):
696-
697694
if len(result.keys()) != 1 or len(result.keys()) != 1:
698695
raise InvalidArgumentsException(
699696
f"you're comparing {len(result.keys())} metrics to {len(result.keys())} "

0 commit comments

Comments
 (0)