Commit 33120aa (initial commit, 0 parents): 139 changed files with 38,255 additions and 0 deletions.
.gitignore
@@ -0,0 +1,7 @@
*.pyc
.vscode/*
calzone.egg-info/
example_data/*result*.csv
example_data/*output*.csv
example_data/*png
test_notebook/*
@@ -0,0 +1,133 @@
import argparse
import numpy as np
import subprocess
import time
import matplotlib.pyplot as plt
from nicegui import ui, app


def run_program():
    """Assemble a cal_metrics.py command from the form values and run it."""
    csv_file = csv_file_input.value
    save_metrics = save_metrics_input.value
    save_plot = save_plot_input.value

    # Collect the checked metrics; an empty selection falls back to "all".
    selected_metrics = [metric for metric, checkbox in metrics_checkboxes.items() if checkbox.value]
    metric_arg = "all" if not selected_metrics else ",".join(selected_metrics)

    # Cast the integer-typed arguments so argparse's type=int parsing
    # never receives a float string such as "10.0" from ui.number.
    args = [
        "--csv_file", str(csv_file),
        "--metrics", str(metric_arg),
        "--n_bootstrap", str(int(n_bootstrap_input.value)),
        "--bootstrap_ci", str(bootstrap_ci_input.value),
        "--class_to_calculate", str(int(class_to_calculate_input.value)),
        "--num_bins", str(int(num_bins_input.value)),
        "--save_metrics", str(save_metrics),
        "--plot_bins", str(int(plot_bins_input.value)),
    ]

    if prevalence_adjustment_checkbox.value:
        args.append("--prevalence_adjustment")

    if plot_checkbox.value:
        args.append("--plot")
        args.append("--save_plot")
        args.append(str(save_plot))
    if verbose_checkbox.value:
        args.append("--verbose")

    command = ["python", "cal_metrics.py"] + args
    print("Running command:", " ".join(command))

    # Run the CLI in a subprocess and capture stdout and stderr separately.
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    output, error = process.communicate()

    output_area.value = output
    if error:
        output_area.value += "\nErrors:\n" + error

    if plot_checkbox.value:
        display_plot(save_plot)


def display_plot(plot_path):
    """Show the saved reliability diagram in the image widget."""
    plot_image.clear()
    plot_image.set_source(plot_path)
    ui.update(plot_image)


def clear_cache():
    """Clear browser-side caches, storage, and cookies, then reload the page."""
    ui.run_javascript('''
        function clearCache() {
            if ('caches' in window) {
                caches.keys().then(function(names) {
                    for (let name of names)
                        caches.delete(name);
                });
            }
            sessionStorage.clear();
            localStorage.clear();
            document.cookie.split(";").forEach(function(c) {
                document.cookie = c.replace(/^ +/, "").replace(/=.*/, "=;expires=" + new Date().toUTCString() + ";path=/");
            });
            location.reload();
        }
        clearCache();
    ''')
    ui.notify('Browser cache cleared')

    # Also reset the plot widget so a stale image is not shown after reload.
    plot_image.clear()
    ui.update(plot_image)


with ui.row().classes('w-full justify-center'):
    with ui.column().classes('w-1/3 p-4'):
        ui.label('Calibration Metrics GUI').classes('text-h4')
        csv_file_input = ui.input(label='CSV File', placeholder='Enter file path').classes('w-full')
        ui.label('Metrics:')
        metrics_options = ["all", "SpiegelhalterZ", "ECE-H", "MCE-H", "HL-H", "ECE-C", "MCE-C", "HL-C", "COX", "Loess"]
        metrics_checkboxes = {metric: ui.checkbox(metric, on_change=lambda m=metric: update_checkboxes(m)) for metric in metrics_options}

    with ui.column().classes('w-1/3 p-4'):
        ui.label('Bootstrap:')
        n_bootstrap_input = ui.number(label='Number of Bootstrap Samples', value=0, min=0)
        bootstrap_ci_input = ui.number(label='Bootstrap Confidence Interval', value=0.95, min=0, max=1, step=0.01)
        ui.label('Setting:')
        class_to_calculate_input = ui.number(label='Class to Calculate Metrics for', value=1, step=1)
        num_bins_input = ui.number(label='Number of Bins for ECE/MCE/HL Test', value=10, min=2, step=1)
        save_metrics_input = ui.input(label='Save Metrics to', placeholder='Enter file path').classes('w-full')

    with ui.column().classes('w-1/3 p-4'):
        ui.label('Prevalence Adjustment:')
        prevalence_adjustment_checkbox = ui.checkbox('Perform Prevalence Adjustment', value=False)
        ui.label('Plot:')
        plot_checkbox = ui.checkbox('Plot Reliability Diagram', value=False)
        plot_bins_input = ui.number(label='Number of Bins for Reliability Diagram', value=10, min=2, step=1)
        save_plot_input = ui.input(label='Save Plot to', placeholder='Enter file path').classes('w-full')
        verbose_checkbox = ui.checkbox('Print Verbose Output', value=True)
        ui.button('Run', on_click=run_program).classes('w-full')
        ui.button('Clear Browser Cache', on_click=clear_cache).classes('w-full')

with ui.row().classes('w-full justify-center'):
    with ui.column().classes('w-2/3 p-4'):
        output_area = ui.textarea(label='Output').classes('w-full')
    with ui.column().classes('w-1/3 p-4'):
        plot_image = ui.image().classes('w-full')


def update_checkboxes(changed_metric):
    """Keep the "all" checkbox mutually exclusive with the individual metrics."""
    if changed_metric == "all":
        all_checked = metrics_checkboxes["all"].value
        for metric, checkbox in metrics_checkboxes.items():
            if metric != "all":
                checkbox.value = False
                checkbox.disabled = all_checked
    else:
        metrics_checkboxes["all"].value = False
        any_checked = any(checkbox.value for metric, checkbox in metrics_checkboxes.items() if metric != "all")
        metrics_checkboxes["all"].disabled = any_checked


ui.run()
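For reference, with the form's default values (0 bootstrap samples, CI 0.95, class 1, 10 bins, verbose checked, plotting off) the Run button assembles and prints a command along the lines of `python cal_metrics.py --csv_file data.csv --metrics all --n_bootstrap 0 --bootstrap_ci 0.95 --class_to_calculate 1 --num_bins 10 --save_metrics results.csv --plot_bins 10 --verbose`, where the two file paths are invented stand-ins for whatever is typed into the path inputs.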
README.md
@@ -0,0 +1,8 @@
# calzone

calzone is a Python package for calculating and analyzing the calibration of probabilistic models, providing a suite of calibration metrics, bootstrap confidence intervals, and reliability-diagram visualizations.

## Documentation

For detailed documentation and API reference, please visit our [documentation page]().
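As a quick orientation, here is a minimal usage sketch based on the `CalibrationMetrics` API exactly as it is invoked by the `cal_metrics.py` script later in this commit; the probabilities and labels are invented for illustration:

```python
import numpy as np
from calzone.metrics import CalibrationMetrics

# Invented example data: per-class predicted probabilities and true labels.
y_proba = np.array([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7], [0.6, 0.4]])
y_true = np.array([0, 1, 1, 0])

# Compute every supported metric for class 1 using 10 probability bins.
cal_metrics = CalibrationMetrics(class_to_calculate=1, num_bins=10)
results = cal_metrics.calculate_metrics(y_true=y_true, y_proba=y_proba, metrics='all')
print(results)  # a dict mapping metric names (e.g. "ECE-H", "COX") to values
```

This is how `cal_metrics.py` consumes the result: it reads the dict's keys as metric names and its values as the metric estimates.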
cal_metrics.py
@@ -0,0 +1,202 @@
import argparse
import numpy as np
import os
from calzone.metrics import CalibrationMetrics, get_CI
from calzone.utils import *
from calzone.vis import plot_reliability_diagram


def perform_calculation(probs, labels, args, suffix=""):
    """
    Calculate calibration metrics and visualize reliability diagram.

    Args:
        probs (numpy.ndarray): Predicted probabilities for each class.
        labels (numpy.ndarray): True labels.
        args (argparse.Namespace): Command-line arguments.
        suffix (str, optional): Suffix for output files. Defaults to "".

    Returns:
        numpy.ndarray: Calculated metrics and confidence intervals (if bootstrapping is used).
    """
    # Initialize CalibrationMetrics object
    cal_metrics = CalibrationMetrics(
        class_to_calculate=args.class_to_calculate, num_bins=args.num_bins
    )

    # Calculate metrics
    metrics_to_calculate = args.metrics.split(',') if args.metrics else ['all']
    if metrics_to_calculate == ['all']:
        metrics_to_calculate = 'all'
    result = cal_metrics.calculate_metrics(
        y_true=labels,
        y_proba=probs,
        metrics=metrics_to_calculate,
        perform_pervalance_adjustment=args.prevalence_adjustment
    )

    keys = list(result.keys())
    result = np.array(list(result.values())).reshape(1, -1)

    # Perform bootstrap if requested
    if args.n_bootstrap > 0:
        bootstrap_results = cal_metrics.bootstrap(
            y_true=labels,
            y_proba=probs,
            n_samples=args.n_bootstrap,
            metrics=metrics_to_calculate,
            perform_pervalance_adjustment=args.prevalence_adjustment
        )
        CI = get_CI(bootstrap_results)
        result = np.vstack((result, np.array(list(CI.values())).T))

    # Print metrics if verbose mode is on
    if args.verbose:
        print_metrics(result, keys, args.n_bootstrap, suffix)

    # Save metrics to CSV
    if args.save_metrics:
        save_metrics_to_csv(result, keys, args.save_metrics, suffix)

    # Plot reliability diagram
    if args.plot:
        plot_reliability(labels, probs, args, suffix)

    return result

def print_metrics(result, keys, n_bootstrap, suffix):
    """
    Print calculated metrics.

    Args:
        result (numpy.ndarray): Calculated metrics and confidence intervals.
        keys (list): Names of the calculated metrics.
        n_bootstrap (int): Number of bootstrap samples.
        suffix (str): Suffix for output files.
    """
    if n_bootstrap > 0:
        print_header = ("Metrics with bootstrap confidence intervals:"
                        if suffix == "" else
                        f"Metrics for {suffix} with bootstrap confidence intervals:")
        print(print_header)
        for i, num in enumerate(keys):
            print(f"{num}: {result[0][i]}", f"({result[1][i]}, {result[2][i]})")
    else:
        print_header = "Metrics:" if suffix == "" else f"Metrics for subgroup {suffix}:"
        print(print_header)
        for i, num in enumerate(keys):
            print(f"{num}: {result[0][i]}")


def save_metrics_to_csv(result, keys, save_metrics, suffix):
    """
    Save calculated metrics to a CSV file.

    Args:
        result (numpy.ndarray): Calculated metrics and confidence intervals.
        keys (list): Names of the calculated metrics.
        save_metrics (str): Path to save the CSV file.
        suffix (str): Suffix for output files.
    """
    if suffix == "":
        filename = save_metrics
    else:
        # Insert the subgroup suffix before the .csv extension.
        split_filename = save_metrics.split('.')
        pathwithoutextension = '.'.join(split_filename[:-1])
        filename = pathwithoutextension + "_" + suffix + '.csv'
    np.savetxt(filename, np.array(result), delimiter=',',
               header=','.join(keys), comments='', fmt='%s')

def plot_reliability(labels, probs, args, suffix):
    """
    Plot and save reliability diagram.

    Args:
        labels (numpy.ndarray): True labels.
        probs (numpy.ndarray): Predicted probabilities for each class.
        args (argparse.Namespace): Command-line arguments.
        suffix (str): Suffix for output files.
    """
    if suffix == "":
        filename = args.save_plot
        if args.save_diagram_output == "":
            diagram_filename = None
        else:
            diagram_filename = args.save_diagram_output
    else:
        # Derive per-subgroup file names from the base save paths.
        split_filename = args.save_plot.split('.')
        pathwithoutextension = '.'.join(split_filename[:-1])
        filename = pathwithoutextension + "_" + suffix + '.png'
        if args.save_diagram_output == "":
            diagram_filename = None
        else:
            split_filename = args.save_diagram_output.split('.')
            pathwithoutextension = '.'.join(split_filename[:-1])
            diagram_filename = pathwithoutextension + "_" + suffix + '.csv'
    reliability, confidence, bin_edge, bin_count = reliability_diagram(
        y_true=labels, y_proba=probs,
        num_bins=args.plot_bins, class_to_plot=args.class_to_calculate,
        save_path=diagram_filename
    )
    plot_reliability_diagram(reliability, confidence, bin_count,
                             save_path=filename, title=suffix, error_bar=True)

def main():
    """
    Main function to parse arguments and perform calibration calculations.
    """
    parser = argparse.ArgumentParser(
        description="Calculate calibration metrics and visualize reliability diagram."
    )
    parser.add_argument("--csv_file", type=str,
                        help="Path to the input CSV file. (If there is a header, it must be in: "
                             "proba_0,proba_1,...,subgroup_1(optional),subgroup_2(optional),...,label. "
                             "If there is no header, the columns must be in the order "
                             "proba_0,proba_1,...,label)")
    parser.add_argument("--metrics", type=str,
                        help="Comma-separated list of specific metrics to calculate "
                             "(SpiegelhalterZ,ECE-H,MCE-H,HL-H,ECE-C,MCE-C,HL-C,COX,Loess,all). "
                             "Default: all")
    parser.add_argument("--prevalence_adjustment", default=False, action="store_true",
                        help="Perform prevalence adjustment (default: False)")
    parser.add_argument("--n_bootstrap", type=int, default=0,
                        help="Number of bootstrap samples (default: 0)")
    parser.add_argument("--bootstrap_ci", type=float, default=0.95,
                        help="Bootstrap confidence interval (default: 0.95)")
    parser.add_argument("--class_to_calculate", type=int, default=1,
                        help="Class to calculate metrics for (default: 1)")
    parser.add_argument("--num_bins", type=int, default=10,
                        help="Number of bins for ECE/MCE/HL calculations (default: 10)")
    parser.add_argument("--save_metrics", type=str,
                        help="Save the metrics to a csv file")
    parser.add_argument("--plot", default=False, action="store_true",
                        help="Plot reliability diagram (default: False)")
    parser.add_argument("--plot_bins", type=int, default=10,
                        help="Number of bins for reliability diagram")
    parser.add_argument("--save_plot", default="", type=str,
                        help="Save the plot to a file")
    parser.add_argument("--save_diagram_output", default="", type=str,
                        help="Save the reliability diagram output to a file")
    parser.add_argument("--verbose", default=False, action="store_true",
                        help="Print verbose output (default: False)")

    args = parser.parse_args()

    # Load data from CSV
    loader = data_loader(args.csv_file)

    # Overall calculation on the full dataset
    perform_calculation(probs=loader.probs, labels=loader.labels, args=args, suffix="")

    # Repeat the calculation for each group within each subgroup column
    if loader.have_subgroup:
        for i, subgroup in enumerate(loader.subgroups):
            groups = np.unique(loader.data[:, loader.subgroup_indices[i]])
            for group in groups:
                subgroup_data = loader.data[loader.data[:, loader.subgroup_indices[i]] == group]
                probs_group = subgroup_data[:, :-len(loader.subgroups) - 1].astype(float)
                labels_group = subgroup_data[:, -1:].astype(int)
                perform_calculation(probs=probs_group, labels=labels_group,
                                    args=args, suffix=f"subgroup_{subgroup}_group_{group}")


if __name__ == "__main__":
    main()
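To make the expected `--csv_file` layout concrete, here is a small sketch (file name and values are invented for illustration) that writes a two-class scores file with the `proba_0,proba_1,...,label` header described in the argument help:

```python
import numpy as np

# Invented two-class scores: columns are proba_0, proba_1, label.
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]])
labels = np.array([[0], [1], [1]])
rows = np.hstack([probs, labels])
np.savetxt("scores.csv", rows, delimiter=",",
           header="proba_0,proba_1,label", comments="",
           fmt=["%.3f", "%.3f", "%d"])
```

A run such as `python cal_metrics.py --csv_file scores.csv --n_bootstrap 100 --verbose` would then print each metric together with its bootstrap confidence interval.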
calzone/__init__.py
@@ -0,0 +1,5 @@
from . import metrics
from . import vis
from . import utils

__version__ = '0.1.0'