Geometric Transform Attention

Takeru Miyato · Bernhard Jaeger · Max Welling · Andreas Geiger

Official code for reproducing our ICLR 2024 paper, "GTA: A Geometry-Aware Attention Mechanism for Multi-view Transformers", a simple way to make your multi-view transformer more expressive!

(3/15/2024): The GTA mechanism is also effective for image generation, which is a purely 2D task. You can find the experimental details in our camera-ready paper and the implementation in the ImageNet generation (DiT) branch listed under Contents.

Contents

This repository contains the following codebases, each of which can be accessed by switching to the corresponding branch:

  • NVS experiments on CLEVR-TR and MSN-Hard (this branch)
  • NVS experiments on ACID and RealEstate (link)
  • ImageNet generation with Diffusion transformers (DiT) (link)

You can find the code of GTA for multi-view ViTs here and for image ViTs here.

Please feel free to reach out to us if you have any questions!

Setup

1. Create the environment and install the Python libraries

conda create -n gta python=3.9
conda activate gta
pip3 install -r requirements.txt

2. Download dataset

export DATADIR=<path_to_datadir>
mkdir -p $DATADIR

CLEVR-TR

Download the dataset from this link and place it under $DATADIR (the training and evaluation commands expect it at ${DATADIR}/clevrtr).

MultiShapeNet Hard (MSN-Hard)

gsutil -m cp -r gs://kubric-public/tfds/kubric_frames/multi_shapenet_conditional/2.8.0/ ${DATADIR}/multi_shapenet_frames/
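The commands in this README refer to two locations under `$DATADIR`: `clevrtr` for CLEVR-TR and `multi_shapenet_frames` for MSN-Hard. As a minimal sketch, assuming those two top-level directories are all that needs checking, the following script reports which of them are missing before you start training. The helper name `missing_datasets` and the constant `EXPECTED_SUBDIRS` are hypothetical and not part of this repository, and the check does not validate the contents of the directories.

```python
import os
from pathlib import Path

# Sub-directories referenced by the commands in this README
# (assumption: checking the top-level directories is sufficient).
EXPECTED_SUBDIRS = ("clevrtr", "multi_shapenet_frames")


def missing_datasets(datadir):
    """Return the expected sub-directories that do not exist under datadir."""
    root = Path(datadir)
    return [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]


# Fall back to the current directory if DATADIR is not exported.
datadir = os.environ.get("DATADIR", ".")
print("missing dataset directories:", missing_datasets(datadir))
```

An empty list means both directories were found; any name that is printed still needs to be downloaded or moved into place.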

Pretrained models (the MSN-Hard pretrained models will be uploaded soon)

Training

CLEVR-TR

torchrun --standalone --nnodes 1 --nproc_per_node 4 train.py runs/clevrtr/GTA/gta/config.yaml ${DATADIR}/clevrtr --seed=0

MSN-Hard

torchrun --standalone --nnodes 1 --nproc_per_node 4 train.py runs/msn/GTA/gta_so3/config.yaml ${DATADIR} --seed=0

Evaluation of PSNR, SSIM and LPIPS

python evaluate.py runs/clevrtr/GTA/gta/config.yaml ${DATADIR}/clevrtr $PATH_TO_CHECKPOINT # CLEVR-TR
python evaluate.py runs/msn/GTA/gta_so3/config.yaml ${DATADIR} $PATH_TO_CHECKPOINT # MSN-Hard
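For readers unfamiliar with the first of these metrics, PSNR is defined as 10 · log10(MAX² / MSE), where MSE is the mean squared error between the rendered and ground-truth pixels. The snippet below is a minimal sketch of that standard definition, assuming pixel values in [0, 1] and flattened images. It is not taken from `evaluate.py`, whose implementation may differ, and the function name `psnr` is our own. SSIM and LPIPS are not covered here.

```python
import math


def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two flattened images.

    Assumes both inputs are equal-length sequences of pixel values
    in the range [0, max_val].
    """
    if len(pred) != len(target) or len(pred) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    # Mean squared error over all pixels.
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0.0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


# Toy example with two "pixels"; real images would be flattened first.
print(psnr([0.0, 0.5], [0.1, 0.4]))
```

Higher PSNR means the prediction is closer to the target; identical inputs give infinity under this definition.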

Acknowledgements

This repository is built on top of SRT and OSRT created by @stelzner. We would like to thank him for his open-source contribution of the SRT models. We also thank @lucidrains for providing the values of J matrices, which are needed to compute the irreps of SO(3) efficiently.

Citation

@inproceedings{Miyato2024GTA,
    title={GTA: A Geometry-Aware Attention Mechanism for Multi-View Transformers},
    author={Miyato, Takeru and Jaeger, Bernhard and Welling, Max and Geiger, Andreas},
    booktitle={International Conference on Learning Representations (ICLR)},
    year={2024}
}