This is an implementation of UNet 3+ in PyTorch.
I referred to the TensorFlow implementation of UNet 3+ on GitHub.
Hit star ⭐ if you find my work useful.
- UNet 3+ for Image Segmentation in PyTorch.
Requirements
- Python >= 3.10
- PyTorch >= 2.2.0
- CUDA 12.0
This code base has been tested with the Python and PyTorch versions listed above, and it is expected to work with newer versions too.
- Clone code
git clone https://github.com/russel0719/UNet-3-Plus-Pytorch.git UNet3P
cd UNet3P
- Install other requirements.
pip install -r requirements.txt
- checkpoint: Model checkpoint and logs directory
- configs: Configuration file
- data: Dataset files (see Data Preparation for more details)
- data_preparation: For LiTS data preparation and data verification
- losses: Implementations of UNet3+ hybrid loss function and dice coefficient
- models: UNet3+ model files
- utils: Generic utility functions
- data_generator.py: Data generator for training, validation and testing
- predict.py: Prediction script used to visualize model output
- train.py: Training script
Configurations are passed through a YAML file. For more details on the config file, read here.
- This code can be used to reproduce the UNet3+ paper results on LiTS - Liver Tumor Segmentation Challenge.
- You can also use it to train UNet3+ on a custom dataset.
For dataset preparation, read here.
This repo contains all three versions of UNet3+.
# | Description | Model Name | Training Supported |
---|---|---|---|
1 | UNet3+ Base model | unet3plus | ✓ |
2 | UNet3+ with Deep Supervision | unet3plus_deepsup | ✓ |
3 | UNet3+ with Deep Supervision and Classification Guided Module | unet3plus_deepsup_cgm | ✓ |
- Note that unet3plus_deepsup_cgm can only be trained with the OUTPUT.CLASSES = 1 option.
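The restriction above can be checked before training starts. The following is a minimal sketch, not code from this repo: the function name `check_cgm_config` and its two arguments are hypothetical, and it assumes the model name and class count have already been read from the config.

```python
def check_cgm_config(model_type: str, output_classes: int) -> None:
    """Raise ValueError if the CGM variant is requested with more than one class.

    Hypothetical helper: the argument names mirror MODEL.TYPE and
    OUTPUT.CLASSES, but this function is not part of the repo.
    """
    if model_type == "unet3plus_deepsup_cgm" and output_classes != 1:
        raise ValueError(
            "unet3plus_deepsup_cgm requires OUTPUT.CLASSES = 1, "
            f"got {output_classes}"
        )


# The base model accepts any class count; the CGM model accepts only one class.
check_cgm_config("unet3plus", 3)
check_cgm_config("unet3plus_deepsup_cgm", 1)
```

Failing early like this avoids discovering the mismatch only after data loading has begun.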
Here is sample code that creates each of the three UNet 3+ variants:
INPUT_SHAPE = [1, 320, 320]
OUTPUT_CHANNELS = 1
unet_3P = UNet3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, deep_supervision=False, CGM=False)
unet_3P_deep_sup = UNet3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, deep_supervision=True, CGM=False)
unet_3P_deep_sup_cgm = UNet3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, deep_supervision=True, CGM=True)
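The three constructor calls differ only in the `deep_supervision` and `CGM` flags, so the model names from the table can be mapped to those flags. This is a sketch under that assumption; `MODEL_FLAGS` and `flags_for` are hypothetical names, and the repo may select the model differently.

```python
# Hypothetical lookup table: model name from the table -> constructor flags.
MODEL_FLAGS = {
    "unet3plus": {"deep_supervision": False, "CGM": False},
    "unet3plus_deepsup": {"deep_supervision": True, "CGM": False},
    "unet3plus_deepsup_cgm": {"deep_supervision": True, "CGM": True},
}


def flags_for(model_type: str) -> dict:
    """Return a copy of the keyword flags for a given model name."""
    if model_type not in MODEL_FLAGS:
        raise ValueError(f"Unknown model type: {model_type!r}")
    return dict(MODEL_FLAGS[model_type])


# Intended use with the constructor shown above (not executed here):
#   model = UNet3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, **flags_for("unet3plus_deepsup"))
print(flags_for("unet3plus_deepsup"))
```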
Here you can find UNet3+ hybrid loss.
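For orientation, the UNet 3+ paper combines focal, MS-SSIM and IoU terms in its hybrid loss, and this repo also ships a dice coefficient. The sketch below illustrates only the dice coefficient and the IoU term on flattened masks, in plain Python rather than on tensors. It is not the implementation in `losses`; the function names and the `smooth` constant are assumptions.

```python
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Dice = 2 * |A ∩ B| / (|A| + |B|) for flat sequences of values in [0, 1]."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    total = sum(y_true) + sum(y_pred)
    # `smooth` (an assumed constant) avoids division by zero on empty masks.
    return (2.0 * intersection + smooth) / (total + smooth)


def iou_loss(y_true, y_pred, smooth=1e-6):
    """IoU loss = 1 - |A ∩ B| / |A ∪ B| for flat sequences of values in [0, 1]."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    union = sum(y_true) + sum(y_pred) - intersection
    return 1.0 - (intersection + smooth) / (union + smooth)


mask = [1, 1, 0, 0]        # ground-truth pixels
prediction = [1, 0, 0, 0]  # predicted pixels
print(dice_coefficient(mask, prediction))  # close to 2/3
print(iou_loss(mask, prediction))          # close to 1/2
```

A tensor version would replace the Python sums with elementwise products and reductions over the spatial dimensions.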
To train a model on the training dataset, call train.py with the required model type and configurations.
For example, to train the base model, run
python train.py MODEL.TYPE=unet3plus
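An override such as `MODEL.TYPE=unet3plus` addresses a nested key of the YAML config by its dotted path. The sketch below shows that mapping on a plain dictionary. It is an illustration only: `apply_override` is a hypothetical helper, the sample config keys other than `MODEL.TYPE` are invented, and the actual parsing is done by whatever configuration library `train.py` uses.

```python
def apply_override(config: dict, override: str) -> dict:
    """Set a nested config value from a 'DOTTED.KEY=value' string.

    Values are kept as strings; a real config library would also
    convert them to the appropriate type.
    """
    key_path, sep, raw_value = override.partition("=")
    if not sep:
        raise ValueError(f"Expected KEY=value, got {override!r}")
    keys = key_path.split(".")
    node = config
    for key in keys[:-1]:
        # Walk down the hierarchy, creating intermediate levels if missing.
        node = node.setdefault(key, {})
    node[keys[-1]] = raw_value
    return config


# Hypothetical config contents, for illustration only.
config = {"MODEL": {"TYPE": "unet3plus_deepsup"}}
apply_override(config, "MODEL.TYPE=unet3plus")
print(config)  # {'MODEL': {'TYPE': 'unet3plus'}}
```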
To validate a model on the validation dataset, call validate.py with the required model type and configurations.
For example, to validate the base model and visualize the results, run
python validate.py MODEL.TYPE=unet3plus
To run inference with a model on the validation dataset, call predict.py with the required model type and configurations.
For example, to run inference with the base model, run
python predict.py MODEL.TYPE=unet3plus
We appreciate any feedback, so reporting problems and asking questions is welcome here.
Licensed under the MIT License.