Commit b1f9d6a

Small fixes to discharge prediction notebook
1 parent 9936852 commit b1f9d6a

File tree

3 files changed: +85 -83 lines

benchmarks/mimiciv/discharge_prediction.ipynb

Lines changed: 79 additions & 80 deletions
@@ -933,6 +933,17 @@
 ")"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"**Log the performance metrics to the report.**\n",
+"\n",
+"We can add a performance metric to the model card using the `log_performance_metric` method, which expects a dictionary where the keys are in the following format: `slice_name/metric_name`. For instance, `overall/accuracy`. \n",
+"\n",
+"We first need to process the evaluation results to get the metrics in the right format."
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
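The markdown cell added in this hunk documents the `slice_name/metric_name` key convention. A minimal sketch of that convention on an invented results dictionary (the keys and values below are illustrative only, not real CyclOps evaluation output):

```python
# Illustrative flat results dictionary using the "slice_name/metric_name"
# key convention described in the new markdown cell. Entries are made up
# for this sketch; they are not produced by CyclOps.
results_flat = {
    "overall/BinaryAccuracy": 0.91,
    "overall/BinaryF1Score": 0.84,
    "age:[65 - 80)/BinaryAccuracy": 0.88,
}

# Splitting each key on "/" recovers the slice and the metric name.
for key, value in results_flat.items():
    slice_name, metric_name = key.split("/")
    print(f"slice={slice_name}, metric={metric_name}, value={value}")
```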
@@ -947,6 +958,67 @@
 ")"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"test_datasets = [\n",
+" test_dataset.shard(NUM_EVALS, i, contiguous=True) for i in range(NUM_EVALS)\n",
+"]\n",
+"eval_timestamps = [test_dataset[\"dischtime\"][-1] for test_dataset in test_datasets]\n",
+"\n",
+"for i, test_dataset in enumerate(test_datasets):\n",
+" report = ModelCardReport()\n",
+" report.log_owner(\n",
+" name=\"CyclOps Team\",\n",
+" contact=\"vectorinstitute.github.io/cyclops/\",\n",
+" email=\"cyclops@vectorinstitute.ai\",\n",
+" )\n",
+" results, dataset_with_preds = mortality_task.evaluate(\n",
+" test_dataset,\n",
+" metric_collection,\n",
+" model_names=model_name,\n",
+" transforms=preprocessor,\n",
+" prediction_column_prefix=\"preds\",\n",
+" slice_spec=slice_spec,\n",
+" batch_size=64,\n",
+" fairness_config=fairness_config,\n",
+" override_fairness_metrics=False,\n",
+" )\n",
+" results_flat = flatten_results_dict(\n",
+" results=results,\n",
+" remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n",
+" model_name=model_name_results,\n",
+" )\n",
+" for name, metric in results_flat.items():\n",
+" split, name = name.split(\"/\") # noqa: PLW2901\n",
+" descriptions = {\n",
+" \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
+" \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
+" \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
+" \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
+" \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
+" \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
+" }\n",
+" report.log_quantitative_analysis(\n",
+" \"performance\",\n",
+" name=name,\n",
+" value=metric.tolist(),\n",
+" description=descriptions[name],\n",
+" metric_slice=split,\n",
+" pass_fail_thresholds=0.7,\n",
+" pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n",
+" )\n",
+" print(str(eval_timestamps[i]))\n",
+" report_path = report.export(\n",
+" output_filename=\"discharge_prediction_report_periodic.html\",\n",
+" synthetic_timestamp=str(eval_timestamps[i]),\n",
+" )\n",
+" shutil.copy(f\"{report_path}\", \".\")"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
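The new evaluation loop relies on `test_dataset.shard(NUM_EVALS, i, contiguous=True)` to split the test set into consecutive, non-overlapping chunks. A rough sketch of what contiguous sharding does, reimplemented on a plain Python list for illustration (this is not the Hugging Face `datasets` implementation, which operates on Arrow tables):

```python
def contiguous_shard(items, num_shards, index):
    """Return the index-th of num_shards contiguous, near-equal slices.

    Sketch of the behaviour of `Dataset.shard(..., contiguous=True)`:
    earlier shards absorb the remainder, one extra element each.
    """
    div, mod = divmod(len(items), num_shards)
    start = index * div + min(index, mod)
    end = start + div + (1 if index < mod else 0)
    return items[start:end]

rows = list(range(10))
shards = [contiguous_shard(rows, 4, i) for i in range(4)]
# Shards preserve order and together cover every row exactly once.
```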
@@ -1165,9 +1237,9 @@
 "source": [
 "report.log_from_dict(\n",
 " data={\n",
-" \"name\": \"Mortality Prediction Model\",\n",
+" \"name\": \"Discharge Prediction Model\",\n",
 " \"description\": \"The model was trained on the MIMICIV dataset \\\n",
-" to predict risk of in-hospital mortality.\",\n",
+" to predict probability of patient being discharged.\",\n",
 " },\n",
 " section_name=\"model_details\",\n",
 ")\n",
@@ -1176,11 +1248,6 @@
 " date=str(date.today()),\n",
 " description=\"Initial Release\",\n",
 ")\n",
-"report.log_owner(\n",
-" name=\"CyclOps Team\",\n",
-" contact=\"vectorinstitute.github.io/cyclops/\",\n",
-" email=\"cyclops@vectorinstitute.ai\",\n",
-")\n",
 "report.log_license(identifier=\"Apache-2.0\")\n",
 "report.log_reference(\n",
 " link=\"https://xgboost.readthedocs.io/en/stable/python/python_api.html\", # noqa: E501\n",
@@ -1241,85 +1308,17 @@
 "Once the model card is populated, you can generate the report using the `export` method. The report is generated in the form of an HTML file. A JSON file containing the model card data will also be generated along with the HTML file. By default, the files will be saved in a folder named `cyclops_reports` in the current working directory. You can change the path by passing a `output_dir` argument when instantiating the `ModelCardReport` class."
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"**Log the performance metrics to the report.**\n",
-"\n",
-"We can add a performance metric to the model card using the `log_performance_metric` method, which expects a dictionary where the keys are in the following format: `slice_name/metric_name`. For instance, `overall/accuracy`. \n",
-"\n",
-"We first need to process the evaluation results to get the metrics in the right format."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"tags": []
-},
-"outputs": [],
-"source": [
-"test_datasets = [\n",
-" test_dataset.shard(NUM_EVALS, i, contiguous=True) for i in range(NUM_EVALS)\n",
-"]"
-]
-},
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
-"eval_timestamps = [test_dataset[\"dischtime\"][-1] for test_dataset in test_datasets]\n",
-"\n",
-"for i, test_dataset in enumerate(test_datasets):\n",
-" if i > 0:\n",
-" report = ModelCardReport()\n",
-" report.log_owner(\n",
-" name=\"CyclOps Team\",\n",
-" contact=\"vectorinstitute.github.io/cyclops/\",\n",
-" email=\"cyclops@vectorinstitute.ai\",\n",
-" )\n",
-" results, dataset_with_preds = mortality_task.evaluate(\n",
-" test_dataset,\n",
-" metric_collection,\n",
-" model_names=model_name,\n",
-" transforms=preprocessor,\n",
-" prediction_column_prefix=\"preds\",\n",
-" slice_spec=slice_spec,\n",
-" batch_size=64,\n",
-" )\n",
-" results_flat = flatten_results_dict(\n",
-" results=results,\n",
-" remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n",
-" model_name=model_name_results,\n",
-" )\n",
-" for name, metric in results_flat.items():\n",
-" split, name = name.split(\"/\") # noqa: PLW2901\n",
-" descriptions = {\n",
-" \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
-" \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
-" \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
-" \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
-" \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
-" \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
-" }\n",
-" report.log_quantitative_analysis(\n",
-" \"performance\",\n",
-" name=name,\n",
-" value=metric.tolist(),\n",
-" description=descriptions[name],\n",
-" metric_slice=split,\n",
-" pass_fail_thresholds=0.7,\n",
-" pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n",
-" )\n",
-" print(str(eval_timestamps[i]))\n",
-" report_path = report.export(\n",
-" output_filename=\"discharge_prediction_report_periodic.html\",\n",
-" synthetic_timestamp=str(eval_timestamps[i]),\n",
-" )\n",
-" shutil.copy(f\"{report_path}\", \".\")\n",
+"report_path = report.export(\n",
+" output_filename=\"discharge_prediction_report_periodic.html\",\n",
+" synthetic_timestamp=str(eval_timestamps[-1]),\n",
+")\n",
+"shutil.copy(f\"{report_path}\", \".\")\n",
 "shutil.rmtree(\"./cyclops_reports\")"
 ]
 },
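The notebook's prose (kept as diff context above) says that `export` writes an HTML report plus a JSON file with the model card data into a `cyclops_reports` folder by default. A toy stand-in for that output layout (`export_report` and its file contents are invented here; this is not the CyclOps implementation):

```python
import json
from pathlib import Path


def export_report(card, output_dir="cyclops_reports",
                  output_filename="report.html"):
    """Toy stand-in for the export behaviour the notebook describes:
    write an HTML report and a companion JSON file with the card data."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    html_path = out / output_filename
    html_path.write_text(f"<html><body><h1>{card['name']}</h1></body></html>")
    # The JSON sidecar shares the HTML file's stem.
    html_path.with_suffix(".json").write_text(json.dumps(card))
    return str(html_path)


report_path = export_report({"name": "Discharge Prediction Model"})
```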

benchmarks/mimiciv/icu_mortality_prediction.ipynb

Lines changed: 3 additions & 3 deletions
@@ -1214,9 +1214,9 @@
 "source": [
 "report.log_from_dict(\n",
 " data={\n",
-" \"name\": \"Mortality Prediction Model\",\n",
+" \"name\": \"ICU Mortality Prediction Model\",\n",
 " \"description\": \"The model was trained on the MIMICIV dataset \\\n",
-" to predict risk of in-hospital mortality.\",\n",
+" to predict risk of mortality in the ICU.\",\n",
 " },\n",
 " section_name=\"model_details\",\n",
 ")\n",
@@ -1265,7 +1265,7 @@
 ")\n",
 "report.log_user(description=\"ML Engineers\")\n",
 "report.log_use_case(\n",
-" description=\"Predicting prolonged length of stay\",\n",
+" description=\"Predicting ICU mortality\",\n",
 " kind=\"primary\",\n",
 ")\n",
 "report.log_fairness_assessment(\n",

docs/source/tutorials/nihcxr/monitor_api.ipynb

Lines changed: 3 additions & 0 deletions
@@ -291,6 +291,9 @@
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
 "version": "3.10.12"
+},
+"nbsphinx": {
+"execute": "never"
 }
 },
 "nbformat": 4,
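The third file's only change is to notebook metadata: it adds an `nbsphinx` block so the docs build skips re-executing this notebook. After the change, the tail of the notebook's top-level metadata looks roughly like this (surrounding keys abbreviated for illustration):

```json
{
  "metadata": {
    "language_info": {
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    },
    "nbsphinx": {
      "execute": "never"
    }
  },
  "nbformat": 4
}
```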
