-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Distribution checks/comparisons (#163)
* Reworking visual checks to be distribution-focused * Updating package names * Adding depdencies used to read the datacube (required for testing) * Further dependencies for the data-cube format * Changing type annotation to work with python 3.9 * Changing type annotation to work with python 3.9 (again)
- Loading branch information
1 parent
b15d214
commit 5da3f83
Showing
9 changed files
with
1,419 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
from typing import List, Dict, Any, Callable | ||
from numbers import Number | ||
|
||
from . import vega | ||
import numpy as np | ||
import pandas as pd | ||
from scipy.spatial.distance import jensenshannon | ||
|
||
from dataclasses import dataclass | ||
|
||
# TODO: Look at scipy's KL and JS | ||
|
||
|
||
@dataclass(repr=False) | ||
class Result: | ||
"""Results of a check. | ||
At a minimum, the combined check, the individual checks and | ||
a visual representation of the evidence. May also include additional | ||
information computed for that individual check. | ||
""" | ||
status: bool | ||
checks: Dict[str, Any] | ||
schema: Dict | ||
|
||
def __repr__(self): | ||
always_display = ["status", "checks", "schema"] | ||
schema = "<missing>" if self.schema is None else "<present>" | ||
|
||
additional = [f for f in dir(self) | ||
if not f.startswith("_") and f not in always_display] | ||
extras = (f"; Additional Fields: {additional}" | ||
if len(additional) > 0 else "") | ||
return f"Result(status:{self.status}, checks:{self.checks}, schema:{schema}{extras})" | ||
|
||
|
||
def contains( | ||
ref_lower: Number, | ||
ref_upper: Number, | ||
pct: float = None | ||
) -> Callable[[pd.DataFrame], bool]: | ||
"""Check-generator function. Returns a function that performs a test. | ||
returns -- A function that takes a dataframe tests it against the bounds. | ||
If pct IS NOT SUPPLIED, the returned function needs | ||
a dataframe with bin-boundaries on the index. It | ||
checks if ref_lower and ref_upper are within the | ||
distribution range. | ||
If pct IS SUPPLIED, the returned function checks takes | ||
a dataframe with bin-boundaries on the index and a 'count' | ||
column. It checks if the distribution between ref_lower | ||
and ref_upper is AT LEAST pct% of the total data. | ||
""" | ||
|
||
def pct_test(bins: pd.DataFrame) -> bool: | ||
total = bins["count"].sum() | ||
level0 = bins.index.get_level_values(0) | ||
level1 = bins.index.get_level_values(1) | ||
subset = bins[(level0 >= ref_lower) & (level1 <= ref_upper)] | ||
covered = subset["count"].sum() | ||
|
||
return covered / total >= pct | ||
|
||
def simple_test(bins: pd.DataFrame) -> bool: | ||
level0 = bins.index.get_level_values(0) | ||
level1 = bins.index.get_level_values(1) | ||
|
||
data_lower = min(level0.min(), level1.min()) | ||
data_upper = max(level0.max(), level1.max()) | ||
return (data_lower <= ref_lower <= data_upper) and ( | ||
data_lower <= ref_upper <= data_upper | ||
) | ||
|
||
if pct is not None: | ||
return pct_test | ||
else: | ||
return simple_test | ||
|
||
|
||
def JS(max_acceptable: float, *, verbose: bool = False) -> Callable[[Any, Any], bool]: | ||
"""Check-generator function. Returns a function that performs a test against jensen-shannon distance. | ||
max_acceptable -- Threshold for the returned check | ||
returns -- Returns a function that checks if JS distance of two lists of bin-counts is less than max_acceptable. | ||
Returned function takes two lists of bins-counts and returns a boolean result. The signature | ||
is roughly (list[Number], list[Number]) -> bool. | ||
""" | ||
|
||
def _inner(a, b): | ||
a = np.asarray(a, dtype=np.float32) | ||
b = np.asarray(b, dtype=np.float32) | ||
js = jensenshannon(a, b) | ||
if verbose: | ||
print(f"JS distance is {js}") | ||
return js <= max_acceptable | ||
|
||
return _inner | ||
|
||
|
||
def check_distribution_range( | ||
distribution: pd.DataFrame, | ||
lower: Number, | ||
upper: Number, | ||
*, | ||
label: str = "distribution", | ||
tests: Dict[str, Callable[[pd.DataFrame, Number, Number], bool]] = {}, | ||
combiner: Callable[[List[bool]], bool] = all, | ||
**kwargs, | ||
) -> Result: | ||
""" | ||
Checks a single distribution against a lower- and upper-bound. | ||
distribution -- Distribution to check | ||
lower -- Lower bound to compare to the distribution | ||
upper -- Upper bound to compare to the distribution | ||
label -- Label to put on resulting plot | ||
tests -- Tests to make against the distribution | ||
(Typed as dict of label/value, but can be list of callables instead) | ||
combiner -- Combines the results of the test | ||
""" | ||
if isinstance(tests, list): | ||
tests = dict(enumerate(tests)) | ||
|
||
combined_args = {**{label: distribution}, **kwargs} | ||
schema, bins = vega.histogram_multi( | ||
xrefs=[lower, upper], return_bins=True, **combined_args | ||
) | ||
|
||
checks = {label: test(bins) for label, test in tests.items()} | ||
status = combiner([*checks.values()]) | ||
if not status: | ||
status_msg = f"Failed ({sum(checks.values())/len(checks.values()):.0%} passing)" | ||
else: | ||
status_msg = "Passed" | ||
|
||
schema["title"]["text"] = ["Distribution Check (Histogram)", status_msg] | ||
|
||
rslt = Result(status, checks, schema) | ||
rslt.bins = bins | ||
return rslt | ||
|
||
|
||
def compare_distributions( | ||
subject: pd.DataFrame, | ||
reference: pd.DataFrame, | ||
*, | ||
tests: Dict[str, Callable[[pd.DataFrame, pd.DataFrame], bool]] = {}, | ||
combiner: Callable[[List[bool]], bool] = all, | ||
**kwargs, | ||
) -> Result: | ||
""" | ||
Compares two distributions. | ||
This function returns a histogram visualization of the two distributions | ||
and the result running passed checks against those two distributions. | ||
NOTE: As tests may be non-symetric, tests will be called with an aligned | ||
distribution built from the subject and the reference. The subject | ||
distribution will be passed first as the first argument. | ||
""" | ||
if isinstance(tests, list): | ||
tests = dict(enumerate(tests)) | ||
|
||
schema, bins = vega.histogram_multi( | ||
Subject=subject, Reference=reference, return_bins=True, **kwargs | ||
) | ||
|
||
groups = dict([*bins.groupby("label")]) | ||
subject_dist = ( | ||
groups["Subject"] | ||
.rename(columns={"count": "subject"}) | ||
.drop(columns=["label"]) | ||
) | ||
reference_dist = ( | ||
groups["Reference"] | ||
.rename(columns={"count": "ref"}) | ||
.drop(columns=["label"]) | ||
) | ||
|
||
aligned = ( | ||
subject_dist | ||
.join(reference_dist, how="outer") | ||
.fillna(0) | ||
) | ||
|
||
checks = {label: test(aligned["subject"].values, aligned["ref"].values) | ||
for label, test in tests.items()} | ||
status = combiner([*checks.values()]) | ||
|
||
if not status: | ||
status_msg = f"Failed ({sum(checks.values())/len(checks.values()):.0%} passing)" | ||
else: | ||
status_msg = "Passed" | ||
|
||
schema["title"]["text"] = ["Distribution Comparison", status_msg] | ||
|
||
rslt = Result(status, checks, schema) | ||
rslt.aligned = aligned | ||
return rslt |
112 changes: 112 additions & 0 deletions
112
src/pyciemss/visuals/histogram_static_bins_multi.vg.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
{ | ||
"$schema": "https://vega.github.io/schema/vega/v5.json", | ||
"description": "Histogram for visualizing a univariate distribution using static bins.", | ||
"width": 500, | ||
"height": 100, | ||
"padding": 5, | ||
"data": [ | ||
{ | ||
"name": "binned", | ||
"values": [] | ||
}, | ||
{"name": "xref", "values": []}, | ||
{"name": "yref", "values": []}, | ||
{ | ||
"name": "_ranges", | ||
"source": ["binned"], | ||
"transform": [ | ||
{ | ||
"type": "aggregate", | ||
"fields": ["bin0", "bin1"], | ||
"ops": ["min", "max"], | ||
"as": ["min", "max"] | ||
} | ||
] | ||
} | ||
], | ||
"legends": [ | ||
{"fill": "color"} | ||
], | ||
"scales": [ | ||
{ | ||
"name": "xscale", | ||
"type": "linear", | ||
"range": "width", | ||
"domain": {"data": "_ranges", "fields": ["min", "max"]} | ||
}, | ||
{ | ||
"name": "yscale", | ||
"type": "linear", | ||
"range": "height", | ||
"round": true, | ||
"domain": {"data": "binned", "field": "count"}, | ||
"zero": true, | ||
"nice": true | ||
}, | ||
{ | ||
"name": "color", | ||
"type": "ordinal", | ||
"domain": {"data": "binned", "field": "label"}, | ||
"range": {"scheme": "dark2"} | ||
} | ||
], | ||
"axes": [ | ||
{"orient": "bottom", "scale": "xscale", "zindex": 1}, | ||
{"orient": "left", "scale": "yscale", "tickCount": 5, "zindex": 1} | ||
], | ||
"title": { | ||
"text": "Histogram", | ||
"orient": "top", | ||
"anchor": "start", | ||
"frame": "group" | ||
}, | ||
"marks": [ | ||
{ | ||
"name": "bins", | ||
"type": "rect", | ||
"from": {"data": "binned"}, | ||
"encode": { | ||
"update": { | ||
"x": {"scale": "xscale", "field": "bin0"}, | ||
"x2": {"scale": "xscale", "field": "bin1", "offset": -0.5}, | ||
"y": {"scale": "yscale", "field": "count"}, | ||
"y2": {"scale": "yscale", "value": 0}, | ||
"tooltip": {"field": "count"}, | ||
"fill": {"scale": "color", "field": "label"}, | ||
"opacity": {"value": 0.7} | ||
}, | ||
"hover": {"fill": {"value": "firebrick"}} | ||
} | ||
}, | ||
{ | ||
"name": "x_highlights", | ||
"type": "rule", | ||
"clip": true, | ||
"from": {"data": "xref"}, | ||
"encode": { | ||
"enter": {"stroke": {"value": "red"}}, | ||
"update": { | ||
"x": {"scale": "xscale", "field": "value"}, | ||
"y2": {"value": 0}, | ||
"y": {"signal": "height"}, | ||
"opacity": {"value": 1} | ||
} | ||
} | ||
}, | ||
{ | ||
"name": "y_highlights", | ||
"type": "rule", | ||
"clip": true, | ||
"from": {"data": "yref"}, | ||
"encode": { | ||
"enter": {"stroke": {"value": "red"}}, | ||
"update": { | ||
"y": {"scale": "yscale", "field": "count"}, | ||
"x2": {"value": 0}, | ||
"x": {"signal": "width"}, | ||
"opacity": {"value": 1} | ||
} | ||
} | ||
} | ||
] | ||
} |
Oops, something went wrong.