Support a pure PyTorch implementation for memory_efficient_attention

Open hoangmit opened this issue 2 years ago • 4 comments

🚀 Feature

I found that the memory_efficient_attention op has no pure PyTorch implementation (i.e., one that does not rely on device-specific ops, external libraries, or Cython). None of the current implementations can be dispatched on a CPU device.

In the spirit of PyTorch 2.0, we should have a Python implementation for most ops.

Something similar to: https://github.com/lucidrains/memory-efficient-attention-pytorch

A pure PyTorch implementation is useful for testing, benchmarking, and generic device support.

hoangmit, Nov 22 '22 19:11

The algorithm from Rabe that you linked to is not the full Flash Attention: it only covers the initial softmax(QK^T), whereas Flash Attention also computes the product with V at the same time, which makes it a lot more effective. To the best of my knowledge, this algorithm cannot be written using PyTorch's primitives; you can bet it would already be used everywhere if it could. Feel free to write a PR to prove me wrong though :)
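To make the distinction concrete, here is a minimal sketch of the Flash-style "online softmax" expressed with plain PyTorch ops, assuming 3-D `[batch, seq, dim]` tensors. The function name `blockwise_attention` and the `block_size` parameter are hypothetical; this is a Python-level loop illustrating the idea, not xformers' kernel, and it has none of the fusion that makes the CUDA version fast. Each block of scores is multiplied by V inside the same loop, and a running max and running denominator rescale the partial results, so the full N×N score matrix never exists at once.

```python
import torch


def blockwise_attention(q, k, v, block_size=128):
    """Hypothetical sketch of Flash-style attention using plain PyTorch ops.

    q: [B, M, D], k: [B, N, D], v: [B, N, Dv]. Keys/values are visited in
    blocks; only a [B, M, block_size] slice of scores is alive at a time.
    """
    scale = q.shape[-1] ** -0.5
    B, M, _ = q.shape
    out = torch.zeros(B, M, v.shape[-1], dtype=q.dtype, device=q.device)
    # Running row-wise max of the scores seen so far.
    row_max = torch.full((B, M, 1), float("-inf"), dtype=q.dtype, device=q.device)
    # Running sum of exp(score - row_max), i.e. the softmax denominator.
    row_sum = torch.zeros(B, M, 1, dtype=q.dtype, device=q.device)

    for start in range(0, k.shape[1], block_size):
        k_blk = k[:, start:start + block_size]
        v_blk = v[:, start:start + block_size]
        scores = torch.bmm(q, k_blk.transpose(1, 2)) * scale  # [B, M, blk]

        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        # Rescale everything accumulated so far to the new max
        # (on the first block this factor is exp(-inf) = 0).
        correction = torch.exp(row_max - new_max)
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        # The product with V happens in the same pass as the softmax.
        out = out * correction + torch.bmm(p, v_blk)
        row_max = new_max

    # Normalize once at the end by the full softmax denominator.
    return out / row_sum
```

A quick sanity check is to compare the result against `torch.softmax(q @ k.transpose(1, 2) * d ** -0.5, dim=-1) @ v` on small random inputs; the two should agree to within floating-point tolerance.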

blefaudeux, Nov 22 '22 21:11

I think the point is to have a memory-efficient fallback for CPU (and other unsupported devices), even if it's much slower than the vanilla PyTorch implementation.

danthe3rd, Nov 22 '22 21:11

Ah OK, that's not how I understood the title and issue. In that case, I'm not sure there's much to be done beyond slicing the attention matrix by rows. Is that what you meant, @hoangmit?
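Slicing by rows works because each output row depends only on the corresponding row of the attention matrix, so processing a few query rows at a time gives exactly the same result with a smaller peak. A minimal sketch, assuming 3-D `[batch, seq, dim]` tensors; `query_chunked_attention` and `chunk_size` are hypothetical names, not an xformers API:

```python
import torch


def query_chunked_attention(q, k, v, chunk_size=256):
    """Hypothetical row-sliced attention: softmax(Q K^T / sqrt(D)) V.

    q: [B, M, D], k: [B, N, D], v: [B, N, Dv]. The score buffer is
    [B, chunk_size, N] instead of [B, M, N].
    """
    scale = q.shape[-1] ** -0.5
    outputs = []
    # Tensor.split yields views of at most chunk_size query rows each.
    for q_chunk in q.split(chunk_size, dim=1):
        scores = torch.bmm(q_chunk, k.transpose(1, 2)) * scale  # [B, chunk, N]
        outputs.append(torch.bmm(scores.softmax(dim=-1), v))
    # Reassemble the output rows in their original order.
    return torch.cat(outputs, dim=1)
```

Smaller `chunk_size` values lower peak memory at the cost of more, smaller matrix multiplies. The keys and values are not chunked here, so each slice of scores still grows linearly with the key length.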

blefaudeux, Nov 22 '22 21:11

I am not familiar with the details. However, lucidrains' repo linked above implements both Flash Attention (Tri Dao's) [1] and Rabe's attention, though they are probably not very fast.

Yes, the main point is to have memory-efficient fallbacks for CPU and unsupported devices (e.g., Apple). I can tolerate twice the runtime, but running out of memory is a deal-breaker.
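A back-of-the-envelope estimate shows why the attention scores, rather than runtime, are what break on memory-constrained devices. For hypothetical sizes of 16 heads, 4096 tokens and float32, the full `[batch, heads, seq_q, seq_k]` score matrix needs 1 GiB on its own, while processing 256 query rows at a time needs 64 MiB. The helper `attention_scores_bytes` is illustrative only and counts just the scores; the softmax output and any autograd buffers add more on top.

```python
def attention_scores_bytes(batch, heads, seq_q, seq_k, bytes_per_elem=4):
    """Bytes needed to hold a [batch, heads, seq_q, seq_k] score matrix."""
    return batch * heads * seq_q * seq_k * bytes_per_elem


# Hypothetical sizes: batch 1, 16 heads, 4096 tokens, float32 (4 bytes).
full = attention_scores_bytes(1, 16, 4096, 4096)
# Same sizes, but only 256 query rows materialized at a time.
chunked = attention_scores_bytes(1, 16, 256, 4096)
print(f"full: {full / 2**30:.2f} GiB, chunked: {chunked / 2**20:.0f} MiB")
# prints "full: 1.00 GiB, chunked: 64 MiB"
```

The full matrix grows quadratically with sequence length, while the chunked buffer grows only linearly for a fixed chunk size, which is why trading extra time for smaller slices avoids the out-of-memory failure.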

[1] https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py

hoangmit, Nov 23 '22 02:11