Distribution checks/comparisons (#163)
* Reworking visual checks to be distribution-focused

* Updating package names

* Adding dependencies used to read the datacube (required for testing)

* Further dependencies for the data-cube format

* Changing type annotation to work with python 3.9

* Changing type annotation to work with python 3.9 (again)
JosephCottam authored May 26, 2023
1 parent b15d214 commit 5da3f83
Showing 9 changed files with 1,419 additions and 0 deletions.
786 changes: 786 additions & 0 deletions notebook/Visual Checks.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions setup.cfg
@@ -19,6 +19,11 @@ install_requires =
    networkx
    pandas
    mira @ git+https://github.com/indralab/mira.git
    xarray
    netcdf4
    h5netcdf
    dask

zip_safe = false
include_package_data = true
python_requires = >=3.9
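
The xarray/netcdf4/h5netcdf/dask additions back the data-cube reading used by the tests. As a rough sketch (the file name and chunking choice are illustrative, not taken from this commit), loading such a cube looks like:

import xarray as xr

# Lazily open a netCDF data cube; h5netcdf is one of the backends added above,
# and chunks={} wraps the arrays in dask rather than loading them eagerly.
cube = xr.open_dataset("example_datacube.nc", engine="h5netcdf", chunks={})
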
Empty file.
201 changes: 201 additions & 0 deletions src/pyciemss/visuals/checks.py
@@ -0,0 +1,201 @@
from typing import List, Dict, Any, Callable, Optional
from numbers import Number

from . import vega
import numpy as np
import pandas as pd
from scipy.spatial.distance import jensenshannon

from dataclasses import dataclass

# TODO: Look at scipy's KL and JS


@dataclass(repr=False)
class Result:
    """Results of a check.
    At a minimum, holds the combined status, the individual check results,
    and a visual representation of the evidence. May also include additional
    information computed for an individual check.
    """
    status: bool
    checks: Dict[str, Any]
    schema: Dict

    def __repr__(self):
        always_display = ["status", "checks", "schema"]
        schema = "<missing>" if self.schema is None else "<present>"

        additional = [f for f in dir(self)
                      if not f.startswith("_") and f not in always_display]
        extras = (f"; Additional Fields: {additional}"
                  if len(additional) > 0 else "")
        return f"Result(status:{self.status}, checks:{self.checks}, schema:{schema}{extras})"


def contains(
    ref_lower: Number,
    ref_upper: Number,
    pct: Optional[float] = None
) -> Callable[[pd.DataFrame], bool]:
    """Check-generator function. Returns a function that performs a test.
    returns -- A function that takes a dataframe and tests it against the bounds.
               If pct IS NOT SUPPLIED, the returned function needs
               a dataframe with bin-boundaries on the index. It
               checks if ref_lower and ref_upper are within the
               distribution range.
               If pct IS SUPPLIED, the returned function takes
               a dataframe with bin-boundaries on the index and a 'count'
               column. It checks if the distribution between ref_lower
               and ref_upper covers AT LEAST the fraction pct of the total data.
    """

    def pct_test(bins: pd.DataFrame) -> bool:
        total = bins["count"].sum()
        level0 = bins.index.get_level_values(0)
        level1 = bins.index.get_level_values(1)
        subset = bins[(level0 >= ref_lower) & (level1 <= ref_upper)]
        covered = subset["count"].sum()

        return covered / total >= pct

    def simple_test(bins: pd.DataFrame) -> bool:
        level0 = bins.index.get_level_values(0)
        level1 = bins.index.get_level_values(1)

        data_lower = min(level0.min(), level1.min())
        data_upper = max(level0.max(), level1.max())
        return (data_lower <= ref_lower <= data_upper) and (
            data_lower <= ref_upper <= data_upper
        )

    if pct is not None:
        return pct_test
    else:
        return simple_test
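
# Example (a sketch; this `bins` frame is hypothetical but shaped like the
# binned output of vega.histogram_multi: bin boundaries on a two-level index
# plus a 'count' column):
#
#   edges = pd.MultiIndex.from_tuples([(0, 1), (1, 2), (2, 3)])
#   bins = pd.DataFrame({"count": [5, 10, 5]}, index=edges)
#   contains(0, 3)(bins)           # True: both bounds fall inside the data range
#   contains(1, 2, pct=0.5)(bins)  # True: the (1, 2) bin holds 50% of the counts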


def JS(max_acceptable: float, *, verbose: bool = False) -> Callable[[Any, Any], bool]:
    """Check-generator function. Returns a function that tests against the Jensen-Shannon distance.
    max_acceptable -- Threshold for the returned check
    returns -- A function that checks if the JS distance between two lists of bin-counts is at most max_acceptable.
               The returned function takes two lists of bin-counts and returns a boolean result. The signature
               is roughly (list[Number], list[Number]) -> bool.
    """

    def _inner(a, b):
        a = np.asarray(a, dtype=np.float32)
        b = np.asarray(b, dtype=np.float32)
        js = jensenshannon(a, b)
        if verbose:
            print(f"JS distance is {js}")
        return js <= max_acceptable

    return _inner
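
# Example (sketch): JS(0.2) builds a check that takes two bin-count lists and
# passes when their Jensen-Shannon distance is at most 0.2.
#
#   close_enough = JS(0.2, verbose=True)
#   close_enough([10, 20, 30], [10, 20, 30])  # prints "JS distance is 0.0", returns True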


def check_distribution_range(
    distribution: pd.DataFrame,
    lower: Number,
    upper: Number,
    *,
    label: str = "distribution",
    tests: Dict[str, Callable[[pd.DataFrame, Number, Number], bool]] = {},
    combiner: Callable[[List[bool]], bool] = all,
    **kwargs,
) -> Result:
    """
    Checks a single distribution against a lower- and upper-bound.
    distribution -- Distribution to check
    lower -- Lower bound to compare to the distribution
    upper -- Upper bound to compare to the distribution
    label -- Label to put on the resulting plot
    tests -- Tests to run against the distribution
             (typed as a dict of label/callable, but a list of callables is also accepted)
    combiner -- Combines the individual test results into a single status
    """
    if isinstance(tests, list):
        tests = dict(enumerate(tests))

    combined_args = {**{label: distribution}, **kwargs}
    schema, bins = vega.histogram_multi(
        xrefs=[lower, upper], return_bins=True, **combined_args
    )

    checks = {label: test(bins) for label, test in tests.items()}
    status = combiner([*checks.values()])
    if not status:
        status_msg = f"Failed ({sum(checks.values())/len(checks.values()):.0%} passing)"
    else:
        status_msg = "Passed"

    schema["title"]["text"] = ["Distribution Check (Histogram)", status_msg]

    rslt = Result(status, checks, schema)
    rslt.bins = bins
    return rslt
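
# Example (sketch; `samples` stands in for whatever distribution data
# vega.histogram_multi accepts, e.g. a dataframe of draws):
#
#   result = check_distribution_range(
#       samples, 0, 10,
#       tests={"in range": contains(0, 10), "90% inside": contains(0, 10, pct=0.9)},
#   )
#   result.status  # True only if every test passed (combiner defaults to `all`)
#   result.bins    # binned data the tests were run against
#   result.schema  # Vega schema titled "Distribution Check (Histogram)"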


def compare_distributions(
    subject: pd.DataFrame,
    reference: pd.DataFrame,
    *,
    tests: Dict[str, Callable[[pd.DataFrame, pd.DataFrame], bool]] = {},
    combiner: Callable[[List[bool]], bool] = all,
    **kwargs,
) -> Result:
    """
    Compares two distributions.
    This function returns a histogram visualization of the two distributions
    and the result of running the passed checks against those two distributions.
    NOTE: As tests may be non-symmetric, tests are called with aligned
          distributions built from the subject and the reference. The subject
          distribution is passed as the first argument.
    """
    if isinstance(tests, list):
        tests = dict(enumerate(tests))

    schema, bins = vega.histogram_multi(
        Subject=subject, Reference=reference, return_bins=True, **kwargs
    )

    groups = dict([*bins.groupby("label")])
    subject_dist = (
        groups["Subject"]
        .rename(columns={"count": "subject"})
        .drop(columns=["label"])
    )
    reference_dist = (
        groups["Reference"]
        .rename(columns={"count": "ref"})
        .drop(columns=["label"])
    )

    aligned = (
        subject_dist
        .join(reference_dist, how="outer")
        .fillna(0)
    )

    checks = {label: test(aligned["subject"].values, aligned["ref"].values)
              for label, test in tests.items()}
    status = combiner([*checks.values()])

    if not status:
        status_msg = f"Failed ({sum(checks.values())/len(checks.values()):.0%} passing)"
    else:
        status_msg = "Passed"

    schema["title"]["text"] = ["Distribution Comparison", status_msg]

    rslt = Result(status, checks, schema)
    rslt.aligned = aligned
    return rslt
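
# Example (sketch; `subject` and `reference` stand in for two sets of samples
# accepted by vega.histogram_multi):
#
#   result = compare_distributions(subject, reference, tests={"JS": JS(0.2)})
#   result.status   # True if the aligned bin counts pass the JS(0.2) check
#   result.aligned  # outer-joined bin counts with 'subject' and 'ref' columns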
112 changes: 112 additions & 0 deletions src/pyciemss/visuals/histogram_static_bins_multi.vg.json
@@ -0,0 +1,112 @@
{
  "$schema": "https://vega.github.io/schema/vega/v5.json",
  "description": "Histogram for visualizing a univariate distribution using static bins.",
  "width": 500,
  "height": 100,
  "padding": 5,
  "data": [
    {
      "name": "binned",
      "values": []
    },
    {"name": "xref", "values": []},
    {"name": "yref", "values": []},
    {
      "name": "_ranges",
      "source": ["binned"],
      "transform": [
        {
          "type": "aggregate",
          "fields": ["bin0", "bin1"],
          "ops": ["min", "max"],
          "as": ["min", "max"]
        }
      ]
    }
  ],
  "legends": [
    {"fill": "color"}
  ],
  "scales": [
    {
      "name": "xscale",
      "type": "linear",
      "range": "width",
      "domain": {"data": "_ranges", "fields": ["min", "max"]}
    },
    {
      "name": "yscale",
      "type": "linear",
      "range": "height",
      "round": true,
      "domain": {"data": "binned", "field": "count"},
      "zero": true,
      "nice": true
    },
    {
      "name": "color",
      "type": "ordinal",
      "domain": {"data": "binned", "field": "label"},
      "range": {"scheme": "dark2"}
    }
  ],
  "axes": [
    {"orient": "bottom", "scale": "xscale", "zindex": 1},
    {"orient": "left", "scale": "yscale", "tickCount": 5, "zindex": 1}
  ],
  "title": {
    "text": "Histogram",
    "orient": "top",
    "anchor": "start",
    "frame": "group"
  },
  "marks": [
    {
      "name": "bins",
      "type": "rect",
      "from": {"data": "binned"},
      "encode": {
        "update": {
          "x": {"scale": "xscale", "field": "bin0"},
          "x2": {"scale": "xscale", "field": "bin1", "offset": -0.5},
          "y": {"scale": "yscale", "field": "count"},
          "y2": {"scale": "yscale", "value": 0},
          "tooltip": {"field": "count"},
          "fill": {"scale": "color", "field": "label"},
          "opacity": {"value": 0.7}
        },
        "hover": {"fill": {"value": "firebrick"}}
      }
    },
    {
      "name": "x_highlights",
      "type": "rule",
      "clip": true,
      "from": {"data": "xref"},
      "encode": {
        "enter": {"stroke": {"value": "red"}},
        "update": {
          "x": {"scale": "xscale", "field": "value"},
          "y2": {"value": 0},
          "y": {"signal": "height"},
          "opacity": {"value": 1}
        }
      }
    },
    {
      "name": "y_highlights",
      "type": "rule",
      "clip": true,
      "from": {"data": "yref"},
      "encode": {
        "enter": {"stroke": {"value": "red"}},
        "update": {
          "y": {"scale": "yscale", "field": "count"},
          "x2": {"value": 0},
          "x": {"signal": "width"},
          "opacity": {"value": 1}
        }
      }
    }
  ]
}