
why is xformers not used for attention computation?

Open jason718 opened this issue 1 year ago • 4 comments

Curious why xformers is not used? Is it for simplicity, or is there a performance reason?

jason718 avatar Oct 09 '24 23:10 jason718

F.scaled_dot_product_attention calls into flash or memory-efficient attention depending on factors such as dtype, masking, and hardware (it should mainly be flash in the torchtitan case, if I understand correctly). Are there other ops that you have in mind?

awgu avatar Oct 09 '24 23:10 awgu
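
For context, a minimal sketch (not from the torchtitan code) of calling F.scaled_dot_product_attention and pinning its backend via torch.nn.attention.sdpa_kernel, which requires PyTorch 2.3 or newer:

```python
# Sketch: let SDPA dispatch on its own, then restrict it to the
# FlashAttention backend to confirm which kernel serves the workload.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# (batch, n_heads, seq_len, head_dim) in bf16 on GPU, the layout SDPA expects.
q, k, v = (torch.randn(8, 16, 2048, 128, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

# Default call: PyTorch picks flash / memory-efficient / math based on
# dtype, masking, head_dim, and hardware.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Restrict dispatch to FlashAttention; this errors out if flash cannot be used.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_flash = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```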

@awgu It looks like xformers has supported FlashAttention v3 since 0.0.28 (flash3.FwOp and flash3.BwOp). It could bring extra training efficiency on the Hopper architecture, since FA3 is not implemented in PyTorch yet.

As I read it from the blog post, this brings a 1.6x-1.8x speedup over FAv2.

[image: FlashAttention-3 speedup chart from the blog post]

casper-hansen avatar Oct 11 '24 10:10 casper-hansen

@casper-hansen Makes sense!

I guess it should not be too hard for users to install xformers and replace the F.scaled_dot_product_attention call with the xformers attention call. This should work as long as the xformers attention is torch.compile-compatible, which I recall it is.

Since torchtitan is mainly meant to show how to set up this kind of distributed training, I think including xformers attention is less important than showing what is achievable with native PyTorch.

awgu avatar Oct 11 '24 15:10 awgu
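
For reference, a rough sketch of the swap described above, assuming xformers (>= 0.0.28) is installed and using its memory_efficient_attention op; xformers then picks the best backend it has available (e.g. FA3 on Hopper). This is an illustration, not torchtitan code:

```python
# Note on layout: SDPA takes (batch, heads, seq, head_dim),
# while xformers expects (batch, seq, heads, head_dim).
import torch
import xformers.ops as xops

def xformers_attention(q, k, v):
    # q, k, v: (B, H, S, D), as they would be passed to
    # F.scaled_dot_product_attention in a Llama-style attention module.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (B, S, H, D)
    out = xops.memory_efficient_attention(
        q, k, v,
        attn_bias=xops.LowerTriangularMask(),  # causal masking
    )
    return out.transpose(1, 2)  # back to (B, H, S, D)
```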

@casper-hansen On H100, F.scaled_dot_product_attention calls into the cuDNN attention backend, which has a much smaller performance gap relative to FA3.

Chillee avatar Oct 13 '24 08:10 Chillee
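
For completeness, a short sketch of explicitly selecting the cuDNN SDPA backend mentioned above; this assumes a recent PyTorch build that exposes SDPBackend.CUDNN_ATTENTION (2.5 or newer):

```python
# Pin SDPA to the cuDNN fused attention backend on an H100.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(4, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```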