pytorch-image-models
Add T2T_ViT
Hi @rwightman, this PR resolves #2364, please check.
Result
Tested the T2T-ViT models and weights on the ImageNet val dataset.
| Model | Acc@1 | Acc@5 | FLOPs (G) | MACs (G) | Params (M) |
|---|---|---|---|---|---|
| t2t_vit_7 | 71.6760 | 90.8860 | 2.0261 | 0.9755 | 4.2557 |
| t2t_vit_10 | 75.1500 | 92.8060 | 2.6476 | 1.2854 | 5.8347 |
| t2t_vit_12 | 76.4800 | 93.4840 | 3.0620 | 1.492 | 6.8874 |
| t2t_vit_14 | 81.5000 | 95.6660 | 8.7526 | 4.334 | 21.4658 |
| t2t_vit_19 | 81.9320 | 95.7440 | 15.6663 | 7.7868 | 39.0851 |
| t2t_vit_24 | 82.2760 | 95.8860 | 25.4543 | 12.6759 | 64.0010 |
| t2t_vit_t_14 | 81.6880 | 95.8520 | 8.6881 | 4.334 | 21.4654 |
| t2t_vit_t_19 | 82.4420 | 96.0820 | 15.6018 | 7.7868 | 39.0847 |
| t2t_vit_t_24 | 82.5540 | 96.0640 | 25.3898 | 12.6759 | 64.0006 |
Test code

```python
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.utils.metrics import AverageMeter, accuracy

device = torch.device('cuda:0')

if __name__ == "__main__":
    val_dataset = datasets.ImageFolder(
        './data/val',
        transforms.Compose([
            # resize to int(224 / 0.9) = 248 with bicubic interpolation (PIL code 3), then center crop to 224
            transforms.Resize(int(224 / 0.9), interpolation=3),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])
    )
    val_loader = DataLoader(
        val_dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True)

    for name in timm.list_models('t2t_vit*'):
        model = timm.create_model(name, pretrained=True).eval()
        model.to(device)
        top1 = AverageMeter()
        top5 = AverageMeter()
        with torch.no_grad():
            for images, target in tqdm(val_loader):
                images = images.to(device)
                target = target.to(device)
                output = model(images)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                top1.update(acc1, images.size(0))
                top5.update(acc5, images.size(0))
        print(f"Model {name} ACC@1 {top1.avg:.4f} ACC@5 {top5.avg:.4f}")
```
Output log

```
100%|██████████████████████████████████████████████| 196/196 [00:39<00:00, 4.92it/s]
Model t2t_vit_7 ACC@1 71.6760 ACC@5 90.8860
FLOPs 2.0261 GFLOPS / MACs 975.534 MMACs / Params 4.2557 M
100%|██████████████████████████████████████████████| 196/196 [00:39<00:00, 4.96it/s]
Model t2t_vit_10 ACC@1 75.1500 ACC@5 92.8060
FLOPs 2.6476 GFLOPS / MACs 1.2854 GMACs / Params 5.8347 M
100%|██████████████████████████████████████████████| 196/196 [00:40<00:00, 4.88it/s]
Model t2t_vit_12 ACC@1 76.4800 ACC@5 93.4840
FLOPs 3.062 GFLOPS / MACs 1.492 GMACs / Params 6.8874 M
100%|██████████████████████████████████████████████| 196/196 [01:08<00:00, 2.87it/s]
Model t2t_vit_14 ACC@1 81.5000 ACC@5 95.6660
FLOPs 8.7526 GFLOPS / MACs 4.334 GMACs / Params 21.4658 M
100%|██████████████████████████████████████████████| 196/196 [01:45<00:00, 1.86it/s]
Model t2t_vit_19 ACC@1 81.9320 ACC@5 95.7440
FLOPs 15.6663 GFLOPS / MACs 7.7868 GMACs / Params 39.0851 M
100%|██████████████████████████████████████████████| 196/196 [02:31<00:00, 1.30it/s]
Model t2t_vit_24 ACC@1 82.2760 ACC@5 95.8860
FLOPs 25.4543 GFLOPS / MACs 12.6759 GMACs / Params 64.001 M
100%|██████████████████████████████████████████████| 196/196 [01:28<00:00, 2.20it/s]
Model t2t_vit_t_14 ACC@1 81.6880 ACC@5 95.8520
FLOPs 8.6881 GFLOPS / MACs 4.334 GMACs / Params 21.4654 M
100%|██████████████████████████████████████████████| 196/196 [02:04<00:00, 1.57it/s]
Model t2t_vit_t_19 ACC@1 82.4420 ACC@5 96.0820
FLOPs 15.6018 GFLOPS / MACs 7.7868 GMACs / Params 39.0847 M
100%|██████████████████████████████████████████████| 196/196 [02:51<00:00, 1.15it/s]
Model t2t_vit_t_24 ACC@1 82.5540 ACC@5 96.0640
FLOPs 25.3898 GFLOPS / MACs 12.6759 GMACs / Params 64.0006 M
```
Tool used to calculate FLOPs/MACs/Params
Reported by calflops:
```python
from calflops import calculate_flops

def flops_param(model):
    # returns formatted strings, e.g. "2.0261 GFLOPS", "975.534 MMACs", "4.2557 M"
    flops, macs, params = calculate_flops(
        model=model,
        input_shape=(1, 3, 224, 224),
        output_as_string=True,
        output_precision=4,
        print_detailed=False,
        print_results=False,
    )
    print(f"FLOPs {flops} / MACs {macs} / Params {params}")
```
Reference
- paper: https://arxiv.org/pdf/2101.11986
- code: https://github.com/yitu-opensource/T2T-ViT
@brianhou0208 thanks for the work, it looks like a good job getting it in shape. I took a closer look using your code, but I have some doubts about this model:
- it requires a workaround w/ AMP + float16 to avoid NaN (see next post)
- compared to simpler models it's really not performing better given the speed, especially against the models in https://huggingface.co/collections/timm/searching-for-better-vit-baselines-663eb74f64f847d2f35a9c19, which are faster with better accuracy at a fraction of the param count, and have fewer macs/activations. Even models that have been here longer, like deit3 (e.g. deit3_medium_patch16_224), are faster/simpler/smaller than these.
For speed comparisons I disabled F.sdpa in the existing vits to be fair. Simpler vits with higher accuracy (imagenet-1k pretrain, also to be fair) are often 30-40% faster.
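(A quick sketch of that kind of throughput comparison; set_fused_attn is timm's switch for F.sdpa in recent versions, and the model picks, batch size, and step counts here are illustrative, not the exact setup used for the numbers above.)

```python
import time
import torch
import timm
from timm.layers import set_fused_attn

set_fused_attn(False)  # disable F.scaled_dot_product_attention in timm's attention layers

@torch.no_grad()
def bench(name, batch_size=64, steps=20):
    model = timm.create_model(name).cuda().eval()
    x = torch.randn(batch_size, *model.pretrained_cfg['input_size'], device='cuda')
    for _ in range(3):  # warmup
        model(x)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(steps):
        model(x)
    torch.cuda.synchronize()
    print(f"{name}: {batch_size * steps / (time.perf_counter() - t0):.1f} img/s")

for name in ('t2t_vit_14', 'deit3_medium_patch16_224', 'vit_mediumd_patch16_reg4_gap_256'):
    bench(name)
```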
So not convinced this is worth the add. Was there a particular reason you had interest in the model?
The workaround mentioned above: force the T2T attention math out of autocast so it runs in float32.

```python
def single_attn(self, x: torch.Tensor) -> torch.Tensor:
    k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)
    if not torch.jit.is_scripting():
        # disable autocast so the kernel-feature attention runs in float32 (avoids NaN w/ AMP + float16)
        with torch.autocast(device_type=v.device.type, enabled=False):
            y = self._attn_impl(k, q, v)
    else:
        y = self._attn_impl(k, q, v)
    # skip connection
    y = v + self.dp(self.proj(y))  # same as token_transformer in T2T layer, use v as skip connection
    return y

def _attn_impl(self, k, q, v):
    kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)
    D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)  # (B, T, m) * (B, m) -> (B, T, 1)
    kptv = torch.einsum('bin,bim->bnm', v.float(), kp)  # (B, emb, m)
    y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb) / Diag
    return y
```
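For context on what those einsums compute: this is linearized (Performer-style) attention with random-feature map `prm_exp`, roughly

$$
y_t = \frac{\phi(q_t)^\top \sum_i \phi(k_i)\, v_i^\top}{\phi(q_t)^\top \sum_i \phi(k_i) + \epsilon}
$$

with $\phi$ = `prm_exp`. Presumably the exponentials inside $\phi$ and the small denominator are what over/underflow in float16, hence the float32 guard.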
Hi @rwightman, I agree with your observation. The T2T-ViT model doesn't have advantages over other models. The only interesting property might be that it doesn't use nn.Conv2d at all, relying instead on nn.Unfold to extract patches.
Most ViT-based models require some form of convolution for input processing, but the T2T-ViT architecture bypasses convolution completely (see the sketch below); maybe that direction could be explored further...
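For illustration, a minimal sketch of the two patchification routes. This is a plain non-overlapping 16x16 split, not the actual T2T stem, which applies nn.Unfold iteratively with overlapping windows:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224)

# conv route: the standard timm ViT patch embed
conv_embed = nn.Conv2d(3, 768, kernel_size=16, stride=16)
tokens_conv = conv_embed(x).flatten(2).transpose(1, 2)  # (1, 196, 768)

# conv-free route in the spirit of T2T-ViT: nn.Unfold + nn.Linear
unfold = nn.Unfold(kernel_size=16, stride=16)
proj = nn.Linear(3 * 16 * 16, 768)
tokens_unfold = proj(unfold(x).transpose(1, 2))         # (1, 196, 768)

print(tokens_conv.shape, tokens_unfold.shape)  # same token layout from both stems
```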
Another issue occurs when loading pre-trained weights and testing whether the first_conv structure adapts to the input shape (C, H, W). If first_conv is set to None, the test_model_default_cfgs_non_std test fails.
https://github.com/huggingface/pytorch-image-models/blob/d81da93c1640a504977b0ee494791e5c634ec63c/tests/test_models.py#L371-L376
In test_model_load_pretrained, if first_conv points at an nn.Linear rather than an nn.Conv2d (as in T2T-ViT, which has no conv), passing that parameter through the builder also raises an error.
https://github.com/huggingface/pytorch-image-models/blob/d81da93c1640a504977b0ee494791e5c634ec63c/timm/models/_builder.py#L225-L239
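Roughly why that fails: adapt_input_conv assumes a 4-D conv weight it can sum/tile over input channels, so a 2-D Linear stem weight can't be remapped. A sketch, assuming the helper is importable from timm.models as in recent versions:

```python
import torch
from timm.models import adapt_input_conv

conv_w = torch.randn(768, 3, 16, 16)      # Conv2d stem weight: (out, in, kH, kW)
print(adapt_input_conv(1, conv_w).shape)  # (768, 1, 16, 16): RGB channels summed into one

linear_w = torch.randn(768, 3 * 16 * 16)  # Unfold+Linear stem weight: (out, in*kH*kW)
try:
    adapt_input_conv(1, linear_w)         # 2-D weight has no conv dims to adapt over
except ValueError as e:
    print('fails as expected:', e)
```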
Since fixing this would involve modifying test_models, and adding T2T-ViT may not be worth that effort, I should probably close this PR.
@brianhou0208 I don't know that not having the input conv is a 'feature'. My very first vit impl here, before the official JAX code that used the Conv2d trick was released, was this: https://github.com/huggingface/pytorch-image-models/blob/7613094fb5cb960813f606a5c42e3c00c961bc8f/timm/models/vision_transformer.py#L139-L169
The conv approach was faster since it's an optimized kernel rather than a chain of API calls. I suppose torch.compile would rectify most of that, but I still don't see the downside to the conv.
Also, the packed vit I started working on (have yet to pick it back up) has to push patchification even further, into the data pipeline: https://github.com/huggingface/pytorch-image-models/blob/379780bb6ca3304d63bf8ca789d5bbce5949d0b5/timm/models/vision_transformer_packed.py