[FEATURE] [RFC] Support for interchangeable attention backends
Is your feature request related to a problem? Please describe.
Currently, many models rely on a standard multi-head self-attention operator. Timm lets the user choose between two versions: an eager PyTorch implementation and the fused implementation provided by PyTorch (torch.nn.functional.scaled_dot_product_attention), the latter dispatching to the backends PyTorch ships (FA2, memory-efficient attention, math/eager). This can be restrictive (a better implementation may be available elsewhere, or upstream issues may prevent PT SDPA from working correctly) or leave performance on the table (FA3 and other newer implementations). Adding more supported backends for the user to choose from (and eventually allowing users to register their own) would alleviate this restriction. Overall, the current way the eager vs. sdpa choice is handled is also somewhat hacky imo.
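For context, roughly how the current per-module switch looks (a simplified, illustrative sketch, not timm's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, fused_attn=True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = fused_attn  # per-module flag toggling eager vs. F.sdpa
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        if self.fused_attn:
            # fused path: PyTorch picks FA2 / mem-efficient / math internally
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            # eager path
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
```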
Describe the solution you'd like
My thoughts are to create a registry for attention backends, similar to how models are managed. Timm would attempt to import supported backends (flash_attn, xformers, others) and register each one that imports successfully. The registry should also be exposed to the user, so they can register some other implementation with the same call signature. I'm not sure if this is the best approach.
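A rough sketch of what I'm imagining (all names below are hypothetical, just mirroring the model registry pattern):

```python
import torch.nn.functional as F

_ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    """Register a callable with signature fn(q, k, v, attn_mask=None, scale=None) -> Tensor."""
    def _decorator(fn):
        _ATTENTION_BACKENDS[name] = fn
        return fn
    return _decorator

def get_attention_backend(name='sdpa'):
    return _ATTENTION_BACKENDS[name]

@register_attention_backend('sdpa')
def sdpa_attention(q, k, v, attn_mask=None, scale=None):
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)

@register_attention_backend('eager')
def eager_attention(q, k, v, attn_mask=None, scale=None):
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale
    if attn_mask is not None:
        attn = attn + attn_mask  # additive float mask
    return attn.softmax(dim=-1) @ v

# Optional backends are registered only if their import succeeds.
try:
    from flash_attn import flash_attn_func

    @register_attention_backend('flash_attn')
    def flash_attention(q, k, v, attn_mask=None, scale=None):
        assert attn_mask is None, 'flash_attn has no general additive-mask support'
        # flash_attn expects (B, N, H, D); timm modules typically carry (B, H, N, D)
        out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), softmax_scale=scale)
        return out.transpose(1, 2)
except ImportError:
    pass
```

An Attention module could then take a backend argument (or read a process-wide default) and look up the callable at forward time, which also gives users the hook to register their own implementation.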
Describe alternatives you've considered
Alternatives would be to modify or monkeypatch the model implementation to call another attention implementation. I'm not sure how necessary this is, since I'm not sure of the performance advantage of FA3/others over PT sdpa for vision models. Part of the reason other libraries keep an attention impl registry seems to be that language models have much more variation in attention than vision models do.
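As a concrete (hypothetical) example of that kind of swap, reusing the existing weights and replacing only the core attention call; this would also be a cheap way to benchmark other backends:

```python
import timm
import torch.nn as nn
import torch.nn.functional as F

class SwappedAttention(nn.Module):
    # Hypothetical drop-in module: reuses the original qkv/proj weights, only the
    # core attention call changes. Simplified -- ignores q/k norm and dropout.
    def __init__(self, orig_attn, attn_fn=F.scaled_dot_product_attention):
        super().__init__()
        self.num_heads = orig_attn.num_heads
        self.qkv = orig_attn.qkv
        self.proj = orig_attn.proj
        self.attn_fn = attn_fn  # e.g. a flash_attn / xformers wrapper taking (q, k, v)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        x = self.attn_fn(q, k, v)
        return self.proj(x.transpose(1, 2).reshape(B, N, C))

model = timm.create_model('vit_base_patch16_224', pretrained=False)
for block in model.blocks:
    block.attn = SwappedAttention(block.attn)
```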
@fffffgggg54 yeah, it could be done, I've thought a bit about it but it hasn't really struck me as a priority vs some other things I've been working towards.
It would really only cover different impls of 'sdpa' ... things like linear attn, etc. typically need their own full modules.
One way to test the value is to pass custom Attention modules / blocks to existing vits, etc. to see if, say, FA3 offers a noteworthy benefit on specific hardware archs (e.g. FA3 is Hopper optimized, though not even Blackwell yet I think). The nice thing about the torch builtins is that you typically get pretty decent performance across a wide range of devices, whereas some custom impls are very hardware specific.
One big complexity with flipping these is that they all seem to handle masks a bit differently. Some are very flexible and some only work with causal masks. For vision it's often no-mask, but sometimes the additive mask is overloaded with pos embeds, etc. FA often can't handle the pos embed cases well...
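To illustrate (a rough sketch, flash_attn's exact kwargs vary by version): PT SDPA accepts an arbitrary float attn_mask, so an additive rel-pos bias can be passed straight through, while flash_attn only exposes causal / window / alibi style options and needs a (B, N, heads, head_dim) fp16/bf16 CUDA layout.

```python
import torch
import torch.nn.functional as F

B, H, N, D = 2, 12, 197, 64
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
rel_pos_bias = torch.randn(1, H, N, N)  # additive bias, e.g. pos embed folded into the mask

# PT SDPA: a float attn_mask is simply added to the attention scores
out = F.scaled_dot_product_attention(q, k, v, attn_mask=rel_pos_bias)

# flash_attn (roughly): no general additive-bias argument, so the bias-as-mask case
# can't go through it directly
# from flash_attn import flash_attn_func
# out = flash_attn_func(q.transpose(1, 2).half().cuda(),
#                       k.transpose(1, 2).half().cuda(),
#                       v.transpose(1, 2).half().cuda(),
#                       causal=False)
```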