Skip to content

Commit

Permalink
Optional prediction column: the return (#381)
Browse files Browse the repository at this point in the history
* Make y_pred column optional for estimated and realized performance + tests

* Update data requirements docs

* Fix _list_missing issues + add "run with prediction" tests

* Fix flake8 issues

* Metric classes were not able to deal with optional predictions yet.

Updated the base class and (unfortunately) added asserts to keep mypy from complaining. Didn't want that logic to leak into remove_nans
  • Loading branch information
nnansters authored Apr 30, 2024
1 parent a52c06d commit fa78f5a
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
2 changes: 1 addition & 1 deletion nannyml/performance_calculation/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def __init__(
self,
name: str,
y_true: str,
y_pred: str,
components: List[Tuple[str, str]],
threshold: Threshold,
y_pred: Optional[str] = None,
y_pred_proba: Optional[Union[str, Dict[str, str]]] = None,
upper_threshold_limit: Optional[float] = None,
lower_threshold_limit: Optional[float] = None,
Expand Down
26 changes: 19 additions & 7 deletions nannyml/performance_calculation/metrics/binary_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ class BinaryClassificationAUROC(Metric):
def __init__(
self,
y_true: str,
y_pred: str,
threshold: Threshold,
y_pred: Optional[str] = None,
y_pred_proba: Optional[str] = None,
**kwargs,
):
Expand Down Expand Up @@ -97,6 +97,8 @@ def __str__(self):
def _fit(self, reference_data: pd.DataFrame):
"""Metric _fit implementation on reference data."""
_list_missing([self.y_true, self.y_pred_proba], list(reference_data.columns))
# we don't want to count missing rows for sampling error
reference_data = _remove_nans(reference_data, (self.y_true,))
self._sampling_error_components = auroc_sampling_error_components(
y_true_reference=reference_data[self.y_true],
y_pred_proba_reference=reference_data[self.y_pred_proba],
Expand All @@ -105,10 +107,10 @@ def _fit(self, reference_data: pd.DataFrame):
def _calculate(self, data: pd.DataFrame):
"""Redefine to handle NaNs and edge cases."""
_list_missing([self.y_true, self.y_pred_proba], list(data.columns))
data = _remove_nans(data, (self.y_true, self.y_pred))
data = _remove_nans(data, (self.y_true,))

y_true = data[self.y_true]
y_pred = data[self.y_pred_proba]
y_pred_proba = data[self.y_pred_proba]

if y_true.nunique() <= 1:
warnings.warn(
Expand All @@ -117,7 +119,7 @@ def _calculate(self, data: pd.DataFrame):
)
return np.NaN
else:
return roc_auc_score(y_true, y_pred)
return roc_auc_score(y_true, y_pred_proba)

def _sampling_error(self, data: pd.DataFrame) -> float:
return auroc_sampling_error(self._sampling_error_components, data)
Expand All @@ -133,8 +135,8 @@ class BinaryClassificationAP(Metric):
def __init__(
self,
y_true: str,
y_pred: str,
threshold: Threshold,
y_pred: Optional[str] = None,
y_pred_proba: Optional[str] = None,
**kwargs,
):
Expand Down Expand Up @@ -174,7 +176,7 @@ def _fit(self, reference_data: pd.DataFrame):
"""Metric _fit implementation on reference data."""
_list_missing([self.y_true, self.y_pred_proba], list(reference_data.columns))
# we don't want to count missing rows for sampling error
reference_data = _remove_nans(reference_data, (self.y_true, self.y_pred))
reference_data = _remove_nans(reference_data, (self.y_true,))

if 1 not in reference_data[self.y_true].unique():
self._sampling_error_components = np.NaN, 0
Expand All @@ -187,7 +189,7 @@ def _fit(self, reference_data: pd.DataFrame):
def _calculate(self, data: pd.DataFrame):
"""Redefine to handle NaNs and edge cases."""
_list_missing([self.y_true, self.y_pred_proba], list(data.columns))
data = _remove_nans(data, (self.y_true, self.y_pred))
data = _remove_nans(data, (self.y_true,))

y_true = data[self.y_true]
y_pred_proba = data[self.y_pred_proba]
Expand Down Expand Up @@ -259,6 +261,7 @@ def _fit(self, reference_data: pd.DataFrame):
def _calculate(self, data: pd.DataFrame):
"""Redefine to handle NaNs and edge cases."""
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down Expand Up @@ -335,6 +338,7 @@ def _fit(self, reference_data: pd.DataFrame):

def _calculate(self, data: pd.DataFrame):
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down Expand Up @@ -411,6 +415,7 @@ def _fit(self, reference_data: pd.DataFrame):

def _calculate(self, data: pd.DataFrame):
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down Expand Up @@ -487,6 +492,7 @@ def _fit(self, reference_data: pd.DataFrame):

def _calculate(self, data: pd.DataFrame):
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down Expand Up @@ -564,6 +570,7 @@ def _fit(self, reference_data: pd.DataFrame):

def _calculate(self, data: pd.DataFrame):
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down Expand Up @@ -674,6 +681,7 @@ def _fit(self, reference_data: pd.DataFrame):

def _calculate(self, data: pd.DataFrame):
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down Expand Up @@ -858,6 +866,7 @@ def _fit(self, reference_data: pd.DataFrame):

def _calculate_true_positives(self, data: pd.DataFrame) -> float:
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand All @@ -882,6 +891,7 @@ def _calculate_true_positives(self, data: pd.DataFrame) -> float:

def _calculate_true_negatives(self, data: pd.DataFrame) -> float:
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand All @@ -906,6 +916,7 @@ def _calculate_true_negatives(self, data: pd.DataFrame) -> float:

def _calculate_false_positives(self, data: pd.DataFrame) -> float:
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand All @@ -930,6 +941,7 @@ def _calculate_false_positives(self, data: pd.DataFrame) -> float:

def _calculate_false_negatives(self, data: pd.DataFrame) -> float:
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def _calculate(self, data: pd.DataFrame):
)

_list_missing([self.y_true, self.y_pred], data)
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

labels = sorted(list(self.y_pred_proba.keys()))
Expand Down Expand Up @@ -306,6 +307,7 @@ def _calculate(self, data: pd.DataFrame):
)

_list_missing([self.y_true, self.y_pred], data)
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

labels = sorted(list(self.y_pred_proba.keys()))
Expand Down Expand Up @@ -401,6 +403,7 @@ def _calculate(self, data: pd.DataFrame):
)

_list_missing([self.y_true, self.y_pred], data)
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

labels = sorted(list(self.y_pred_proba.keys()))
Expand Down Expand Up @@ -496,6 +499,7 @@ def _calculate(self, data: pd.DataFrame):
)

_list_missing([self.y_true, self.y_pred], data)
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

labels = sorted(list(self.y_pred_proba.keys()))
Expand Down Expand Up @@ -588,6 +592,7 @@ def _fit(self, reference_data: pd.DataFrame):

def _calculate(self, data: pd.DataFrame):
_list_missing([self.y_true, self.y_pred], data)
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down
6 changes: 6 additions & 0 deletions nannyml/performance_calculation/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _fit(self, reference_data: pd.DataFrame):
def _calculate(self, data: pd.DataFrame):
"""Redefine to handle NaNs and edge cases."""
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down Expand Up @@ -139,6 +140,7 @@ def _fit(self, reference_data: pd.DataFrame):
def _calculate(self, data: pd.DataFrame):
"""Redefine to handle NaNs and edge cases."""
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down Expand Up @@ -201,6 +203,7 @@ def _fit(self, reference_data: pd.DataFrame):
def _calculate(self, data: pd.DataFrame):
"""Redefine to handle NaNs and edge cases."""
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down Expand Up @@ -263,6 +266,7 @@ def _fit(self, reference_data: pd.DataFrame):
def _calculate(self, data: pd.DataFrame):
"""Redefine to handle NaNs and edge cases."""
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down Expand Up @@ -330,6 +334,7 @@ def _fit(self, reference_data: pd.DataFrame):
def _calculate(self, data: pd.DataFrame):
"""Redefine to handle NaNs and edge cases."""
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down Expand Up @@ -392,6 +397,7 @@ def _fit(self, reference_data: pd.DataFrame):
def _calculate(self, data: pd.DataFrame):
"""Redefine to handle NaNs and edge cases."""
_list_missing([self.y_true, self.y_pred], list(data.columns))
assert self.y_pred
data = _remove_nans(data, (self.y_true, self.y_pred))

y_true = data[self.y_true]
Expand Down

0 comments on commit fa78f5a

Please sign in to comment.