Skip to content

Commit 123d3b7

Browse files
authored
Merge pull request #408 from jnesfield/main
Added Numerical Range Data Quality Check
2 parents 8507064 + 5730807 commit 123d3b7

File tree

11 files changed

+562
-17
lines changed

11 files changed

+562
-17
lines changed

nannyml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
from .calibration import Calibrator, IsotonicCalibrator, needs_calibration
4141
from .chunk import Chunk, Chunker, CountBasedChunker, DefaultChunker, PeriodBasedChunker, SizeBasedChunker
42-
from .data_quality import MissingValuesCalculator, UnseenValuesCalculator
42+
from .data_quality import MissingValuesCalculator, UnseenValuesCalculator, NumericalRangeCalculator
4343
from .datasets import (
4444
load_modified_california_housing_dataset,
4545
load_synthetic_binary_classification_dataset,

nannyml/data_quality/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77

88
from .missing import MissingValuesCalculator
99
from .unseen import UnseenValuesCalculator
10+
from .range import NumericalRangeCalculator

nannyml/data_quality/missing/calculator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,7 @@ def __init__(
7676
... timestamp_column_name='timestamp',
7777
... ).fit(reference_df)
7878
>>> res = calc.calculate(analysis_df)
79-
>>> for column_name in res.feature_column_names:
80-
... res = res.filter(period='analysis', column_name=column_name).plot().show()
79+
>>> res.filter(period='analysis').plot().show()
8180
"""
8281
super(MissingValuesCalculator, self).__init__(
8382
chunk_size, chunk_number, chunk_period, chunker, timestamp_column_name

nannyml/data_quality/missing/result.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ def plot(
7979
... timestamp_column_name='timestamp',
8080
... ).fit(reference)
8181
>>> res = calc.calculate(analysis)
82-
>>> for column_name in res.column_names:
83-
... res = res.filter(period='analysis', column_name=column_name).plot().show()
82+
>>> res.filter(period='analysis').plot().show()
8483
8584
"""
8685
return plot_metrics(
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Author: James Nesfield <jamesnesfield@live.com>
2+
#
3+
# License: Apache Software License 2.0
4+
5+
"""Package containing the Data Quality Calculators implementation."""
6+
7+
from .calculator import NumericalRangeCalculator
8+
from .result import Result
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
# Author: James Nesfield <jamesnesfield@live.com>
2+
#
3+
# License: Apache Software License 2.0
4+
5+
"""Continuous numerical variable range monitor to ensure range supplied is within training bounds."""
6+
7+
from typing import Any, Dict, List, Optional, Union
8+
9+
import numpy as np
10+
import pandas as pd
11+
from pandas import MultiIndex
12+
13+
from nannyml.base import AbstractCalculator, _list_missing, _split_features_by_type
14+
from nannyml.chunk import Chunker
15+
from nannyml.exceptions import InvalidArgumentsException
16+
from nannyml.thresholds import Threshold, calculate_threshold_values, ConstantThreshold
17+
from nannyml.usage_logging import UsageEvent, log_usage
18+
from .result import Result
19+
20+
"""
21+
Values Out Of Range Data Quality Module.
22+
"""
23+
24+
25+
class NumericalRangeCalculator(AbstractCalculator):
    """NumericalRangeCalculator ensures the monitoring data set numerical ranges match the reference data set ones.

    During ``fit`` the calculator records a ``[min, max]`` interval per monitored column from the
    reference data. During ``calculate`` it reports, per chunk, the number of values falling
    outside that interval (or the fraction of such values when ``normalize=True``).
    """

    def __init__(
        self,
        column_names: Union[str, List[str]],
        normalize: bool = True,
        timestamp_column_name: Optional[str] = None,
        chunk_size: Optional[int] = None,
        chunk_number: Optional[int] = None,
        chunk_period: Optional[str] = None,
        chunker: Optional[Chunker] = None,
        threshold: Threshold = ConstantThreshold(lower=None, upper=0),
    ):
        """Creates a new NumericalRangeCalculator instance.

        Parameters
        ----------
        column_names: Union[str, List[str]]
            A string or list containing the names of features in the provided data set.
            Out-of-range values will be calculated for each entry in this list.
        normalize: bool, default=True
            Whether to provide the out-of-range values ratio (True) or the absolute number of
            out-of-range values (False).
        timestamp_column_name: str
            The name of the column containing the timestamp of the model prediction.
        chunk_size: int
            Splits the data into chunks containing `chunks_size` observations.
            Only one of `chunk_size`, `chunk_number` or `chunk_period` should be given.
        chunk_number: int
            Splits the data into `chunk_number` pieces.
            Only one of `chunk_size`, `chunk_number` or `chunk_period` should be given.
        chunk_period: str
            Splits the data according to the given period.
            Only one of `chunk_size`, `chunk_number` or `chunk_period` should be given.
        chunker : Chunker
            The `Chunker` used to split the data sets into a lists of chunks.
        threshold: Threshold, default=ConstantThreshold(lower=None, upper=0)
            The threshold strategy used to evaluate the per-chunk metric values.
            The default raises an alert for any chunk containing at least one out-of-range value.

        Examples
        --------
        >>> import nannyml as nml
        >>> reference_df, analysis_df, _ = nml.load_synthetic_car_price_dataset()
        >>> feature_column_names = [col for col in reference_df.columns if col not in [
        ...     'fuel','transmission','timestamp', 'y_pred', 'y_true']]
        >>> calc = nml.NumericalRangeCalculator(
        ...     column_names=feature_column_names,
        ...     timestamp_column_name='timestamp',
        ... ).fit(reference_df)
        >>> res = calc.calculate(analysis_df)
        >>> res.filter(period='analysis').plot().show()
        """
        super(NumericalRangeCalculator, self).__init__(
            chunk_size, chunk_number, chunk_period, chunker, timestamp_column_name
        )
        if isinstance(column_names, str):
            self.column_names = [column_names]
        elif isinstance(column_names, list):
            for el in column_names:
                if not isinstance(el, str):
                    raise InvalidArgumentsException(
                        f"column_names elements should be either a column name string or a list of strings, found\n{el}"
                    )
            self.column_names = column_names
        else:
            # Bug fix: the second literal previously lacked the 'f' prefix, so the
            # offending value was never interpolated into the error message.
            raise InvalidArgumentsException(
                "column_names should be either a column name string or a list of columns names strings, "
                f"found\n{column_names}"
            )
        self.result: Optional[Result] = None

        # threshold strategy is the same across all columns
        self.threshold = threshold
        self._upper_alert_thresholds: Dict[str, Optional[float]] = {column_name: 0 for column_name in self.column_names}
        self._lower_alert_thresholds: Dict[str, Optional[float]] = {column_name: 0 for column_name in self.column_names}

        # The metric can never go below 0; the upper bound depends on normalization (set below).
        self.lower_threshold_value_limit: float = 0
        self.upper_threshold_value_limit: Optional[float] = None
        self.normalize = normalize

        if self.normalize:
            self.data_quality_metric = 'out_of_range_values_rate'
            self.upper_threshold_value_limit = 1
        else:
            self.data_quality_metric = 'out_of_range_values_count'
            # NOTE(review): np.nan is used here as "no upper limit" — confirm
            # calculate_threshold_values treats NaN the same as None.
            self.upper_threshold_value_limit = np.nan

        # Per-column reference bounds, tracked as a list [min, max]; populated in _fit.
        self._reference_value_ranges: Dict[str, list] = {column_name: list() for column_name in self.column_names}

    def _calculate_out_of_range_stats(self, data: pd.Series, lower_bound: float, upper_bound: float):
        """Return the count (or, when normalizing, the fraction) of values outside [lower_bound, upper_bound]."""
        count_tot = data.shape[0]
        count_out_of_range = ((data < lower_bound) | (data > upper_bound)).sum()
        if self.normalize:
            # assumes chunks are non-empty (count_tot > 0) — the chunker provides non-empty chunks
            count_out_of_range = count_out_of_range / count_tot
        return count_out_of_range

    @log_usage(UsageEvent.DQ_CALC_VALUES_OUT_OF_RANGE_FIT, metadata_from_self=['normalize'])
    def _fit(self, reference_data: pd.DataFrame, *args, **kwargs):
        """Fits the calculator on reference data, recording per-column [min, max] value ranges."""
        if reference_data.empty:
            raise InvalidArgumentsException('data contains no rows. Please provide a valid data set.')

        _list_missing(self.column_names, reference_data)

        # All provided columns must be continuous
        # We do not make int categorical
        continuous_column_names, categorical_column_names = _split_features_by_type(reference_data, self.column_names)
        if not set(self.column_names) == set(continuous_column_names):
            raise InvalidArgumentsException(
                f"Specified columns_names for NumericalRangeCalculator must all be continuous. "
                f"Categorical columns found: {categorical_column_names}"
            )

        for col in self.column_names:
            self._reference_value_ranges[col] = [reference_data[col].min(), reference_data[col].max()]

        # Calculate on the reference data itself so the result carries a baseline period.
        self.result = self._calculate(data=reference_data)
        self.result.data[('chunk', 'period')] = 'reference'

        return self

    @log_usage(UsageEvent.DQ_CALC_VALUES_OUT_OF_RANGE_RUN, metadata_from_self=['normalize'])
    def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
        """Calculates the out-of-range metric per chunk and merges it into the running Result."""
        if data.empty:
            raise InvalidArgumentsException('data contains no rows. Please provide a valid data set.')

        _list_missing(self.column_names, data)

        chunks = self.chunker.split(data)

        rows = []
        for chunk in chunks:
            row = {
                'key': chunk.key,
                'chunk_index': chunk.chunk_index,
                'start_index': chunk.start_index,
                'end_index': chunk.end_index,
                'start_datetime': chunk.start_datetime,
                'end_datetime': chunk.end_datetime,
                # 'analysis' by default; _fit overwrites the period to 'reference' after fitting.
                'period': 'analysis',
            }

            for column_name in self.column_names:
                for k, v in self._calculate_for_column(chunk.data, column_name).items():
                    row[f'{column_name}_{k}'] = v

            rows.append(row)

        result_index = _create_multilevel_index(
            column_names=self.column_names,
        )
        # Columns are assigned positionally, relying on the row dicts' insertion order
        # matching the multilevel index layout.
        res = pd.DataFrame(rows)
        res.columns = result_index
        res = res.reset_index(drop=True)

        if self.result is None:
            # First call (during fit): establish thresholds from the reference results.
            self._set_metric_thresholds(res)
            res = self._populate_alert_thresholds(res)
            self.result = Result(
                results_data=res,
                column_names=self.column_names,
                data_quality_metric=self.data_quality_metric,
                timestamp_column_name=self.timestamp_column_name,
                chunker=self.chunker,
            )
        else:
            # TODO: review subclassing setup => superclass + '_filter' is screwing up typing.
            # Dropping the intermediate '_filter' and directly returning the correct 'Result' class works OK
            # but this causes us to lose the "common behavior" in the top level 'filter' method when overriding.
            # Applicable here but to many of the base classes as well (e.g. fitting and calculating)
            res = self._populate_alert_thresholds(res)
            self.result = self.result.filter(period='reference')
            self.result.data = pd.concat([self.result.data, res], ignore_index=True)

        return self.result

    def _calculate_for_column(self, data: pd.DataFrame, column_name: str) -> Dict[str, Any]:
        """Compute the out-of-range metric for one column of a single chunk."""
        result = {}
        value_range = self._reference_value_ranges[column_name]
        value = self._calculate_out_of_range_stats(data[column_name], value_range[0], value_range[1])
        result['value'] = value
        return result

    def _set_metric_thresholds(self, result_data: pd.DataFrame):
        """Derive per-column alert thresholds from the reference-period metric values."""
        for column_name in self.column_names:
            (
                self._lower_alert_thresholds[column_name],
                self._upper_alert_thresholds[column_name],
            ) = calculate_threshold_values(  # noqa: E501
                threshold=self.threshold,
                data=result_data.loc[:, (column_name, 'value')],
                lower_threshold_value_limit=self.lower_threshold_value_limit,
                upper_threshold_value_limit=self.upper_threshold_value_limit,
                logger=self._logger,
            )

    def _populate_alert_thresholds(self, result_data: pd.DataFrame) -> pd.DataFrame:
        """Attach threshold columns and a boolean alert flag per monitored column.

        A None threshold means "unbounded" on that side, hence the +/- inf substitution.
        """
        for column_name in self.column_names:
            result_data[(column_name, 'upper_threshold')] = self._upper_alert_thresholds[column_name]
            result_data[(column_name, 'lower_threshold')] = self._lower_alert_thresholds[column_name]
            result_data[(column_name, 'alert')] = result_data.apply(
                lambda row: True
                if (
                    row[(column_name, 'value')]
                    > (
                        np.inf
                        if row[(column_name, 'upper_threshold')] is None
                        else row[(column_name, 'upper_threshold')]  # noqa: E501
                    )
                    or row[(column_name, 'value')]
                    < (
                        -np.inf
                        if row[(column_name, 'lower_threshold')] is None
                        else row[(column_name, 'lower_threshold')]  # noqa: E501
                    )
                )
                else False,
                axis=1,
            )
        return result_data
252+
def _create_multilevel_index(
253+
column_names,
254+
):
255+
chunk_column_names = ['key', 'chunk_index', 'start_index', 'end_index', 'start_date', 'end_date', 'period']
256+
chunk_tuples = [('chunk', chunk_column_name) for chunk_column_name in chunk_column_names]
257+
column_tuples = [
258+
(column_name, 'value')
259+
for column_name in column_names
260+
# for el in ['value', 'upper_threshold', 'lower_threshold', 'alert']
261+
]
262+
tuples = chunk_tuples + column_tuples
263+
return MultiIndex.from_tuples(tuples)

0 commit comments

Comments
 (0)