pytorch-image-models
Flash attention
Hello,
Vision transformers in timm currently use a custom Attention implementation rather than nn.MultiheadAttention. PyTorch 2.0 will ship with flash attention, an exact (not approximate) implementation of attention that is much faster for both training and inference (see this issue and these results from xformers: 2x faster training for ViT-B-16).
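For reference, here is a minimal sketch (not timm's actual code) of how a ViT-style Attention block could route through PyTorch 2.0's `torch.nn.functional.scaled_dot_product_attention`, which dispatches to the flash kernel when the inputs and hardware support it:

```python
# Hypothetical sketch of a ViT-style attention block using PyTorch 2.0's
# fused attention; F.scaled_dot_product_attention picks the flash kernel
# when the inputs/hardware allow it.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # (B, N, 3, heads, head_dim) -> 3 tensors of (B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # Fused scaled dot-product attention (flash kernel when supported)
        x = F.scaled_dot_product_attention(q, k, v)
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
```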
Do you plan to replace Attention with nn.MultiheadAttention at some point, and to update the weights accordingly?
Thank you, Simon