diff --git a/GANDLF/models/brain_age.py b/GANDLF/models/brain_age.py index 436cf732b..3bc88e78b 100644 --- a/GANDLF/models/brain_age.py +++ b/GANDLF/models/brain_age.py @@ -1,7 +1,7 @@ import torch.nn as nn import sys import torchvision - +import traceback def brainage(parameters): """ @@ -18,11 +18,13 @@ def brainage(parameters): """ # Check that the input data is 2D - if parameters["model"]["dimension"] != 2: - sys.exit("Brain Age predictions only works on 2D data") + assert parameters["model"]["dimension"] == 2, "Brain Age predictions only work on 2D data" - # Load the pretrained VGG16 model - model = torchvision.models.vgg16(pretrained=True) + try: + # Load the pretrained VGG16 model + model = torchvision.models.vgg16(pretrained=True) + except Exception: + sys.exit("Error: Failed to load VGG16 model: " + traceback.format_exc()) # Remove the final convolutional layer model.final_convolution_layer = None @@ -36,19 +38,13 @@ def brainage(parameters): features = list(model.classifier.children())[:-1] # Remove the last layer features.extend( [ - nn.Linear( - num_features, 1024 - ), # Add a linear layer with 1024 output features + nn.Linear(num_features, 1024), # Add a linear layer with 1024 output features nn.ReLU(True), # Add a ReLU activation function nn.Dropout2d(0.8), # Add a 2D dropout layer with a probability of 0.8 - nn.Linear( - 1024, 1 - ), # Add a linear layer with 1 output feature (for brain age prediction) + nn.Linear(1024, 1), # Add a linear layer with 1 output feature (for brain age prediction) ] ) - model.classifier = nn.Sequential( - *features - ) # Replace the model classifier with the modified one + model.classifier = nn.Sequential(*features) # Replace the model classifier with the modified one # Set the "amp" parameter to False (not yet implemented for VGG) parameters["model"]["amp"] = False