
Commit

calzone first commit
jasonfan1997 committed Sep 25, 2024
0 parents commit 33120aa
Showing 139 changed files with 38,255 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
*.pyc
.vscode/*
calzone.egg-info/
example_data/*result*.csv
example_data/*output*.csv
example_data/*png
test_notebook/*
133 changes: 133 additions & 0 deletions GUI_cal_metrics.py
@@ -0,0 +1,133 @@
import argparse
import numpy as np
import subprocess
import time
import matplotlib.pyplot as plt
from nicegui import ui, app
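

# NiceGUI front end for cal_metrics.py: the widgets below collect the options,
# build the equivalent cal_metrics.py command line, and run it as a subprocess.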


def run_program():
    csv_file = csv_file_input.value
    save_metrics = save_metrics_input.value
    save_plot = save_plot_input.value

    selected_metrics = [metric for metric, checkbox in metrics_checkboxes.items() if checkbox.value]
    metric_arg = "all" if not selected_metrics else ",".join(selected_metrics)

    args = [
        "--csv_file", str(csv_file),
        "--metrics", str(metric_arg),
        "--n_bootstrap", str(int(n_bootstrap_input.value)),
        "--bootstrap_ci", str(bootstrap_ci_input.value),
        "--class_to_calculate", str(class_to_calculate_input.value),
        "--num_bins", str(num_bins_input.value),
        "--save_metrics", str(save_metrics),
        "--plot_bins", str(plot_bins_input.value),
    ]

    if prevalence_adjustment_checkbox.value:
        args.append("--prevalence_adjustment")

    if plot_checkbox.value:
        args.append("--plot")
        args.append("--save_plot")
        args.append(str(save_plot))
    if verbose_checkbox.value:
        args.append("--verbose")

    command = ["python", "cal_metrics.py"] + args
    print("Running command:", " ".join(command))

    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    output, error = process.communicate()

    output_area.value = output
    if error:
        output_area.value += "\nErrors:\n" + error

    if plot_checkbox.value:
        display_plot(save_plot)

def display_plot(plot_path):
    plot_image.clear()
    plot_image.set_source(plot_path)
    ui.update(plot_image)

def clear_cache():
    ui.run_javascript('''
        function clearCache() {
            if ('caches' in window) {
                caches.keys().then(function(names) {
                    for (let name of names)
                        caches.delete(name);
                });
            }
            if (window.performance && window.performance.memory) {
                window.performance.memory.usedJSHeapSize = 0;
            }
            sessionStorage.clear();
            localStorage.clear();
            document.cookie.split(";").forEach(function(c) {
                document.cookie = c.replace(/^ +/, "").replace(/=.*/, "=;expires=" + new Date().toUTCString() + ";path=/");
            });
            location.reload(true);
        }
        clearCache();
    ''')
    ui.notify('Browser cache cleared')

    plot_image.clear()
    ui.update(plot_image)

with ui.row().classes('w-full justify-center'):
    with ui.column().classes('w-1/3 p-4'):
        ui.label('Calibration Metrics GUI').classes('text-h4')
        csv_file_input = ui.input(label='CSV File', placeholder='Enter file path').classes('w-full')
        ui.label('Metrics:')
        metrics_options = ["all", "SpiegelhalterZ", "ECE-H", "MCE-H", "HL-H", "ECE-C", "MCE-C", "HL-C", "COX", "Loess"]
        metrics_checkboxes = {metric: ui.checkbox(metric, on_change=lambda m=metric: update_checkboxes(m)) for metric in metrics_options}

    with ui.column().classes('w-1/3 p-4'):
        ui.label('Bootstrap:')
        n_bootstrap_input = ui.number(label='Number of Bootstrap Samples', value=0, min=0)
        bootstrap_ci_input = ui.number(label='Bootstrap Confidence Interval', value=0.95, min=0, max=1, step=0.01)
        ui.label('Setting:')
        class_to_calculate_input = ui.number(label='Class to Calculate Metrics for', value=1, step=1)
        num_bins_input = ui.number(label='Number of Bins for ECE/MCE/HL Test', value=10, min=2, step=1)
        save_metrics_input = ui.input(label='Save Metrics to', placeholder='Enter file path').classes('w-full')

    with ui.column().classes('w-1/3 p-4'):
        ui.label('Prevalence Adjustment:')
        prevalence_adjustment_checkbox = ui.checkbox('Perform Prevalence Adjustment', value=False)
        ui.label('Plot:')
        plot_checkbox = ui.checkbox('Plot Reliability Diagram', value=False)
        plot_bins_input = ui.number(label='Number of Bins for Reliability Diagram', value=10, min=2, step=1)
        save_plot_input = ui.input(label='Save Plot to', placeholder='Enter file path').classes('w-full')
        verbose_checkbox = ui.checkbox('Print Verbose Output', value=True)
        ui.button('Run', on_click=run_program).classes('w-full')
        ui.button('Clear Browser Cache', on_click=clear_cache).classes('w-full')

with ui.row().classes('w-full justify-center'):
    with ui.column().classes('w-2/3 p-4'):
        output_area = ui.textarea(label='Output').classes('w-full')
    with ui.column().classes('w-1/3 p-4'):
        plot_image = ui.image().classes('w-full')

def update_checkboxes(changed_metric):
    if changed_metric == "all":
        all_checked = metrics_checkboxes["all"].value
        for metric, checkbox in metrics_checkboxes.items():
            if metric != "all":
                checkbox.value = False
                checkbox.disabled = all_checked
    else:
        metrics_checkboxes["all"].value = False
        any_checked = any(checkbox.value for metric, checkbox in metrics_checkboxes.items() if metric != "all")
        metrics_checkboxes["all"].disabled = any_checked

ui.run()
8 changes: 8 additions & 0 deletions README.md
@@ -0,0 +1,8 @@

# calzone

calzone is a Python package for calculating and analyzing calibration metrics for probabilistic classification models.

## Documentation

For detailed documentation and API reference, please visit our [documentation page]().
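
A minimal usage sketch of the Python API, mirroring how `cal_metrics.py` in this commit calls it (the probabilities and labels below are made-up example data; replace them with real model predictions):

```python
import numpy as np
from calzone.metrics import CalibrationMetrics

# Made-up example data: per-class predicted probabilities and true labels.
y_proba = np.array([[0.9, 0.1], [0.3, 0.7], [0.2, 0.8], [0.6, 0.4]])
y_true = np.array([0, 1, 1, 0])

# Compute all calibration metrics for class 1 with 10 bins
# (the defaults used by cal_metrics.py).
cal_metrics = CalibrationMetrics(class_to_calculate=1, num_bins=10)
result = cal_metrics.calculate_metrics(y_true=y_true, y_proba=y_proba, metrics='all')
print(result)  # dictionary mapping metric names to values
```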
202 changes: 202 additions & 0 deletions cal_metrics.py
@@ -0,0 +1,202 @@
import argparse
import numpy as np
import os
from calzone.metrics import CalibrationMetrics, get_CI
from calzone.utils import *
from calzone.vis import plot_reliability_diagram


def perform_calculation(probs, labels, args, suffix=""):
    """
    Calculate calibration metrics and visualize reliability diagram.

    Args:
        probs (numpy.ndarray): Predicted probabilities for each class.
        labels (numpy.ndarray): True labels.
        args (argparse.Namespace): Command-line arguments.
        suffix (str, optional): Suffix for output files. Defaults to "".

    Returns:
        numpy.ndarray: Calculated metrics and confidence intervals (if bootstrapping is used).
    """
    # Initialize CalibrationMetrics object
    cal_metrics = CalibrationMetrics(
        class_to_calculate=args.class_to_calculate, num_bins=args.num_bins
    )

    # Calculate metrics
    metrics_to_calculate = args.metrics.split(',') if args.metrics else ['all']
    if metrics_to_calculate == ['all']:
        metrics_to_calculate = 'all'
    result = cal_metrics.calculate_metrics(
        y_true=labels,
        y_proba=probs,
        metrics=metrics_to_calculate,
        perform_pervalance_adjustment=args.prevalence_adjustment
    )

    keys = list(result.keys())
    result = np.array(list(result.values())).reshape(1, -1)

    # Perform bootstrap if requested
    if args.n_bootstrap > 0:
        bootstrap_results = cal_metrics.bootstrap(
            y_true=labels,
            y_proba=probs,
            n_samples=args.n_bootstrap,
            metrics=metrics_to_calculate,
            perform_pervalance_adjustment=args.prevalence_adjustment
        )
        CI = get_CI(bootstrap_results)
        result = np.vstack((result, np.array(list(CI.values())).T))

    # Print metrics if verbose mode is on
    if args.verbose:
        print_metrics(result, keys, args.n_bootstrap, suffix)

    # Save metrics to CSV
    if args.save_metrics:
        save_metrics_to_csv(result, keys, args.save_metrics, suffix)

    # Plot reliability diagram
    if args.plot:
        plot_reliability(labels, probs, args, suffix)

    return result

def print_metrics(result, keys, n_bootstrap, suffix):
    """
    Print calculated metrics.

    Args:
        result (numpy.ndarray): Calculated metrics and confidence intervals.
        keys (list): Names of the calculated metrics.
        n_bootstrap (int): Number of bootstrap samples.
        suffix (str): Suffix for output files.
    """
    if n_bootstrap > 0:
        print_header = ("Metrics with bootstrap confidence intervals:"
                        if suffix == "" else
                        f"Metrics for {suffix} with bootstrap confidence intervals:")
        print(print_header)
        for i, num in enumerate(keys):
            print(f"{num}: {result[0][i]}", f"({result[1][i]}, {result[2][i]})")
    else:
        print_header = "Metrics:" if suffix == "" else f"Metrics for subgroup {suffix}:"
        print(print_header)
        for i, num in enumerate(keys):
            print(f"{num}: {result[0][i]}")

def save_metrics_to_csv(result, keys, save_metrics, suffix):
    """
    Save calculated metrics to a CSV file.

    Args:
        result (numpy.ndarray): Calculated metrics and confidence intervals.
        keys (list): Names of the calculated metrics.
        save_metrics (str): Path to save the CSV file.
        suffix (str): Suffix for output files.
    """
    if suffix == "":
        filename = save_metrics
    else:
        split_filename = save_metrics.split('.')
        pathwithoutextension = '.'.join(split_filename[:-1])
        filename = pathwithoutextension + "_" + suffix + '.csv'
    np.savetxt(filename, np.array(result), delimiter=',',
               header=','.join(keys), comments='', fmt='%s')

def plot_reliability(labels, probs, args, suffix):
    """
    Plot and save reliability diagram.

    Args:
        labels (numpy.ndarray): True labels.
        probs (numpy.ndarray): Predicted probabilities for each class.
        args (argparse.Namespace): Command-line arguments.
        suffix (str): Suffix for output files.
    """
    if suffix == "":
        filename = args.save_plot
        if args.save_diagram_output == "":
            diagram_filename = None
        else:
            diagram_filename = args.save_diagram_output
    else:
        split_filename = args.save_metrics.split('.')
        pathwithoutextension = '.'.join(split_filename[:-1])
        filename = pathwithoutextension + "_" + suffix + '.png'
        if args.save_diagram_output == "":
            diagram_filename = None
        else:
            split_filename = args.save_diagram_output.split('.')
            pathwithoutextension = '.'.join(split_filename[:-1])
            diagram_filename = pathwithoutextension + "_" + suffix + '.csv'
    reliability, confidence, bin_edge, bin_count = reliability_diagram(
        y_true=labels, y_proba=probs,
        num_bins=args.plot_bins, class_to_plot=args.class_to_calculate,
        save_path=diagram_filename
    )
    plot_reliability_diagram(reliability, confidence, bin_count,
                             save_path=filename, title=suffix, error_bar=True)

def main():
    """
    Main function to parse arguments and perform calibration calculations.
    """
    parser = argparse.ArgumentParser(
        description="Calculate calibration metrics and visualize reliability diagram."
    )
    parser.add_argument("--csv_file", type=str,
                        help="Path to the input CSV file. (If there is a header, it must be in: "
                             "proba_0,proba_1,...,subgroup_1(optional),subgroup_2(optional),...,label. "
                             "If there is no header, the columns must be in the order "
                             "proba_0,proba_1,...,label.)")
    parser.add_argument("--metrics", type=str,
                        help="Comma-separated list of specific metrics to calculate "
                             "(SpiegelhalterZ,ECE-H,MCE-H,HL-H,ECE-C,MCE-C,HL-C,COX,Loess,all). "
                             "Default: all")
    parser.add_argument("--prevalence_adjustment", default=False, action="store_true",
                        help="Perform prevalence adjustment (default: False)")
    parser.add_argument("--n_bootstrap", type=int, default=0,
                        help="Number of bootstrap samples (default: 0)")
    parser.add_argument("--bootstrap_ci", type=float, default=0.95,
                        help="Bootstrap confidence interval (default: 0.95)")
    parser.add_argument("--class_to_calculate", type=int, default=1,
                        help="Class to calculate metrics for (default: 1)")
    parser.add_argument("--num_bins", type=int, default=10,
                        help="Number of bins for ECE/MCE/HL calculations (default: 10)")
    parser.add_argument("--save_metrics", type=str,
                        help="Save the metrics to a csv file")
    parser.add_argument("--plot", default=False, action="store_true",
                        help="Plot reliability diagram (default: False)")
    parser.add_argument("--plot_bins", type=int, default=10,
                        help="Number of bins for reliability diagram")
    parser.add_argument("--save_plot", default="", type=str,
                        help="Save the plot to a file")
    parser.add_argument("--save_diagram_output", default="", type=str,
                        help="Save the reliability diagram output to a file")
    parser.add_argument("--verbose", default=True, action="store_true",
                        help="Print verbose output")

    args = parser.parse_args()

    # Load data from CSV
    loader = data_loader(args.csv_file)

    # Perform calculations
    if not loader.have_subgroup:
        perform_calculation(probs=loader.probs, labels=loader.labels, args=args, suffix="")
    else:
        perform_calculation(probs=loader.probs, labels=loader.labels, args=args, suffix="")
        for i, subgroup in enumerate(loader.subgroups):
            groups = np.unique(loader.data[:, loader.subgroup_indices[i]])
            for group in groups:
                subgroup_data = loader.data[loader.data[:, loader.subgroup_indices[i]] == group]
                probs_group = subgroup_data[:, :-len(loader.subgroups)-1].astype(float)
                labels_group = subgroup_data[:, -1:].astype(int)
                perform_calculation(probs=probs_group, labels=labels_group,
                                    args=args, suffix=f"subgroup_{subgroup}_group_{group}")


if __name__ == "__main__":
    main()
5 changes: 5 additions & 0 deletions calzone/__init__.py
@@ -0,0 +1,5 @@
from . import metrics
from . import vis
from . import utils

__version__ = '0.1.0'