# Rethinking Interactive Image Segmentation with Low Latency, High Quality, and Diverse Prompts

PyTorch implementation for the CVPR 2024 paper *Rethinking Interactive Image Segmentation with Low Latency, High Quality, and Diverse Prompts*.

Qin Liu, Jaemin Cho, Mohit Bansal, Marc Niethammer

UNC-Chapel Hill


## Installation

The code is tested with `python=3.10`, `torch=2.2.0`, and `torchvision=0.17.0`.

```bash
git clone https://github.com/uncbiag/SegNext
cd SegNext
```

Now create a new conda environment and install the required packages:

```bash
conda create -n segnext python=3.10
conda activate segnext
conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install -r requirements.txt
```
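To confirm the environment matches the tested versions, a quick sanity check:

```python
import torch
import torchvision

# Expect torch 2.2.0 / torchvision 0.17.0 with CUDA available.
print(torch.__version__, torchvision.__version__, torch.cuda.is_available())
```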

## Getting Started

First, download the three model weights: `vitb_sax1` (408 MB), `vitb_sax2` (435 MB), and `vitb_sax2_ft` (435 MB). The weights are saved automatically to the `weights` folder.

```bash
python download.py
```
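After the download finishes, you can sanity-check a checkpoint before using it. A minimal sketch: the filename below is the one used by the evaluation command later in this README, and the other downloaded filenames may differ, so treat the path as an assumption.

```python
from pathlib import Path
import torch

# Checkpoint filename taken from the evaluation command below; adjust
# to whatever download.py actually saved in ./weights.
ckpt_path = Path("./weights/vitb_sa2_cocolvis_hq44k_epoch_0.pth")

assert ckpt_path.exists(), f"missing checkpoint: {ckpt_path}"
print(f"{ckpt_path.name}: {ckpt_path.stat().st_size / 1e6:.0f} MB")

# Load on CPU just to inspect the contents; a state dict is a mapping
# from parameter names to tensors.
state = torch.load(ckpt_path, map_location="cpu")
if isinstance(state, dict):
    print(f"{len(state)} top-level entries, e.g. {list(state)[:3]}")
```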

Run the interactive GUI with the downloaded weights. The `assets` folder contains demo images.

```bash
./run_demo.sh
```

## Datasets

We train and test our method on three datasets: DAVIS, COCO+LVIS, and HQSeg-44K.

| Dataset    | Description                               | Download Link                                |
|------------|-------------------------------------------|----------------------------------------------|
| DAVIS      | 345 images with one object each (test)    | DAVIS.zip (43 MB)                            |
| HQSeg-44K  | 44,320 images (train); 1,537 images (val) | official site                                |
| COCO+LVIS* | 99k images with 1.5M instances (train)    | original LVIS images + combined annotations  |

Don't forget to change the paths to the datasets in `config.yml` after downloading and unpacking them.

(*) To prepare COCO+LVIS, download the original LVIS v1.0 dataset, then download and unpack the pre-processed annotations, obtained by combining the COCO and LVIS datasets, into the folder with LVIS v1.0. (The combined annotations were prepared by RITM.)
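Since broken dataset paths are the most common setup error, a small hedged sketch for checking them: it loads `config.yml` and reports whether every path-like entry exists. The key structure is an assumption; inspect the actual file to see how the paths are named.

```python
from pathlib import Path
import yaml  # pip install pyyaml

# Load the repository config; key names vary, so we scan recursively
# rather than assuming a particular layout.
with open("config.yml") as f:
    cfg = yaml.safe_load(f)

def check_paths(node, prefix=""):
    """Recursively report every string value that looks like a filesystem path."""
    if isinstance(node, dict):
        for key, value in node.items():
            check_paths(value, f"{prefix}{key}.")
    elif isinstance(node, str) and "/" in node:
        status = "ok" if Path(node).expanduser().exists() else "MISSING"
        print(f"{status:8s} {prefix[:-1]} = {node}")

check_paths(cfg)
```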

## Evaluation

We provide a script (`run_eval.sh`) to evaluate the presented models. The following command runs the NoC (number of clicks) evaluation on all test datasets:

```bash
python ./segnext/scripts/evaluate_model.py --gpus=0 \
    --checkpoint=./weights/vitb_sa2_cocolvis_hq44k_epoch_0.pth \
    --datasets=DAVIS,HQSeg44K
```
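For readers unfamiliar with the metrics: NoC90/NoC95 is the average number of clicks needed to reach 90%/95% IoU, and NoF95 counts failure cases that never reach 95% IoU within the click budget. A minimal, repository-independent sketch of the computation (the 20-click budget is an assumption, not the script's setting):

```python
import numpy as np

def noc_and_failures(ious_per_click, target_iou=0.90, max_clicks=20):
    """ious_per_click: per-sample IoU curves, one IoU value per click.

    Returns the mean number of clicks to first reach target_iou
    (samples that never reach it count as max_clicks) and the number
    of such failure cases.
    """
    nocs, failures = [], 0
    for curve in ious_per_click:
        reached = np.argwhere(np.array(curve) >= target_iou)
        if len(reached) == 0:
            nocs.append(max_clicks)
            failures += 1
        else:
            nocs.append(int(reached[0, 0]) + 1)  # clicks are 1-indexed
    return float(np.mean(nocs)), failures

# Toy example: two samples, three simulated clicks each.
print(noc_and_failures([[0.6, 0.85, 0.93], [0.5, 0.7, 0.8]], 0.90, 3))
```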
| Train Dataset | Model               | HQSeg-44K 5-mIoU | HQSeg-44K NoC90 | HQSeg-44K NoC95 | HQSeg-44K NoF95 | DAVIS 5-mIoU | DAVIS NoC90 | DAVIS NoC95 | DAVIS NoF95 |
|---------------|---------------------|------------------|-----------------|-----------------|-----------------|--------------|-------------|-------------|-------------|
| C+L           | vitb-sax1 (408 MB)  | 85.41            | 7.47            | 11.94           | 731             | 90.13        | 5.46        | 13.31       | 177         |
| C+L           | vitb-sax2 (435 MB)  | 85.71            | 7.18            | 11.52           | 700             | 89.85        | 5.34        | 12.80       | 163         |
| C+L+HQ        | vitb-sax2 (435 MB)  | 91.75            | 5.32            | 9.42            | 583             | 91.87        | 4.43        | 10.73       | 123         |

C+L denotes COCO+LVIS; C+L+HQ denotes COCO+LVIS training followed by HQSeg-44K fine-tuning.

For SAT latency evaluation, please refer to `eval_sat_latency.ipynb`.
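The notebook covers the exact protocol; as a general pattern, GPU inference latency should be measured with explicit synchronization so that asynchronous CUDA kernels are included in the timing. A minimal sketch with a stand-in model (not the repository's API):

```python
import time
import torch

# Stand-in model and input; replace with the actual SegNext model.
model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1).cuda().eval()
x = torch.randn(1, 3, 1024, 1024, device="cuda")

with torch.no_grad():
    for _ in range(10):            # warm-up: cuDNN autotuning, caching
        model(x)
    torch.cuda.synchronize()       # wait for queued kernels to finish

    start = time.perf_counter()
    for _ in range(100):
        model(x)
    torch.cuda.synchronize()       # include all async work in the timing
    elapsed = time.perf_counter() - start

print(f"mean latency per forward pass: {elapsed / 100 * 1e3:.2f} ms")
```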

## Training

We provide a script (`run_train.sh`) for training our models on the COCO+LVIS and HQSeg-44K datasets. You can start training with the following commands. By default we use 4 A6000 GPUs for training.

```bash
# train vitb-sax1 model on coco+lvis
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_cocolvis_sax1.py
torchrun --nproc-per-node=4 --master-port 29504 ./segnext/train.py ${MODEL_CONFIG} --batch-size=16 --gpus=0,1,2,3

# train vitb-sax2 model on coco+lvis
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_cocolvis_sax2.py
torchrun --nproc-per-node=4 --master-port 29505 ./segnext/train.py ${MODEL_CONFIG} --batch-size=16 --gpus=0,1,2,3

# finetune vitb-sax2 model on hqseg-44k
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_hqseg44k_sax2.py
torchrun --nproc-per-node=4 --master-port 29506 ./segnext/train.py ${MODEL_CONFIG} --batch-size=12 --gpus=0,1,2,3 --weights ./weights/vitb_sa2_cocolvis_epoch_90.pth
```
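Before launching a multi-GPU run, it can help to confirm that the expected number of devices is visible, since `torchrun` spawns one process per GPU and `--nproc-per-node` should not exceed the device count. A small pre-flight check:

```python
import torch

required = 4  # matches --nproc-per-node in the commands above
available = torch.cuda.device_count()
print(f"visible CUDA devices: {available}")
for i in range(available):
    props = torch.cuda.get_device_properties(i)
    print(f"  cuda:{i} {props.name}, {props.total_memory / 1e9:.0f} GB")

if available < required:
    raise SystemExit(
        f"need {required} GPUs for these commands; reduce --nproc-per-node "
        f"and --gpus to match the {available} available device(s)"
    )
```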

## Citation

```bibtex
@article{liu2024rethinking,
  title={Rethinking Interactive Image Segmentation with Low Latency, High Quality, and Diverse Prompts},
  author={Liu, Qin and Cho, Jaemin and Bansal, Mohit and Niethammer, Marc},
  journal={arXiv preprint arXiv:2404.00741},
  year={2024}
}
```