Commit 93ac6e7

Fix handling NaN values when fitting JS univariate drift (#340)
* Add column & method for univariate fitting errors

* Refactor to use single data cleaning method

* Filter NaN's when fitting JS

* Refactor data cleaning to accept columns argument

  Previously the data cleaning method operated by accepting multiple dataframes and inspecting each dataframe separately for `NaN`'s. Depending on how the data is processed after cleaning, splitting columns into separate dataframes can be rather annoying. To avoid that, this commit changes the method to accept a single dataframe and a columns argument. The columns argument specifies which column subsets should be inspected for `NaN`'s, enabling the same behaviour with a more convenient syntax.

* Remove errors and use warning behaviour instead

  The performance calculator for binary classification had checks in place to raise an exception if the prediction column contains nothing but `NaN`'s. This behaviour contradicts the warning functionality in the same functions, which should return `NaN` and issue a warning. It is also inconsistent with other calculators, which issue a warning instead of raising an error. This commit removes the errors and relies on the existing warning functionality.

* Refactor more data cleaning methods

* Deal with mypy overload issue

---------

Co-authored-by: Niels Nuyttens <niels@nannyml.com>
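For context on the "Filter NaN's when fitting JS" bullet: a minimal sketch (not NannyML's implementation; the data is made up) of why unfiltered `NaN`'s break histogram binning of a continuous column.

```python
import numpy as np
import pandas as pd

reference = pd.Series([0.2, 0.5, np.nan, 0.9, 1.4])

# With NaN's present, NumPy cannot autodetect a finite bin range:
# np.histogram(reference) raises "ValueError: autodetected range of [nan, nan] is not finite".

# Dropping the NaN's first makes the binning step well defined again.
counts, edges = np.histogram(reference.dropna(), bins=3)
print(counts, edges)
```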
1 parent 51105c7 commit 93ac6e7

File tree

9 files changed, +126 -174 lines changed


nannyml/base.py

Lines changed: 32 additions & 6 deletions
```diff
@@ -8,7 +8,7 @@
 import copy
 import logging
 from abc import ABC, abstractmethod
-from typing import Generic, List, Optional, Tuple, TypeVar, Union
+from typing import Generic, Iterable, List, Optional, Tuple, TypeVar, Union, overload
 
 import numpy as np
 import pandas as pd
@@ -533,12 +533,38 @@ def _column_is_categorical(column: pd.Series) -> bool:
     return column.dtype in ['object', 'string', 'category', 'bool']
 
 
-def _remove_missing_data(column: pd.Series):
-    if isinstance(column, pd.Series):
-        column = column.dropna().reset_index(drop=True)
+@overload
+def _remove_nans(data: pd.Series) -> pd.Series:
+    ...
+
+@overload
+def _remove_nans(data: pd.DataFrame, columns: Optional[Iterable[Union[str, Iterable[str]]]]) -> pd.DataFrame:
+    ...
+
+
+def _remove_nans(
+    data: Union[pd.Series, pd.DataFrame], columns: Optional[Iterable[Union[str, Iterable[str]]]] = None
+) -> Tuple[pd.DataFrame, ...]:
+    """Remove rows with NaN values in the specified columns.
+
+    If no columns are given, drop rows with NaN values in any column. If columns are given, drop rows with NaN values
+    in the specified columns. If a set of columns is given, drop rows with NaN values in all of the columns in the set.
+    """
+    # If no columns are given, drop rows with NaN values in any columns
+    if columns is None:
+        mask = ~data.isna()
+        if isinstance(mask, pd.DataFrame):
+            mask = mask.all(axis=1)
     else:
-        column = column[~np.isnan(column)]
-    return column
+        mask = np.ones(len(data), dtype=bool)
+        for column_selector in columns:
+            nans = data[column_selector].isna()
+            if isinstance(nans, pd.DataFrame):
+                nans = nans.all(axis=1)
+            mask &= ~nans
+
+    # NaN values have been dropped. Try to infer types again
+    return data[mask].reset_index(drop=True).infer_objects()
 
 
 def _column_is_continuous(column: pd.Series) -> bool:
```
nannyml/drift/univariate/calculator.py

Lines changed: 37 additions & 26 deletions
```diff
@@ -39,7 +39,7 @@
 from nannyml.chunk import Chunker
 from nannyml.drift.univariate.methods import FeatureType, Method, MethodFactory
 from nannyml.drift.univariate.result import Result
-from nannyml.exceptions import InvalidArgumentsException
+from nannyml.exceptions import CalculatorException, InvalidArgumentsException
 from nannyml.thresholds import ConstantThreshold, StandardDeviationThreshold, Threshold
 from nannyml.usage_logging import UsageEvent, log_usage
 
@@ -271,34 +271,45 @@ def _fit(self, reference_data: pd.DataFrame, *args, **kwargs) -> UnivariateDrift
                 if column_name not in self.categorical_column_names:
                     self.categorical_column_names.append(column_name)
 
+        timestamps = reference_data[self.timestamp_column_name] if self.timestamp_column_name else None
         for column_name in self.continuous_column_names:
-            self._column_to_models_mapping[column_name] += [
-                MethodFactory.create(
-                    key=method,
-                    feature_type=FeatureType.CONTINUOUS,
-                    chunker=self.chunker,
-                    computation_params=self.computation_params or {},
-                    threshold=self.thresholds[method],
-                ).fit(
-                    reference_data=reference_data[column_name],
-                    timestamps=reference_data[self.timestamp_column_name] if self.timestamp_column_name else None,
-                )
-                for method in self.continuous_method_names
-            ]
+            methods = []
+            for method in self.continuous_method_names:
+                try:
+                    methods.append(
+                        MethodFactory.create(
+                            key=method,
+                            feature_type=FeatureType.CONTINUOUS,
+                            chunker=self.chunker,
+                            computation_params=self.computation_params or {},
+                            threshold=self.thresholds[method],
+                        ).fit(
+                            reference_data=reference_data[column_name],
+                            timestamps=timestamps,
+                        )
+                    )
+                except Exception as ex:
+                    raise CalculatorException(f"Failed to fit method {method} for column {column_name}: {ex!r}") from ex
+            self._column_to_models_mapping[column_name] = methods
 
         for column_name in self.categorical_column_names:
-            self._column_to_models_mapping[column_name] += [
-                MethodFactory.create(
-                    key=method,
-                    feature_type=FeatureType.CATEGORICAL,
-                    chunker=self.chunker,
-                    threshold=self.thresholds[method],
-                ).fit(
-                    reference_data=reference_data[column_name],
-                    timestamps=reference_data[self.timestamp_column_name] if self.timestamp_column_name else None,
-                )
-                for method in self.categorical_method_names
-            ]
+            methods = []
+            for method in self.categorical_method_names:
+                try:
+                    methods.append(
+                        MethodFactory.create(
+                            key=method,
+                            feature_type=FeatureType.CATEGORICAL,
+                            chunker=self.chunker,
+                            threshold=self.thresholds[method],
+                        ).fit(
+                            reference_data=reference_data[column_name],
+                            timestamps=timestamps,
+                        )
+                    )
+                except Exception as ex:
+                    raise CalculatorException(f"Failed to fit method {method} for column {column_name}: {ex!r}") from ex
+            self._column_to_models_mapping[column_name] = methods
 
         self.result = self._calculate(reference_data)
         self.result.data['chunk', 'chunk', 'period'] = 'reference'
```
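From the caller's side, a fitting failure for any single method/column now surfaces as a `CalculatorException` naming both. A hedged sketch (constructor arguments follow the public `UnivariateDriftCalculator` API; the column name and data are made up):

```python
import numpy as np
import pandas as pd
import nannyml as nml
from nannyml.exceptions import CalculatorException

reference_df = pd.DataFrame({'feature_1': np.random.normal(size=1_000)})

calc = nml.UnivariateDriftCalculator(
    column_names=['feature_1'],
    continuous_methods=['jensen_shannon'],
)

try:
    calc.fit(reference_df)
except CalculatorException as exc:
    # e.g. "Failed to fit method jensen_shannon for column feature_1: ..."
    print(exc)
```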

nannyml/drift/univariate/methods.py

Lines changed: 13 additions & 12 deletions
```diff
@@ -29,7 +29,7 @@
 from scipy.stats import chi2_contingency, ks_2samp, wasserstein_distance
 
 from nannyml._typing import Self
-from nannyml.base import _column_is_categorical, _remove_missing_data
+from nannyml.base import _remove_nans, _column_is_categorical
 from nannyml.chunk import Chunker
 from nannyml.exceptions import InvalidArgumentsException, NotFittedException
 from nannyml.thresholds import Threshold, calculate_threshold_values
@@ -278,6 +278,7 @@ def __init__(self, **kwargs) -> None:
         self._reference_proba_in_bins: np.ndarray
 
     def _fit(self, reference_data: pd.Series, timestamps: Optional[pd.Series] = None):
+        reference_data = _remove_nans(reference_data)
         if _column_is_categorical(reference_data):
             treat_as_type = 'cat'
         else:
@@ -305,7 +306,7 @@ def _fit(self, reference_data: pd.Series, timestamps: Optional[pd.Series] = None
 
     def _calculate(self, data: pd.Series):
         reference_proba_in_bins = copy(self._reference_proba_in_bins)
-        data = _remove_missing_data(data)
+        data = _remove_nans(data)
         if data.empty:
             return np.nan
         if self._treat_as_type == 'cont':
@@ -374,7 +375,7 @@ def __init__(self, **kwargs) -> None:
         self.n_bins = kwargs['computation_params'].get('n_bins', 10_000)
 
     def _fit(self, reference_data: pd.Series, timestamps: Optional[pd.Series] = None) -> Self:
-        reference_data = _remove_missing_data(reference_data)
+        reference_data = _remove_nans(reference_data)
         if (self.calculation_method == 'auto' and len(reference_data) < 10_000) or self.calculation_method == 'exact':
             self._reference_data = reference_data
         else:
@@ -389,7 +390,7 @@ def _fit(self, reference_data: pd.Series, timestamps: Optional[pd.Series] = None
         return self
 
     def _calculate(self, data: pd.Series):
-        data = _remove_missing_data(data)
+        data = _remove_nans(data)
         if data.empty:
             return np.nan
         if not self._fitted:
@@ -443,13 +444,13 @@ def __init__(self, **kwargs) -> None:
         self._fitted = False
 
     def _fit(self, reference_data: pd.Series, timestamps: Optional[pd.Series] = None) -> Self:
-        reference_data = _remove_missing_data(reference_data)
+        reference_data = _remove_nans(reference_data)
         self._reference_data_vcs = reference_data.value_counts().loc[lambda v: v != 0]
         self._fitted = True
         return self
 
     def _calculate(self, data: pd.Series):
-        data = _remove_missing_data(data)
+        data = _remove_nans(data)
         if data.empty:
             return np.nan
         if not self._fitted:
@@ -505,7 +506,7 @@ def __init__(self, **kwargs) -> None:
         self._reference_proba: Optional[dict] = None
 
     def _fit(self, reference_data: pd.Series, timestamps: Optional[pd.Series] = None) -> Self:
-        reference_data = _remove_missing_data(reference_data)
+        reference_data = _remove_nans(reference_data)
         ref_labels = reference_data.unique()
         self._reference_proba = {label: (reference_data == label).sum() / len(reference_data) for label in ref_labels}
 
@@ -516,7 +517,7 @@ def _calculate(self, data: pd.Series):
             raise NotFittedException(
                 "tried to call 'calculate' on an unfitted method " f"{self.display_name}. Please run 'fit' first"
             )
-        data = _remove_missing_data(data)
+        data = _remove_nans(data)
         if data.empty:
             return np.nan
         data_labels = data.unique()
@@ -574,7 +575,7 @@ def __init__(self, **kwargs) -> None:
         self.n_bins = kwargs['computation_params'].get('n_bins', 10_000)
 
     def _fit(self, reference_data: pd.Series, timestamps: Optional[pd.Series] = None) -> Self:
-        reference_data = _remove_missing_data(reference_data)
+        reference_data = _remove_nans(reference_data)
         if (self.calculation_method == 'auto' and len(reference_data) < 10_000) or self.calculation_method == 'exact':
             self._reference_data = reference_data
         else:
@@ -592,7 +593,7 @@ def _calculate(self, data: pd.Series):
             raise NotFittedException(
                 "tried to call 'calculate' on an unfitted method " f"{self.display_name}. Please run 'fit' first"
             )
-        data = _remove_missing_data(data)
+        data = _remove_nans(data)
         if data.empty:
             return np.nan
         if (
@@ -668,7 +669,7 @@ def __init__(self, **kwargs) -> None:
         self._reference_proba_in_bins: np.ndarray
 
     def _fit(self, reference_data: pd.Series, timestamps: Optional[pd.Series] = None) -> Self:
-        reference_data = _remove_missing_data(reference_data)
+        reference_data = _remove_nans(reference_data)
         if _column_is_categorical(reference_data):
             treat_as_type = 'cat'
         else:
@@ -695,7 +696,7 @@ def _fit(self, reference_data: pd.Series, timestamps: Optional[pd.Series] = None
         return self
 
     def _calculate(self, data: pd.Series):
-        data = _remove_missing_data(data)
+        data = _remove_nans(data)
         if data.empty:
             return np.nan
         reference_proba_in_bins = copy(self._reference_proba_in_bins)
```
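To connect the renamed helper to the metric itself: a generic illustration of a binned Jensen-Shannon comparison on cleaned data, using SciPy rather than NannyML's own implementation (the bin strategy and sample data are arbitrary choices here):

```python
import numpy as np
import pandas as pd
from scipy.spatial.distance import jensenshannon

reference = pd.Series(np.concatenate([np.random.normal(0.0, 1.0, 995), [np.nan] * 5]))
analysis = pd.Series(np.random.normal(0.3, 1.0, 1_000))

# Drop NaN's first, the cleanup that _remove_nans now performs before fitting.
reference, analysis = reference.dropna(), analysis.dropna()

# Bin both samples on shared edges and compare the binned distributions.
edges = np.histogram_bin_edges(np.concatenate([reference, analysis]), bins='doane')
ref_counts, _ = np.histogram(reference, bins=edges)
ana_counts, _ = np.histogram(analysis, bins=edges)

print(jensenshannon(ref_counts, ana_counts))  # jensenshannon normalizes the counts itself
```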

nannyml/performance_calculation/metrics/base.py

Lines changed: 0 additions & 22 deletions
```diff
@@ -255,25 +255,3 @@ def inner_wrapper(wrapped_class: Type[Metric]) -> Type[Metric]:
         return wrapped_class
 
     return inner_wrapper
-
-
-def _common_data_cleaning(y_true: pd.Series, y_pred: Union[pd.Series, pd.DataFrame]):
-    y_true, y_pred = (
-        y_true.reset_index(drop=True),
-        y_pred.reset_index(drop=True),
-    )
-
-    if isinstance(y_pred, pd.DataFrame):
-        y_true = y_true[~y_pred.isna().all(axis=1)]
-    else:
-        y_true = y_true[~y_pred.isna()]
-    y_pred.dropna(inplace=True)
-
-    y_pred = y_pred[~y_true.isna()]
-    y_true.dropna(inplace=True)
-
-    # NaN values have been dropped. Try to infer types again
-    y_pred = y_pred.infer_objects()
-    y_true = y_true.infer_objects()
-
-    return y_true, y_pred
```
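For code that previously relied on `_common_data_cleaning(y_true, y_pred)`, the same row filtering can be expressed with `_remove_nans` on a single dataframe and a columns argument, roughly as follows (the column names are illustrative and this is not an excerpt from the updated metrics):

```python
import numpy as np
import pandas as pd

from nannyml.base import _remove_nans

df = pd.DataFrame({'y_true': [0, 1, np.nan], 'y_pred': [0, np.nan, 1]})

# Keep only rows where both the target and the prediction are present,
# then split them back out; no need to juggle separately cleaned series.
clean = _remove_nans(df, columns=['y_true', 'y_pred'])
y_true, y_pred = clean['y_true'], clean['y_pred']
print(y_true.tolist(), y_pred.tolist())
```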
