feature(Report): Calibration plot and many fixes. (#164)
* feature(Report): Calibration plot.

* fix: Diagrams were called twice.

* chore: Diagram title is already updated elsewhere.

* fix: Spelling.

* fix: title
szemyd authored May 12, 2023
1 parent c115d3e commit 411a333
Showing 7 changed files with 112 additions and 51 deletions.
18 changes: 10 additions & 8 deletions docs/examples/evaluate_classification_with_probabilities.py
@@ -9,18 +9,20 @@

from krisi import score

score(
y=np.random.randint(0, 2, 1000),
predictions=np.random.randint(0, 2, 1000),
probabilities=np.random.uniform(0, 1, 1000),
sc = score(
y=np.random.randint(0, 1, 1000),
predictions=np.random.randint(0, 1, 1000),
probabilities=np.random.uniform(0, 1, (1000, 1)),
# classification=True, # Optional, tries to decide based on if target contains integers
calculation="single",
).print()
)
sc.print()
sc.generate_report()

score(
y=np.random.randint(0, 2, 1000),
predictions=np.random.randint(0, 2, 1000),
probabilities=np.random.uniform(0, 1, 1000),
y=np.random.randint(0, 1, 1000),
predictions=np.random.randint(0, 1, 1000),
probabilities=np.random.uniform(0, 1, (1000, 1)),
# classification=True, # Optional, tries to decide based on if target contains integers
calculation="rolling",
).print()
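
A side note on the synthetic data in this example: np.random.randint's upper bound is exclusive, so randint(0, 1, 1000) yields only zeros. If the toy target is meant to contain both classes, the earlier bound of 2 is needed; a minimal sketch (not part of this commit, values are illustrative):

import numpy as np

y = np.random.randint(0, 2, 1000)                       # labels drawn from {0, 1}
probabilities = np.random.uniform(0, 1, (1000, 1))      # one probability column, matching the new example
predictions = (probabilities[:, 0] > 0.5).astype(int)   # predictions derived from the probabilities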
37 changes: 28 additions & 9 deletions src/krisi/evaluate/library/default_metrics_classification.py
@@ -1,3 +1,4 @@
import pandas as pd
from sklearn.metrics import (
accuracy_score,
brier_score_loss,
@@ -7,7 +8,11 @@
recall_score,
)

from krisi.evaluate.library.diagrams import display_single_value, display_time_series
from krisi.evaluate.library.diagrams import (
callibration_plot,
display_single_value,
display_time_series,
)
from krisi.evaluate.library.metric_wrappers import brier_multi
from krisi.evaluate.metric import Metric
from krisi.evaluate.type import MetricCategories
@@ -18,7 +23,7 @@
category=MetricCategories.class_err,
info="In multilabel classification, this function computes subset accuracy: the set of labels predicted for a sample must exactly match the corresponding set of labels in y_true. https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html",
func=accuracy_score,
plot_funcs=[(display_single_value, dict(width=500.0))],
plot_funcs=[(display_single_value, dict(width=750.0))],
plot_func_rolling=(display_time_series, dict(width=1500.0)),
)
""" ~ """
@@ -30,7 +35,7 @@
info="The recall is the ratio tp / (tp + fn) where tp is the number of true positives and fn the number of false negatives. The recall is intuitively the ability of the classifier to find all the positive samples.\nThe best value is 1 and the worst value is 0. https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html",
parameters={"average": "binary"},
func=recall_score,
plot_funcs=[(display_single_value, dict(width=500.0))],
plot_funcs=[(display_single_value, dict(width=750.0))],
plot_func_rolling=(display_time_series, dict(width=1500.0)),
)
"""~"""
@@ -41,7 +46,7 @@
info="The precision is the ratio tp / (tp + fp) where tp is the number of true positives and fp the number of false positives. The precision is intuitively the ability of the classifier not to label as positive a sample that is negative.\nThe best value is 1 and the worst value is 0. https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html",
parameters={"average": "binary"},
func=precision_score,
plot_funcs=[(display_single_value, dict(width=500.0))],
plot_funcs=[(display_single_value, dict(width=750.0))],
plot_func_rolling=(display_time_series, dict(width=1500.0)),
)
"""~"""
@@ -52,7 +57,7 @@
info="The F1 score can be interpreted as a harmonic mean of the precision and recall, where an F1 score reaches its best value at 1 and worst score at 0. The relative contribution of precision and recall to the F1 score are equal. https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html",
parameters={"average": "macro"},
func=f1_score,
plot_funcs=[(display_single_value, dict(width=500.0))],
plot_funcs=[(display_single_value, dict(width=750.0))],
plot_func_rolling=(display_time_series, dict(width=1500.0)),
supports_multiclass=True,
)
@@ -63,7 +68,7 @@
category=MetricCategories.class_err,
info="The Matthews correlation coefficient is used in machine learning as a measure of the quality of binary and multiclass classifications. It takes into account true and false positives and negatives and is generally regarded as a balanced measure which can be used even if the classes are of very different sizes. The MCC is in essence a correlation coefficient value between -1 and +1. A coefficient of +1 represents a perfect prediction, 0 an average random prediction and -1 an inverse prediction. The statistic is also known as the phi coefficient.",
func=matthews_corrcoef,
plot_funcs=[(display_single_value, dict(width=500.0))],
plot_funcs=[(display_single_value, dict(width=750.0))],
plot_func_rolling=(display_time_series, dict(width=1500.0)),
)
"""~"""
@@ -76,20 +81,33 @@
y_true=y, y_prob=prob, **kwargs
),
parameters=dict(pos_label=1),
plot_funcs=[(display_single_value, dict(width=500.0))],
plot_funcs=[(display_single_value, dict(width=750.0))],
plot_func_rolling=(display_time_series, dict(width=1500.0)),
accepts_probabilities=True,
)
"""~"""
calibration = Metric[float](
name="Calibration Plot",
key="calibration",
category=MetricCategories.class_err,
info="Used to plot the calibration of a model with it probabilities.",
func=lambda y, pred, prob: pd.concat(
[y, prob.iloc[:, 0].rename("probs")], axis="columns"
),
plot_funcs=[(callibration_plot, dict(width=1500.0, bin_size=0.1))],
plot_func_rolling=(display_time_series, dict(width=1500.0)),
accepts_probabilities=True,
supports_multiclass=True,
)
"""~"""

brier_score_multi = Metric[float](
name="Brier Score Multilabel",
key="brier_score_multi",
category=MetricCategories.class_err,
info="Multilabel calculation of the Brier score loss. The smaller the Brier score loss, the better, hence the naming with “loss”. The Brier score measures the mean squared difference between the predicted probability and the actual outcome. The Brier score always takes on a value between zero and one, since this is the largest possible difference between a predicted probability (which must be between zero and one) and the actual outcome (which can take on values of only 0 and 1). It can be decomposed as the sum of refinement loss and calibration loss.",
func=lambda y, pred, prob, **kwargs: brier_multi(y, prob, **kwargs),
parameters=dict(pos_label=1),
plot_funcs=[(display_single_value, dict(width=500.0))],
plot_funcs=[(display_single_value, dict(width=750.0))],
plot_func_rolling=(display_time_series, dict(width=1500.0)),
accepts_probabilities=True,
supports_multiclass=True,
@@ -103,6 +121,7 @@
matthew_corr,
brier_score,
brier_score_multi,
calibration,
]
"""~"""
minimal_classification_metrics = [accuracy, f_one_score]
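
For context on the new calibration metric: its func concatenates the target with the first probability column, so the value handed to the plot function is a two-column DataFrame. A minimal sketch of what that lambda produces, assuming y arrives as a pandas Series named "y" (which the plot function relies on):

import numpy as np
import pandas as pd

y = pd.Series(np.random.randint(0, 2, 5), name="y")          # binary target
prob = pd.DataFrame({"prob_1": np.random.uniform(0, 1, 5)})  # probability column(s)

# Mirrors the metric's func: keep the first probability column, renamed to "probs"
data = pd.concat([y, prob.iloc[:, 0].rename("probs")], axis="columns")
print(list(data.columns))  # ['y', 'probs'] — the columns callibration_plot reads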
66 changes: 48 additions & 18 deletions src/krisi/evaluate/library/diagrams.py
@@ -11,21 +11,21 @@
from krisi.evaluate.type import MetricResult


def display_time_series(data: List[MetricResult], title: str = "") -> "go.Figure":
def display_time_series(data: List[MetricResult], **kwargs) -> "go.Figure":
import plotly.express as px

title = kwargs.get("title", "")
df = pd.DataFrame(data, columns=[title])
df["iteration"] = list(range(len(data)))
fig = px.line(
df,
x="iteration",
y=title,
)
fig.update_layout(title=title)
return fig


def display_single_value(data: MetricResult, title: str = "") -> "go.Figure":
def display_single_value(data: MetricResult, **kwargs) -> "go.Figure":
import plotly.graph_objects as go

fig = go.Figure()
@@ -39,22 +39,15 @@ def display_single_value(data: MetricResult, title: str = "") -> "go.Figure":
)
)

fig.update_layout(title=title)
return fig


def display_acf_plot(
data: MetricResult,
title: str = "",
plot_pacf: bool = False,
data: MetricResult, plot_pacf: bool = False, **kwargs
) -> "go.Figure":
import plotly.graph_objects as go

title = (
title + " - Partial Autocorrelation (PACF)"
if plot_pacf
else "Autocorrelation (ACF)"
)
title = "Partial Autocorrelation (PACF)" if plot_pacf else "Autocorrelation (ACF)"

if not isinstance(data, pd.Series):
data = pd.Series(data)
@@ -101,16 +94,53 @@ def display_acf_plot(
return fig


def display_density_plot(
data: MetricResult,
title: str = "",
plot_pacf: bool = False,
) -> "go.Figure":
def display_density_plot(data: MetricResult, **kwargs) -> "go.Figure":
import plotly.express as px

if not isinstance(data, pd.Series):
data = pd.Series(data)
fig = px.histogram(data, marginal="box") # or violin, rug
fig.update_layout(title=title)

return fig


def callibration_plot(
data: MetricResult, bin_size: float = 0.1, **kwargs
) -> "go.Figure":
import plotly.graph_objects as go

y_true = data["y"]
y_prob = data["probs"]

# sort probabilities and corresponding true labels in ascending order
order = np.argsort(y_prob)
y_true = y_true[order]
y_prob = y_prob[order]

# calculate fraction of positives at each probability bin
bins = np.arange(0, 1.1, bin_size)
bin_indices = np.digitize(y_prob, bins)
bin_counts = np.bincount(bin_indices, minlength=len(bins) + 1)
fraction_positives = np.cumsum(bin_counts[:-1]) / np.sum(y_true)

scatter = go.Scatter(
x=bins, y=fraction_positives, mode="lines+markers", name="Data Points"
)

perfect_calibration = go.Scatter(
x=[0, 1],
y=[0, 1],
mode="lines",
name="Perfect Calibration",
line=dict(color="black", dash="dash"),
)

layout = go.Layout(
title="Calibration Curve",
xaxis=dict(title="Predicted Probability", tickvals=bins),
yaxis=dict(title="Fraction of Positives"),
showlegend=True,
)

fig = go.Figure(data=[scatter, perfect_calibration], layout=layout)
return fig
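
For comparison: a conventional calibration curve shows, per probability bin, the fraction of observed positives against the bin's predicted probability, whereas the helper above derives its y values from cumulative bin counts divided by the total number of positives. A minimal sketch of the per-bin statistics with scikit-learn's calibration_curve (illustrative data, independent of the helper above):

import numpy as np
from sklearn.calibration import calibration_curve

y_true = np.random.randint(0, 2, 1000)   # observed binary outcomes
y_prob = np.random.uniform(0, 1, 1000)   # predicted probabilities for the positive class

# Fraction of positives and mean predicted probability, computed per bin
fraction_positives, mean_predicted = calibration_curve(y_true, y_prob, n_bins=10)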
12 changes: 9 additions & 3 deletions src/krisi/report/interactive.py
@@ -3,7 +3,7 @@
from krisi.evaluate.type import ScoreCardMetadata
from krisi.report.type import InteractiveFigure, PlotlyInput
from krisi.utils.environment import is_notebook
from krisi.utils.iterable_helpers import flatten
from krisi.utils.iterable_helpers import del_dict_keys, flatten

if TYPE_CHECKING:
from dash import dcc, html
@@ -28,7 +28,10 @@ def figure_with_controller(figure: InteractiveFigure):
return block(
graph=dcc.Graph(
id=figure.id,
figure=figure.get_figure(width=figure.plot_args["width"] - 216.0),
figure=figure.get_figure(
width=figure.plot_args["width"] - 216.0,
**del_dict_keys(figure.plot_args, "width"),
),
className="h-full flex align-center",
),
title=None,
@@ -51,7 +54,10 @@ def figure_with_controller(figure: InteractiveFigure):
return dcc.Graph(
className="h-full flex align-center",
id=figure.id,
figure=figure.get_figure(width=figure.plot_args["width"] - 216.0),
figure=figure.get_figure(
width=figure.plot_args["width"] - 216.0,
**del_dict_keys(figure.plot_args, "width"),
),
style={"display": "inline-block"},
)

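
The pattern in both branches above is to adjust the layout width once while forwarding any remaining plot arguments (for example bin_size for the calibration plot) to the figure factory. A small sketch of the same idea with plain dicts (values are hypothetical):

from krisi.utils.iterable_helpers import del_dict_keys

plot_args = {"width": 1500.0, "bin_size": 0.1}

figure_kwargs = {
    "width": plot_args["width"] - 216.0,   # layout width adjusted once
    **del_dict_keys(plot_args, "width"),   # everything else forwarded untouched
}
print(figure_kwargs)  # {'width': 1284.0, 'bin_size': 0.1}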
14 changes: 8 additions & 6 deletions src/krisi/report/report.py
@@ -202,26 +202,28 @@ def create_report_from_scorecard(
report_title: str,
save_path: Path,
) -> Report:
custom_metric_html, interactive_figures = get_waterfall_metric_html(
obj.get_all_metrics()
)

get_html_elements = None
if DisplayModes.pdf in display_modes or DisplayModes.pdf.value in display_modes:
custom_metric_html, interactive_diagrams = get_waterfall_metric_html(
obj.get_all_metrics()
)
get_html_elements = get_html_elements_for_injection_scorecard(
obj=obj,
author=author,
project_name=obj.metadata.project_name,
date=datetime.datetime.now().strftime("%Y-%m-%d"),
custom_metric_html=custom_metric_html,
)
else:
get_html_elements = None
interactive_diagrams = get_all_interactive_diagrams(obj.get_all_metrics())
interactive_diagrams.sort(key=lambda x: x.plot_args["width"], reverse=True)

return Report(
title=report_title,
general_description="General Description",
modes=display_modes,
scorecard_metadata=obj.metadata,
figures=interactive_figures,
figures=interactive_diagrams,
html_template_url=html_template_url,
css_template_url=css_template_url,
get_html_elements=get_html_elements,
11 changes: 4 additions & 7 deletions src/krisi/report/type.py
@@ -68,18 +68,15 @@ def plotly_interactive(
default_kwargs = kwargs

def wrapper(*args, **kwargs) -> "go.Figure":
width = kwargs.pop("width", None)
height = kwargs.pop("height", None)
title = kwargs.pop("title", None)

for key, value in default_kwargs.items():
if key not in kwargs:
kwargs[key] = value

fig = plot_function(data_source, *args, **kwargs)
width = kwargs.pop("width", None)
height = kwargs.pop("height", None)
title = kwargs.get("title", None)

# fig.update_layout(width=width)
# fig.update_layout(autosize=False, width=width, height=height)
fig = plot_function(data_source, *args, **kwargs)
fig.update_layout(autosize=False, width=width, height=height)
if title is not None:
fig.update_layout(title=title)
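
The reordering above changes what the wrapped plot functions see: defaults are merged first, then width and height are popped so they only affect the final layout, while title is read with get rather than pop, so plot functions such as display_time_series can still pick it up from kwargs. A short sketch of that kwarg handling (hypothetical values):

default_kwargs = {"width": 1500.0, "bin_size": 0.1, "title": "Calibration Plot"}
kwargs = {}

for key, value in default_kwargs.items():
    if key not in kwargs:
        kwargs[key] = value

width = kwargs.pop("width", None)   # layout-only; never reaches the plot function
height = kwargs.pop("height", None)
title = kwargs.get("title", None)   # get, not pop: plot functions may also read it

print(kwargs)  # {'bin_size': 0.1, 'title': 'Calibration Plot'} — what plot_function receives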
5 changes: 5 additions & 0 deletions src/krisi/utils/iterable_helpers.py
@@ -118,3 +118,8 @@ def is_int(s: Any):
return False
else:
return True


def del_dict_keys(d: dict, keys: Union[str, List[str]]) -> dict:
keys = wrap_in_list(keys)
return {k: v for k, v in d.items() if k not in keys}
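
A quick illustration of the new helper, assuming wrap_in_list passes lists through unchanged (values are hypothetical):

plot_args = {"width": 1500.0, "bin_size": 0.1, "title": "Calibration Plot"}

del_dict_keys(plot_args, "width")             # {'bin_size': 0.1, 'title': 'Calibration Plot'}
del_dict_keys(plot_args, ["width", "title"])  # {'bin_size': 0.1}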
