TokenFusion
[CVPR 2022] Code release for "Multimodal Token Fusion for Vision Transformers"
Multimodal Token Fusion for Vision Transformers
By Yikai Wang, Xinghao Chen, Lele Cao, Wenbing Huang, Fuchun Sun, Yunhe Wang.
This repository is a PyTorch implementation of "Multimodal Token Fusion for Vision Transformers", in CVPR 2022.

Homogeneous predictions:

Heterogeneous predictions:
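At a high level, TokenFusion detects uninformative tokens in each modality's transformer branch and substitutes them with the aligned tokens of the other modality. The snippet below is only a minimal sketch of that substitution idea, assuming two token streams of matching shape, a learned per-token score, and hard thresholding; the class and variable names are illustrative and do not correspond to this repository's code.

import torch
import torch.nn as nn

# Minimal sketch of inter-modal token substitution (illustrative only; not the repo's API).
class TokenExchange(nn.Module):
    def __init__(self, dim, threshold=0.02):
        super().__init__()
        # Small MLP predicting a per-token importance score in [0, 1].
        self.score = nn.Sequential(
            nn.Linear(dim, dim // 4), nn.ReLU(),
            nn.Linear(dim // 4, 1), nn.Sigmoid())
        self.threshold = threshold

    def forward(self, x0, x1):
        # x0, x1: (batch, num_tokens, dim) token sequences of the two modalities.
        s0, s1 = self.score(x0), self.score(x1)
        # Tokens with low scores are judged uninformative and are replaced
        # by the corresponding tokens of the other modality.
        y0 = torch.where(s0 >= self.threshold, x0, x1)
        y1 = torch.where(s1 >= self.threshold, x1, x0)
        return y0, y1

# Example: exchange tokens between an RGB branch and a depth branch.
exchange = TokenExchange(dim=64)
rgb_tokens = torch.randn(2, 196, 64)
depth_tokens = torch.randn(2, 196, 64)
rgb_tokens, depth_tokens = exchange(rgb_tokens, depth_tokens)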

Datasets
For the semantic segmentation task on NYUDv2 (official dataset), we provide a download link for the dataset here. The dataset was originally preprocessed in this repository, and we add depth data to it.
For the image-to-image translation task, we use the sample dataset of Taskonomy; a download link for the sample dataset is here.
Please modify the data paths in the code at the places marked with the comment 'Modify data path'.
Dependencies
python==3.6
pytorch==1.7.1
torchvision==0.8.2
numpy==1.19.2
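As an optional sanity check of the environment (not part of the repository), the following snippet prints the installed versions and whether CUDA is visible:

# Optional environment check; expected versions are listed above.
import torch, torchvision, numpy
print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("numpy:", numpy.__version__)
print("CUDA available:", torch.cuda.is_available())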
Semantic Segmentation
First,
cd semantic_segmentation
Download a SegFormer model pretrained on ImageNet from weights, e.g., mit_b3.pth, and move it to the folder 'pretrained'.
Training script for segmentation with RGB and Depth input,
python main.py --backbone mit_b3 -c exp_name --lamda 1e-6 --gpu 0 1 2
Evaluation script,
python main.py --gpu 0 --resume path_to_pth --evaluate # optionally use --save-img to visualize results
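To inspect a downloaded checkpoint before evaluation, a standard PyTorch load is enough; the exact keys stored in the .pth file depend on how this repository saves checkpoints, so the snippet below is only a generic sketch.

# Generic checkpoint inspection (the stored keys depend on the repo's saving code).
import torch
ckpt = torch.load("path_to_pth", map_location="cpu")
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))  # e.g., model weights, optimizer state, epoch, ...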
Checkpoint models, training logs, mask ratios and the single-scale performance on NYUDv2 are provided as follows:
Method | Backbone | Pixel Acc. (%) | Mean Acc. (%) | Mean IoU (%) | Download |
---|---|---|---|---|---|
CEN | ResNet101 | 76.2 | 62.8 | 51.1 | Google Drive |
CEN | ResNet152 | 77.0 | 64.4 | 51.6 | Google Drive |
Ours | SegFormer-B3 | 78.7 | 67.5 | 54.8 | Google Drive |
A MindSpore implementation is available at: https://gitee.com/mindspore/models/tree/master/research/cv/TokenFusion
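For reference, the metrics in the table above are the standard segmentation metrics. The sketch below shows a generic way to compute pixel accuracy, mean accuracy, and mean IoU from a confusion matrix; it is not this repository's exact evaluation code.

# Generic segmentation metrics from a confusion matrix (illustrative only).
import numpy as np

def seg_metrics(conf):
    # conf[i, j]: number of pixels of ground-truth class i predicted as class j.
    tp = np.diag(conf).astype(np.float64)
    pixel_acc = tp.sum() / conf.sum()
    mean_acc = (tp / np.maximum(conf.sum(axis=1), 1)).mean()
    mean_iou = (tp / np.maximum(conf.sum(axis=1) + conf.sum(axis=0) - tp, 1)).mean()
    return pixel_acc, mean_acc, mean_iou

# Example with a random 40x40 confusion matrix (NYUDv2 has 40 classes).
print(seg_metrics(np.random.randint(0, 100, (40, 40))))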
Image-to-Image Translation
First,
cd image2image_translation
Training script, from Shade and Texture to RGB,
python main.py --gpu 0 -c exp_name
This script automatically evaluates on the validation dataset every 5 training epochs.
Predicted images will be automatically saved during training, in the following folder structure:
code_root/ckpt/exp_name/results
├── input0   # 1st modality input
├── input1   # 2nd modality input
├── fake0    # 1st branch output
├── fake1    # 2nd branch output
├── fake2    # ensemble output
├── best     # current best output
│   ├── fake0
│   ├── fake1
│   └── fake2
└── real     # ground truth output
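As a quick way to compare an ensemble prediction with its ground truth outside of training, a minimal numpy/Pillow snippet could look like the following; the file name is hypothetical and Pillow is assumed to be available (it is installed together with torchvision).

# Mean absolute pixel error between one prediction and its ground truth
# (the file name '0.png' is only a placeholder).
import numpy as np
from PIL import Image

pred = np.asarray(Image.open("ckpt/exp_name/results/fake2/0.png"), dtype=np.float32)
real = np.asarray(Image.open("ckpt/exp_name/results/real/0.png"), dtype=np.float32)
print("mean abs error:", np.abs(pred - real).mean())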
Checkpoint models:
Method | Task | FID | KID | Download |
---|---|---|---|---|
CEN | Texture+Shade->RGB | 62.6 | 1.65 | - |
Ours | Texture+Shade->RGB | 45.5 | 1.00 | Google Drive |
3D Object Detection (under construction)
Data preparation, environments, and training scripts follow Group-Free and ImVoteNet.
E.g.,
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port 2229 --nproc_per_node 4 train_dist.py --max_epoch 600 --val_freq 25 --save_freq 25 --lr_decay_epochs 420 480 540 --num_point 20000 --num_decoder_layers 6 --size_cls_agnostic --size_delta 0.0625 --heading_delta 0.04 --center_delta 0.1111111111111 --weight_decay 0.00000001 --query_points_generator_loss_coef 0.2 --obj_loss_coef 0.4 --dataset sunrgbd --data_root . --use_img --log_dir log/exp_name
Citation
If you find our work useful for your research, please consider citing the following paper.
@inproceedings{wang2022tokenfusion,
  title={Multimodal Token Fusion for Vision Transformers},
  author={Wang, Yikai and Chen, Xinghao and Cao, Lele and Huang, Wenbing and Sun, Fuchun and Wang, Yunhe},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}