Why is xformers not used for attention computation?
Curious why xformers is not used. Is it for simplicity, or is there a performance reason?
F.scaled_dot_product_attention dispatches to flash attention or memory-efficient attention depending on several factors (it should mainly be flash attention for the torchtitan case, if I understand correctly). Are there other ops you have in mind?
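For reference, here is a minimal sketch of how the dispatch can be inspected or pinned via torch.nn.attention.sdpa_kernel (assuming a recent PyTorch that exposes it; the shapes are made up for illustration):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(8, 16, 2048, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Default: PyTorch picks among the flash / memory-efficient / cuDNN / math backends.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Restrict dispatch to flash attention; SDPA errors out with "no available kernel"
# if the inputs or hardware are not supported by the selected backend.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_flash = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```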
@awgu It looks like xformers has supported Flash Attention v3 since 0.0.28 (flash3.FwOp and flash3.BwOp). This could bring extra training efficiency on the Hopper architecture, since FA3 is not implemented in PyTorch yet.
According to the blog post, this brings a 1.6x-1.8x speedup over FAv2.
@casper-hansen Makes sense!
I guess it should not be too hard for users to install xformers and replace the F.scaled_dot_product_attention call with the xformers attention call. This should work as long as the xformers attention is torch.compile-compatible, which I recall it is.
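Something like this rough sketch should work for the swap (assuming xformers >= 0.0.28; the xformers.ops.fmha.flash3 module path and the causal-mask helper are my best guess and may differ across versions):

```python
import xformers.ops as xops
from xformers.ops import fmha

# F.scaled_dot_product_attention takes (batch, heads, seq, head_dim); xformers'
# memory_efficient_attention expects (batch, seq, heads, head_dim), hence the transposes.
def xformers_attention(q, k, v):
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = xops.memory_efficient_attention(
        q, k, v,
        attn_bias=xops.LowerTriangularMask(),     # causal masking, like is_causal=True in SDPA
        op=(fmha.flash3.FwOp, fmha.flash3.BwOp),  # pin the FA3 kernels (Hopper only)
    )
    return out.transpose(1, 2)
```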
Since torchtitan is mainly meant to show how to set up this kind of distributed training, I think including xformers attention is less important than showing what is achievable with torch-native ops.
@casper-hansen On H100, F.scaled_dot_product_attention calls into cuDNN attention, which has a much smaller performance gap with FA3.
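If anyone wants to check the dispatch and the gap on their own H100, here is a quick sketch (assuming a PyTorch build that exposes SDPBackend.CUDNN_ATTENTION; shapes are arbitrary):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.benchmark import Timer

# Illustrative shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(8, 16, 4096, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Time the same SDPA call with the cuDNN and flash backends pinned.
for backend in (SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION):
    with sdpa_kernel(backend):
        timer = Timer(
            stmt="F.scaled_dot_product_attention(q, k, v, is_causal=True)",
            globals={"F": F, "q": q, "k": k, "v": v},
        )
        print(backend, timer.timeit(100))
```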