Skip to content

Commit

Permalink
Merge pull request #141 from ANTsX/NewWeightsDABext
Browse files Browse the repository at this point in the history
ENH:  New weights for brain extraction and deep atropos.
  • Loading branch information
ntustison authored Oct 23, 2024
2 parents 9049dc4 + 61bf5cd commit d9f8a9a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 18 deletions.
59 changes: 45 additions & 14 deletions antspynet/utilities/brain_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def brain_extraction(image,
brain_extraction_t1 = brain_extraction(image, modality="t1", verbose=verbose)
brain_mask = ants.iMath_get_largest_component(
ants.threshold_image(brain_extraction_t1, 0.5, 10000))
brain_mask = ants.morphology(brain_mask,"close",morphological_radius).iMath_fill_holes()
brain_mask = ants.morphology(brain_mask, "close", morphological_radius).iMath_fill_holes()

brain_extraction_t1nobrainer = brain_extraction(image * ants.iMath_MD(brain_mask, radius=morphological_radius),
modality = "t1nobrainer", verbose=verbose)
Expand Down Expand Up @@ -140,6 +140,9 @@ def brain_extraction(image,
weights_file_name_prefix = "brainExtractionInfantT1"
elif modality == "t2infant":
weights_file_name_prefix = "brainExtractionInfantT2"
elif modality == "bw20":
weights_file_name_prefix = "brainExtractionBrainWeb20"
is_standard_network = True
else:
raise ValueError("Unknown modality type.")

Expand All @@ -151,11 +154,14 @@ def brain_extraction(image,
if verbose:
print("Brain extraction: retrieving template.")

reorient_template_file_name_path = get_antsxnet_data("S_template3")
reorient_template = ants.image_read(reorient_template_file_name_path)
if is_standard_network and (modality != "t1.v1" and modality != "mra"):
ants.set_spacing(reorient_template, (1.5, 1.5, 1.5))
resampled_image_size = reorient_template.shape
if modality == "bw20":
reorient_template = ants.image_read(get_antsxnet_data("nki"))
resampled_image_size = reorient_template.shape
else:
reorient_template = ants.image_read(get_antsxnet_data("S_template3"))
if is_standard_network and (modality != "t1.v1" and modality != "mra"):
ants.set_spacing(reorient_template, (1.5, 1.5, 1.5))
resampled_image_size = reorient_template.shape

number_of_filters = (8, 16, 32, 64)
mode = "classification"
Expand All @@ -164,11 +170,21 @@ def brain_extraction(image,
number_of_classification_labels = 1
mode = "sigmoid"

unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
number_of_outputs=number_of_classification_labels, mode=mode,
number_of_filters=number_of_filters, dropout_rate=0.0,
convolution_kernel_size=3, deconvolution_kernel_size=2,
weight_decay=1e-5)
unet_model = None
if modality == "bw20":
mode = "classification"
number_of_classification_labels = 4 # background, brain, skull, skin/misc.
unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
number_of_outputs=number_of_classification_labels, mode=mode,
number_of_filters=number_of_filters, dropout_rate=0.0,
convolution_kernel_size=3, deconvolution_kernel_size=2,
weight_decay=0)
else:
unet_model = create_unet_model_3d((*resampled_image_size, channel_size),
number_of_outputs=number_of_classification_labels, mode=mode,
number_of_filters=number_of_filters, dropout_rate=0.0,
convolution_kernel_size=3, deconvolution_kernel_size=2,
weight_decay=1e-5)

unet_model.load_weights(weights_file_name)

Expand All @@ -195,15 +211,30 @@ def brain_extraction(image,
print("Brain extraction: prediction and decoding.")

predicted_data = unet_model.predict(batchX, verbose=verbose)
probability_images_array = decode_unet(predicted_data, reorient_template)
probability_images = decode_unet(predicted_data, reorient_template)

if verbose:
print("Brain extraction: renormalize probability mask to native space.")

xfrm_inv = xfrm.invert()
probability_image = xfrm_inv.apply_to_image(probability_images_array[0][number_of_classification_labels-1], input_images[0])

return(probability_image)
if modality == "bw20":
probability_images_warped = list()
for i in range(number_of_classification_labels):
probability_images_warped.append(xfrm_inv.apply_to_image(
probability_images[0][i], input_images[0]))

image_matrix = ants.image_list_to_matrix(probability_images_warped, input_images[0] * 0 + 1)
segmentation_matrix = np.argmax(image_matrix, axis=0)
segmentation_image = ants.matrix_to_images(
np.expand_dims(segmentation_matrix, axis=0), input_images[0] * 0 + 1)[0]

return_dict = {'segmentation_image' : segmentation_image,
'probability_images' : probability_images_warped}
return(return_dict)
else:
probability_image = xfrm_inv.apply_to_image(probability_images[0][number_of_classification_labels-1], input_images[0])
return(probability_image)

else:

Expand Down
5 changes: 3 additions & 2 deletions antspynet/utilities/deep_atropos.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def deep_atropos(t1,
np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0]

return_dict = {'segmentation_image' : segmentation_image,
'probability_images' : probability_images}
'probability_images' : probability_images}
return(return_dict)

else:
Expand Down Expand Up @@ -222,7 +222,8 @@ def deep_atropos(t1,
rescale_intensities=True,
verbose=verbose)
if i == 0:
t1_mask = brain_extraction(input_images[0], modality="t1", verbose=verbose)
t1_bw20 = brain_extraction(input_images[0], modality="bw20", verbose=verbose)
t1_mask = t1_bw20['probability_images'][1]
n4 = n4 * t1_mask
reg = ants.registration(hcp_t1_template, n4,
type_of_transform="antsRegistrationSyNQuick[a]",
Expand Down
6 changes: 4 additions & 2 deletions antspynet/utilities/get_pretrained_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def switch_networks(argument):
"brainExtractionInfantT1T2": "https://ndownloader.figshare.com/files/22968833",
"brainExtractionInfantT1": "https://ndownloader.figshare.com/files/22968836",
"brainExtractionInfantT2": "https://ndownloader.figshare.com/files/22968830",
"brainExtractionBrainWeb20" : "https://figshare.com/ndownloader/files/49946058",
"brainSegmentation": "https://ndownloader.figshare.com/files/13900010",
"brainSegmentationPatchBased": "https://ndownloader.figshare.com/files/14249717",
"bratsStage1": "https://figshare.com/ndownloader/files/42384756",
Expand Down Expand Up @@ -171,8 +172,8 @@ def switch_networks(argument):
"tb_antsxnet": "https://figshare.com/ndownloader/files/45820599",
"wholeTumorSegmentationT2Flair": "https://ndownloader.figshare.com/files/14087045",
"wholeLungMaskFromVentilation": "https://ndownloader.figshare.com/files/28914441",
"DeepAtroposHcpT1Weights": "https://figshare.com/ndownloader/files/49132504",
"DeepAtroposHcpT1T2Weights": "https://figshare.com/ndownloader/files/49132498",
"DeepAtroposHcpT1Weights": "https://figshare.com/ndownloader/files/49946070",
"DeepAtroposHcpT1T2Weights": "https://figshare.com/ndownloader/files/49132498", # "https://figshare.com/ndownloader/files/49132504"
"DeepAtroposHcpT1FAWeights": "https://figshare.com/ndownloader/files/49132507",
"DeepAtroposHcpT1T2FAWeights": "https://figshare.com/ndownloader/files/49132501"
}
Expand Down Expand Up @@ -204,6 +205,7 @@ def switch_networks(argument):
"brainExtractionInfantT1T2",
"brainExtractionInfantT1",
"brainExtractionInfantT2",
"brainExtractionBrainWeb20",
"brainSegmentation",
"brainSegmentationPatchBased",
"bratsStage1",
Expand Down

0 comments on commit d9f8a9a

Please sign in to comment.