[ICLR'24] GTA: A Geometry-Aware Attention Mechanism for Multi-View Transformers
Geometric Transform Attention
Takeru Miyato · Bernhard Jaeger · Max Welling · Andreas Geiger
OpenReview | arXiv | Project Page
Official code for reproducing our ICLR 2024 work "GTA: A Geometry-Aware Attention Mechanism for Multi-View Transformers", a simple way to make your multi-view transformer more expressive!
(3/15/2024): The GTA mechanism is also effective for image generation, a purely 2D task. You can find the experimental details in our camera-ready paper and the implementation on this branch.
Contents
This repository contains the following different codebases, each of which can be accessed by switching to the corresponding branch:
- NVS experiments on CLEVR-TR and MSN-Hard (this branch)
- NVS experiments on ACID and RealEstate (link)
- ImageNet generation with Diffusion transformers (DiT) (link)
You can find the code of GTA for multi-view ViTs here and for image ViTs here.
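For intuition, here is a toy, single-head sketch of the core idea: each token carries a geometric transform, queries/keys/values are mapped into a shared frame by the inverse representation, attention is computed there, and the output is mapped back. This is an illustrative sketch, not the code in this repository: it uses SO(2) rotations as a stand-in for the camera transforms, and the names `rho`, `gta_attention`, and `thetas` are ours. The actual implementation uses richer representations (e.g., irreps of SO(3)) and multi-head attention.

```python
import numpy as np

def rho(theta, d):
    """Block-diagonal representation of a 2D rotation acting on a d-dim feature (d even)."""
    c, s = np.cos(theta), np.sin(theta)
    block = np.array([[c, -s], [s, c]])
    return np.kron(np.eye(d // 2), block)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gta_attention(q, k, v, thetas):
    """Single-head attention where token i carries a geometric transform rho(thetas[i]).

    q, k, v: (n, d) arrays; thetas: (n,) per-token angles (toy SO(2) case).
    """
    n, d = q.shape
    R = np.stack([rho(t, d) for t in thetas])   # (n, d, d)
    Rinv = R.transpose(0, 2, 1)                 # rotation inverse = transpose
    qg = np.einsum('nij,nj->ni', Rinv, q)       # map q_i into the shared frame
    kg = np.einsum('nij,nj->ni', Rinv, k)
    vg = np.einsum('nij,nj->ni', Rinv, v)
    A = softmax(qg @ kg.T / np.sqrt(d))         # attention in the shared frame
    out = A @ vg
    return np.einsum('nij,nj->ni', R, out)      # map back to each token's own frame
```

With all angles set to zero this reduces to ordinary attention, and rotating every input by a common rho(g) while shifting every angle by g rotates the output by the same rho(g), which is the kind of geometric consistency the mechanism is built around.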
Please feel free to reach out to us if you have any questions!
Setup
1. Create a conda environment and install Python libraries
conda create -n gta python=3.9
conda activate gta
pip3 install -r requirements.txt
2. Download the datasets
export DATADIR=<path_to_datadir>
mkdir -p $DATADIR
CLEVR-TR
Download the dataset from this link and place it under $DATADIR
MultiShapeNet Hard (MSN-Hard)
gsutil -m cp -r gs://kubric-public/tfds/kubric_frames/multi_shapenet_conditional/2.8.0/ ${DATADIR}/multi_shapenet_frames/
Pretrained models: MSN-Hard pretrained models will be uploaded soon.
Training
CLEVR-TR
torchrun --standalone --nnodes 1 --nproc_per_node 4 train.py runs/clevrtr/GTA/gta/config.yaml ${DATADIR}/clevrtr --seed=0
MSN-Hard
torchrun --standalone --nnodes 1 --nproc_per_node 4 train.py runs/msn/GTA/gta_so3/config.yaml ${DATADIR} --seed=0
Evaluation of PSNR, SSIM and LPIPS
python evaluate.py runs/clevrtr/GTA/gta/config.yaml ${DATADIR}/clevrtr $PATH_TO_CHECKPOINT # CLEVR-TR
python evaluate.py runs/msn/GTA/gta_so3/config.yaml ${DATADIR} $PATH_TO_CHECKPOINT # MSN-Hard
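As a reference for the first of the metrics the evaluation reports, PSNR for images with values in [0, 1] is 10·log10(1 / MSE). A minimal sketch (the function name `psnr` is ours, not the script's API; SSIM and LPIPS require dedicated libraries and are omitted):

```python
import numpy as np

def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB for arrays with values in [0, max_val]."""
    mse = np.mean((pred - target) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)
```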
Acknowledgements
This repository is built on top of SRT and OSRT created by @stelzner. We would like to thank him for his open-source contribution of the SRT models. We also thank @lucidrains for providing the values of J matrices, which are needed to compute the irreps of SO(3) efficiently.
Citation
@inproceedings{Miyato2024GTA,
title={GTA: A Geometry-Aware Attention Mechanism for Multi-View Transformers},
author={Miyato, Takeru and Jaeger, Bernhard and Welling, Max and Geiger, Andreas},
booktitle={International Conference on Learning Representations (ICLR)},
year={2024}
}