Fix PR curve (order), add AUPRC to comparison plot
amrit110 committed Dec 7, 2023
1 parent cb7d32f commit 7183444
Showing 3 changed files with 117 additions and 17 deletions.
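The heart of this commit is an index-order bug: precision_recall_curve-style metrics conventionally return (precisions, recalls, thresholds), but the plotter read recalls from index 0 and precisions from index 1, mirroring every plotted curve. The commit also threads a new auprcs argument through precision_recall_curve_comparison so each slice's legend entry can carry its area under the PR curve, mirroring the existing aurocs handling in roc_curve_comparison. A minimal sketch of the ordering convention, assuming cyclops's binary curve output follows scikit-learn (an assumption, not something this diff states):

import numpy as np
from sklearn.metrics import precision_recall_curve

# Toy labels and scores, purely illustrative.
y_true = np.array([0, 1, 1, 0, 1])
y_score = np.array([0.2, 0.8, 0.6, 0.4, 0.9])

# scikit-learn returns (precisions, recalls, thresholds): precision first.
precisions, recalls, _ = precision_recall_curve(y_true, y_score)

# A PR plot puts recall on x and precision on y; unpacking the tuple the
# other way around, as the pre-fix code did, mirrors the curve.
x, y = recalls, precisions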
46 changes: 33 additions & 13 deletions cyclops/report/plot/classification.py
@@ -132,7 +132,7 @@ def roc_curve(
assert isinstance(
auroc,
float,
), "aurocs must be a float for binary tasks"
), "AUROCs must be a float for binary tasks"
name = f"Model (AUC = {auroc:.2f})"
else:
name = "Model"
@@ -153,7 +153,7 @@
if auroc is not None:
assert (
len(auroc) == self.class_num # type: ignore[arg-type]
), "Aurocs must be of length class_num for \
), "AUROCs must be of length class_num for \
multiclass/multilabel tasks"
name = f"{self.class_names[i]} (AUC = {auroc[i]:.2f})" # type: ignore[index] # noqa: E501
else:
@@ -228,7 +228,7 @@ def roc_curve_comparison(
assert isinstance(
aurocs[slice_name],
float,
), "Aurocs must be a float for binary tasks"
), "AUROCs must be a float for binary tasks"
name = f"{slice_name} (AUC = {aurocs[slice_name]:.2f})"
else:
name = slice_name
@@ -246,13 +246,13 @@
for slice_name, slice_curve in roc_curves.items():
assert (
len(slice_curve[0]) == len(slice_curve[1]) == self.class_num
), f"Fprs and tprs must be of length class_num for \
), f"FPRs and TPRs must be of length class_num for \
multiclass/multilabel tasks in slice {slice_name}"
for i in range(self.class_num):
if aurocs and slice_name in aurocs:
assert (
len(aurocs[slice_name]) == self.class_num # type: ignore[arg-type] # noqa: E501
), "Aurocs must be of length class_num for \
), "AUROCs must be of length class_num for \
multiclass/multilabel tasks"
name = f"{slice_name}, {self.class_names[i]} \
(AUC = {aurocs[slice_name][i]:.2f})"  # type: ignore[index]
@@ -324,8 +324,8 @@ def precision_recall_curve(
The figure object.
"""
-recalls = precision_recall_curve[0]
-precisions = precision_recall_curve[1]
+recalls = precision_recall_curve[1]
+precisions = precision_recall_curve[0]

if self.task_type == "binary":
trace = line_plot(
@@ -365,6 +365,9 @@
def precision_recall_curve_comparison(
self,
precision_recall_curves: Dict[str, Tuple[npt.NDArray[np.float_], ...]],
auprcs: Optional[
Dict[str, Union[float, List[float], npt.NDArray[np.float_]]]
] = None,
title: Optional[str] = "Precision-Recall Curve Comparison",
layout: Optional[go.Layout] = None,
**plot_kwargs: Any,
@@ -377,6 +380,8 @@ def precision_recall_curve_comparison(
Dictionary of precision-recall curves, where the key is \
the group or subpopulation name and the value is a tuple \
of (precisions, recalls, thresholds)
auprcs : Dict[str, Union[float, list, np.ndarray]], optional
AUPRCs for each subpopulation or group specified by name, by default None
layout : Optional[go.Layout], optional
Customized figure layout, by default None
title: str, optional
@@ -393,11 +398,18 @@ def precision_recall_curve_comparison(
trace = []
if self.task_type == "binary":
for slice_name, slice_curve in precision_recall_curves.items():
name = f"{slice_name}"
if auprcs and slice_name in auprcs:
assert isinstance(
auprcs[slice_name],
float,
), "AUPRCs must be a float for binary tasks"
name = f"{slice_name} (AUC = {auprcs[slice_name]:.2f})"
else:
name = f"{slice_name}"
trace.append(
line_plot(
-x=slice_curve[0],
-y=slice_curve[1],
+x=slice_curve[1],
+y=slice_curve[0],
trace_name=name,
**plot_kwargs,
),
@@ -409,11 +421,19 @@ def precision_recall_curve_comparison(
), f"Recalls and precisions must be of length class_num for \
multiclass/multilabel tasks in slice {slice_name}"
for i in range(self.class_num):
name = f"{slice_name}: {self.class_names[i]}"
if auprcs and slice_name in auprcs:
assert (
len(auprcs[slice_name]) == self.class_num # type: ignore[arg-type] # noqa: E501
), "AUPRCs must be of length class_num for \
multiclass/multilabel tasks"
name = f"{slice_name}, {self.class_names[i]} \
(AUC = {auprcs[i]:.2f})"
else:
name = f"{slice_name}: {self.class_names[i]}"
trace.append(
line_plot(
-x=slice_curve[0][i],
-y=slice_curve[1][i],
+x=slice_curve[1][i],
+y=slice_curve[0][i],
trace_name=name,
**plot_kwargs,
),
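Combined with the notebook changes below, the new keyword is exercised like this (a sketch assembled from the tutorial cells in this same commit; results, model_name, and plotter are defined earlier in those notebooks):

# Collect per-slice PR curves and average-precision values, then pass both
# so each legend entry carries its AUPRC (mirroring roc_curve_comparison).
pr_curves = {
    slice_name: slice_results["BinaryPrecisionRecallCurve"]
    for slice_name, slice_results in results[model_name].items()
}
average_precisions = {
    slice_name: slice_results["BinaryAveragePrecision"]
    for slice_name, slice_results in results[model_name].items()
}
pr_plot = plotter.precision_recall_curve_comparison(
    pr_curves,
    auprcs=average_precisions,  # keyword added by this commit
)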
44 changes: 42 additions & 2 deletions docs/source/tutorials/kaggle/heart_failure_prediction.ipynb
@@ -299,7 +299,7 @@
"source": [
"class_counts = df[\"outcome\"].value_counts()\n",
"class_ratio = class_counts[0] / class_counts[1]\n",
"print(class_ratio)"
"print(class_ratio, class_counts)"
]
},
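The class-ratio printout above is also why the commit starts reporting average precision next to AUROC: a no-skill classifier scores about the positive-class prevalence on AUPRC but 0.5 on AUROC no matter how skewed the labels are. A hedged sketch reusing the df["outcome"] column from the cell above (assuming a 0/1 encoding, which the class_counts[0] / class_counts[1] line suggests):

# A random classifier's PR curve hovers at the positive-class prevalence,
# so AUPRC has a data-dependent baseline, unlike AUROC.
prevalence = df["outcome"].mean()  # assumes a 0/1 label encoding
print(f"no-skill AUPRC baseline ~= {prevalence:.3f}; no-skill AUROC = 0.5")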
{
@@ -714,6 +714,7 @@
" \"recall\",\n",
" \"f1_score\",\n",
" \"auroc\",\n",
" \"average_precision\",\n",
" \"roc_curve\",\n",
" \"precision_recall_curve\",\n",
"]\n",
@@ -895,14 +896,15 @@
" results=results_female,\n",
" model_name=model_name,\n",
")\n",
"\n",
"# ruff: noqa: W505\n",
"for name, metric in results_female_flat.items():\n",
" split, name = name.split(\"/\") # noqa: PLW2901\n",
" descriptions = {\n",
" \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
" \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
" \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
" \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
" \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
" \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
" }\n",
" report.log_quantitative_analysis(\n",
@@ -922,6 +924,7 @@
" \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
" \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
" \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
" \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
" \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
" }\n",
" report.log_quantitative_analysis(\n",
@@ -986,6 +989,43 @@
"roc_plot.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extracting the precision-recall curves and average precision results for all the slices\n",
"pr_curves = {\n",
" slice_name: slice_results[\"BinaryPrecisionRecallCurve\"]\n",
" for slice_name, slice_results in results[model_name].items()\n",
"}\n",
"average_precisions = {\n",
" slice_name: slice_results[\"BinaryAveragePrecision\"]\n",
" for slice_name, slice_results in results[model_name].items()\n",
"}\n",
"pr_curves.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plotting the precision-recall curves for all the slices\n",
"pr_plot = plotter.precision_recall_curve_comparison(\n",
" pr_curves,\n",
" auprcs=average_precisions,\n",
")\n",
"report.log_plotly_figure(\n",
" fig=pr_plot,\n",
" caption=\"Precision-Recall Curve Comparison\",\n",
" section_name=\"quantitative analysis\",\n",
")\n",
"pr_plot.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
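A note on the metric the notebooks now log: BinaryAveragePrecision is the step-wise area under the precision-recall curve, which is why the comparison plot labels it "AUC". A self-contained cross-check of that identity with scikit-learn, on toy data (illustrative only, not taken from either tutorial):

import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve

# Toy data, illustrative only.
y_true = np.array([0, 1, 1, 0, 1, 0, 0, 1])
y_score = np.array([0.1, 0.9, 0.7, 0.3, 0.8, 0.2, 0.5, 0.6])

# AP = sum_n (R_n - R_{n-1}) * P_n: a step-function integral of the
# PR curve. recall decreases along the returned arrays, hence the minus.
precision, recall, _ = precision_recall_curve(y_true, y_score)
ap_manual = float(np.sum(-np.diff(recall) * precision[:-1]))

assert np.isclose(average_precision_score(y_true, y_score), ap_manual)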
44 changes: 42 additions & 2 deletions docs/source/tutorials/mimiciv/mortality_prediction.ipynb
@@ -404,7 +404,7 @@
"# The data is heavily unbalanced.\n",
"class_counts = cohort[\"mortality_outcome\"].value_counts()\n",
"class_ratio = class_counts[0] / class_counts[1]\n",
"print(class_ratio)"
"print(class_ratio, class_counts)"
]
},
{
@@ -780,6 +780,7 @@
" \"recall\",\n",
" \"f1_score\",\n",
" \"auroc\",\n",
" \"average_precision\",\n",
" \"roc_curve\",\n",
" \"precision_recall_curve\",\n",
"]\n",
@@ -937,13 +938,15 @@
"metadata": {},
"outputs": [],
"source": [
"# ruff: noqa: W505\n",
"for name, metric in results_flat.items():\n",
" split, name = name.split(\"/\") # noqa: PLW2901\n",
" descriptions = {\n",
" \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n",
" \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n",
" \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n",
" \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n",
" \"BinaryAveragePrecision\": \"The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.\",\n",
" \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n",
" }\n",
" report.log_quantitative_analysis(\n",
@@ -992,6 +995,24 @@
"roc_curves.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extracting the precision-recall curves and average precision results for all the slices\n",
"pr_curves = {\n",
" slice_name: slice_results[\"BinaryPrecisionRecallCurve\"]\n",
" for slice_name, slice_results in results[model_name].items()\n",
"}\n",
"average_precisions = {\n",
" slice_name: slice_results[\"BinaryAveragePrecision\"]\n",
" for slice_name, slice_results in results[model_name].items()\n",
"}\n",
"pr_curves.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -1002,12 +1023,31 @@
"roc_plot = plotter.roc_curve_comparison(roc_curves, aurocs=aurocs)\n",
"report.log_plotly_figure(\n",
" fig=roc_plot,\n",
" caption=\"ROC Curve for Female Patients\",\n",
" caption=\"ROC Curve Comparison\",\n",
" section_name=\"quantitative analysis\",\n",
")\n",
"roc_plot.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plotting the precision-recall curves for all the slices\n",
"pr_plot = plotter.precision_recall_curve_comparison(\n",
" pr_curves,\n",
" auprcs=average_precisions,\n",
")\n",
"report.log_plotly_figure(\n",
" fig=pr_plot,\n",
" caption=\"Precision-Recall Curve Comparison\",\n",
" section_name=\"quantitative analysis\",\n",
")\n",
"pr_plot.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
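One subtlety in the multiclass/multilabel branches shown above: auprcs (like aurocs) maps a slice name to a length-class_num sequence, so building a per-class legend label takes two lookups — slice name first, class index second, as in the corrected f-strings above. A toy sketch with assumed class names and values:

import numpy as np

class_names = ["alive", "deceased", "unknown"]  # assumed, illustrative
auprcs = {"sex:F": np.array([0.91, 0.72, 0.65])}  # assumed values

for slice_name, per_class in auprcs.items():
    for i, auprc in enumerate(per_class):
        # Mirrors the legend text built in precision_recall_curve_comparison.
        print(f"{slice_name}, {class_names[i]} (AUC = {auprc:.2f})")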
