[CVPR 2022] Code release for "Multimodal Token Fusion for Vision Transformers"

Multimodal Token Fusion for Vision Transformers

By Yikai Wang, Xinghao Chen, Lele Cao, Wenbing Huang, Fuchun Sun, Yunhe Wang.


This repository is a PyTorch implementation of "Multimodal Token Fusion for Vision Transformers", in CVPR 2022.

Figure: homogeneous predictions.

Figure: heterogeneous predictions.


For the semantic segmentation task on NYUDv2 (official dataset), we provide a link to download the dataset here. The provided dataset was originally preprocessed in this repository, and we added depth data to it.

For the image-to-image translation task, we use the sample dataset of Taskonomy; a link to download the sample dataset is here.

Please modify the data paths in the code, at the places marked with the comment 'Modify data path'.



Semantic Segmentation


cd semantic_segmentation

Download the SegFormer pretrained model (pretrained on ImageNet) from weights, e.g., mit_b3.pth. Move this pretrained model to the folder 'pretrained'.

Training script for segmentation with RGB and Depth input,

python --backbone mit_b3 -c exp_name --lamda 1e-6 --gpu 0 1 2

Evaluation script,

python --gpu 0 --resume path_to_pth --evaluate  # optionally use --save-img to visualize results

Checkpoint models, training logs, mask ratios and the single-scale performance on NYUDv2 are provided as follows:

| Method | Backbone | Pixel Acc. (%) | Mean Acc. (%) | Mean IoU (%) | Download |
|---|---|---|---|---|---|
| CEN | ResNet101 | 76.2 | 62.8 | 51.1 | Google Drive |
| CEN | ResNet152 | 77.0 | 64.4 | 51.6 | Google Drive |
| Ours | SegFormer-B3 | 78.7 | 67.5 | 54.8 | Google Drive |
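
For readers unfamiliar with the three segmentation metrics reported above, here is a minimal sketch of how pixel accuracy, mean accuracy, and mean IoU are conventionally derived from a confusion matrix. It is not the evaluation code of this repository; the function names and the `ignore_label=255` default are assumptions made for illustration.

```python
import numpy as np

def confusion_matrix(gt, pred, num_classes, ignore_label=255):
    """Rows index ground-truth classes, columns index predicted classes.

    ignore_label is an assumed convention for unlabeled pixels.
    """
    gt = np.asarray(gt).ravel()
    pred = np.asarray(pred).ravel()
    valid = gt != ignore_label
    # Encode each (gt, pred) pair as a single index, then count occurrences.
    idx = gt[valid].astype(np.int64) * num_classes + pred[valid].astype(np.int64)
    counts = np.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)

def segmentation_scores(cm):
    """Return (pixel accuracy, mean per-class accuracy, mean IoU) as fractions."""
    cm = cm.astype(np.float64)
    tp = np.diag(cm)                # correctly classified pixels per class
    gt_total = cm.sum(axis=1)       # ground-truth pixels per class
    pred_total = cm.sum(axis=0)     # predicted pixels per class
    pixel_acc = tp.sum() / cm.sum()
    with np.errstate(divide="ignore", invalid="ignore"):
        class_acc = tp / gt_total
        iou = tp / (gt_total + pred_total - tp)
    # Entries that are 0/0 (class never seen) become NaN and are skipped.
    return pixel_acc, np.nanmean(class_acc), np.nanmean(iou)

if __name__ == "__main__":
    cm = confusion_matrix([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2)
    print(segmentation_scores(cm))
```

Multiplying the returned fractions by 100 gives percentages in the same units as the table.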

A MindSpore implementation is available at:

Image-to-Image Translation


cd image2image_translation

Training script, from Shade and Texture to RGB,

python --gpu 0 -c exp_name

This script will auto-evaluate on the validation dataset every 5 training epochs.

Predicted images will be automatically saved during training, in the following folder structure:

  ├── input0  # 1st modality input
  ├── input1  # 2nd modality input
  ├── fake0   # 1st branch output 
  ├── fake1   # 2nd branch output
  ├── fake2   # ensemble output
  ├── best    # current best output
  │    ├── fake0
  │    ├── fake1
  │    └── fake2
  └── real    # ground truth output
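
As an illustration of the layout above, the following sketch creates the same folder tree under a chosen root, which can be handy when writing a custom viewer or post-processing script against it. The helper name `make_result_dirs` and the root path are hypothetical; the training script creates these folders on its own.

```python
import tempfile
from pathlib import Path

# Sub-folders taken from the structure listed above.
SUBDIRS = [
    "input0", "input1",                        # modality inputs
    "fake0", "fake1", "fake2",                 # branch outputs and ensemble output
    "best/fake0", "best/fake1", "best/fake2",  # current best outputs
    "real",                                    # ground truth
]

def make_result_dirs(root):
    """Create the result folders under root and return their relative paths."""
    root = Path(root)
    for sub in SUBDIRS:
        (root / sub).mkdir(parents=True, exist_ok=True)
    return sorted(p.relative_to(root).as_posix()
                  for p in root.rglob("*") if p.is_dir())

if __name__ == "__main__":
    # A temporary directory stands in for the real results root.
    print(make_result_dirs(tempfile.mkdtemp()))
```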

Checkpoint models:

| Method | Task | FID | KID | Download |
|---|---|---|---|---|
| CEN | Texture+Shade->RGB | 62.6 | 1.65 | - |
| Ours | Texture+Shade->RGB | 45.5 | 1.00 | Google Drive |
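
For context on the FID column, the sketch below computes the Fréchet distance between two Gaussians fitted to feature vectors, which is the quantity FID is based on. It assumes the features have already been extracted (standard FID uses Inception network activations, which are not computed here), and it is not the evaluation code used to produce the numbers above; both function names are illustrative.

```python
import numpy as np

def feature_statistics(features):
    """Mean and covariance of an (N, D) array of feature vectors."""
    features = np.asarray(features, dtype=np.float64)
    return features.mean(axis=0), np.cov(features, rowvar=False)

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))."""
    mu1 = np.asarray(mu1, dtype=np.float64)
    mu2 = np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)
    diff = mu1 - mu2
    # For positive semi-definite covariances, the eigenvalues of
    # sigma1 @ sigma2 are real and non-negative, so the trace of the matrix
    # square root equals the sum of their square roots.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    real = rng.normal(size=(500, 8))           # stand-in for real-image features
    fake = rng.normal(loc=0.5, size=(500, 8))  # stand-in for generated-image features
    print(frechet_distance(*feature_statistics(real), *feature_statistics(fake)))
```

Lower values indicate that the two feature distributions are closer.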

3D Object Detection (under construction)

Data preparation, environments, and training scripts follow Group-Free and ImVoteNet.


CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port 2229 --nproc_per_node 4 --max_epoch 600 --val_freq 25 --save_freq 25 --lr_decay_epochs 420 480 540 --num_point 20000 --num_decoder_layers 6 --size_cls_agnostic --size_delta 0.0625 --heading_delta 0.04 --center_delta 0.1111111111111 --weight_decay 0.00000001 --query_points_generator_loss_coef 0.2 --obj_loss_coef 0.4 --dataset sunrgbd --data_root . --use_img --log_dir log/exp_name


If you find our work useful for your research, please consider citing the following paper.

  title={Multimodal Token Fusion for Vision Transformers},
  author={Wang, Yikai and Chen, Xinghao and Cao, Lele and Huang, Wenbing and Sun, Fuchun and Wang, Yunhe},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},