# PyTorch Implementation of 3D UNet

This implementation is based on the original 3D U-Net paper and is adapted for MRI and CT image segmentation tasks.

> Link to the paper: [https://arxiv.org/pdf/1606.06650v1.pdf](https://arxiv.org/pdf/1606.06650v1.pdf)

## Model Architecture

The model follows an encoder-decoder design. Because the analysis (encoder) path halves the spatial resolution four times, each spatial dimension of the input must be divisible by 16.

![model](./images/unet_model.png)

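Since each input dimension must be divisible by 16, volumes usually need padding before being fed to the network. A minimal sketch of computing the target shape (the function name `pad_to_multiple` is hypothetical, not part of this repo):

```python
def pad_to_multiple(shape, multiple=16):
    """Round each spatial dimension up to the next multiple of `multiple`.

    Returns the target shape a volume should be padded to before it is
    passed through the encoder-decoder (which downsamples by 16 overall).
    """
    return tuple(((s + multiple - 1) // multiple) * multiple for s in shape)

# e.g. a 130 x 200 x 97 volume would be padded to 144 x 208 x 112
print(pad_to_multiple((130, 200, 97)))  # -> (144, 208, 112)
```
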
## Dataset

The Dataset class used for training the network is adapted to the **Medical Segmentation Decathlon challenge**.

This dataset comprises ten segmentation tasks on various organs, including **Liver Tumours, Brain Tumours, Hippocampus, Lung Tumours, Prostate, Cardiac, Pancreas Tumour, Colon Cancer, Hepatic Vessels and Spleen segmentation**.

- Please note that if a task contains more than 2 classes (1: foreground, 0: background), you will need to modify the model output in the train.py file so that it can be reshaped to the size of the ground-truth mask.

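For illustration, a multi-class setup typically pairs `(N, C, D, H, W)` logits with an integer `(N, D, H, W)` mask, as PyTorch's `CrossEntropyLoss` expects. This is a hedged sketch with dummy tensors, not the actual code in train.py:

```python
import torch
import torch.nn as nn

num_classes = 3  # assumption: a task with 3 classes instead of binary fg/bg

# dummy stand-ins for the model output and the ground-truth mask
output = torch.randn(2, num_classes, 16, 16, 16)       # (N, C, D, H, W) logits
mask = torch.randint(0, num_classes, (2, 16, 16, 16))  # (N, D, H, W) class ids

# CrossEntropyLoss requires exactly this shape pairing, which is why the
# model output must be reshaped to match the ground-truth mask's size
loss = nn.CrossEntropyLoss()(output, mask)
```
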
> The link to the dataset: [http://medicaldecathlon.com/](http://medicaldecathlon.com/)

- The Dataset class uses the MONAI package for reading MRI or CT volumes and for applying augmentations to them in the transform.py file. You can modify the applied transformations in this file according to your preferences.

## Configure the network

All configurations and hyperparameters are set in the config.py file.
Please note that you need to change the path to the dataset directory in the config.py file before running the model.

**Parameters:**

- DATASET_PATH -> the directory path to the dataset .tar files

- TASK_ID -> specifies the segmentation task ID (see the dict below for hints)

- IN_CHANNELS -> number of input channels

- NUM_CLASSES -> specifies the number of output channels for the distinct classes

- BACKGROUND_AS_CLASS -> if True, the model treats the background as a class

- TRAIN_VAL_TEST_SPLIT -> the ratios in which the dataset should be split. The length of the array should be 3.

- TRAINING_EPOCH -> number of training epochs

- VAL_BATCH_SIZE -> specifies the batch size of the validation DataLoader

- TEST_BATCH_SIZE -> specifies the batch size of the test DataLoader

- TRAIN_CUDA -> if True, moves the model and inference onto the GPU

- BCE_WEIGHTS -> the class weights for the Binary Cross Entropy loss

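Put together, a config.py might look like the following. The values shown are illustrative assumptions only, not recommended settings:

```python
# config.py -- illustrative values only
DATASET_PATH = "/path/to/decathlon"     # directory containing the .tar task files
TASK_ID = 4                             # check the task dict for the real ID mapping
IN_CHANNELS = 1                         # single-modality MRI/CT input
NUM_CLASSES = 1                         # binary foreground/background segmentation
BACKGROUND_AS_CLASS = False
TRAIN_VAL_TEST_SPLIT = [0.8, 0.1, 0.1]  # must have length 3
TRAINING_EPOCH = 100
VAL_BATCH_SIZE = 1
TEST_BATCH_SIZE = 1
TRAIN_CUDA = True
BCE_WEIGHTS = [0.004, 0.996]            # class weights for the BCE loss
```
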
## Training

After configuring config.py, you can start training by running

`python train.py`

We also employ TensorBoard to visualize the training process:

`tensorboard --logdir=runs/`
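Under the hood, logging to `runs/` can be done with PyTorch's `SummaryWriter`; a minimal sketch (the log directory and tag names here are hypothetical, not necessarily those used in train.py):

```python
from torch.utils.tensorboard import SummaryWriter

# logs land under runs/, which is where the tensorboard command above looks
writer = SummaryWriter(log_dir="runs/demo")
for epoch in range(3):
    # dummy value standing in for the real training loss
    writer.add_scalar("Loss/train", 1.0 / (epoch + 1), epoch)
writer.close()
```
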