Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure synthesis metrics have an option to take voided image #981

Merged
merged 24 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
b0247ac
added standardized CLI options
scap3yvt Apr 18, 2024
f920859
Merge branch '849-feature-standardize-the-cli-option-for-output-direc…
sarthakpati Apr 18, 2024
025e2d1
Merge branch 'master' of https://github.com/mlcommons/GaNDLF
sarthakpati Apr 29, 2024
abb1143
Merge branch 'mlcommons:master' into master
sarthakpati Jun 17, 2024
1dddc54
Merge branch 'mlcommons:master' into master
sarthakpati Jul 25, 2024
dad1e22
Merge branch 'new-apis_v0.1.0-dev' of https://github.com/mlcommons/Ga…
sarthakpati Jul 26, 2024
3f8b965
Merge branch 'new-apis_v0.1.0-dev' of https://github.com/mlcommons/Ga…
sarthakpati Jul 26, 2024
040f351
Merge branch 'new-apis_v0.1.0-dev' of https://github.com/mlcommons/Ga…
sarthakpati Jul 26, 2024
781f503
Merge branch 'new-apis_v0.1.0-dev' of https://github.com/mlcommons/Ga…
sarthakpati Jul 26, 2024
5fdd2c1
Merge branch 'master' of https://github.com/mlcommons/GaNDLF
sarthakpati Jul 31, 2024
b7f4570
Merge branch 'master' of https://github.com/sarthakpati/GaNDLF
sarthakpati Sep 10, 2024
2e1e025
Merge branch 'master' of https://github.com/sarthakpati/GaNDLF
sarthakpati Nov 20, 2024
c02b9dd
Merge branch 'master' of https://github.com/mlcommons/GaNDLF
sarthakpati Nov 21, 2024
f737db0
Merge branch 'master' of https://github.com/sarthakpati/GaNDLF
sarthakpati Dec 16, 2024
ca49b4c
putting comments and added parameter to get rmse
sarthakpati Dec 19, 2024
1c4afd7
ensure that the brain mask and void image are treated differently
sarthakpati Dec 19, 2024
05f7101
updated dictionary for spell checker
sarthakpati Dec 19, 2024
ba3ea83
lint fix
sarthakpati Dec 19, 2024
e4f6850
typo fix and unnecessary word removed
sarthakpati Dec 19, 2024
b1f0b54
updated ncc metrics
sarthakpati Dec 19, 2024
c12c37d
addressing comment
sarthakpati Dec 19, 2024
08dc42a
ensure `ncc` gets picked up correctly
sarthakpati Dec 19, 2024
8b9a797
putting the `ncc` in the metric calculation itself for clarity
sarthakpati Dec 19, 2024
1e9447c
fixing error
sarthakpati Dec 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions .spelling/.spelling/expect.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Abhishek

Check warning on line 1 in .spelling/.spelling/expect.txt

View workflow job for this annotation

GitHub Actions / Check Spelling

Skipping `.spelling/.spelling/expect.txt` because there seems to be more noise (732) than unique words (0) (total: 732 / 0). (noisy-file)
Abousamra
acdfbac
acsconv
Expand Down Expand Up @@ -487,6 +487,7 @@
rgbtorgba
rigourous
Ritesh
rmse
rmsprop
rocm
rocmdocs
Expand Down Expand Up @@ -561,7 +562,6 @@
thresholding
Thu
tiatoolbox
tiffslide
timepoints
timm
tio
Expand Down Expand Up @@ -597,8 +597,6 @@
unitwise
unsqueeze
upenn
Uploaing
Uploded
upsample
upsampled
upsampling
Expand Down Expand Up @@ -725,7 +723,6 @@
zwezggl
zzokqk
thirdparty
adopy
Shohei
crcrpar
lrs
Expand Down
63 changes: 34 additions & 29 deletions GANDLF/cli/generate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
overall_stats,
structural_similarity_index,
mean_squared_error,
root_mean_squared_error,
peak_signal_noise_ratio,
mean_squared_log_error,
mean_absolute_error,
ncc_mean,
ncc_std,
ncc_max,
ncc_min,
ncc_metrics,
)
from GANDLF.losses.segmentation import dice
from GANDLF.metrics.segmentation import (
Expand Down Expand Up @@ -302,21 +300,22 @@ def __percentile_clip(
reference_tensor = (
input_tensor if reference_tensor is None else reference_tensor
)
v_min, v_max = np.percentile(
reference_tensor, [p_min, p_max]
) # get p_min percentile and p_max percentile

# get p_min percentile and p_max percentile
v_min, v_max = np.percentile(reference_tensor, [p_min, p_max])
# set lower bound to be 0 if strictlyPositive is enabled
v_min = max(v_min, 0.0) if strictlyPositive else v_min
output_tensor = np.clip(
input_tensor, v_min, v_max
) # clip values to percentiles from reference_tensor
output_tensor = (output_tensor - v_min) / (
v_max - v_min
) # normalizes values to [0;1]
# clip values to percentiles from reference_tensor
output_tensor = np.clip(input_tensor, v_min, v_max)
# normalizes values to [0;1]
output_tensor = (output_tensor - v_min) / (v_max - v_min)
return output_tensor

input_df = __update_header_location_case_insensitive(input_df, "Mask", False)
# these are additional columns that could be present for synthesis tasks
for column_to_make_case_insensitive in ["Mask", "VoidImage"]:
input_df = __update_header_location_case_insensitive(
input_df, column_to_make_case_insensitive, False
)

for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
current_subject_id = row["SubjectID"]
overall_stats_dict[current_subject_id] = {}
Expand All @@ -332,16 +331,26 @@ def __percentile_clip(
)
).byte()

void_image_present = True if "VoidImage" in row else False
void_image = (
__fix_2d_tensor(torchio.ScalarImage(row["VoidImage"]).data)
if "VoidImage" in row
else torch.from_numpy(
np.ones(target_image.numpy().shape, dtype=np.uint8)
)
)

# Get Infill region (we really are only interested in the infill region)
output_infill = (pred_image * mask).float()
gt_image_infill = (target_image * mask).float()

# Normalize to [0;1] based on GT (otherwise MSE will depend on the image intensity range)
normalize = parameters.get("normalize", True)
if normalize:
# use all the tissue that is not masked for normalization
reference_tensor = (
target_image * ~mask
) # use all the tissue that is not masked for normalization
target_image * ~mask if not void_image_present else void_image
)
gt_image_infill = __percentile_clip(
gt_image_infill,
reference_tensor=reference_tensor,
Expand All @@ -364,18 +373,10 @@ def __percentile_clip(
# ncc metrics
compute_ncc = parameters.get("compute_ncc", True)
if compute_ncc:
overall_stats_dict[current_subject_id]["ncc_mean"] = ncc_mean(
output_infill, gt_image_infill
)
overall_stats_dict[current_subject_id]["ncc_std"] = ncc_std(
output_infill, gt_image_infill
)
overall_stats_dict[current_subject_id]["ncc_max"] = ncc_max(
output_infill, gt_image_infill
)
overall_stats_dict[current_subject_id]["ncc_min"] = ncc_min(
output_infill, gt_image_infill
)
calculated_ncc_metrics = ncc_metrics(output_infill, gt_image_infill)
for key, value in calculated_ncc_metrics.items():
# we don't need the ".item()" here, since the values are already scalars
overall_stats_dict[current_subject_id][key] = value

# only voxels that are to be inferred (-> flat array)
# these are required for mse, psnr, etc.
Expand All @@ -386,6 +387,10 @@ def __percentile_clip(
output_infill, gt_image_infill
).item()

overall_stats_dict[current_subject_id]["rmse"] = root_mean_squared_error(
output_infill, gt_image_infill
).item()

overall_stats_dict[current_subject_id]["msle"] = mean_squared_log_error(
output_infill, gt_image_infill
).item()
Expand Down
2 changes: 1 addition & 1 deletion GANDLF/entrypoints/hf_hub_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
@click.option(
"--hf-template",
"-hft",
help="Adding the template path for the model card it is Required during Uploaing a model",
help="Adding the template path for the model card: it is required during model upload",
default=huggingface_file_path,
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
Expand Down
6 changes: 2 additions & 4 deletions GANDLF/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,11 @@
from .synthesis import (
structural_similarity_index,
mean_squared_error,
root_mean_squared_error,
peak_signal_noise_ratio,
mean_squared_log_error,
mean_absolute_error,
ncc_mean,
ncc_std,
ncc_max,
ncc_min,
ncc_metrics,
)
import GANDLF.metrics.classification as classification
import GANDLF.metrics.regression as regression
Expand Down
101 changes: 39 additions & 62 deletions GANDLF/metrics/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,28 @@ def mean_squared_error(prediction: torch.Tensor, target: torch.Tensor) -> torch.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.

Returns:
torch.Tensor: The mean squared error or its square root.
"""
mse = MeanSquaredError()
mse = MeanSquaredError(squared=True)
return mse(preds=prediction, target=target)


def root_mean_squared_error(
    prediction: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
    """
    Computes the root mean squared error (RMSE) between the target and prediction.

    Args:
        prediction (torch.Tensor): The prediction tensor.
        target (torch.Tensor): The target tensor.

    Returns:
        torch.Tensor: The root mean squared error.
    """
    # squared=False makes torchmetrics return sqrt(MSE), i.e. the RMSE
    mse = MeanSquaredError(squared=False)
    return mse(preds=prediction, target=target)


Expand Down Expand Up @@ -78,10 +98,9 @@ def peak_signal_noise_ratio(
return psnr(preds=prediction, target=target)
else: # implementation of PSNR that does not give 'inf'/'nan' when 'mse==0'
mse = mean_squared_error(target, prediction)
if data_range == None: # compute data_range like torchmetrics if not given
min_v = (
0 if torch.min(target) > 0 else torch.min(target)
) # look at this line
if data_range is None: # compute data_range like torchmetrics if not given
# put the min value to 0 if all values are positive
min_v = 0 if torch.min(target) > 0 else torch.min(target)
max_v = torch.max(target)
else:
min_v, max_v = data_range
Expand Down Expand Up @@ -158,69 +177,27 @@ def __convert_to_grayscale(image: sitk.Image) -> sitk.Image:
return correlation_filter.Execute(target_image, pred_image)


def ncc_mean(prediction: torch.Tensor, target: torch.Tensor) -> float:
    """
    Computes the mean of the normalized cross correlation between target and prediction.

    Args:
        prediction (torch.Tensor): The prediction tensor.
        target (torch.Tensor): The target tensor.

    Returns:
        float: The mean of the normalized cross correlation image.
    """
    correlation_image = _get_ncc_image(target, prediction)
    statistics = sitk.StatisticsImageFilter()
    statistics.Execute(correlation_image)
    return statistics.GetMean()


def ncc_std(prediction: torch.Tensor, target: torch.Tensor) -> float:
    """
    Computes the standard deviation of the normalized cross correlation
    between target and prediction.

    Args:
        prediction (torch.Tensor): The prediction tensor.
        target (torch.Tensor): The target tensor.

    Returns:
        float: The standard deviation of the normalized cross correlation image.
    """
    correlation_image = _get_ncc_image(target, prediction)
    statistics = sitk.StatisticsImageFilter()
    statistics.Execute(correlation_image)
    return statistics.GetSigma()


def ncc_max(prediction: torch.Tensor, target: torch.Tensor) -> float:
    """
    Computes the maximum of the normalized cross correlation between target and prediction.

    Args:
        prediction (torch.Tensor): The prediction tensor.
        target (torch.Tensor): The target tensor.

    Returns:
        float: The maximum of the normalized cross correlation image.
    """
    correlation_image = _get_ncc_image(target, prediction)
    statistics = sitk.StatisticsImageFilter()
    statistics.Execute(correlation_image)
    return statistics.GetMaximum()


def ncc_min(prediction: torch.Tensor, target: torch.Tensor) -> float:
def ncc_metrics(prediction: torch.Tensor, target: torch.Tensor) -> dict:
"""
Computes normalized cross correlation minimum between target and prediction.
Computes normalized cross correlation metrics between target and prediction.

Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.

Returns:
float: The normalized cross correlation minimum.
dict: The normalized cross correlation metrics.
"""
stats_filter = sitk.StatisticsImageFilter()
corr_image = _get_ncc_image(target, prediction)
stats_filter.Execute(corr_image)
return stats_filter.GetMinimum()
stats_filter = sitk.LabelStatisticsImageFilter()
stats_filter.UseHistogramsOn()
# ensure that we are not considering zeros
onesImage = corr_image == corr_image
stats_filter.Execute(corr_image, onesImage)
return {
"ncc_mean": stats_filter.GetMean(1),
"ncc_std": stats_filter.GetSigma(1),
"ncc_max": stats_filter.GetMaximum(1),
"ncc_min": stats_filter.GetMinimum(1),
"ncc_median": stats_filter.GetMedian(1),
}
Loading