|
933 | 933 | ")"
|
934 | 934 | ]
|
935 | 935 | },
|
| 936 | + { |
| 937 | + "cell_type": "markdown", |
| 938 | + "metadata": {}, |
| 939 | + "source": [ |
| 940 | + "**Log the performance metrics to the report.**\n", |
| 941 | + "\n", |
|     | 942 | +        "We can add performance metrics to the model card using the `log_quantitative_analysis` method. The flattened evaluation results use keys in the format `slice_name/metric_name`, for instance `overall/accuracy`, which we split into a slice name and a metric name before logging each value.\n", |
|     | 943 | +        "\n", |
|     | 944 | +        "We first need to process the evaluation results with `flatten_results_dict` to get the metrics into this format." |
| 945 | + ] |
| 946 | + }, |
936 | 947 | {
|
937 | 948 | "cell_type": "code",
|
938 | 949 | "execution_count": null,
|
|
947 | 958 | ")"
|
948 | 959 | ]
|
949 | 960 | },
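For reference, here is a minimal sketch of the key format described above. The values are assumed toy numbers, not real evaluation output; the real dictionary comes from `flatten_results_dict`. It shows how each `slice_name/metric_name` key is split before logging:

```python
# Hypothetical flattened results; real values come from flatten_results_dict.
results_flat = {"overall/BinaryAccuracy": 0.92, "overall/BinaryF1Score": 0.88}

for key, value in results_flat.items():
    slice_name, metric_name = key.split("/")
    print(slice_name, metric_name, value)
# overall BinaryAccuracy 0.92
# overall BinaryF1Score 0.88
```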
|
| 961 | + { |
| 962 | + "cell_type": "code", |
| 963 | + "execution_count": null, |
| 964 | + "metadata": {}, |
| 965 | + "outputs": [], |
| 966 | + "source": [ |
| 967 | + "test_datasets = [\n", |
| 968 | + " test_dataset.shard(NUM_EVALS, i, contiguous=True) for i in range(NUM_EVALS)\n", |
| 969 | + "]\n", |
| 970 | + "eval_timestamps = [test_dataset[\"dischtime\"][-1] for test_dataset in test_datasets]\n", |
| 971 | + "\n", |
| 972 | + "for i, test_dataset in enumerate(test_datasets):\n", |
| 973 | + " report = ModelCardReport()\n", |
| 974 | + " report.log_owner(\n", |
| 975 | + " name=\"CyclOps Team\",\n", |
| 976 | + " contact=\"vectorinstitute.github.io/cyclops/\",\n", |
| 977 | + " email=\"cyclops@vectorinstitute.ai\",\n", |
| 978 | + " )\n", |
| 979 | + " results, dataset_with_preds = mortality_task.evaluate(\n", |
| 980 | + " test_dataset,\n", |
| 981 | + " metric_collection,\n", |
| 982 | + " model_names=model_name,\n", |
| 983 | + " transforms=preprocessor,\n", |
| 984 | + " prediction_column_prefix=\"preds\",\n", |
| 985 | + " slice_spec=slice_spec,\n", |
| 986 | + " batch_size=64,\n", |
| 987 | + " fairness_config=fairness_config,\n", |
| 988 | + " override_fairness_metrics=False,\n", |
| 989 | + " )\n", |
| 990 | + " results_flat = flatten_results_dict(\n", |
| 991 | + " results=results,\n", |
| 992 | + " remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n", |
| 993 | + " model_name=model_name_results,\n", |
| 994 | + " )\n", |
|     | 995 | +        "    descriptions = {\n", |
|     | 996 | +        "        \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", |
|     | 997 | +        "        \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as sensitivity or true positive rate.\",\n", |
|     | 998 | +        "        \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", |
|     | 999 | +        "        \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", |
|     | 1000 | +        "        \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n", |
|     | 1001 | +        "        \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", |
|     | 1002 | +        "    }\n", |
|     | 1003 | +        "    for name, metric in results_flat.items():\n", |
|     | 1004 | +        "        split, name = name.split(\"/\")  # noqa: PLW2901\n", |
| 1005 | + " report.log_quantitative_analysis(\n", |
| 1006 | + " \"performance\",\n", |
| 1007 | + " name=name,\n", |
| 1008 | + " value=metric.tolist(),\n", |
| 1009 | + " description=descriptions[name],\n", |
| 1010 | + " metric_slice=split,\n", |
| 1011 | + " pass_fail_thresholds=0.7,\n", |
| 1012 | + " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", |
| 1013 | + " )\n", |
|     | 1014 | +        "    print(eval_timestamps[i])\n", |
| 1015 | + " report_path = report.export(\n", |
| 1016 | + " output_filename=\"discharge_prediction_report_periodic.html\",\n", |
| 1017 | + " synthetic_timestamp=str(eval_timestamps[i]),\n", |
| 1018 | + " )\n", |
|     | 1019 | +        "    shutil.copy(report_path, \".\")" |
| 1020 | + ] |
| 1021 | + }, |
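As a side note, here is a minimal sketch of how `Dataset.shard` produces the contiguous evaluation slices used above. Toy data is assumed; the notebook shards the MIMIC-IV test split into `NUM_EVALS` pieces:

```python
from datasets import Dataset

# Toy dataset standing in for the real test split.
ds = Dataset.from_dict({"x": list(range(10))})

# Split into 2 contiguous shards, mirroring
# test_dataset.shard(NUM_EVALS, i, contiguous=True) above.
shards = [ds.shard(num_shards=2, index=i, contiguous=True) for i in range(2)]
print([shard["x"] for shard in shards])  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```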
950 | 1022 | {
|
951 | 1023 | "cell_type": "markdown",
|
952 | 1024 | "metadata": {},
|
|
1165 | 1237 | "source": [
|
1166 | 1238 | "report.log_from_dict(\n",
|
1167 | 1239 | " data={\n",
|
1168 |
| - " \"name\": \"Mortality Prediction Model\",\n", |
| 1240 | + " \"name\": \"Discharge Prediction Model\",\n", |
1169 | 1241 |         "        \"description\": \"The model was trained on the MIMIC-IV dataset \\\n",
|
1170 |
| - " to predict risk of in-hospital mortality.\",\n", |
|     | 1242 | +        "        to predict the probability of a patient being discharged.\",\n", |
1171 | 1243 | " },\n",
|
1172 | 1244 | " section_name=\"model_details\",\n",
|
1173 | 1245 | ")\n",
|
|
1176 | 1248 | " date=str(date.today()),\n",
|
1177 | 1249 | " description=\"Initial Release\",\n",
|
1178 | 1250 | ")\n",
|
1179 |
| - "report.log_owner(\n", |
1180 |
| - " name=\"CyclOps Team\",\n", |
1181 |
| - " contact=\"vectorinstitute.github.io/cyclops/\",\n", |
1182 |
| - " email=\"cyclops@vectorinstitute.ai\",\n", |
1183 |
| - ")\n", |
1184 | 1251 | "report.log_license(identifier=\"Apache-2.0\")\n",
|
1185 | 1252 | "report.log_reference(\n",
|
1186 | 1253 | " link=\"https://xgboost.readthedocs.io/en/stable/python/python_api.html\", # noqa: E501\n",
|
|
1241 | 1308 |         "Once the model card is populated, you can generate the report using the `export` method. The report is generated as an HTML file, and a JSON file containing the model card data is saved alongside it. By default, the files are saved in a folder named `cyclops_reports` in the current working directory. You can change the path by passing an `output_dir` argument when instantiating the `ModelCardReport` class."
|
1242 | 1309 | ]
|
1243 | 1310 | },
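A minimal sketch of changing the save location, assuming the `ModelCardReport` constructor accepts `output_dir` as the text above describes and that the class is imported as in the rest of this notebook:

```python
from cyclops.report import ModelCardReport

# Save reports under ./my_reports instead of the default ./cyclops_reports.
report = ModelCardReport(output_dir="./my_reports")
# ... populate the report as shown above ...
report_path = report.export(output_filename="discharge_prediction_report.html")
print(report_path)  # path to the HTML; a JSON model card is saved alongside it
```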
|
1244 |
| - { |
1245 |
| - "cell_type": "markdown", |
1246 |
| - "metadata": {}, |
1247 |
| - "source": [ |
1248 |
| - "**Log the performance metrics to the report.**\n", |
1249 |
| - "\n", |
1250 |
| - "We can add a performance metric to the model card using the `log_performance_metric` method, which expects a dictionary where the keys are in the following format: `slice_name/metric_name`. For instance, `overall/accuracy`. \n", |
1251 |
| - "\n", |
1252 |
| - "We first need to process the evaluation results to get the metrics in the right format." |
1253 |
| - ] |
1254 |
| - }, |
1255 |
| - { |
1256 |
| - "cell_type": "code", |
1257 |
| - "execution_count": null, |
1258 |
| - "metadata": { |
1259 |
| - "tags": [] |
1260 |
| - }, |
1261 |
| - "outputs": [], |
1262 |
| - "source": [ |
1263 |
| - "test_datasets = [\n", |
1264 |
| - " test_dataset.shard(NUM_EVALS, i, contiguous=True) for i in range(NUM_EVALS)\n", |
1265 |
| - "]" |
1266 |
| - ] |
1267 |
| - }, |
1268 | 1311 | {
|
1269 | 1312 | "cell_type": "code",
|
1270 | 1313 | "execution_count": null,
|
1271 | 1314 | "metadata": {},
|
1272 | 1315 | "outputs": [],
|
1273 | 1316 | "source": [
|
1274 |
| - "eval_timestamps = [test_dataset[\"dischtime\"][-1] for test_dataset in test_datasets]\n", |
1275 |
| - "\n", |
1276 |
| - "for i, test_dataset in enumerate(test_datasets):\n", |
1277 |
| - " if i > 0:\n", |
1278 |
| - " report = ModelCardReport()\n", |
1279 |
| - " report.log_owner(\n", |
1280 |
| - " name=\"CyclOps Team\",\n", |
1281 |
| - " contact=\"vectorinstitute.github.io/cyclops/\",\n", |
1282 |
| - " email=\"cyclops@vectorinstitute.ai\",\n", |
1283 |
| - " )\n", |
1284 |
| - " results, dataset_with_preds = mortality_task.evaluate(\n", |
1285 |
| - " test_dataset,\n", |
1286 |
| - " metric_collection,\n", |
1287 |
| - " model_names=model_name,\n", |
1288 |
| - " transforms=preprocessor,\n", |
1289 |
| - " prediction_column_prefix=\"preds\",\n", |
1290 |
| - " slice_spec=slice_spec,\n", |
1291 |
| - " batch_size=64,\n", |
1292 |
| - " )\n", |
1293 |
| - " results_flat = flatten_results_dict(\n", |
1294 |
| - " results=results,\n", |
1295 |
| - " remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n", |
1296 |
| - " model_name=model_name_results,\n", |
1297 |
| - " )\n", |
1298 |
| - " for name, metric in results_flat.items():\n", |
1299 |
| - " split, name = name.split(\"/\") # noqa: PLW2901\n", |
1300 |
| - " descriptions = {\n", |
1301 |
| - " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", |
1302 |
| - " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n", |
1303 |
| - " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", |
1304 |
| - " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", |
1305 |
| - " \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n", |
1306 |
| - " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", |
1307 |
| - " }\n", |
1308 |
| - " report.log_quantitative_analysis(\n", |
1309 |
| - " \"performance\",\n", |
1310 |
| - " name=name,\n", |
1311 |
| - " value=metric.tolist(),\n", |
1312 |
| - " description=descriptions[name],\n", |
1313 |
| - " metric_slice=split,\n", |
1314 |
| - " pass_fail_thresholds=0.7,\n", |
1315 |
| - " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", |
1316 |
| - " )\n", |
1317 |
| - " print(str(eval_timestamps[i]))\n", |
1318 |
| - " report_path = report.export(\n", |
1319 |
| - " output_filename=\"discharge_prediction_report_periodic.html\",\n", |
1320 |
| - " synthetic_timestamp=str(eval_timestamps[i]),\n", |
1321 |
| - " )\n", |
1322 |
| - " shutil.copy(f\"{report_path}\", \".\")\n", |
| 1317 | + "report_path = report.export(\n", |
| 1318 | + " output_filename=\"discharge_prediction_report_periodic.html\",\n", |
| 1319 | + " synthetic_timestamp=str(eval_timestamps[-1]),\n", |
| 1320 | + ")\n", |
|     | 1321 | +        "shutil.copy(report_path, \".\")\n", |
1323 | 1322 | "shutil.rmtree(\"./cyclops_reports\")"
|
1324 | 1323 | ]
|
1325 | 1324 | },
|
|