Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade base torchmetrics version #742

Merged
merged 29 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
faa2a06
Modified classification metrics
szmazurek Nov 19, 2023
97c395c
Upgrade to torchmetrics 1.1.2
szmazurek Nov 19, 2023
baf5153
Hotfix for test_full.py script
szmazurek Nov 19, 2023
1c962d7
Merge pull request #10 from szmazurek/feature/torchmetrics_upgrade
szmazurek Nov 19, 2023
429b59e
Update GANDLF/metrics/generic.py
szmazurek Nov 20, 2023
705e014
Update GANDLF/metrics/generic.py
szmazurek Nov 20, 2023
b4b993f
Update GANDLF/metrics/generic.py
szmazurek Nov 20, 2023
d005330
Update classification.py
szmazurek Nov 20, 2023
4686f35
Update generic.py
szmazurek Nov 20, 2023
9ca3cf9
Manually trigger this action
szmazurek Nov 21, 2023
8ad6eb5
Merge pull request #11 from szmazurek/feature/torchmetrics_upgrade
szmazurek Nov 21, 2023
0f62f67
Bugfix
szmazurek Nov 21, 2023
b58928c
Bugfixes
szmazurek Nov 21, 2023
04ef297
added `WARNING` string
sarthakpati Nov 21, 2023
ebcb73e
Modifications and hotfixes:
szmazurek Nov 21, 2023
fbb5fd4
Merge branch 'master' into feature/torchmetrics_upgrade
szmazurek Nov 21, 2023
303ce9e
Refactoring and simplification:
szmazurek Nov 22, 2023
1b2ee5a
Merge pull request #12 from szmazurek/feature/torchmetrics_upgrade
szmazurek Nov 22, 2023
2ef57f6
Merge branch 'master' into master
szmazurek Nov 22, 2023
111a03a
Update generic.py
szmazurek Nov 22, 2023
711099d
Update GANDLF/utils/generic.py
szmazurek Nov 22, 2023
8259fc9
Update GANDLF/metrics/generic.py
szmazurek Nov 22, 2023
a34e401
Update GANDLF/metrics/generic.py
szmazurek Nov 22, 2023
68a9926
Update GANDLF/metrics/generic.py
szmazurek Nov 22, 2023
2053681
Update GANDLF/utils/__init__.py
szmazurek Nov 22, 2023
c6da07a
Update GANDLF/metrics/classification.py
szmazurek Nov 22, 2023
18988ad
Update GANDLF/metrics/classification.py
szmazurek Nov 22, 2023
b6af6db
Update GANDLF/utils/generic.py
szmazurek Nov 22, 2023
c0d251f
Update GANDLF/metrics/generic.py
szmazurek Nov 22, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 64 additions & 22 deletions GANDLF/metrics/classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torchmetrics as tm
from torch.nn.functional import one_hot
from ..utils import get_output_from_calculator
from GANDLF.utils.generic import determine_classification_task_type


def overall_stats(predictions, ground_truth, params):
Expand All @@ -26,42 +28,82 @@ def overall_stats(predictions, ground_truth, params):
"per_class_average": "macro",
"per_class_weighted": "weighted",
}
task = determine_classification_task_type(params)
# consider adding a "multilabel field in the future"
# metrics that need the "average" parameter
for average_type, average_type_key in average_types_keys.items():

for average_type_key in average_types_keys.values():
# multidim_average is not used when constructing these metrics
# think of having it
calculators = {
"accuracy": tm.Accuracy(
num_classes=params["model"]["num_classes"], average=average_type_key
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key,
),
"precision": tm.Precision(
num_classes=params["model"]["num_classes"], average=average_type_key
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key,
),
"recall": tm.Recall(
num_classes=params["model"]["num_classes"], average=average_type_key
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key,
),
"f1": tm.F1Score(
num_classes=params["model"]["num_classes"], average=average_type_key
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key,
),
"specificity": tm.Specificity(
num_classes=params["model"]["num_classes"], average=average_type_key
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key,
),
"aucroc": tm.AUROC(
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key
if average_type_key != "micro"
else "macro",
),
## weird error for multi-class problem, where pos_label is not getting set
# "aucroc": tm.AUROC(
# num_classes=params["model"]["num_classes"], average=average_type_key
# ),
}
for metric_name, calculator in calculators.items():
output_metrics[
f"{metric_name}_{average_type}"
] = get_output_from_calculator(predictions, ground_truth, calculator)
if metric_name == "aucroc":
one_hot_preds = one_hot(
predictions.long(),
num_classes=params["model"]["num_classes"],
)
output_metrics[metric_name] = get_output_from_calculator(
one_hot_preds.float(), ground_truth, calculator
)
else:
output_metrics[metric_name] = get_output_from_calculator(
predictions, ground_truth, calculator
)

#### HERE WE NEED TO MODIFY TESTS - ROC IS RETURNING A TUPLE. WE MAY ALSO DISCRAD IT ####
# what is AUC metric telling at all? Computing it for predictions and ground truth
# is not making sense
# metrics that do not have any "average" parameter
calculators = {
"auc": tm.AUC(reorder=True),
## weird error for multi-class problem, where pos_label is not getting set
# "roc": tm.ROC(num_classes=params["model"]["num_classes"]),
}
for metric_name, calculator in calculators.items():
output_metrics[metric_name] = get_output_from_calculator(
predictions, ground_truth, calculator
)
# calculators = {
#
# # "auc": tm.AUC(reorder=True),
# ## weird error for multi-class problem, where pos_label is not getting set
# "roc": tm.ROC(task=task, num_classes=params["model"]["num_classes"]),
# }
# for metric_name, calculator in calculators.items():
# if metric_name == "roc":
# one_hot_preds = one_hot(
# predictions.long(), num_classes=params["model"]["num_classes"]
# )
# output_metrics[metric_name] = get_output_from_calculator(
# one_hot_preds.float(), ground_truth, calculator
# )
# else:
# output_metrics[metric_name] = get_output_from_calculator(
# predictions, ground_truth, calculator
# )

return output_metrics
59 changes: 44 additions & 15 deletions GANDLF/metrics/generic.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,60 @@
import torch
from torchmetrics import F1Score, Precision, Recall, JaccardIndex, Accuracy, Specificity
from torchmetrics import (
F1Score,
Precision,
Recall,
JaccardIndex,
Accuracy,
Specificity,
)
from GANDLF.utils.tensor import one_hot
from GANDLF.utils.generic import (
determine_classification_task_type,
define_average_type_key,
define_multidim_average_type_key,
)


def generic_function_output_with_check(predicted_classes, label, metric_function):
def generic_function_output_with_check(
predicted_classes, label, metric_function
):
if torch.min(predicted_classes) < 0:
print(
"WARNING: Negative values detected in prediction, cannot compute torchmetrics calculations."
)
return torch.zeros((1), device=predicted_classes.device)
else:
# I need to do this with try-except, otherwise for binary problems it will
# raise and error as the binary metrics do not have .num_classes
# attribute.
# https://github.com/Lightning-AI/torchmetrics/blob/v1.1.2/src/torchmetrics/classification/accuracy.py#L31-L146 link to example from BinaryAccuracy.
try:
max_clamp_val = metric_function.num_classes - 1
except AttributeError:
max_clamp_val = 1
szmazurek marked this conversation as resolved.
Show resolved Hide resolved
predicted_new = torch.clamp(
predicted_classes.cpu().int(), max=metric_function.num_classes - 1
predicted_classes.cpu().int(), max=max_clamp_val
)
predicted_new = predicted_new.reshape(label.shape)
return metric_function(predicted_new, label.cpu().int())


def generic_torchmetrics_score(output, label, metric_class, metric_key, params):
def generic_torchmetrics_score(
output, label, metric_class, metric_key, params
):
task = determine_classification_task_type(params)
num_classes = params["model"]["num_classes"]
predicted_classes = output
if params["problem_type"] == "classification":
predicted_classes = torch.argmax(output, 1)
elif params["problem_type"] == "segmentation":
label = one_hot(label, params["model"]["class_list"])
else:
params["metrics"][metric_key]["multi_class"] = False
params["metrics"][metric_key]["mdmc_average"] = None
metric_function = metric_class(
average=params["metrics"][metric_key]["average"],
task=task,
num_classes=num_classes,
multiclass=params["metrics"][metric_key]["multi_class"],
mdmc_average=params["metrics"][metric_key]["mdmc_average"],
threshold=params["metrics"][metric_key]["threshold"],
average=define_average_type_key(params, metric_key),
multidim_average=define_multidim_average_type_key(params, metric_key),
)

return generic_function_output_with_check(
Expand All @@ -45,19 +67,25 @@ def recall_score(output, label, params):


def precision_score(output, label, params):
return generic_torchmetrics_score(output, label, Precision, "precision", params)
return generic_torchmetrics_score(
output, label, Precision, "precision", params
)


def f1_score(output, label, params):
return generic_torchmetrics_score(output, label, F1Score, "f1", params)


def accuracy(output, label, params):
return generic_torchmetrics_score(output, label, Accuracy, "accuracy", params)
return generic_torchmetrics_score(
output, label, Accuracy, "accuracy", params
)


def specificity_score(output, label, params):
return generic_torchmetrics_score(output, label, Specificity, "specificity", params)
return generic_torchmetrics_score(
output, label, Specificity, "specificity", params
)


def iou_score(output, label, params):
Expand All @@ -67,10 +95,11 @@ def iou_score(output, label, params):
predicted_classes = torch.argmax(output, 1)
elif params["problem_type"] == "segmentation":
label = one_hot(label, params["model"]["class_list"])

task = determine_classification_task_type(params)
recall = JaccardIndex(
reduction=params["metrics"]["iou"]["reduction"],
task=task,
num_classes=num_classes,
average=define_average_type_key(params, "iou"),
threshold=params["metrics"]["iou"]["threshold"],
)

Expand Down
1 change: 1 addition & 0 deletions GANDLF/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
suppress_stdout_stderr,
set_determinism,
print_and_format_metrics,
determine_classification_task_type,
)

from .modelio import (
Expand Down
59 changes: 56 additions & 3 deletions GANDLF/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import SimpleITK as sitk
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from os import devnull
from typing import Dict, Any, Union


@contextmanager
Expand Down Expand Up @@ -48,6 +49,18 @@ def checkPatchDivisibility(patch_size, number=16):
return True


def determine_classification_task_type(params: Dict[str, Union[Dict[str, Any], Any]]) -> str:
"""Determine the task (binary or multiclass) from the model config.
Args:
params (dict): The parameter dictionary containing training and data information.

Returns:
str: A string that denotes the classification task type.
"""
task = "binary" if params["model"]["num_classes"] == 2 else "multiclass"
return task


def get_date_time():
"""
Get a well-parsed date string
Expand Down Expand Up @@ -146,7 +159,10 @@ def checkPatchDimensions(patch_size, numlay):
patch_size_to_check = patch_size_to_check[:-1]

if all(
[x >= 2 ** (numlay + 1) and x % 2**numlay == 0 for x in patch_size_to_check]
[
x >= 2 ** (numlay + 1) and x % 2**numlay == 0
for x in patch_size_to_check
]
):
return numlay
else:
Expand Down Expand Up @@ -182,7 +198,9 @@ def get_array_from_image_or_tensor(input_tensor_or_image):
elif isinstance(input_tensor_or_image, np.ndarray):
return input_tensor_or_image
else:
raise ValueError("Input must be a torch.Tensor or sitk.Image or np.ndarray")
raise ValueError(
"Input must be a torch.Tensor or sitk.Image or np.ndarray"
)


def set_determinism(seed=42):
Expand Down Expand Up @@ -252,7 +270,9 @@ def __update_metric_from_list_to_single_string(input_metrics_dict) -> dict:
output_metrics_dict = deepcopy(cohort_level_metrics)
for metric in metrics_dict_from_parameters:
if isinstance(sample_level_metrics[metric], np.ndarray):
to_print = (sample_level_metrics[metric] / length_of_dataloader).tolist()
to_print = (
sample_level_metrics[metric] / length_of_dataloader
).tolist()
else:
to_print = sample_level_metrics[metric] / length_of_dataloader
output_metrics_dict[metric] = to_print
Expand All @@ -266,3 +286,36 @@ def __update_metric_from_list_to_single_string(input_metrics_dict) -> dict:
)

return output_metrics_dict


def define_average_type_key(
params: Dict[str, Union[Dict[str, Any], Any]], metric_name: str
) -> str:
"""Determine if the the 'average' filed is defined in the metric config.
If not, fallback to the default 'macro'
values.
Args:
params (dict): The parameter dictionary containing training and data information.
metric_name (str): The name of the metric.

Returns:
str: The average type key.
"""
average_type_key = params["metrics"][metric_name].get("average", "macro")
return average_type_key


def define_multidim_average_type_key(params, metric_name) -> str:
"""Determine if the the 'multidim_average' filed is defined in the metric config.
If not, fallback to the default 'global'.
Args:
params (dict): The parameter dictionary containing training and data information.
metric_name (str): The name of the metric.

Returns:
str: The average type key.
"""
average_type_key = params["metrics"][metric_name].get(
"multidim_average", "global"
)
return average_type_key
10 changes: 5 additions & 5 deletions samples/config_all_options.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ metrics:
# - hausdorff # hausdorff 100 percentile, segmentation
# - hausdorff95 # hausdorff 95 percentile, segmentation
# - mse # regression/classification
# - accuracy # classification
# - accuracy # classification ## more details https://lightning.ai/docs/torchmetrics/v1.1.2/classification/accuracy.html
# - classification_accuracy # classification
# - balanced_accuracy # classification ## more details https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html
# - per_label_accuracy # used for classification
# - f1 # classification/segmentation
# - precision # classification/segmentation ## more details https://torchmetrics.readthedocs.io/en/latest/references/modules.html#id3
# - recall # classification/segmentation ## more details https://torchmetrics.readthedocs.io/en/latest/references/modules.html#id4
# - iou # classification/segmentation ## more details https://torchmetrics.readthedocs.io/en/latest/references/modules.html#iou
# - f1 # classification/segmentation ## more details https://lightning.ai/docs/torchmetrics/v1.1.2/classification/f1_score.html
# - precision # classification/segmentation ## more details https://lightning.ai/docs/torchmetrics/v1.1.2/classification/precision.html
# - recall # classification/segmentation ## more details https://lightning.ai/docs/torchmetrics/v1.1.2/classification/recall.html
# - iou # classification/segmentation ## more details https://lightning.ai/docs/torchmetrics/v1.1.2/classification/jaccard_index.html
## this customizes the inference, primarily used for segmentation outputs
inference_mechanism: {
grid_aggregator_overlap: crop, # this option provides the option to strategize the grid aggregation output; should be either 'crop' or 'average' - https://torchio.readthedocs.io/patches/patch_inference.html#grid-aggregator
Expand Down
14 changes: 10 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
readme = readme_file.read()
except Exception as error:
readme = "No README information found."
sys.stderr.write("Warning: Could not open '%s' due %s\n" % ("README.md", error))
sys.stderr.write(
"Warning: Could not open '%s' due %s\n" % ("README.md", error)
)


class CustomInstallCommand(install):
Expand All @@ -39,7 +41,9 @@ def run(self):

except Exception as error:
__version__ = "0.0.1"
sys.stderr.write("Warning: Could not open '%s' due %s\n" % (filepath, error))
sys.stderr.write(
"Warning: Could not open '%s' due %s\n" % (filepath, error)
)

# Handle cases where specific files need to be bundled into the final package as installed via PyPI
dockerfiles = [
Expand All @@ -54,7 +58,9 @@ def run(self):
]
setup_files = ["setup.py", ".dockerignore", "pyproject.toml", "MANIFEST.in"]
all_extra_files = dockerfiles + entrypoint_files + setup_files
all_extra_files_pathcorrected = [os.path.join("../", item) for item in all_extra_files]
all_extra_files_pathcorrected = [
os.path.join("../", item) for item in all_extra_files
]
# find_packages should only ever find these as subpackages of gandlf, not as top-level packages
# generate this dynamically?
# GANDLF.GANDLF is needed to prevent recursion madness in deployments
Expand Down Expand Up @@ -99,7 +105,7 @@ def run(self):
"psutil",
"medcam",
"opencv-python",
"torchmetrics==0.8.1",
"torchmetrics==1.1.2",
"zarr==2.10.3",
"pydicom",
"onnx",
Expand Down
Loading
Loading