
Flash attention

Open SimJeg opened this issue 2 years ago • 0 comments

Hello,

Vision transformers in timm currently use a custom implementation of attention instead of nn.MultiheadAttention. PyTorch 2.0 will ship with flash attention, an exact (not approximate) implementation of attention that is much faster for both training and inference (see this issue and these results from xformers: 2x faster training for ViT-B-16).

Do you plan to replace Attention with nn.MultiheadAttention at some point, and update the weights accordingly?
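For context, a minimal sketch of an alternative to switching to nn.MultiheadAttention: keep timm's existing qkv/proj layout but route the attention computation through PyTorch 2.0's F.scaled_dot_product_attention, which dispatches to the flash-attention kernel when available. This is an illustrative assumption, not timm's actual implementation; the module structure here mirrors the timm-style Attention block described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """timm-style attention block, with the softmax(QK^T)V step replaced
    by PyTorch 2.0's fused scaled_dot_product_attention (sketch)."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # Project to q, k, v; each has shape (B, num_heads, N, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # Fused kernel; falls back to the math implementation if a
        # flash/memory-efficient backend is unavailable on this hardware.
        x = F.scaled_dot_product_attention(q, k, v)
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
```

A possible upside of this route over nn.MultiheadAttention: since the fused qkv and proj parameter layout matches the custom block, existing pretrained weights could in principle load without any conversion, whereas nn.MultiheadAttention uses different parameter names and shapes.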

Thank you, Simon

SimJeg · Jan 08 '23 09:01