FastChat
Add xformers and support training on V100s
Why are these changes needed?
We are going to use xformers instead of flash attention. xformers is better because:
- It supports more GPU architectures than flash attention, including the V100
- It has a similar memory footprint and FLOP count compared to flash attention
- It is developed and maintained by Meta and offers more useful functionality
We can gradually deprecate flash attention.
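For context, here is a minimal, illustrative sketch of calling xformers' memory-efficient attention with a causal mask. The tensor shapes and dtypes are assumptions for this example, not FastChat's actual monkey patch.

```python
# Illustrative sketch only: calling xformers' memory-efficient attention
# in place of a standard (or FlashAttention) kernel.
# Shapes, dtypes, and the causal mask are assumptions for this example.
import torch
import xformers.ops as xops

batch, seq_len, n_heads, head_dim = 2, 512, 32, 128
q = torch.randn(batch, seq_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal (lower-triangular) attention, as used for autoregressive LMs.
out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask(),
    p=0.0,  # attention dropout probability
)
# `out` keeps the same [batch, seq_len, n_heads, head_dim] layout as the inputs.
```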
cc @DachengLi1
Co-authored-by: Dacheng Li <[email protected]>
Related issue number (if applicable)
Checks
- [x] I've run `format.sh` to lint the changes in this PR.
- [x] I've included any doc changes needed.
- [ ] I've made sure the relevant tests are passing (if applicable).
The authors of Flash Attention have also developed a Triton-based implementation (https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py). How about replacing their original implementation with the Triton-based one?
This works fantastically on V100 GPUs; please merge it ASAP! Appreciate it!
@ss-zheng thanks, will merge soon.
Thanks for this great PR! Can you also apply it to other finetuning scripts like `train_lora.py`, etc.?
FYI, I just learned that xformers' memory-efficient attention has been upstreamed to `torch.nn.functional.scaled_dot_product_attention`: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html. However, it needs a code change in transformers to enable/pin it down. I think this PR is still worth merging.
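For reference, a hedged sketch of the upstreamed PyTorch API (requires PyTorch >= 2.0); the shapes are illustrative only, and the backend selector is used here just to force the memory-efficient kernel.

```python
# Illustrative only: PyTorch 2.0's fused scaled-dot-product attention.
# Note the [batch, n_heads, seq_len, head_dim] layout, which differs from xformers.
import torch
import torch.nn.functional as F

batch, n_heads, seq_len, head_dim = 2, 32, 512, 128
q = torch.randn(batch, n_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict dispatch to the memory-efficient backend (the xformers kernel)
# to confirm it is actually usable on this GPU.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```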
@zhisbug just a reminder that the developer of Flash Attention gave up on the V100s.
https://github.com/Dao-AILab/flash-attention/issues/148#issuecomment-1573216640
> The authors of Flash Attention have also developed a Triton-based implementation (https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py). How about replacing their original implementation with the Triton-based one?
That would force everyone to use PyTorch 2.0, which is not deployed in many supercomputing centres.
So can we train flash-attention models like Llama 2 on V100s?
@jshin49 It doesn't seem so; see the comment above.