Skip to content

Commit

Permalink
Add confusion matrix example to tutorial notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Nov 23, 2023
1 parent b174a88 commit dac6a4b
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 472 deletions.
148 changes: 1 addition & 147 deletions cyclops/models/plotter.py
Original file line number Diff line number Diff line change
@@ -1,152 +1,7 @@
"""Plotting functions."""

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import seaborn as sns

from cyclops.models.utils import metrics_binary


def plot_pretty_confusion_matrix(confusion_matrix: np.ndarray) -> None:
"""Plot pretty confusion matrix.
Parameters
----------
confusion_matrix : np.ndarray
confusion matrix
"""
sns.set(style="white")
_, ax = plt.subplots(figsize=(9, 6))
sns.heatmap(
np.eye(2),
annot=confusion_matrix,
fmt="g",
annot_kws={"size": 50},
cmap=sns.color_palette(["tomato", "palegreen"], as_cmap=True),
cbar=False,
yticklabels=["True", "False"],
xticklabels=["True", "False"],
ax=ax,
)
ax.xaxis.tick_top()
ax.xaxis.set_label_position("top")
ax.tick_params(labelsize=20, length=0)

ax.set_title("Confusion Matrix for Test Set", size=24, pad=20)
ax.set_xlabel("Predicted Values", size=20)
ax.set_ylabel("Actual Values", size=20)

additional_texts = [
"(True Positive)",
"(False Negative)",
"(False Positive)",
"(True Negative)",
]
for text_elt, additional_text in zip(ax.texts, additional_texts):
ax.text(
*text_elt.get_position(),
"\n" + additional_text,
color=text_elt.get_color(),
ha="center",
va="top",
size=24,
)
plt.tight_layout()
plt.show()


def plot_confusion_matrix(confusion_matrix: np.ndarray, class_names: list) -> go.Figure:
"""Plot confusion matrix.
Parameters
----------
confusion_matrix : np.ndarray
confusion matrix
class_names : list
data class names
Returns
-------
go.Figure
plot figure
"""
confusion_matrix = (
confusion_matrix.astype("float") / confusion_matrix.sum(axis=1)[:, np.newaxis]
)

layout = {
"title": "Confusion Matrix",
"xaxis": {"title": "Predicted value"},
"yaxis": {"title": "Real value"},
}

fig = go.Figure(
data=go.Heatmap(
z=confusion_matrix,
x=class_names,
y=class_names,
hoverongaps=False,
colorscale="Greens",
),
layout=layout,
)
fig.update_layout(height=512, width=1024)
return fig


def plot_auroc_across_timesteps(
y_pred_values: np.ndarray,
y_pred_labels: np.ndarray,
y_test_labels: np.ndarray,
) -> go.Figure:
"""Plot AUC_ROC across timesteps.
Parameters
----------
y_pred_values : np.ndarray
prediction values
y_pred_labels : np.ndarray
prediction labels
y_test_labels : np.ndarray
data labels
Returns
-------
go.Figure
plot figures
"""
num_timesteps = y_pred_labels.shape[1]
auroc_timesteps = []
for i in range(num_timesteps):
labels = y_test_labels[:, i]
pred_vals = y_pred_values[:, i]
preds = y_pred_labels[:, i]
pred_vals = pred_vals[labels != -1]
preds = preds[labels != -1]
labels = labels[labels != -1]
pred_metrics = metrics_binary(labels, pred_vals, preds, verbose=False)
auroc_timesteps.append(pred_metrics["auroc"])

print(auroc_timesteps)

prediction_hours = list(range(24, 168, 24))
fig = go.Figure(
data=[go.Bar(x=prediction_hours, y=auroc_timesteps, name="model confidence")],
)

fig.update_xaxes(tickvals=prediction_hours)
fig.update_yaxes(range=[min(auroc_timesteps) - 0.05, max(auroc_timesteps) + 0.05])

fig.update_layout(
title="AUROC split by no. of hours after admission",
autosize=False,
xaxis_title="No. of hours after admission",
)
return fig


def plot_risk_mortality(predictions: np.ndarray, labels: np.ndarray) -> go.Figure:
Expand Down Expand Up @@ -218,13 +73,12 @@ def plot_risk_mortality(predictions: np.ndarray, labels: np.ndarray) -> go.Figur
fig.update_yaxes(range=[label_h, 1])
fig.update_xaxes(tickvals=prediction_hours)
fig.update_xaxes(showline=True, linewidth=2, linecolor="black")

fig.add_hline(y=0.5)

fig.update_layout(
title="Model output visualization",
autosize=False,
xaxis_title="No. of hours after admission",
yaxis_title="Model confidence",
)

return fig
207 changes: 0 additions & 207 deletions cyclops/models/predictor.py

This file was deleted.

Loading

0 comments on commit dac6a4b

Please sign in to comment.