Skip to content

Commit dceb470

Browse files
authored
Add some unit tests for the ModelCardReport methods (#529)
1 parent 2ce9da3 commit dceb470

File tree

9 files changed

+229
-46
lines changed

9 files changed

+229
-46
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ repos:
5656
- id: doctest
5757
name: doctest
5858
entry: python3 -m doctest -o NORMALIZE_WHITESPACE
59-
files: "^cyclops/evaluate/"
59+
files: "^cyclops/"
6060
language: system
6161

6262
- repo: local

cyclops/data/slicer.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class SliceSpec:
110110
... {
111111
... "feature_1": {"value": "value_1"},
112112
... "feature_2": {
113-
... "min_value": "2020-01-01", keep_nulls: False,
113+
... "min_value": "2020-01-01", "keep_nulls": False,
114114
... },
115115
... "feature_3": {"year": ["2000", "2010", "2020"]},
116116
... },
@@ -119,8 +119,22 @@ class SliceSpec:
119119
>>> for slice_name, slice_func in slice_spec.slices():
120120
... print(slice_name)
121121
... # do something with slice_func here (e.g. dataset.filter(slice_func))
122-
123-
"""
122+
feature_1:non_null
123+
feature_2:non_null&feature_3:non_null
124+
feature_1:value_1
125+
feature_1:value_1, value_2
126+
!(feature_1:value_1)
127+
feature_1:[2020-01-01 - 2020-12-31]
128+
feature_1:(5 - 60)
129+
feature_1:year=[2020, 2021, 2022]
130+
feature_1:month=[6, 7, 8]
131+
feature_1:month=6, day=1
132+
feature_1:contains value_1
133+
feature_1:contains ['value_1', 'value_2']
134+
feature_1:value_1&feature_2:[2020-01-01 - inf]&feature_3:year=['2000', '2010', '2020']
135+
overall
136+
137+
""" # noqa: W505
124138

125139
spec_list: List[Dict[str, Dict[str, Any]]] = field(
126140
default_factory=lambda: [{}],

cyclops/monitor/clinical_applicator.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ class ClinicalShiftApplicator:
1717
The source and target datasets are then generated by splitting
1818
the original dataset along the categorical feature.
1919
20-
Examples
21-
--------
22-
>>> from cyclops.monitor.clinical_applicator import ClinicalShiftApplicator
23-
>>> from cyclops.data.utils import load_nih
24-
>>> ds = load_nih(path="/mnt/data/nihcxr")
25-
>>> applicator = ClinicalShiftApplicator("hospital_type",
26-
source = ["hospital_type_1", "hospital_type_2"]
27-
target = ["hospital_type_3", "hospital_type_4", "hospital_type_5"]
28-
)
29-
>>> ds_source, ds_target = applicator.apply_shift(ds)
20+
# Examples
21+
# --------
22+
# >>> from cyclops.monitor.clinical_applicator import ClinicalShiftApplicator
23+
# >>> from cyclops.data.loader import load_nihcxr
24+
# >>> ds = load_nihcxr(path="/mnt/data/nihcxr")
25+
# >>> applicator = ClinicalShiftApplicator("hospital_type",
26+
# source = ["hospital_type_1", "hospital_type_2"]
27+
# target = ["hospital_type_3", "hospital_type_4", "hospital_type_5"]
28+
# )
29+
# >>> ds_source, ds_target = applicator.apply_shift(ds)
3030
3131
3232
Parameters

cyclops/monitor/reductor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class Reductor:
4343
Examples
4444
--------
4545
>>> # (Data is loaded from memory)
46-
>>> from drift_detection.reductor import Reductor
46+
>>> from cyclops.monitor.reductor import Reductor
4747
>>> from sklearn.datasets import load_diabetes
4848
>>> X, y = load_diabetes(return_X_y=True)
4949
>>> reductor = Reductor("pca")

cyclops/monitor/synthetic_applicator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
class SyntheticShiftApplicator:
1515
"""The SyntheticShiftApplicator class is used induce synthetic dataset shift.
1616
17-
Examples
18-
--------
19-
>>> from drift_detection.experimenter import Experimenter
20-
>>> from sklearn.datasets import load_diabetes
21-
>>> X, y = load_diabetes(return_X_y=True)
22-
>>> X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, random_state=42)
23-
>>> applicator = SyntheticShiftApplicator(shift_type="gn_shift")
24-
>>> X_shift = applicator.apply_shift(X_train, noise_amt=0.1, delta=0.1)
17+
# Examples
18+
# --------
19+
# >>> from sklearn.datasets import load_diabetes
20+
# >>> X, y = load_diabetes(return_X_y=True)
21+
# >>> dataset = Dataset.from_dict({"X": X, "y": y})
22+
# >>> dataset = dataset.train_test_split(test_size=0.5, seed=42)
23+
# >>> applicator = SyntheticShiftApplicator(shift_type="gn_shift")
24+
# >>> X_shift = applicator.apply_shift(dataset["test"])
2525
2626
Parameters
2727
----------

cyclops/monitor/tester.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class TSTester:
147147
>>> tester.fit(X_s)
148148
>>> p_val, dist = tester.test_shift(X_t)
149149
>>> print(p_val, dist)
150-
0.0 0.3
150+
1.3805797e-12 0.51
151151
"""
152152

153153
def __init__(
@@ -348,26 +348,27 @@ class DCTester:
348348
rate and entropy are shown to be powerful discriminative statistics
349349
for harmful covariate shift (HCS).
350350
351-
Examples
352-
--------
353-
>>> from cyclops.monitor.tester import DCTester
354-
355-
>>> nih_ds = load_nihcxr(DATA_DIR)
356-
>>> base_model = DenseNet(weights="densenet121-res224-nih")
357-
>>> detectron = DCTester("detectron", model=base_model)
358-
>>> detectron = DCTester("detectron",
359-
base_model=base_model,
360-
model=base_model,
361-
feature_columns="image",
362-
transforms=transforms,
363-
task="multilabel",
364-
max_epochs_per_model=5,
365-
ensemble_size=5,
366-
lr=0.01,
367-
num_runs=5)
368-
369-
>>> detectron.fit(source_ds)
370-
>>> p_val, distance = detectron.predict(target_ds)
351+
# Examples
352+
# --------
353+
# >>> from cyclops.monitor.tester import DCTester
354+
# >>> from cyclops.data.loader import load_nihcxr
355+
# >>> from cyclops.models.catalog import DenseNet
356+
357+
# >>> nih_ds = load_nihcxr(DATA_DIR)
358+
# >>> base_model = DenseNet(weights="densenet121-res224-nih")
359+
# >>> detectron = DCTester("detectron",
360+
# base_model=base_model,
361+
# model=base_model,
362+
# feature_columns="image",
363+
# transforms=None,
364+
# task="multilabel",
365+
# max_epochs_per_model=5,
366+
# ensemble_size=5,
367+
# lr=0.01,
368+
# num_runs=5
369+
# )
370+
# >>> detectron.fit(source_ds)
371+
# >>> p_val, distance = detectron.predict(target_ds)
371372
372373
Parameters
373374
----------

cyclops/report/report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def log_descriptor(
202202
203203
Examples
204204
--------
205-
>>> from cylops.report import ModelCardReport
205+
>>> from cyclops.report import ModelCardReport
206206
>>> report = ModelCardReport()
207207
>>> report.log_descriptor(
208208
... name="tradeoff",

cyclops/report/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def str_to_snake_case(string: str) -> str:
4343
>>> str_to_snake_case("Hello-World")
4444
'hello_world'
4545
>>> str_to_snake_case("Hello_World")
46-
'hello_world'
46+
'hello__world'
4747
>>> str_to_snake_case("Hello World")
4848
'hello_world'
4949
>>> str_to_snake_case("hello_world")

tests/cyclops/report/test_report.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""Test cyclops report module model report."""
2+
3+
from unittest import TestCase
4+
5+
from cyclops.report import ModelCardReport
6+
7+
8+
class TestModelCardReport(TestCase):
9+
"""Test ModelCardReport."""
10+
11+
def setUp(self):
12+
"""Set up test fixtures."""
13+
self.model_card_report = ModelCardReport("reports")
14+
15+
def test_instantiation_with_optional_output_dir(self):
16+
"""Test instantiation with optional output_dir."""
17+
assert self.model_card_report.output_dir == "reports"
18+
19+
def test_log_owner_with_name(self):
20+
"""Test log_owner with name."""
21+
self.model_card_report.log_owner(name="John Doe")
22+
assert (
23+
self.model_card_report._model_card.model_details.owners[0].name
24+
== "John Doe"
25+
)
26+
27+
def test_log_owner_with_name_and_contact(self):
28+
"""Test log_owner with name and contact."""
29+
self.model_card_report.log_owner(
30+
name="John Doe",
31+
contact="john.doe@example.com",
32+
)
33+
assert (
34+
self.model_card_report._model_card.model_details.owners[0].name
35+
== "John Doe"
36+
)
37+
assert (
38+
self.model_card_report._model_card.model_details.owners[0].contact
39+
== "john.doe@example.com"
40+
)
41+
42+
def test_log_owner_with_name_and_role(self):
43+
"""Test log_owner with name and role."""
44+
self.model_card_report.log_owner(name="John Doe", role="Developer")
45+
assert (
46+
self.model_card_report._model_card.model_details.owners[0].name
47+
== "John Doe"
48+
)
49+
assert (
50+
self.model_card_report._model_card.model_details.owners[0].role
51+
== "Developer"
52+
)
53+
54+
def test_valid_name_and_description(self):
55+
"""Test valid name and description."""
56+
self.model_card_report.log_descriptor(
57+
name="ethical_considerations",
58+
description="This model was trained on data collected from a potentially biased source.",
59+
section_name="considerations",
60+
)
61+
62+
section = self.model_card_report._model_card.get_section("considerations")
63+
descriptor = section.ethical_considerations
64+
65+
assert (
66+
descriptor[0].description
67+
== "This model was trained on data collected from a potentially biased source."
68+
)
69+
70+
def test_log_user_with_description_to_considerations_section(self):
71+
"""Test log_user with description to considerations section."""
72+
self.model_card_report.log_user(description="This is a user description")
73+
assert len(self.model_card_report._model_card.considerations.users) == 1
74+
assert (
75+
self.model_card_report._model_card.considerations.users[0].description
76+
== "This is a user description"
77+
)
78+
79+
def test_log_performance_metric(self):
80+
"""Test log_performance_metric."""
81+
self.model_card_report.log_quantitative_analysis(
82+
analysis_type="performance",
83+
name="accuracy",
84+
value=0.85,
85+
metric_slice="test",
86+
decision_threshold=0.8,
87+
description="Accuracy of the model on the test set",
88+
pass_fail_thresholds=[0.9, 0.85, 0.8],
89+
pass_fail_threshold_fns=[lambda x, t: x >= t for _ in range(3)],
90+
)
91+
assert (
92+
self.model_card_report._model_card.quantitative_analysis.performance_metrics[
93+
0
94+
].type
95+
== "accuracy"
96+
)
97+
assert (
98+
self.model_card_report._model_card.quantitative_analysis.performance_metrics[
99+
0
100+
].value
101+
== 0.85
102+
)
103+
assert (
104+
self.model_card_report._model_card.quantitative_analysis.performance_metrics[
105+
0
106+
].slice
107+
== "test"
108+
)
109+
assert (
110+
self.model_card_report._model_card.quantitative_analysis.performance_metrics[
111+
0
112+
].decision_threshold
113+
== 0.8
114+
)
115+
assert (
116+
self.model_card_report._model_card.quantitative_analysis.performance_metrics[
117+
0
118+
].description
119+
== "Accuracy of the model on the test set"
120+
)
121+
assert (
122+
len(
123+
self.model_card_report._model_card.quantitative_analysis.performance_metrics[
124+
0
125+
].tests,
126+
)
127+
== 3
128+
)
129+
130+
def test_log_quantitative_analysis_performance(self):
131+
"""Test log_quantitative_analysis (performance)."""
132+
self.model_card_report.log_quantitative_analysis(
133+
analysis_type="performance",
134+
name="accuracy",
135+
value=0.85,
136+
)
137+
assert (
138+
self.model_card_report._model_card.quantitative_analysis.performance_metrics[
139+
0
140+
].type
141+
== "accuracy"
142+
)
143+
assert (
144+
self.model_card_report._model_card.quantitative_analysis.performance_metrics[
145+
0
146+
].value
147+
== 0.85
148+
)
149+
150+
def test_log_quantitative_analysis_fairness(self):
151+
"""Test log_quantitative_analysis (fairness)."""
152+
self.model_card_report.log_quantitative_analysis(
153+
analysis_type="fairness",
154+
name="disparate_impact",
155+
value=0.9,
156+
)
157+
assert (
158+
self.model_card_report._model_card.fairness_analysis.fairness_reports[
159+
0
160+
].type
161+
== "disparate_impact"
162+
)
163+
assert (
164+
self.model_card_report._model_card.fairness_analysis.fairness_reports[
165+
0
166+
].value
167+
== 0.9
168+
)

0 commit comments

Comments
 (0)