Adjusting and tuning a segmentation model
- class_weights: When sampling image patches for training the UNet, how often should those patches be centered around foreground voxels of a specific class?
- feature_channels: This parameter has the largest impact on the capacity of the learned model. Monitor the training and validation loss curves: if the training loss keeps going down while the validation loss goes up, the model is overfitting, so try decreasing feature_channels.
- train_batch_size: How many image patches are fed through the model at one time during training?
- crop_size: How large are the image patches that are fed into the model?
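The advice on feature_channels can be turned into a simple check on logged loss curves. The following is a minimal sketch, not part of the toolbox: the function overfitting_epochs and its window parameter are hypothetical, and it assumes you have per-epoch training and validation losses as plain lists.

```python
from typing import List, Sequence


def overfitting_epochs(train_loss: Sequence[float], val_loss: Sequence[float],
                       window: int = 3) -> List[int]:
    """Hypothetical helper: return the epochs at which, compared with `window`
    epochs earlier, the training loss is lower but the validation loss is higher."""
    if len(train_loss) != len(val_loss):
        raise ValueError("train_loss and val_loss must have the same length")
    flagged = []
    for epoch in range(window, len(train_loss)):
        train_dropped = train_loss[epoch] < train_loss[epoch - window]
        val_rose = val_loss[epoch] > val_loss[epoch - window]
        if train_dropped and val_rose:
            flagged.append(epoch)
    return flagged


train = [1.0, 0.8, 0.6, 0.5, 0.4, 0.3]
val = [1.0, 0.85, 0.7, 0.72, 0.78, 0.85]
print(overfitting_epochs(train, val))
```

If epochs keep getting flagged, that is the situation described above where reducing feature_channels is worth trying. The window smooths out single-epoch noise; its value is an arbitrary choice.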
Segmentation models are trained on equally shaped crops (patches) that are taken from the raw image. The raw images for different patients can have different sizes, but we can only feed equally shaped patches into the model for training. How are those patches chosen?
First, one of the classes is chosen. "Classes" here means either one of the foreground classes (for example, heart or lung in the Lung model) or the background class. The background class captures the information that a voxel is assigned to none of the foreground classes. The choice of class is guided by the class_weights parameter of the model configuration: a list with N+1 entries for an N-class segmentation model, where the first entry is the background class. The values must sum to 1.0.
Then, with the class chosen, a voxel of that class is picked at random. This voxel becomes the center point of the training patch: a patch of size crop_size is constructed around it, while keeping the patch inside the boundaries of the image. This patch is then fed into the segmentation model during training.
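The sampling steps above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the toolbox's implementation: the function sample_patch_slices is hypothetical, the labels are assumed to be a single integer label map (0 for background, 1 to N for the foreground classes in class_weights order), and "keeping the patch inside the image" is assumed to mean shifting the patch window. It also assumes the image is at least crop_size in every dimension and that the chosen class is present; a real pipeline may pad or handle missing classes differently.

```python
import numpy as np


def sample_patch_slices(labels, class_weights, crop_size, rng):
    """Hypothetical sketch of class-weighted patch sampling.

    labels: integer label map, 0 = background, 1..N = foreground classes.
    Returns (tuple of slices for the patch, index of the chosen class).
    """
    weights = np.asarray(class_weights, dtype=float)
    if not np.isclose(weights.sum(), 1.0):
        raise ValueError("class_weights must sum to 1.0")
    if any(c > s for c, s in zip(crop_size, labels.shape)):
        raise ValueError("crop_size must not exceed the image size")
    # Step 1: choose a class according to class_weights (index 0 is background).
    chosen = int(rng.choice(len(weights), p=weights))
    # Step 2: pick a random voxel that belongs to the chosen class.
    candidates = np.argwhere(labels == chosen)
    if len(candidates) == 0:
        raise ValueError(f"class {chosen} has no voxels in this image")
    center = candidates[rng.integers(len(candidates))]
    # Step 3: center the patch on that voxel, then shift it back inside the image.
    slices = []
    for c, size, dim in zip(center, crop_size, labels.shape):
        start = int(c) - size // 2
        start = min(max(start, 0), dim - size)
        slices.append(slice(start, start + size))
    return tuple(slices), chosen


labels = np.zeros((10, 20, 20), dtype=int)
labels[2:4, 5:8, 5:8] = 1   # a small "lung" region
labels[5, 18, 18] = 2       # a single "heart" voxel near the image border
rng = np.random.default_rng(0)
slices, chosen = sample_patch_slices(labels, [0.0, 0.0, 1.0], (4, 8, 8), rng)
patch = labels[slices]      # always shape (4, 8, 8), even near the border
```

With class_weights = [0.0, 0.0, 1.0], the heart voxel at (5, 18, 18) is always chosen; because it sits near the edge, the window is shifted so that the patch still has the full crop size.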
Let's take a simplified Lung model, with only "lung" and "heart" as foreground classes. You would set ground_truth_ids = ["lung", "heart"]. Here are examples for different choices of class weights:
- class_weights = [0.1, 0.6, 0.3]: Choose background in 10% of cases, lung in 60%, and heart in 30%.
- class_weights = [0.0, 0.5, 0.5]: Always center the training patch on a foreground voxel, choosing heart and lung equally often.
- class_weights = equally_weighted_classes(["lung", "heart"]): This helper function creates the same weight list, [0.0, 0.5, 0.5].
- class_weights = equally_weighted_classes(["lung", "heart"], background_weight=0.2): This creates the weight list [0.2, 0.4, 0.4].
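To make the two helper results above concrete, here is a small sketch that reproduces them: background weight first, with the remaining probability mass split evenly across the foreground classes. It is not the toolbox's equally_weighted_classes; the name equally_weighted_classes_sketch and its validation rules are assumptions chosen only to match the documented outputs.

```python
from typing import List, Sequence


def equally_weighted_classes_sketch(foreground_classes: Sequence[str],
                                    background_weight: float = 0.0) -> List[float]:
    """Hypothetical re-implementation of the documented behaviour:
    [background_weight, w, w, ...] with w = (1 - background_weight) / N."""
    if not foreground_classes:
        raise ValueError("need at least one foreground class")
    if not 0.0 <= background_weight < 1.0:
        raise ValueError("background_weight must be in [0.0, 1.0)")
    foreground_weight = (1.0 - background_weight) / len(foreground_classes)
    return [background_weight] + [foreground_weight] * len(foreground_classes)


print(equally_weighted_classes_sketch(["lung", "heart"]))
print(equally_weighted_classes_sketch(["lung", "heart"], background_weight=0.2))
```

Because the background entry comes first, the resulting list has N+1 entries and sums to 1.0, matching the class_weights requirements described above.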
Make sure that the background class receives enough samples, in particular if your scans cover a large body area with only a small area of foreground voxels (for example, the scans are full-body CT, but you want to build a heart segmentation model). Loosely speaking, you also need to show the model what a body region looks like where there is certainly no heart.
We built the Prostate segmentation models on Azure ND24 VMs. These have 4 Tesla P40 GPUs with particularly large VRAM (24 GB per GPU) and 448 GB of system RAM. For these GPUs, we use:
- crop_size=(64, 224, 224)
- train_batch_size=8
If you want to adjust the models to run on GPUs with less memory, it is best to tune the train_batch_size parameter first.

For example, when running on Azure NC24 VMs with 4 Tesla K80 GPUs (12 GB per GPU), try train_batch_size=4.
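One way to see why halving train_batch_size matches halving the GPU memory is to count the input voxels per training step. The sketch below only counts input voxels; it assumes, as a rough rule of thumb, that activation memory in a fully convolutional model such as a UNet grows roughly in proportion to that count. The function voxels_per_batch is hypothetical.

```python
from math import prod
from typing import Sequence


def voxels_per_batch(crop_size: Sequence[int], batch_size: int) -> int:
    """Number of input voxels fed through the model in one training step."""
    return batch_size * prod(crop_size)


nd24 = voxels_per_batch((64, 224, 224), batch_size=8)  # ND24 settings (P40, 24 GB per GPU)
nc24 = voxels_per_batch((64, 224, 224), batch_size=4)  # suggested for NC24 (K80, 12 GB per GPU)
print(f"{nd24:,} vs {nc24:,} voxels per batch, ratio {nd24 / nc24:.1f}")
```

Keeping crop_size fixed and halving the batch halves the voxels per step, in line with the halved memory per GPU. Shrinking crop_size instead would also reduce memory, but it changes how much anatomical context each patch contains, which is presumably why the batch size is the first parameter to tune.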