
Add xformer and support training on V100s

Open zhisbug opened this issue 1 year ago • 1 comment

Why are these changes needed?

We are going to use xformers instead of Flash Attention. xformers is the better choice because:

  • It supports more GPU architectures than Flash Attention, including V100.
  • It has a similar memory footprint and FLOP count to Flash Attention.
  • It is developed and maintained by Meta and offers more useful functionality.

We can gradually deprecate Flash Attention.
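For illustration only (this is not the patch in this PR, and the shapes and causal mask below are assumptions), a minimal sketch of what calling xformers' memory-efficient attention looks like:

```python
# Minimal sketch, not the code in this PR: xformers' memory-efficient
# attention kernel, which also runs on pre-Ampere GPUs such as V100.
import torch
import xformers.ops as xops

# Assumed shapes: (batch, seq_len, n_heads, head_dim), the layout xformers expects.
q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)

# Causal (decoder-style) attention without materializing the full
# seq_len x seq_len score matrix.
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())
```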

cc @DachengLi1

Co-authored-by: Dacheng Li <[email protected]>

Related issue number (if applicable)

Checks

  • [x] I've run format.sh to lint the changes in this PR.
  • [x] I've included any doc changes needed.
  • [ ] I've made sure the relevant tests are passing (if applicable).

zhisbug avatar May 15 '23 09:05 zhisbug

The authors of Flash Attention have also developed a Triton-based implementation (https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py). How about replacing the original implementation with the Triton-based one?

SkyAndCloud avatar May 25 '23 05:05 SkyAndCloud

This works fantastically on V100 GPUs; please merge it ASAP! Appreciate it!

ss-zheng avatar Jun 21 '23 23:06 ss-zheng

@ss-zheng thanks, will merge soon.

zhisbug avatar Jun 21 '23 23:06 zhisbug

Thanks for this great PR! Could you also apply it to the other fine-tuning scripts, e.g. train_lora.py?

nuance1979 avatar Jun 22 '23 17:06 nuance1979

FYI, I just learned that xformers' memory-efficient attention has been upstreamed into torch.nn.functional.scaled_dot_product_attention: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

However, enabling or pinning a specific kernel still requires a code change in transformers, so I think this PR is still worth merging.
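For reference, a minimal sketch (assuming PyTorch 2.0+; not part of this PR, and the shapes are assumptions) of the upstreamed API, forcing the memory-efficient backend:

```python
# Minimal sketch, assuming PyTorch >= 2.0: the memory-efficient / flash
# kernels are selected behind this single call.
import torch
import torch.nn.functional as F

# scaled_dot_product_attention expects (batch, n_heads, seq_len, head_dim).
q = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16)

# Restrict kernel selection to the memory-efficient backend (xformers-derived).
with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                    enable_mem_efficient=True,
                                    enable_math=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```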

nuance1979 avatar Jun 24 '23 17:06 nuance1979

@zhisbug just a reminder that the developer of Flash Attention has given up on supporting V100s.

https://github.com/Dao-AILab/flash-attention/issues/148#issuecomment-1573216640

surak avatar Jul 20 '23 11:07 surak

> The authors of Flash Attention have also developed a Triton-based implementation (https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py). How about replacing the original implementation with the Triton-based one?

That would force everyone to use PyTorch 2.0, which is not deployed in many supercomputing centres.

surak avatar Jul 20 '23 11:07 surak

So can we train Flash Attention models like Llama 2 on V100s?

jshin49 avatar Aug 02 '23 08:08 jshin49

@jshin49 it doesn't seem so; see the comment above.

surak avatar Aug 02 '23 08:08 surak