Skip to content

Commit

Permalink
updated ncc metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakpati committed Dec 19, 2024
1 parent e4f6850 commit b1f0b54
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 77 deletions.
20 changes: 4 additions & 16 deletions GANDLF/cli/generate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
peak_signal_noise_ratio,
mean_squared_log_error,
mean_absolute_error,
ncc_mean,
ncc_std,
ncc_max,
ncc_min,
ncc_metrics,
)
from GANDLF.losses.segmentation import dice
from GANDLF.metrics.segmentation import (
Expand Down Expand Up @@ -375,18 +372,9 @@ def __percentile_clip(
# ncc metrics
compute_ncc = parameters.get("compute_ncc", True)
if compute_ncc:
overall_stats_dict[current_subject_id]["ncc_mean"] = ncc_mean(
output_infill, gt_image_infill
)
overall_stats_dict[current_subject_id]["ncc_std"] = ncc_std(
output_infill, gt_image_infill
)
overall_stats_dict[current_subject_id]["ncc_max"] = ncc_max(
output_infill, gt_image_infill
)
overall_stats_dict[current_subject_id]["ncc_min"] = ncc_min(
output_infill, gt_image_infill
)
calculated_ncc_metrics = ncc_metrics(output_infill, gt_image_infill)
for key, value in calculated_ncc_metrics.items():
overall_stats_dict[current_subject_id][key] = value.item()

# only voxels that are to be inferred (-> flat array)
# these are required for mse, psnr, etc.
Expand Down
5 changes: 1 addition & 4 deletions GANDLF/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@
peak_signal_noise_ratio,
mean_squared_log_error,
mean_absolute_error,
ncc_mean,
ncc_std,
ncc_max,
ncc_min,
ncc_metrics,
)
import GANDLF.metrics.classification as classification
import GANDLF.metrics.regression as regression
Expand Down
72 changes: 15 additions & 57 deletions GANDLF/metrics/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,69 +163,27 @@ def __convert_to_grayscale(image: sitk.Image) -> sitk.Image:
return correlation_filter.Execute(target_image, pred_image)


def ncc_mean(prediction: torch.Tensor, target: torch.Tensor) -> float:
def ncc_metrics(prediction: torch.Tensor, target: torch.Tensor) -> dict:
"""
Computes normalized cross correlation mean between target and prediction.
Computes normalized cross correlation metrics between target and prediction.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.
Returns:
float: The normalized cross correlation mean.
dict: The normalized cross correlation metrics.
"""
stats_filter = sitk.StatisticsImageFilter()
corr_image = _get_ncc_image(target, prediction)
stats_filter.Execute(corr_image)
return stats_filter.GetMean()


def ncc_std(prediction: torch.Tensor, target: torch.Tensor) -> float:
"""
Computes normalized cross correlation standard deviation between target and prediction.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.
Returns:
float: The normalized cross correlation standard deviation.
"""
stats_filter = sitk.StatisticsImageFilter()
corr_image = _get_ncc_image(target, prediction)
stats_filter.Execute(corr_image)
return stats_filter.GetSigma()


def ncc_max(prediction: torch.Tensor, target: torch.Tensor) -> float:
"""
Computes normalized cross correlation maximum between target and prediction.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.
Returns:
float: The normalized cross correlation maximum.
"""
stats_filter = sitk.StatisticsImageFilter()
corr_image = _get_ncc_image(target, prediction)
stats_filter.Execute(corr_image)
return stats_filter.GetMaximum()


def ncc_min(prediction: torch.Tensor, target: torch.Tensor) -> float:
"""
Computes normalized cross correlation minimum between target and prediction.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.
Returns:
float: The normalized cross correlation minimum.
"""
stats_filter = sitk.StatisticsImageFilter()
corr_image = _get_ncc_image(target, prediction)
stats_filter.Execute(corr_image)
return stats_filter.GetMinimum()
stats_filter = sitk.LabelStatisticsImageFilter()
stats_filter.UseHistogramsOn()
# ensure that we are not considering zeros
onesImage = corr_image == corr_image
stats_filter.Execute(corr_image, onesImage)
return {
"mean": stats_filter.GetMean(1),
"std": stats_filter.GetSigma(1),
"max": stats_filter.GetMaximum(1),
"min": stats_filter.GetMinimum(1),
"median": stats_filter.GetMedian(1),
}

0 comments on commit b1f0b54

Please sign in to comment.