
[MICCAI 2024] HySparK: Hybrid Sparse Masking for Large Scale Medical Image Pre-Training


FengheTan9/HySparK



News

  • Code and weights are now released 😎!
  • HySparK has been accepted to MICCAI 2024 (early accept)!
  • Code will be released soon! 😘

TODOs

  • [x] Paper released
  • [x] Code released
  • [x] Weights released

Models

Pre-trained weights

| Name | Pre-training data scale | Weights |
|---|---|---|
| HySparK-B | 6.8k CT scans | hybird_ct_pretrained_timm_style_mask75.pth |
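The `mask75` suffix in the checkpoint name suggests a 75% mask ratio during pre-training. As a rough illustration of what that ratio means for a 3D CT crop, here is a minimal sketch of generic uniform random patch masking. It assumes a 96³ input (the `img_size=96` from the fine-tuning example) and a hypothetical 16³ patch size, and it is not HySparK's hybrid sparse masking scheme; `random_patch_mask` is an illustrative name.

```python
import random


def random_patch_mask(img_size=96, patch_size=16, mask_ratio=0.75, seed=None):
    """Return a flat list of booleans, one per 3D patch; True means masked.

    Generic uniform random masking, NOT the exact HySparK strategy.
    img_size and patch_size are illustrative assumptions.
    """
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    per_axis = img_size // patch_size          # patches along each axis
    num_patches = per_axis ** 3                # cubic grid for a 3D volume
    num_masked = int(round(num_patches * mask_ratio))
    rng = random.Random(seed)
    masked_ids = set(rng.sample(range(num_patches), num_masked))
    return [i in masked_ids for i in range(num_patches)]


mask = random_patch_mask(seed=0)
print(len(mask), sum(mask))  # 216 patches in total, 162 of them masked
```

With these assumed sizes the grid is 6 × 6 × 6 = 216 patches, so a 75% ratio leaves only 54 visible patches for the encoder.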

Getting Started

Prepare Environment

conda create -n hyspark python=3.9
conda activate hyspark
pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
pip install packaging timm==0.5.4
pip install transformers==4.34.1 typed-argument-parser
pip install numpy==1.21.2 opencv-python==4.5.5.64 opencv-python-headless==4.5.5.64
pip install 'monai[all]==1.2.0'
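After installing, a quick sanity check can catch version drift (for example, a later `pip install` silently upgrading numpy). This is a hypothetical helper script, not part of the repository; the pins are copied from the install commands, and the check uses only the standard-library `importlib.metadata`.

```python
from importlib import metadata

# Version pins copied from the install commands in this README.
PINS = {
    "torch": "1.13.0",
    "torchvision": "0.14.0",
    "torchaudio": "0.13.0",
    "timm": "0.5.4",
    "transformers": "4.34.1",
    "numpy": "1.21.2",
    "opencv-python": "4.5.5.64",
    "monai": "1.2.0",
}


def check_pins(pins, get_version=metadata.version):
    """Return {package: (expected, found)} for every package that does not match.

    `found` is None when the package is not installed.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        # CUDA wheels report versions like "1.13.0+cu117"; compare the public part only.
        if found is None or found.split("+")[0] != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(PINS)
    for name, (expected, found) in problems.items():
        print(f"{name}: expected {expected}, found {found}")
    print("environment OK" if not problems else f"{len(problems)} mismatch(es)")
```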

Prepare Datasets

We recommend converting your datasets to the nnUNet format:

└── HySparK
    ├── data
        ├── Dataset060_TotalSegmentator
            └── imagesTr
                ├── xxx_0000.nii.gz
                ├── ...
        ├── Dataset006_FLARE2022
            └── imagesTr
                ├── xxx_0000.nii.gz
                ├── ...
        └── Other_dataset
            └── imagesTr
                ├── xxx_0000.nii.gz
                ├── ...

To prepare custom datasets, use the provided helper functions (organize in nnunet-style or organize_by_names).
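The repository's own helpers are the recommended route. If you prefer to script the conversion yourself, the following sketch renames single-modality CT volumes to nnUNet's `<case>_0000.nii.gz` convention (the `_0000` suffix is the channel index) and places them under `imagesTr`. The function names and the `Dataset999_Custom` folder are hypothetical, not part of the repository.

```python
import shutil
from pathlib import Path


def nnunet_image_name(src, channel=0):
    """Map e.g. 'patient_01.nii.gz' -> 'patient_01_0000.nii.gz'."""
    name = Path(src).name
    if not name.endswith(".nii.gz"):
        raise ValueError(f"expected a .nii.gz file, got {name}")
    case_id = name[: -len(".nii.gz")]
    return f"{case_id}_{channel:04d}.nii.gz"


def organize_nnunet_style(src_files, dataset_dir, copy=True):
    """Place each volume under <dataset_dir>/imagesTr with an nnUNet-style name.

    Returns the destination paths. With copy=False nothing is written,
    which is handy for a dry run.
    """
    images_tr = Path(dataset_dir) / "imagesTr"
    targets = []
    for src in src_files:
        dst = images_tr / nnunet_image_name(src)
        if copy:
            images_tr.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src, dst)
        targets.append(dst)
    return targets
```

For example, `organize_nnunet_style(["raw/case1.nii.gz"], "data/Dataset999_Custom", copy=False)` returns the planned path `data/Dataset999_Custom/imagesTr/case1_0000.nii.gz` without touching the disk.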

Then run:

python generate_js.py

An example dataset.json will be generated in ./data. Its content should look like this:

{
    "training": [
        {
            "image": "./Dataset060_TotalSegmentator/imagesTr/xxx_0000.nii.gz"
        },
        {
            "image": "./Dataset006_FLARE2022/imagesTr/xxx_0000.nii.gz"
        }
    ]
}
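A minimal sketch of what a script like `generate_js.py` might produce, assuming it simply collects every `*/imagesTr/*.nii.gz` under `./data` and writes paths relative to that folder. This is not the repository's script; `dataset_entries` and `write_dataset_json` are hypothetical names.

```python
import json
from pathlib import Path


def dataset_entries(image_paths, data_root):
    """Turn image paths under data_root into {"image": "./<relative path>"} entries."""
    root = Path(data_root)
    paths = sorted(Path(p) for p in image_paths)
    return [{"image": "./" + p.relative_to(root).as_posix()} for p in paths]


def write_dataset_json(data_root="./data", out_name="dataset.json"):
    """Scan <data_root>/*/imagesTr/*.nii.gz and write <data_root>/<out_name>."""
    root = Path(data_root)
    entries = dataset_entries(root.glob("*/imagesTr/*.nii.gz"), root)
    with open(root / out_name, "w") as f:
        json.dump({"training": entries}, f, indent=4)
    return entries
```

Because the file is produced by `json.dump`, it is always valid JSON, which avoids the syntax slips that creep into hand-edited files.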

Start Training

To train on multiple GPUs:

# An example of training on 4 GPUs with DDP
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12351 main.py --exp_name=debug --data_path=./data --model=hyspark --bs=12 --exp_dir=debug_hyspark_ddp_4

To train on a single GPU:

# An example of training on single GPU
python main.py --exp_name=debug --data_path=./data --model=hyspark --bs=4 --exp_dir=debug_hyspark
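Both launch modes run the same `main.py`. A common pattern for telling them apart is to read the environment variables that `torchrun` exports (`WORLD_SIZE`, `RANK`, `LOCAL_RANK`) and fall back to a single process when they are absent. The sketch below illustrates that pattern only; whether `main.py` does exactly this is an assumption, and `distributed_context` is a hypothetical name.

```python
import os


def distributed_context(env=None):
    """Read the variables torchrun exports; fall back to a single process.

    Under the 4-GPU torchrun command, each of the 4 processes sees
    WORLD_SIZE=4 and a distinct LOCAL_RANK in 0..3. Under plain
    `python main.py` none of them is set, so this returns rank 0 of 1.
    """
    env = os.environ if env is None else env
    world_size = int(env.get("WORLD_SIZE", 1))
    rank = int(env.get("RANK", 0))
    local_rank = int(env.get("LOCAL_RANK", 0))
    return {
        "distributed": world_size > 1,
        "world_size": world_size,
        "rank": rank,
        "local_rank": local_rank,  # typically used to pick the CUDA device
    }
```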

Fine-tuning

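Load the pre-trained weights:

# An example of fine-tuning on BTCV (num_classes=14)
import torch
from models.network.hyspark_model import build_hybird

model = build_hybird(in_channel=1, n_classes=14, img_size=96).cuda()

model_dict = torch.load("./[your_ckpt_path]/hybird_ct_pretrained_timm_style_mask75.pth", map_location="cpu")

# load_state_dict returns the missing/unexpected keys (a non-empty tuple that is
# always truthy), so report them instead of testing the return value.
missing, unexpected = model.load_state_dict(model_dict, strict=False)
print(f"HySparK pre-trained weights loaded ({len(missing)} missing, {len(unexpected)} unexpected keys)")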
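`strict=False` only tolerates missing or unexpected keys; PyTorch still raises an error when a key exists in both the checkpoint and the model with different shapes (for example, a task head built for a different `n_classes`). A minimal sketch of filtering the checkpoint first, assuming it is a plain state dict as in the example above; `filter_matching_weights` is a hypothetical helper, not part of the repository.

```python
def filter_matching_weights(checkpoint, model_state):
    """Split a checkpoint into entries the model can load and entries it cannot.

    Both arguments map parameter names to tensors (anything with a .shape).
    Returns (compatible, skipped), where skipped maps name -> reason.
    """
    compatible, skipped = {}, {}
    for name, tensor in checkpoint.items():
        if name not in model_state:
            skipped[name] = "not in model"
        elif tuple(tensor.shape) != tuple(model_state[name].shape):
            skipped[name] = (
                f"shape {tuple(tensor.shape)} != {tuple(model_state[name].shape)}"
            )
        else:
            compatible[name] = tensor
    return compatible, skipped


# Usage with the objects from the fine-tuning example:
# compatible, skipped = filter_matching_weights(model_dict, model.state_dict())
# missing, unexpected = model.load_state_dict(compatible, strict=False)
```

Printing `skipped` makes it obvious which layers were left at their random initialization.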

For the downstream fine-tuning pipeline, refer to UNETR.

Acknowledgements

This codebase uses helper functions from SparK.

Citation

If our code, paper, or weights help your research, please cite:

@inproceedings{tang2024hyspark,
  title={Hyspark: Hybrid sparse masking for large scale medical image pre-training},
  author={Tang, Fenghe and Xu, Ronghao and Yao, Qingsong and Fu, Xueming and Quan, Quan and Zhu, Heqin and Liu, Zaiyi and Zhou, S Kevin},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={330--340},
  year={2024},
  organization={Springer}
}

License

This project is released under the Apache 2.0 license. Please see the LICENSE file for more information.
