PyTorch and Torchvision need to be installed before running the scripts, together with PIL for data preprocessing and tqdm for showing the training progress.
To run this repository, install Python 3.5 and PyTorch 0.4.0 with Anaconda.
You can download Anaconda and read the installation instructions on the official website: https://www.anaconda.com/download/
Create a new environment and install PyTorch and torchvision in it:
conda create --name segfew python=3.5
conda activate segfew
conda install pytorch=0.4.0
conda install torchvision -c pytorch
Clone this repository:
git clone https://github.com/ahirsharan/MTL_Segmentation.git
- (U-Net) Convolutional Networks for Biomedical Image Segmentation (2015): [Paper]
- (Meta-Transfer Learning) Meta-Transfer Learning for Few-Shot Learning: [Paper]
- COCO Stuff: COCO comes in two partitions. CocoStuff10k contains only 10k images, used for training and evaluation; this version is outdated but can still be used for small-scale testing and training, and can be downloaded here. The official dataset, with all 164k examples, can be downloaded from the official website.
- Few-Shot: FSS-1000 contains 1000 object-class folders, each with 10 images and their ground-truth segmentation masks. This dataset can be used for few-shot learning and can be downloaded here.
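Because few-shot segmentation is evaluated on classes that are never seen during training, FSS-1000 is normally split by class folder rather than by image. The plain-Python sketch below illustrates such a split; the function name `split_classes` and the default 520/240/240 class counts are illustrative assumptions, not necessarily what `FewShotPreprocessing.py` does.

```python
import random
from typing import Dict, List


def split_classes(class_names: List[str], n_train: int = 520, n_val: int = 240,
                  seed: int = 0) -> Dict[str, List[str]]:
    """Split class folders (not images) into disjoint train/val/test sets.

    Hypothetical helper: the default counts assume the 1000 FSS-1000 classes
    are divided 520/240/240; all remaining classes go to the test split.
    """
    if n_train + n_val > len(class_names):
        raise ValueError("n_train + n_val exceeds the number of classes")
    names = sorted(class_names)          # deterministic order before shuffling
    random.Random(seed).shuffle(names)   # local RNG, global random state untouched
    return {
        "train": names[:n_train],
        "val": names[n_train:n_train + n_val],
        "test": names[n_train + n_val:],
    }


classes = [f"class_{i:04d}" for i in range(1000)]
splits = split_classes(classes)
print({k: len(v) for k, v in splits.items()})  # {'train': 520, 'val': 240, 'test': 240}
```

Keeping the class sets disjoint is what lets the meta-test phase measure generalization to novel classes.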
In addition to the Cross-Entropy (CE) loss, the following losses are implemented:
- Dice loss, which measures the overlap between two samples and can better reflect the training objective (maximizing the mIoU), but is highly non-convex and can be hard to optimize.
- CE Dice loss, the sum of the Dice loss and CE: CE gives smooth optimization, while the Dice loss is a good indicator of the quality of the segmentation results.
- Focal loss, a variant of CE designed for class imbalance, in which the loss of confident predictions is scaled down.
- Lovász-Softmax, a good alternative to the Dice loss that directly optimizes the mean intersection-over-union, based on the convex Lovász extension of submodular losses (for more details, see the paper: The Lovász-Softmax loss).
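A minimal sketch of the Dice, CE Dice and Focal losses for the binary (foreground/background) case, written in plain Python over flat lists of per-pixel foreground probabilities so the formulas are easy to follow. It is an illustration, not the code in `utils/losses.py`, which works on PyTorch tensors and multiple classes; the `smooth` term and the default focal `gamma=2.0` are assumptions, and Lovász-Softmax is omitted for brevity.

```python
import math
from typing import Sequence

EPS = 1e-7  # clamp probabilities so log() stays finite


def _clamp(p: float) -> float:
    return min(max(p, EPS), 1.0 - EPS)


def bce_loss(probs: Sequence[float], targets: Sequence[int]) -> float:
    """Mean binary cross-entropy over pixels."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = _clamp(p)
        total -= t * math.log(p) + (1 - t) * math.log(1.0 - p)
    return total / len(probs)


def dice_loss(probs: Sequence[float], targets: Sequence[int], smooth: float = 1.0) -> float:
    """1 - soft Dice coefficient; `smooth` avoids 0/0 on empty masks."""
    intersection = sum(p * t for p, t in zip(probs, targets))
    denom = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + smooth) / (denom + smooth)


def ce_dice_loss(probs: Sequence[float], targets: Sequence[int]) -> float:
    """CE for smooth gradients plus Dice as a direct overlap term."""
    return bce_loss(probs, targets) + dice_loss(probs, targets)


def focal_loss(probs: Sequence[float], targets: Sequence[int], gamma: float = 2.0) -> float:
    """CE scaled by (1 - p_t)^gamma, so confident pixels contribute less."""
    total = 0.0
    for p, t in zip(probs, targets):
        p_t = _clamp(p if t == 1 else 1.0 - p)  # probability assigned to the true class
        total -= (1.0 - p_t) ** gamma * math.log(p_t)
    return total / len(probs)


probs = [0.9, 0.8, 0.2, 0.1]
targets = [1, 1, 0, 0]
print(round(dice_loss(probs, targets), 4))                  # 0.12
print(focal_loss(probs, targets) < bce_loss(probs, targets))  # True
```

The last line shows the focal-loss behaviour described above: on these mostly confident predictions the focal term is smaller than plain CE.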
The code structure is based on MTL-template and Pytorch-Segmentation.
.
|
├── FewShotPreprocessing.py # utility to organise the Few-Shot data into train, val and test sets
|
|
├── dataloader
| ├── dataset_loader.py # data loader for pre-train datasets
| ├── mdataset_loader.py # data loader for meta task dataset
| └── samplers.py # samplers for meta task dataset (Few-Shot)
|
|
├── models
| ├── mtl.py # meta-transfer class
| ├── unet_mtl.py # unet class
| └── conv2d_mtl.py # meta-transfer convolution class
|
├── trainer
| ├── pre.py # pre-train trainer class
| └── meta.py # meta-train trainer class
|
|
├── utils
| ├── gpu_tools.py # GPU tool functions
| ├── metrics.py # Metrics functions
| ├── losses.py # Loss functions
| ├── lovasz_losses.py # Lovasz Loss function
| └── misc.py # miscellaneous tool functions
|
├── main.py # the python file with main function and parameter settings
├── run_pre.py # the script to run pre-train phase
└── run_meta.py # the script to run meta-train and meta-test phases
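In meta-transfer learning, the pre-trained convolution weights W and bias b are frozen, and each layer learns lightweight scaling (Φ_S1) and shifting (Φ_S2) parameters, so the layer computes (W ⊙ Φ_S1) X + (b + Φ_S2); these are the "SS weights" trained with `meta_lr1`. The plain-Python sketch below shows only how the effective parameters are formed. The function name `scale_shift` is hypothetical, and the scale shape (one factor per output/input kernel pair, broadcast over the spatial kernel) is an assumption that may differ from `conv2d_mtl.py`.

```python
from typing import List, Tuple

Kernel = List[List[float]]  # k x k spatial kernel


def scale_shift(weight: List[List[Kernel]], bias: List[float],
                scale: List[List[float]], shift: List[float]
                ) -> Tuple[List[List[Kernel]], List[float]]:
    """Return (W * Phi_S1, b + Phi_S2) without modifying the frozen W and b.

    weight: [out][in][k][k] frozen pre-trained kernels
    scale:  [out][in] scaling factors, broadcast over each k x k kernel (assumed shape)
    shift:  [out] additive bias offsets
    """
    new_weight = [
        [[[w * scale[o][i] for w in row] for row in weight[o][i]]
         for i in range(len(weight[o]))]
        for o in range(len(weight))
    ]
    new_bias = [b + s for b, s in zip(bias, shift)]
    return new_weight, new_bias


W = [[[[1.0, 2.0], [3.0, 4.0]]]]  # 1 output channel, 1 input channel, 2x2 kernel
b = [0.5]

# Scale initialised to 1 and shift to 0 reproduce the pre-trained layer exactly.
w_id, b_id = scale_shift(W, b, scale=[[1.0]], shift=[0.0])
assert w_id == W and b_id == b

w_new, b_new = scale_shift(W, b, scale=[[2.0]], shift=[0.25])
print(w_new, b_new)  # [[[[2.0, 4.0], [6.0, 8.0]]]] [0.75]
```

Because only the small scale and shift tensors are updated during meta-training, the base network keeps its pre-trained features while adapting to few-shot tasks.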
Run the pre-train phase:
python run_pre.py
Run the meta-train and meta-test phases:
python run_meta.py
Hyperparameters and options in main.py:
Option | Description |
---|---|
model_type | The network architecture |
dataset | Meta dataset |
phase | pre-train, meta-train or meta-eval |
seed | Manual seed for PyTorch; "0" means using a random seed |
gpu | GPU id |
dataset_dir | Directory for the images |
max_epoch | Epoch number for the meta-train phase |
num_batch | The number of different tasks used for meta-train |
shot | Shot number: how many samples for one class in a task |
teshot | Test-shot number: how many samples for one class in a meta-test task |
way | Way number: how many classes in a task |
train_query | The number of training samples for each class in a task |
val_query | The number of test samples for each class in a task |
meta_lr1 | Learning rate for the SS weights |
meta_lr2 | Learning rate for the base-learner weights (meta task) |
base_lr | Learning rate for the inner loop |
update_step | The number of updates for the inner loop |
step_size | The number of epochs after which the meta learning rates are reduced |
gamma | Gamma for the meta-train learning rate decay |
init_weights | The pretrained weights for the meta-train phase |
pre_init_weights | The pretrained weights for the pre-train phase |
eval_weights | The meta-trained weights for the meta-eval phase |
meta_label | Additional label for meta-train |
pre_max_epoch | Epoch number for the pre-train phase |
pre_batch_size | Batch size for the pre-train phase |
pre_lr | Learning rate for the pre-train phase |
pre_gamma | Gamma for the pre-train learning rate decay |
pre_step_size | The number of epochs after which the pre-train learning rate is reduced |
pre_custom_weight_decay | Weight decay for the optimizer during pre-train |
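The `way`, `shot` and query options (`train_query`, `val_query`) describe how each meta-training task is built. A plain-Python sketch of drawing one N-way K-shot episode is shown below; the function name `sample_episode` and the input layout (a dict from class name to image paths) are hypothetical, and `samplers.py` may organise this differently.

```python
import random
from typing import Dict, List, Tuple

Pair = Tuple[str, str]  # (class name, image path)


def sample_episode(images_by_class: Dict[str, List[str]], way: int, shot: int,
                   query: int, rng: random.Random) -> Tuple[List[Pair], List[Pair]]:
    """Draw one task: `way` classes, with `shot` support and `query` query images per class."""
    classes = rng.sample(sorted(images_by_class), way)
    support: List[Pair] = []
    queries: List[Pair] = []
    for cls in classes:
        # Sample without replacement so no image is both support and query.
        picks = rng.sample(images_by_class[cls], shot + query)
        support += [(cls, img) for img in picks[:shot]]
        queries += [(cls, img) for img in picks[shot:]]
    return support, queries


# FSS-1000 has 10 images per class, so shot + query must not exceed 10.
data = {f"class_{c}": [f"class_{c}/{i}.jpg" for i in range(1, 11)] for c in range(5)}
support, queries = sample_episode(data, way=2, shot=3, query=2, rng=random.Random(0))
print(len(support), len(queries))  # 6 4
```

Each task therefore contains `way * shot` support and `way * query` query images; `num_batch` such tasks would make up one meta-training epoch under this reading of the options.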
[Figures: Mean IoU and CE Loss curves]
- The pre-trained weights for both the pre-train and meta tasks, selected by maximum IoU, can be found here.
- Some of the best results for 3-shot learning 😄: