This repository contains the reference code for the paper Dual-Spatial Normalized Transformer for image captioning.
Clone the repository and create the `m2release` conda environment using the `environment.yml` file:

```
conda env create -f environment.yml
conda activate m2release
```
Then download the spaCy language data by executing the following command:

```
python -m spacy download en
```
Note: Python 3.6 and PyTorch (>1.8.0) are required to run our code.
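As a quick sanity check of the environment, a minimal sketch (not part of the repository) that only verifies the requirements listed above:

```python
# Verify the interpreter, PyTorch, and the downloaded spaCy data.
import sys

import spacy
import torch

print(sys.version.split()[0])      # expect 3.6.x
print(torch.__version__)           # expect > 1.8.0
print(torch.cuda.is_available())   # True if a GPU is visible

nlp = spacy.load("en")             # works once 'python -m spacy download en' has run
print(nlp("a quick check"))
```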
To run the code, annotations, evaluation tools and visual features for the COCO dataset are needed.
First, most annotations have been prepared by [1]: download `annotations.zip` and rename the extracted folder to `annotations`. In addition, download `image_info_test2014.json` and put it into `annotations`.
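A minimal sketch for checking that the annotation files are in place (it only looks for the folder and file named above):

```python
# Sanity check: the annotations folder exists and image_info_test2014.json parses.
import json
import os

assert os.path.isdir("annotations"), "rename the extracted folder to 'annotations'"

with open(os.path.join("annotations", "image_info_test2014.json")) as f:
    info = json.load(f)

print(len(info.get("images", [])), "images listed in image_info_test2014.json")
```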
Second, download the evaluation tools (Access code: xh6e) and extract them into the project root directory.
Then, the visual features are computed with the code provided by [2]. To reproduce our results, download the COCO features file in ResNeXt_101/trainval (Access code: bnvu) and extract it as `X101_grid_feats_coco_trainval.hdf5`.
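A minimal sketch for inspecting the extracted feature file with h5py; the internal key layout is not documented here, so the snippet only lists whatever entries are present:

```python
# Peek at the first few entries of the grid-feature file and their shapes.
import h5py

with h5py.File("X101_grid_feats_coco_trainval.hdf5", "r") as f:
    keys = list(f.keys())
    print(len(keys), "entries")
    for k in keys[:5]:
        item = f[k]
        print(k, item.shape if hasattr(item, "shape") else type(item))
```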
To reproduce the results reported in our paper, download the pretrained model file `DSNT.pth` (Access code: gvnn) and place it in the code folder.
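A minimal sketch for verifying that the checkpoint loads; the exact contents of `DSNT.pth` (a plain state dict vs. a dict of training state) are not assumed here, since `test.py` performs the actual loading:

```python
# Load the checkpoint on CPU and peek at its top-level structure.
import torch

checkpoint = torch.load("DSNT.pth", map_location="cpu")
if isinstance(checkpoint, dict):
    print("top-level keys:", list(checkpoint.keys())[:10])
else:
    print("loaded object of type:", type(checkpoint))
```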
Run `python test.py` using the following arguments:
| Argument | Possible values |
|---|---|
| `--batch_size` | Batch size (default: 40) |
| `--workers` | Number of workers (default: 8) |
| `--features_path` | Path to detection features file |
| `--annotation_folder` | Path to folder with COCO annotations |
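For example, to evaluate the pretrained model with the default batch size (the paths below are placeholders, as in the training example further down):

```
python test.py --batch_size 40 --features_path /path/to/features --annotation_folder /path/to/annotations
```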
Run `python train.py` using the following arguments:
| Argument | Possible values |
|---|---|
| `--exp_name` | Experiment name |
| `--batch_size` | Batch size (default: 40) |
| `--workers` | Number of workers (default: 8) |
| `--head` | Number of heads (default: 4) |
| `--resume_last` | If used, the training will be resumed from the last checkpoint |
| `--resume_best` | If used, the training will be resumed from the best checkpoint |
| `--features_path` | Path to detection features file |
| `--annotation_folder` | Path to folder with COCO annotations |
| `--logs_folder` | Path to folder for TensorBoard logs (default: "tensorboard_logs") |
For example, to train our model with the parameters used in our experiments, use:

```
python train.py --exp_name PGT --batch_size 40 --head 4 --features_path /path/to/features --annotation_folder /path/to/annotations
```
[1] Cornia, M., Stefanini, M., Baraldi, L., & Cucchiara, R. (2020). Meshed-memory transformer for image captioning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition.
[2] Jiang, H., Misra, I., Rohrbach, M., Learned-Miller, E., & Chen, X. (2020). In defense of grid features for visual question answering. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition.
Thanks to Cornia et al. (M2 Transformer), Zhang et al. (RSTNet), and Luo et al. (DLCT) for their open-source code.
Thanks to Jiang et al. for their significant findings on visual representation [2].