Xueming Fu<sup>1,2</sup>, Quan Quan<sup>3</sup>, Heqin Zhu<sup>1,2</sup>, Zaiyi Liu<sup>4,5</sup>, S. Kevin Zhou<sup>1,2,3</sup>

<sup>2</sup> Suzhou Institute for Advanced Research, University of Science and Technology of China
<sup>3</sup> Institute of Computing Technology, Chinese Academy of Sciences
<sup>4</sup> Department of Radiology, Guangdong Provincial People’s Hospital
<sup>5</sup> Guangdong Provincial Key Laboratory of Artificial Intelligence in Medical Image Analysis and Application
- Code and weights are now released 😎!
- HySparK has been accepted by MICCAI 2024 (early accept)!
- Code will be released soon! 😘
- Paper released
- Code released
- Weights released
Name | Pre-trained data scale | Weights |
---|---|---|
HySparK-B | 6.8k CT scans | `hybird_ct_pretrained_timm_style_mask75.pth` |
```bash
conda create -n hyspark python=3.9
conda activate hyspark
pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
pip install packaging timm==0.5.4
pip install transformers==4.34.1 typed-argument-parser
pip install numpy==1.21.2 opencv-python==4.5.5.64 opencv-python-headless==4.5.5.64
pip install 'monai[all]'
pip install monai==1.2.0
```
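After installing, a quick sanity check (a minimal sketch, nothing HySparK-specific) can confirm the pinned versions and CUDA availability:

```python
import torch, timm, monai

# Quick environment check against the versions pinned above
print(torch.__version__)          # expect 1.13.0 (+cu117)
print(timm.__version__)           # expect 0.5.4
print(monai.__version__)          # expect 1.2.0
print(torch.cuda.is_available())  # should be True on a CUDA 11.7 machine
```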
We recommend converting your dataset into the nnU-Net format:
```
└── HySparK
    ├── data
        ├── Dataset060_TotalSegmentator
        │   └── imagesTr
        │       ├── xxx_0000.nii.gz
        │       ├── ...
        ├── Dataset006_FLARE2022
        │   └── imagesTr
        │       ├── xxx_0000.nii.gz
        │       ├── ...
        └── Other_dataset
            └── imagesTr
                ├── xxx_0000.nii.gz
                ├── ...
```
To prepare a custom dataset, use the provided function to organize it in the nnU-Net style, or use `organize_by_names`; a sketch of the idea is shown below.
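For illustration only, here is a minimal, hypothetical sketch of nnU-Net-style organization (the helper name `organize_nnunet_style` and the paths are assumptions, not the repo's actual API):

```python
import shutil
from pathlib import Path

def organize_nnunet_style(src_dir: str, dst_dir: str) -> None:
    """Copy each NIfTI volume into imagesTr/ with the nnU-Net _0000 channel suffix."""
    images_tr = Path(dst_dir) / "imagesTr"
    images_tr.mkdir(parents=True, exist_ok=True)
    for src in sorted(Path(src_dir).glob("*.nii.gz")):
        case_id = src.name[:-len(".nii.gz")]
        shutil.copy(src, images_tr / f"{case_id}_0000.nii.gz")

# e.g. organize_nnunet_style("./raw_scans", "./data/Other_dataset")
```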
Then run:

```bash
python generate_js.py
```

An example `dataset.json` will be generated in `./data`. Its content should look like this:
```json
{
    "training": [
        {
            "image": "./Dataset060_TotalSegmentator/imagesTr/xxx_0000.nii.gz"
        },
        {
            "image": "./Dataset006_FLARE2022/imagesTr/xxx_0000.nii.gz"
        }
    ]
}
```
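For reference, here is a minimal sketch of what such a manifest generator does (an approximation of `generate_js.py`, not its exact contents):

```python
import json
from pathlib import Path

data_root = Path("./data")
training = []
# Collect every scan under <dataset>/imagesTr, with paths relative to ./data
for images_tr in sorted(data_root.glob("*/imagesTr")):
    for scan in sorted(images_tr.glob("*.nii.gz")):
        training.append({"image": f"./{scan.relative_to(data_root)}"})

with (data_root / "dataset.json").open("w") as f:
    json.dump({"training": training}, f, indent=4)
```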
Run training on multiple GPUs:

```bash
# An example of training on 4 GPUs with DDP
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12351 main.py --exp_name=debug --data_path=./data --model=hyspark --bs=12 --exp_dir=debug_hyspark_ddp_4
```
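Assuming `--bs` is the per-process batch size (the usual DDP convention), the example above corresponds to an effective batch size of 48 across four GPUs; change `--master_port` if 12351 is already in use on your machine.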
Run training on a single GPU:

```bash
# An example of training on a single GPU
python main.py --exp_name=debug --data_path=./data --model=hyspark --bs=4 --exp_dir=debug_hyspark
```
Load the pre-trained weights:

```python
import torch
from models.network.hyspark_model import build_hybird

# An example of fine-tuning on BTCV (num_classes=14)
model = build_hybird(in_channel=1, n_classes=14, img_size=96).cuda()
model_dict = torch.load("./[your_ckpt_path]/hybird_ct_pretrained_timm_style_mask75.pth")
# strict=False skips decoder/head keys that are absent from the pre-trained checkpoint
missing, unexpected = model.load_state_dict(model_dict, strict=False)
print(f"HySparK pre-trained weights loaded (missing keys: {len(missing)}, unexpected keys: {len(unexpected)})")
```
For the downstream fine-tuning pipeline, please refer to UNETR.
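For a rough picture of what one fine-tuning step looks like on top of `build_hybird`, here is a sketch (the loss, optimizer settings, and assumed output shape are illustrative choices, not the repo's actual pipeline):

```python
import torch
from monai.losses import DiceCELoss
from models.network.hyspark_model import build_hybird

model = build_hybird(in_channel=1, n_classes=14, img_size=96).cuda()
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# One illustrative step on a random (B, 1, 96, 96, 96) CT patch
images = torch.rand(2, 1, 96, 96, 96).cuda()
labels = torch.randint(0, 14, (2, 1, 96, 96, 96)).cuda()

logits = model(images)  # assumed output shape: (B, 14, 96, 96, 96)
loss = loss_fn(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```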
This codebase uses helper functions from SparK.

If the code, paper, or weights help your research, please cite:
```bibtex
@inproceedings{tang2024hyspark,
  title={HySparK: Hybrid sparse masking for large scale medical image pre-training},
  author={Tang, Fenghe and Xu, Ronghao and Yao, Qingsong and Fu, Xueming and Quan, Quan and Zhu, Heqin and Liu, Zaiyi and Zhou, S Kevin},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={330--340},
  year={2024},
  organization={Springer}
}
```
This project is released under the Apache 2.0 license. Please see the LICENSE file for more information.