This is the official implementation of Cross-modal Prototype Driven Network for Radiology Report Generation, accepted at ECCV 2022.
Radiology report generation (RRG) aims to automatically describe a radiology image in human-like language. As an alternative to expert diagnosis, RRG could potentially support the work of radiologists and reduce the burden of manual reporting. Previous approaches often adopt an encoder-decoder architecture and focus on single-modal feature learning, while few studies explore cross-modal feature interaction. Here we propose a Cross-modal PROtotype driven NETwork (XPRONET) to promote cross-modal pattern learning and exploit it to improve radiology report generation. This is achieved by three well-designed, fully differentiable and complementary modules: a shared cross-modal prototype matrix to record the cross-modal prototypes; a cross-modal prototype network to learn the cross-modal prototypes and embed the cross-modal information into the visual and textual features; and an improved multi-label contrastive loss to enable and enhance multi-label prototype learning. Experimental results demonstrate that XPRONET obtains substantial improvements on two commonly used medical report generation benchmark datasets, IU-Xray and MIMIC-CXR, where it exceeds recent state-of-the-art approaches by a large margin on IU-Xray and achieves SOTA performance on MIMIC-CXR.
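For intuition only, below is a minimal PyTorch sketch of the general idea of querying a shared, learnable prototype matrix with single-modal features and fusing the retrieved prototype information back into those features. The class/prototype counts, shapes, and attention-style interaction are simplifying assumptions for illustration and do not reproduce the exact modules implemented in this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossModalPrototypeQuery(nn.Module):
    """Illustrative sketch: single-modal features attend to a shared,
    learnable prototype matrix, and the retrieved prototype information
    is fused back into the features. Not the repository's exact module."""

    def __init__(self, num_classes=14, proto_per_class=10, dim=512):
        super().__init__()
        # Shared prototype matrix, e.g. a small bank of vectors per disease label.
        self.prototypes = nn.Parameter(
            torch.randn(num_classes * proto_per_class, dim))
        self.fuse = nn.Linear(2 * dim, dim)

    def forward(self, feats):                      # feats: (batch, seq_len, dim)
        # Cosine similarity between each feature and every prototype vector.
        attn = torch.einsum('bsd,pd->bsp',
                            F.normalize(feats, dim=-1),
                            F.normalize(self.prototypes, dim=-1)).softmax(dim=-1)
        # Retrieve a prototype summary for every position and fuse it back.
        retrieved = torch.einsum('bsp,pd->bsd', attn, self.prototypes)
        return self.fuse(torch.cat([feats, retrieved], dim=-1))

# Usage: the same module can be shared by the visual and textual branches.
# query = CrossModalPrototypeQuery()
# visual_out = query(torch.randn(2, 49, 512))   # patch features
# text_out = query(torch.randn(2, 60, 512))     # token embeddings
```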
If you use or extend our work, please cite our paper.
@inproceedings{wang2022cross,
title={Cross-modal prototype driven network for radiology report generation},
author={Wang, Jun and Bhalerao, Abhir and He, Yulan},
booktitle={Computer Vision--ECCV 2022: 17th European Conference, Tel Aviv, Israel, October 23--27, 2022, Proceedings, Part XXXV},
pages={563--579},
year={2022},
organization={Springer}
}
12/22/2023
- XPRONet now supports multi-GPU (distributed) and mixed-precision training. To use these features, please ensure your PyTorch version is >= 1.8. (Note that test results may differ slightly between multi-GPU and single-GPU testing due to the DDP sampler; a rough sketch of the setup is given after this list.)
- We provide separate test scripts to enable quick testing on each dataset.
- We recommend re-generating the initial prototype matrix if you apply your own data pre-processing to the dataset, e.g., a different image resolution or downsampled images.
- We have optimized and cleaned up some parts of the code.
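As a rough illustration of the multi-GPU and mixed-precision setup mentioned above, the snippet below shows a generic PyTorch DDP + AMP training step. The model interface (assumed to return the loss directly), batch size, learning rate, and launch command are illustrative assumptions, not the repository's actual training code.

```python
# Minimal sketch of DDP + mixed-precision training (PyTorch >= 1.8).
# Launch with e.g.: torchrun --nproc_per_node=4 train.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def train(model, dataset, epochs=1):
    dist.init_process_group('nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    model = DDP(model.cuda(local_rank), device_ids=[local_rank])
    sampler = DistributedSampler(dataset)        # source of small multi-GPU/single-GPU differences
    loader = DataLoader(dataset, batch_size=16, sampler=sampler)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()         # mixed-precision loss scaling

    for epoch in range(epochs):
        sampler.set_epoch(epoch)                 # reshuffle shards across processes
        for images, targets in loader:
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():      # run the forward pass in fp16 where safe
                loss = model(images.cuda(local_rank), targets.cuda(local_rank))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
```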
The following packages are required to run the scripts:
- [Python >= 3.6]
- [PyTorch >= 1.6]
- [Torchvision]
- [Pycocoevalcap]
- You can create the environment via conda:
conda env create --name [env_name] --file env.yml
You can download the trained models here.
We use two datasets (IU X-Ray and MIMIC-CXR) in our paper.
For IU X-Ray, you can download the dataset from here.
For MIMIC-CXR, you can download the dataset from here.
After downloading the datasets, put them in the data directory.
You can generate the pseudo labels for each dataset by leveraging the automatic labeler ChexBert.
We also provide the generated labels in the files directory.
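If you do want to regenerate the labels yourself, the sketch below shows one possible way to attach pseudo labels to an annotation file. The annotation structure (a dict of splits with a 'report' field per example) and the `run_chexbert_labeler` callable are hypothetical placeholders; invoke the ChexBert labeler however its own repository describes.

```python
# Hypothetical sketch of attaching pseudo labels to an annotation file.
# `run_chexbert_labeler` is a placeholder for however you invoke the ChexBert
# labeler; it is not an API provided by this repository.
import json

def add_pseudo_labels(ann_path, run_chexbert_labeler, out_path):
    with open(ann_path) as f:
        ann = json.load(f)                        # assumed: {'train': [...], 'val': [...], 'test': [...]}
    for split, examples in ann.items():
        reports = [ex['report'] for ex in examples]
        labels = run_chexbert_labeler(reports)    # one 14-dim multi-hot vector per report
        for ex, lab in zip(examples, labels):
            ex['labels'] = lab
    with open(out_path, 'w') as f:
        json.dump(ann, f)
```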
The processed cross-modal prototypes are provided in the files directory. For those who prefer to generate the initialization prototypes on their own, you should:
- Leverage the pretrained visual extractor (ImageNet-pretrained) and BERT (ChexBert) to extract the visual and textual features.
- Concatenate the visual and textual features.
- Utilize the K-means algorithm to cluster the cross-modal features into 14 clusters.
The above procedure is described in detail in our paper; a rough sketch is given below.
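The sketch below illustrates these three steps under simplifying assumptions: a ResNet-101 ImageNet backbone for images, a generic BERT-style text encoder standing in for the ChexBert encoder, and scikit-learn's KMeans. It is not the exact script used to produce the provided prototype file.

```python
# Illustrative sketch of building the initial cross-modal prototype matrix:
# extract visual + textual features, concatenate them, then K-means into 14 clusters.
import numpy as np
import torch
import torchvision.models as models
from sklearn.cluster import KMeans
from transformers import AutoModel, AutoTokenizer

@torch.no_grad()
def build_prototypes(images, reports, text_model_name='bert-base-uncased'):
    # 1) Visual features from an ImageNet-pretrained backbone (globally pooled).
    resnet = models.resnet101(pretrained=True).eval()
    visual_encoder = torch.nn.Sequential(*list(resnet.children())[:-1])
    vis_feats = visual_encoder(images).flatten(1)           # (N, 2048)

    # 2) Textual features; a generic BERT stands in for the ChexBert encoder here.
    tokenizer = AutoTokenizer.from_pretrained(text_model_name)
    bert = AutoModel.from_pretrained(text_model_name).eval()
    tokens = tokenizer(reports, padding=True, truncation=True, return_tensors='pt')
    txt_feats = bert(**tokens).last_hidden_state[:, 0]      # (N, 768) [CLS] embeddings

    # 3) Concatenate and cluster the cross-modal features into 14 clusters.
    fused = torch.cat([vis_feats, txt_feats], dim=1).numpy()
    kmeans = KMeans(n_clusters=14, random_state=0).fit(fused)
    return kmeans.cluster_centers_                           # (14, 2048 + 768)

# centers = build_prototypes(image_batch, report_list)
# np.save('init_prototypes.npy', centers)
```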
Our experiments on IU X-Ray were done on a machine with 1x2080Ti.
Run bash run_iu_xray.sh to train a model on the IU X-Ray data.
Our experiments on MIMIC-CXR were done on a machine with 4x2080Ti.
Run bash run_mimic_cxr.sh to train a model on the MIMIC-CXR data.
- A slightly better result can be obtained on MIMIC-CXR by replacing the test labels in the dataset with the labels generated by a DenseNet-121 visual extractor trained on the training set; these labels are provided in the files directory.
| Test labels | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|---|---|---|---|---|---|---|---|
| Original test labels | 0.3440 | 0.2148 | 0.1457 | 0.1055 | 0.1376 | 0.2789 | 0.1542 |
| Visual extractor labels | 0.3440 | 0.2147 | 0.1456 | 0.1055 | 0.1376 | 0.2789 | 0.1561 |
Our project references the code in the following repositories. Thanks for their work and sharing.