pytorch-image-models

Add T2T_ViT

brianhou0208 opened this issue 10 months ago • 4 comments (status: Open)

Hi @rwightman, this PR resolves #2364; please check.

Result

Tested the T2T-ViT models and pretrained weights on the ImageNet validation set:

| Model | Acc@1 (%) | Acc@5 (%) | FLOPs (G) | MACs (G) | Params (M) |
|---|---|---|---|---|---|
| t2t_vit_7 | 71.6760 | 90.8860 | 2.0261 | 0.9755 | 4.2557 |
| t2t_vit_10 | 75.1500 | 92.8060 | 2.6476 | 1.2854 | 5.8347 |
| t2t_vit_12 | 76.4800 | 93.4840 | 3.0620 | 1.492 | 6.8874 |
| t2t_vit_14 | 81.5000 | 95.6660 | 8.7526 | 4.334 | 21.4658 |
| t2t_vit_19 | 81.9320 | 95.7440 | 15.6663 | 7.7868 | 39.0851 |
| t2t_vit_24 | 82.2760 | 95.8860 | 25.4543 | 12.6759 | 64.0010 |
| t2t_vit_t_14 | 81.6880 | 95.8520 | 8.6881 | 4.334 | 21.4654 |
| t2t_vit_t_19 | 82.4420 | 96.0820 | 15.6018 | 7.7868 | 39.0847 |
| t2t_vit_t_24 | 82.5540 | 96.0640 | 25.3898 | 12.6759 | 64.0006 |
Test code:

```python
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.utils.metrics import AverageMeter, accuracy

device = torch.device('cuda:0')

if __name__ == "__main__":
    # ImageNet val: bicubic resize of the short side to 224 / 0.9 (crop_pct 0.9), then center crop 224
    val_dataset = datasets.ImageFolder(
        './data/val',
        transforms.Compose([
            transforms.Resize(int(224 / 0.9), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])
    )

    val_loader = DataLoader(
        val_dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True)

    # evaluate every registered T2T-ViT variant with its pretrained weights
    for name in timm.list_models('t2t_vit*'):
        model = timm.create_model(name, pretrained=True).eval()
        model.to(device)
        top1 = AverageMeter()
        top5 = AverageMeter()

        with torch.no_grad():
            for images, target in tqdm(val_loader):
                images = images.to(device)
                target = target.to(device)
                output = model(images)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                top1.update(acc1.item(), images.size(0))
                top5.update(acc5.item(), images.size(0))
        print(f"Model {name} ACC@1 {top1.avg:.4f} ACC@5 {top5.avg:.4f}")
```
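For readers without timm at hand, the top-k accuracy reported above can be sketched in plain PyTorch. `topk_accuracy` is a hypothetical helper, not timm's API; for single-label targets it should agree with `timm.utils.metrics.accuracy`, but it returns Python floats (percent) rather than tensors.

```python
import torch


def topk_accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)):
    """Percent of samples whose true label is among the k highest logits, for each k."""
    maxk = max(topk)
    # indices of the maxk largest logits per sample, shape (B, maxk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    # (B, maxk) boolean: does the ranked prediction equal the label?
    correct = pred.eq(target.unsqueeze(1))
    return [correct[:, :k].any(dim=1).float().mean().item() * 100.0 for k in topk]


if __name__ == "__main__":
    logits = torch.tensor([[0.1, 0.7, 0.2],
                           [0.5, 0.3, 0.2],
                           [0.2, 0.3, 0.5],
                           [0.6, 0.1, 0.3]])
    labels = torch.tensor([1, 2, 1, 0])
    acc1, acc2 = topk_accuracy(logits, labels, topk=(1, 2))
    print(acc1, acc2)  # 50.0 75.0
```

Since only one label per row can match, "any hit within the first k" is the same count timm computes by summing matches.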

Output log:

```text
100%|██████████████████████████████████████████████| 196/196 [00:39<00:00,  4.92it/s]
Model t2t_vit_7 ACC@1 71.6760 ACC@5 90.8860
FLOPs 2.0261 GFLOPS / MACs 975.534 MMACs / Params 4.2557 M

100%|██████████████████████████████████████████████| 196/196 [00:39<00:00,  4.96it/s]
Model t2t_vit_10 ACC@1 75.1500 ACC@5 92.8060
FLOPs 2.6476 GFLOPS / MACs 1.2854 GMACs / Params 5.8347 M

100%|██████████████████████████████████████████████| 196/196 [00:40<00:00,  4.88it/s]
Model t2t_vit_12 ACC@1 76.4800 ACC@5 93.4840
FLOPs 3.062 GFLOPS / MACs 1.492 GMACs / Params 6.8874 M

100%|██████████████████████████████████████████████| 196/196 [01:08<00:00,  2.87it/s]
Model t2t_vit_14 ACC@1 81.5000 ACC@5 95.6660
FLOPs 8.7526 GFLOPS / MACs 4.334 GMACs / Params 21.4658 M

100%|██████████████████████████████████████████████| 196/196 [01:45<00:00,  1.86it/s]
Model t2t_vit_19 ACC@1 81.9320 ACC@5 95.7440
FLOPs 15.6663 GFLOPS / MACs 7.7868 GMACs / Params 39.0851 M

100%|██████████████████████████████████████████████| 196/196 [02:31<00:00,  1.30it/s]
Model t2t_vit_24 ACC@1 82.2760 ACC@5 95.8860
FLOPs 25.4543 GFLOPS / MACs 12.6759 GMACs / Params 64.001 M

100%|██████████████████████████████████████████████| 196/196 [01:28<00:00,  2.20it/s]
Model t2t_vit_t_14 ACC@1 81.6880 ACC@5 95.8520
FLOPs 8.6881 GFLOPS / MACs 4.334 GMACs / Params 21.4654 M

100%|██████████████████████████████████████████████| 196/196 [02:04<00:00,  1.57it/s]
Model t2t_vit_t_19 ACC@1 82.4420 ACC@5 96.0820
FLOPs 15.6018 GFLOPS / MACs 7.7868 GMACs / Params 39.0847 M

100%|██████████████████████████████████████████████| 196/196 [02:51<00:00,  1.15it/s]
Model t2t_vit_t_24 ACC@1 82.5540 ACC@5 96.0640
FLOPs 25.3898 GFLOPS / MACs 12.6759 GMACs / Params 64.0006 M
```
FLOPs/MACs/Params were calculated with calflops:

```python
from calflops import calculate_flops

def flops_param(model):
    # report FLOPs / MACs / Params for a single 224x224 RGB input
    flops, macs, params = calculate_flops(
        model=model,
        input_shape=(1, 3, 224, 224),
        output_as_string=True,
        output_precision=4,
        print_detailed=False,
        print_results=False
    )
    print(f"FLOPs {flops} / MACs {macs} / Params {params}")
```
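As an independent cross-check of the Params column, the count can be reproduced directly from `model.parameters()` without calflops. `count_params_m` is a hypothetical helper, not part of calflops or timm; whether calflops counts frozen (`requires_grad=False`) parameters is not verified here, so the `trainable_only` switch is there to explain a small mismatch if one shows up.

```python
import torch.nn as nn


def count_params_m(model: nn.Module, trainable_only: bool = False) -> float:
    """Number of parameters in millions, summed over model.parameters()."""
    return sum(
        p.numel() for p in model.parameters()
        if p.requires_grad or not trainable_only
    ) / 1e6


if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
    toy[2].weight.requires_grad_(False)  # freeze the 20*5 = 100 weights of the last layer
    print(round(count_params_m(toy) * 1e6))                       # 325
    print(round(count_params_m(toy, trainable_only=True) * 1e6))  # 225
```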

Reference

- paper: https://arxiv.org/pdf/2101.11986
- code: https://github.com/yitu-opensource/T2T-ViT

brianhou0208, Jan 22 '25 20:01

@brianhou0208 thanks for the work; it looks like you did a good job getting it into shape. I took a closer look using your code, but I have some doubts about this model:

  1. it requires a workaround with AMP + float16 to avoid NaNs (see next post)
  2. compared to simpler models, it's really not performing better given the speed, especially compared to the models in this collection (https://huggingface.co/collections/timm/searching-for-better-vit-baselines-663eb74f64f847d2f35a9c19): they are faster and more accurate at a fraction of the param count, with fewer MACs/activations. Even some models that have been in timm longer, like deit3 (e.g. deit3_medium_patch16_224), are faster/simpler/smaller than these.

For the speed comparisons I disabled F.sdpa (F.scaled_dot_product_attention) in the existing ViT to be fair. Simpler ViTs with higher accuracy (also ImageNet-1k pretrained, to be fair) are often 30-40% faster.

So I'm not convinced this is worth adding. Was there a particular reason you were interested in this model?

rwightman, Jan 24 '25 17:01

```python
    def single_attn(self, x: torch.Tensor) -> torch.Tensor:
        k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)

        if not torch.jit.is_scripting():
            # Workaround for NaNs with AMP + float16: keep autocast from running
            # the Performer einsums in _attn_impl in reduced precision
            with torch.autocast(device_type=v.device.type, enabled=False):
                y = self._attn_impl(k, q, v)
        else:
            y = self._attn_impl(k, q, v)

        # skip connection
        y = v + self.dp(self.proj(y))  # same as token_transformer in T2T layer, use v as skip connection
        return y

    def _attn_impl(self, k, q, v):
        kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)
        D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)  # (B, T, m) * (B, m) -> (B, T, 1)
        kptv = torch.einsum('bin,bim->bnm', v.float(), kp)  # (B, emb, m)
        y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb)/Diag
        return y
```
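To make the workaround easier to follow, here is a standalone sketch of the Performer-style positive-random-feature attention that `_attn_impl` computes, forced to float32 with autocast disabled. The function names are hypothetical, and the random features are drawn as plain Gaussians (the standard Performer construction), whereas the T2T-ViT code uses its own scaled orthogonal matrix, so treat this as an illustration of the technique rather than a drop-in replacement. The last line shows why float16 is fragile here: the kernel takes exp() of unbounded logits, and exp(12) ≈ 1.6e5 already exceeds the float16 maximum of 65504.

```python
import math

import torch


def prm_exp(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Positive random features phi(x) = exp(w x - |x|^2 / 2) / sqrt(m).

    x: (B, T, d), w: (m, d) -> (B, T, m), computed in float32.
    """
    x = x.float()
    xd = (x * x).sum(dim=-1, keepdim=True) / 2          # (B, T, 1), broadcast over m
    wtx = torch.einsum('btd,md->btm', x, w.float())     # (B, T, m)
    return torch.exp(wtx - xd) / math.sqrt(w.shape[0])


def performer_attention_fp32(q, k, v, w, eps: float = 1e-8):
    """Single-head linear attention with positive random features, autocast disabled."""
    with torch.autocast(device_type=q.device.type, enabled=False):
        qp, kp = prm_exp(q, w), prm_exp(k, w)            # (B, T, m) each
        # normalizer D_t = phi(q_t) . sum_s phi(k_s) -> (B, T, 1)
        d = torch.einsum('btm,bm->bt', qp, kp.sum(dim=1)).unsqueeze(-1)
        kptv = torch.einsum('bsn,bsm->bnm', v.float(), kp)  # (B, d_v, m)
        y = torch.einsum('btm,bnm->btn', qp, kptv)        # (B, T, d_v)
        return y / (d + eps)


if __name__ == "__main__":
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 8, 16) for _ in range(3))
    w = torch.randn(32, 16)                               # m = 32 Gaussian features
    print(performer_attention_fp32(q, k, v, w).shape)     # torch.Size([2, 8, 16])
    # exp(12) does not fit in float16, hence the NaN/inf risk under AMP
    print(torch.exp(torch.tensor(12.0)).half())           # tensor(inf, dtype=torch.float16)
```

Because every feature is positive, each output row is a convex combination of the values, so attending over all-ones values returns (almost exactly) ones; the test below uses that as a sanity check.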

rwightman, Jan 24 '25 17:01

Hi @rwightman, I agree with your observation: the T2T-ViT model has no clear advantage over other models. Its only distinguishing feature may be that it uses no nn.Conv2d at all, relying instead on nn.Unfold to extract patches. Most ViT-based models need some form of convolution to process the input, whereas the T2T-ViT architecture bypasses convolution entirely; maybe this direction could be explored further...
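As a concrete picture of the conv-free tokenization described above, the sketch below chains three overlapping nn.Unfold "soft splits" (7x7 stride 4, then 3x3 stride 2 twice, the settings used by the T2T-ViT reference code) and shows a 224x224 image becoming 196 tokens without any nn.Conv2d. The class name is hypothetical, token_dim=64 and embed_dim=384 are example sizes, and the token transformer/performer layers between splits are replaced by plain nn.Linear placeholders, so it illustrates shapes only, not the actual model.

```python
import math

import torch
import torch.nn as nn


class T2TSoftSplitSketch(nn.Module):
    """Conv-free tokenization in the style of T2T-ViT (placeholders for the T2T layers)."""

    def __init__(self, in_chans: int = 3, token_dim: int = 64, embed_dim: int = 384):
        super().__init__()
        self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
        self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.mix0 = nn.Linear(in_chans * 7 * 7, token_dim)   # placeholder for T2T layer 1
        self.mix1 = nn.Linear(token_dim * 3 * 3, token_dim)  # placeholder for T2T layer 2
        self.project = nn.Linear(token_dim * 3 * 3, embed_dim)

    @staticmethod
    def _to_map(t: torch.Tensor) -> torch.Tensor:
        # (B, N, C) tokens -> (B, C, sqrt(N), sqrt(N)) feature map ("re-structurization")
        b, n, c = t.shape
        side = math.isqrt(n)
        return t.transpose(1, 2).reshape(b, c, side, side)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = self.mix0(self.soft_split0(x).transpose(1, 2))                # (B, 3136, 64)
        t = self.mix1(self.soft_split1(self._to_map(t)).transpose(1, 2))  # (B, 784, 64)
        t = self.soft_split2(self._to_map(t)).transpose(1, 2)             # (B, 196, 576)
        return self.project(t)                                            # (B, 196, embed_dim)


if __name__ == "__main__":
    tokens = T2TSoftSplitSketch()(torch.randn(2, 3, 224, 224))
    print(tokens.shape)  # torch.Size([2, 196, 384])
```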

Another issue comes up with pretrained weights, in the tests that check whether the layer named by first_conv adapts to a different input shape (C, H, W). If first_conv is set to None, the test_model_default_cfgs_non_std test fails: https://github.com/huggingface/pytorch-image-models/blob/d81da93c1640a504977b0ee494791e5c634ec63c/tests/test_models.py#L371-L376 In test_model_load_pretrained, if first_conv names a non-conv layer, as in T2T-ViT where the first projection is an nn.Linear rather than an nn.Conv2d, adapting its weight also raises an error: https://github.com/huggingface/pytorch-image-models/blob/d81da93c1640a504977b0ee494791e5c634ec63c/timm/models/_builder.py#L225-L239
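The error presumably comes from the input-channel adaptation being written for 4-D conv weights. When the first learnable layer is an nn.Linear fed by nn.Unfold, its 2-D weight can still be adapted by first regrouping its columns into (C, k, k), since nn.Unfold flattens each patch channel-major. The helper below is hypothetical (not a timm function), modeled on the usual conv-stem rule of summing over RGB for grayscale and repeat-and-rescale otherwise; the changes the test harness itself would need are not shown.

```python
import math

import torch
import torch.nn.functional as F


def adapt_unfold_linear_weight(weight: torch.Tensor, in_chans: int,
                               kernel_size: int, orig_chans: int = 3) -> torch.Tensor:
    """Hypothetical helper: adapt the (out, orig_chans*k*k) weight of an nn.Linear
    that consumes nn.Unfold patches to `in_chans` input channels."""
    out_features = weight.shape[0]
    # nn.Unfold flattens each patch channel-major, so columns regroup as (C, k, k)
    w = weight.reshape(out_features, orig_chans, kernel_size, kernel_size)
    if in_chans == 1:
        w = w.sum(dim=1, keepdim=True)  # grayscale: fold the RGB filters together
    elif in_chans != orig_chans:
        repeat = math.ceil(in_chans / orig_chans)
        w = w.repeat(1, repeat, 1, 1)[:, :in_chans] * (orig_chans / in_chans)
    return w.reshape(out_features, in_chans * kernel_size * kernel_size)


if __name__ == "__main__":
    torch.manual_seed(0)
    k = 7
    rgb_weight = torch.randn(64, 3 * k * k)  # e.g. a projection after a 7x7 unfold
    gray_weight = adapt_unfold_linear_weight(rgb_weight, in_chans=1, kernel_size=k)
    x = torch.randn(1, 1, 28, 28)
    # grayscale input through the adapted weight == the same image replicated to RGB
    a = F.unfold(x, k, stride=4, padding=2).transpose(1, 2) @ gray_weight.t()
    b = F.unfold(x.repeat(1, 3, 1, 1), k, stride=4, padding=2).transpose(1, 2) @ rgb_weight.t()
    print(torch.allclose(a, b, atol=1e-3))  # True
```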

Since this involves modifying test_models, and adding T2T-ViT is not worth the effort, I should probably close this PR.

brianhou0208, Jan 24 '25 19:01

@brianhou0208 I'm not sure not having the input conv is a 'feature'. My very first ViT implementation here, written before the official JAX code (which used the Conv2D trick) was released, was this: https://github.com/huggingface/pytorch-image-models/blob/7613094fb5cb960813f606a5c42e3c00c961bc8f/timm/models/vision_transformer.py#L139-L169

The conv approach was faster since it's a single optimized kernel rather than a chain of API calls. I suppose torch.compile would rectify most of that, but I still don't see the downside to the conv.
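The two patchification routes compute the same thing: a Conv2d with kernel_size == stride == patch size is exactly nn.Unfold followed by a matmul with the conv weight flattened to (D, C·p·p). The sketch below checks that numerically with example sizes (patch 16, embed dim 192); it does not measure the speed difference described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch, embed_dim = 16, 192
x = torch.randn(2, 3, 224, 224)

# Conv2d trick: one strided conv computes every patch embedding at once
conv = nn.Conv2d(3, embed_dim, kernel_size=patch, stride=patch)
with torch.no_grad():
    conv_tokens = conv(x).flatten(2).transpose(1, 2)               # (B, 196, D)

    # Unfold route with the same parameters: extract patches, then a matmul
    patches = F.unfold(x, kernel_size=patch, stride=patch)         # (B, 3*16*16, 196)
    weight = conv.weight.reshape(embed_dim, -1)                    # (D, 768), channel-major like unfold
    unfold_tokens = patches.transpose(1, 2) @ weight.t() + conv.bias  # (B, 196, D)

print(torch.allclose(conv_tokens, unfold_tokens, atol=1e-4))       # True
```

The conv version is one fused call; the unfold version materializes a (B, 768, 196) patch tensor first, which is the "chain of API calls" overhead.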

Also, the packed ViT I started working on (I have yet to pick it back up) has to push patchification further into the data pipeline: https://github.com/huggingface/pytorch-image-models/blob/379780bb6ca3304d63bf8ca789d5bbce5949d0b5/timm/models/vision_transformer_packed.py
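Pushing patchification into the data pipeline roughly means each sample leaves the loader already as a sequence of flattened patches with grid coordinates, so images of different resolutions can be concatenated into one packed sequence. The following is a hypothetical sketch of that idea; `patchify` and `pack` are illustrative names, not the API of vision_transformer_packed.py, and it assumes earlier transforms already made H and W multiples of the patch size.

```python
import torch
import torch.nn.functional as F


def patchify(img: torch.Tensor, patch: int = 16):
    """(C, H, W) image -> (N, C*patch*patch) patches and (N, 2) (row, col) grid coords."""
    c, h, w = img.shape
    gh, gw = h // patch, w // patch
    patches = (img.reshape(c, gh, patch, gw, patch)
                  .permute(1, 3, 0, 2, 4)                # (gh, gw, C, p, p)
                  .reshape(gh * gw, c * patch * patch))  # channel-major, like F.unfold
    rows, cols = torch.meshgrid(torch.arange(gh), torch.arange(gw), indexing='ij')
    coords = torch.stack([rows.flatten(), cols.flatten()], dim=-1)
    return patches, coords


def pack(samples):
    """Concatenate (patches, coords) pairs into one sequence plus per-token sample ids."""
    patches = torch.cat([p for p, _ in samples])
    coords = torch.cat([c for _, c in samples])
    seq_ids = torch.cat([torch.full((p.shape[0],), i, dtype=torch.long)
                         for i, (p, _) in enumerate(samples)])
    return patches, coords, seq_ids


if __name__ == "__main__":
    a, b = torch.randn(3, 32, 48), torch.randn(3, 64, 16)  # two different resolutions
    patches, coords, seq_ids = pack([patchify(a), patchify(b)])
    print(patches.shape, seq_ids.tolist())  # torch.Size([10, 768]) [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
    # same values as F.unfold with stride == kernel_size
    print(torch.equal(patchify(a)[0], F.unfold(a[None], 16, stride=16)[0].t()))  # True
```

The model side would then use `seq_ids` to keep attention within each image and `coords` for position embeddings.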

rwightman, Jan 24 '25 21:01