Skip to content

Commit 3d8283f

Browse files
authored
Merge pull request #496 from VectorInstitute/report_baseline_comparison
Report Template Update
2 parents 03097ef + 533fe36 commit 3d8283f

File tree

14 files changed

+2574
-1174
lines changed

14 files changed

+2574
-1174
lines changed

cyclops/data/loader.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from datasets.arrow_dataset import Dataset
99
from datasets.features import Image, Value
1010
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
11+
from sklearn.model_selection import GroupShuffleSplit
1112

1213
from cyclops.data.preprocess import nihcxr_preprocess
1314
from cyclops.data.utils import generate_timestamps
@@ -20,6 +21,7 @@ def load_nihcxr(
2021
train_time_range: Tuple[str, str] = ("1/1/2019", "10/19/2019"),
2122
test_time_range: Tuple[str, str] = ("10/20/2019", "12/25/2019"),
2223
progress: bool = False,
24+
seed: int = 0,
2325
) -> Dataset:
2426
"""Load NIH Chest X-Ray dataset as a Huggingface dataset."""
2527
if not progress:
@@ -40,9 +42,15 @@ def load_nihcxr(
4042
train_df = df[df["Image Index"].isin(train_id)]
4143
test_df = df[df["Image Index"].isin(test_id)]
4244

45+
gss = GroupShuffleSplit(train_size=0.8, test_size=0.2, random_state=seed)
46+
train_inds, val_inds = next(
47+
gss.split(X=range(len(train_df)), groups=train_df["Patient ID"]),
48+
)
49+
4350
nih_ds = DatasetDict(
4451
{
45-
"train": Dataset.from_pandas(train_df),
52+
"train": Dataset.from_pandas(train_df.iloc[train_inds]),
53+
"val": Dataset.from_pandas(train_df.iloc[val_inds]),
4654
"test": Dataset.from_pandas(test_df),
4755
},
4856
)

cyclops/monitor/detector.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,7 @@ def balanced_sensitivity_test(
316316
np.random.choice(ds_target.shape[0], sample, replace=False),
317317
)
318318
ds_target_balanced = concatenate_datasets(
319-
ds_target_sample1,
320-
ds_target_sample2,
319+
[ds_target_sample1, ds_target_sample2],
321320
)
322321

323322
drift_results = self._detect_shift_sample(ds_target_balanced)

cyclops/report/model_card/fields.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,3 +545,90 @@ class FairnessReport(BaseModelCardField, composable_with=["FairnessAnalysis"]):
545545
description="Tests related to fairness considerations.",
546546
default_factory=list,
547547
)
548+
549+
550+
class MetricCard(
551+
BaseModelCardField,
552+
list_factory=True,
553+
composable_with=["MetricCardCollection"],
554+
):
555+
"""A metric card."""
556+
557+
name: Optional[StrictStr] = Field(
558+
None,
559+
description="The name of the metric.",
560+
)
561+
562+
type: Optional[StrictStr] = Field(
563+
None,
564+
description="The type of metric.",
565+
)
566+
567+
slice: Optional[StrictStr] = Field(
568+
None,
569+
description="The name of the slice the metric was computed on.",
570+
)
571+
572+
tooltip: Optional[StrictStr] = Field(
573+
None,
574+
description="A tooltip for the metric.",
575+
)
576+
577+
value: Optional[StrictFloat] = Field(
578+
None,
579+
description="The value of the metric.",
580+
)
581+
582+
threshold: Optional[StrictFloat] = Field(
583+
None,
584+
description="Threshold required to pass the test.",
585+
)
586+
587+
passed: Optional[StrictBool] = Field(
588+
None,
589+
description="Whether the model result satisfies the given threshold.",
590+
)
591+
592+
history: List[StrictFloat] = Field(
593+
None,
594+
description="History of the metric over time.",
595+
)
596+
597+
trend: Optional[StrictStr] = Field(
598+
None,
599+
description="The trend of the metric over time.",
600+
)
601+
602+
plot: Optional[GraphicsCollection] = Field(
603+
None,
604+
description="A plot of the performance over time.",
605+
)
606+
607+
608+
class MetricCardCollection(BaseModelCardField, composable_with="Overview"):
609+
"""A collection of metric cards to be displayed in the model card."""
610+
611+
metrics: Optional[List[StrictStr]] = Field(
612+
None,
613+
description="A list of metric names in the Metric Card collection.",
614+
)
615+
616+
tooltips: Optional[List[StrictStr]] = Field(
617+
None,
618+
description="A list of tooltips in the Metric Card collection.",
619+
)
620+
621+
slices: Optional[List[StrictStr]] = Field(
622+
None,
623+
description="A list of slices in the Metric Card collection.",
624+
)
625+
626+
values: Optional[List[List[StrictStr]]] = Field(
627+
None,
628+
description="A list of values for each slice in the Metric Card collection.",
629+
)
630+
631+
collection: Optional[List[MetricCard]] = Field(
632+
description="A collection of metric cards.",
633+
default_factory=list,
634+
)

cyclops/report/model_card/model_card.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
FairnessAnalysis,
1414
ModelDetails,
1515
ModelParameters,
16+
Overview,
1617
QuantitativeAnalysis,
1718
)
1819

@@ -30,27 +31,15 @@ class Config(BaseModelCardConfig):
3031

3132
extra: Extra = Extra.forbid
3233

33-
model_details: Optional[ModelDetails] = Field(
34-
None,
35-
description="Descriptive metadata for the model.",
36-
)
37-
model_parameters: Optional[ModelParameters] = Field(
34+
overview: Optional[Overview] = Field(
3835
None,
39-
description="Technical metadata for the model.",
36+
description="A high-level overview of the model.",
4037
)
4138
datasets: Optional[Datasets] = Field(
4239
None,
4340
description="Information about the datasets used to train, validate \
4441
and/or test the model.",
4542
)
46-
considerations: Optional[Considerations] = Field(
47-
None,
48-
description=inspect.cleandoc(
49-
"""
50-
Any considerations related to model construction, training, and
51-
application""",
52-
),
53-
)
5443
quantitative_analysis: Optional[QuantitativeAnalysis] = Field(
5544
None,
5645
description="Quantitative analysis of model performance.",
@@ -63,6 +52,22 @@ class Config(BaseModelCardConfig):
6352
None,
6453
description="Fairness analysis being reported.",
6554
)
55+
model_details: Optional[ModelDetails] = Field(
56+
None,
57+
description="Descriptive metadata for the model.",
58+
)
59+
model_parameters: Optional[ModelParameters] = Field(
60+
None,
61+
description="Technical metadata for the model.",
62+
)
63+
considerations: Optional[Considerations] = Field(
64+
None,
65+
description=inspect.cleandoc(
66+
"""
67+
Any considerations related to model construction, training, and
68+
application""",
69+
),
70+
)
6671

6772
def get_section(self, section_name: str) -> BaseModelCardSection:
6873
"""Retrieve a section from the model card.

cyclops/report/model_card/sections.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
GraphicsCollection,
1616
KeyVal,
1717
License,
18+
MetricCardCollection,
1819
Owner,
1920
PerformanceMetric,
2021
Reference,
@@ -26,6 +27,15 @@
2627
)
2728

2829

30+
class Overview(BaseModelCardSection):
31+
"""Overview section with aggregate metrics."""
32+
33+
metric_cards: Optional[MetricCardCollection] = Field(
34+
None,
35+
description="Comparative metrics between baseline and periodic report.",
36+
)
37+
38+
2939
class ModelDetails(BaseModelCardSection):
3040
"""Details about the model."""
3141

0 commit comments

Comments
 (0)