Feature: distribution calculators (#352)

* First version of continuous distribution calculator working
* Refactor plotting to support drift results for alert specification
* Support running ContinuousDistributionCalculator in the Runner
* Fix pickling ContinuousDistributionCalculator
* Working version of CategoricalDistributionCalculator
* This is not how overload works.
* Getting index-based plots to work
* Support categorical distribution calculator in the runner
* Fix Flake8 & mypy
* Expose option to downscale resolution of individual joyplots for continuous distribution plots
* Expose cumulative density for KDE quartiles
* Use first point >= quartile instead of closest
* Updated default thresholds for Univariate Drift detection methods
* Fix broken ranker tests. This is why we do PR's kids.
* Fix linting
* Register summary stats in CLI runner (#353)
* Unique identifier column to nannyML datasets (#348)
  * add unique ID column
  * Remove duplicate 'identifier' column
  * Fix broken tests
  * isort changes

Co-authored-by: Niels Nuyttens <niels@nannyml.com>
Co-authored-by: Michael Van de Steene <michael@nannyml.com>
Co-authored-by: Michael Van de Steene <124588413+michael-nml@users.noreply.github.com>
Co-authored-by: Santiago Víquez <santi.viquez@gmail.com>
1 parent: 7b5b969 · commit: 2af5e7b
Showing 12 changed files with 1,145 additions and 5 deletions.
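Before the file-by-file diff, a minimal usage sketch of the new calculator (not part of the commit). It assumes the public fit()/calculate() entry points that nannyml's AbstractCalculator provides around _fit()/_calculate(); the 'channel' column and the generated frames are made up for illustration.

# Minimal usage sketch (not part of the commit); column name and data
# are hypothetical.
import numpy as np
import pandas as pd

from nannyml.distribution.categorical import CategoricalDistributionCalculator

rng = np.random.default_rng(0)

def make_frame(categories):
    # Synthetic frame: one categorical column plus a timestamp column.
    return pd.DataFrame({
        'channel': rng.choice(categories, size=10_000),
        'ts': pd.date_range('2024-01-01', periods=10_000, freq='min'),
    })

reference_df = make_frame(['web', 'store', 'phone'])
analysis_df = make_frame(['web', 'store', 'phone', None])  # drifted: NaNs appear

calc = CategoricalDistributionCalculator(
    column_names='channel',
    timestamp_column_name='ts',
    chunk_size=1_000,
)
calc.fit(reference_df)                # chunks labelled 'reference'
result = calc.calculate(analysis_df)  # appends chunks labelled 'analysis'
print(result.data.head())

Note how the period labelling falls out of the implementation below: _fit runs _calculate once on reference data before setting _was_fitted, so every later calculate() call tags its chunks as 'analysis'.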
File: nannyml/distribution/__init__.py (path inferred from the imports)
@@ -0,0 +1,2 @@
from .categorical import CategoricalDistributionCalculator
from .continuous import ContinuousDistributionCalculator
File: nannyml/distribution/categorical/__init__.py (path inferred from the imports)
@@ -0,0 +1 @@
from .calculator import CategoricalDistributionCalculator
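Assuming the inferred file locations above, these re-exports make the calculator importable directly from the distribution package:

from nannyml.distribution import CategoricalDistributionCalculator
# equivalently:
# from nannyml.distribution.categorical import CategoricalDistributionCalculator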
File: nannyml/distribution/categorical/calculator.py (path inferred from the imports)
@@ -0,0 +1,140 @@
from typing import List, Optional, Union

import numpy as np
import pandas as pd
from typing_extensions import Self

from nannyml import Chunker
from nannyml.base import AbstractCalculator, _list_missing
from nannyml.distribution.categorical.result import Result
from nannyml.exceptions import InvalidArgumentsException


class CategoricalDistributionCalculator(AbstractCalculator):
    def __init__(
        self,
        column_names: Union[str, List[str]],
        timestamp_column_name: Optional[str] = None,
        chunk_size: Optional[int] = None,
        chunk_number: Optional[int] = None,
        chunk_period: Optional[str] = None,
        chunker: Optional[Chunker] = None,
    ):
        super().__init__(
            chunk_size,
            chunk_number,
            chunk_period,
            chunker,
            timestamp_column_name,
        )

        self.column_names = column_names if isinstance(column_names, List) else [column_names]
        self.result: Optional[Result] = None
        self._was_fitted: bool = False

    def _fit(self, reference_data: pd.DataFrame, *args, **kwargs) -> Self:
        self.result = self._calculate(reference_data)
        self._was_fitted = True

        return self

    def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
        if data.empty:
            raise InvalidArgumentsException('data contains no rows. Please provide a valid data set.')

        _list_missing(self.column_names, data)

        # result_data = pd.DataFrame(columns=_create_multilevel_index(self.column_names))
        result_data = pd.DataFrame()

        chunks = self.chunker.split(data)
        chunks_data = pd.DataFrame(
            {
                'key': [c.key for c in chunks],
                'chunk_index': [c.chunk_index for c in chunks],
                'start_datetime': [c.start_datetime for c in chunks],
                'end_datetime': [c.end_datetime for c in chunks],
                'start_index': [c.start_index for c in chunks],
                'end_index': [c.end_index for c in chunks],
                'period': ['analysis' if self._was_fitted else 'reference' for _ in chunks],
            }
        )

        for column in self.column_names:
            value_counts = calculate_value_counts(
                data=data[column],
                chunker=self.chunker,
                timestamps=data.get(self.timestamp_column_name, default=None),
                max_number_of_categories=5,
                missing_category_label='Missing',
                column_name=column,
            )
            result_data = pd.concat([result_data, pd.merge(chunks_data, value_counts, on='chunk_index')])

        # result_data.index = pd.MultiIndex.from_tuples(list(zip(result_data['column_name'], result_data['value'])))

        if self.result is None:
            self.result = Result(result_data, self.column_names, self.timestamp_column_name, self.chunker)
        else:
            # self.result = self.result.data.loc[self.result.data['period'] == 'reference', :]
            self.result.data = pd.concat([self.result.data, result_data]).reset_index(drop=True)

        return self.result


def calculate_value_counts(
    data: Union[np.ndarray, pd.Series],
    chunker: Chunker,
    missing_category_label,
    max_number_of_categories,
    timestamps: Optional[Union[np.ndarray, pd.Series]] = None,
    column_name: Optional[str] = None,
):
    if isinstance(data, np.ndarray):
        if column_name is None:
            raise InvalidArgumentsException("'column_name' can not be None when 'data' is of type 'np.ndarray'.")
        data = pd.Series(data, name=column_name)
    else:
        column_name = data.name

    data = data.astype("category")
    cat_str = [str(value) for value in data.cat.categories.values]
    data = data.cat.rename_categories(cat_str)
    data = data.cat.add_categories([missing_category_label, 'Other'])
    data = data.fillna(missing_category_label)

    if max_number_of_categories:
        top_categories = data.value_counts().index.tolist()[:max_number_of_categories]
        if data.nunique() > max_number_of_categories + 1:
            data.loc[~data.isin(top_categories)] = 'Other'

    data = data.cat.remove_unused_categories()

    categories_ordered = data.value_counts().index.tolist()
    categorical_data = pd.Categorical(data, categories_ordered)

    # TODO: deal with None timestamps
    if isinstance(timestamps, pd.Series):
        timestamps = timestamps.reset_index()

    chunks = chunker.split(pd.concat([pd.Series(categorical_data, name=column_name), timestamps], axis=1))
    data_with_chunk_keys = pd.concat([chunk.data.assign(chunk_index=chunk.chunk_index) for chunk in chunks])

    value_counts_table = (
        data_with_chunk_keys.groupby(['chunk_index'])[column_name]
        .value_counts()
        .to_frame('value_counts')
        .sort_values(by=['chunk_index', 'value_counts'])
        .reset_index()
        .rename(columns={column_name: 'value'})
        .assign(column_name=column_name)
    )

    value_counts_table['value_counts_total'] = value_counts_table['chunk_index'].map(
        value_counts_table.groupby('chunk_index')['value_counts'].sum()
    )
    value_counts_table['value_counts_normalised'] = (
        value_counts_table['value_counts'] / value_counts_table['value_counts_total']
    )

    return value_counts_table
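The calculate_value_counts helper also works standalone, one column at a time. A quick sketch, assuming SizeBasedChunker is importable from nannyml.chunk as in released nannyml versions; the sample Series is made up:

# Standalone sketch for calculate_value_counts (not part of the commit).
import numpy as np
import pandas as pd

from nannyml.chunk import SizeBasedChunker

rng = np.random.default_rng(42)
series = pd.Series(rng.choice(['a', 'b', 'c', None], size=300), name='channel')

# Returns one row per (chunk_index, value) with raw counts, per-chunk
# totals and normalised frequencies; NaNs are bucketed under 'Missing'
# and any tail beyond the top categories under 'Other'.
table = calculate_value_counts(
    data=series,
    chunker=SizeBasedChunker(chunk_size=100),
    missing_category_label='Missing',
    max_number_of_categories=5,
)
print(table[['chunk_index', 'value', 'value_counts', 'value_counts_normalised']])

Since the Series carries a name, column_name can be omitted; it is only required when data is a bare np.ndarray.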