LLaVA
feat: use torch2 implemented flash attn
The current flash_attn package only supports certain GPUs (e.g. A100), see https://github.com/haotian-liu/LLaVA/issues/153. Using the flash attention implementation built into PyTorch 2.0 allows other GPUs, such as the A6000, to benefit as well.
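For context, the core idea is to call the attention kernel that ships with PyTorch 2.0 (`torch.nn.functional.scaled_dot_product_attention`) instead of the external flash_attn package. A minimal sketch of that idea, not the exact patch in this PR; the tensor shapes below are placeholders:

```python
# Minimal sketch: PyTorch 2.0's built-in scaled dot-product attention.
# PyTorch dispatches to its FlashAttention kernel when the inputs meet the
# kernel's constraints (CUDA device, fp16/bf16 dtype, supported head_dim, ...).
import torch
import torch.nn.functional as F

def sdpa_attention(q, k, v, is_causal=True):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    q = torch.randn(2, 32, 128, 64, device=device, dtype=dtype)
    k = torch.randn(2, 32, 128, 64, device=device, dtype=dtype)
    v = torch.randn(2, 32, 128, 64, device=device, dtype=dtype)
    print(sdpa_attention(q, k, v).shape)  # torch.Size([2, 32, 128, 64])
```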
Hi @wade3han
Thank you for your contribution! I tried this on RTX 3090s, and it does not seem to give me any memory savings or speedups compared with using train.py.
Can you please share the statistics on your side: (1) memory savings; (2) speedups?
Thank you!
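For reference, one illustrative way to collect those two numbers (the harness below is an assumption, not something in this repo; `step_fn` stands in for one training forward/backward step):

```python
# Illustrative benchmark helper: average step time and peak CUDA memory.
import time
import torch

def benchmark(step_fn, warmup=3, iters=10):
    # step_fn: callable running one forward/backward step on CUDA.
    torch.cuda.reset_peak_memory_stats()
    for _ in range(warmup):
        step_fn()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        step_fn()
    torch.cuda.synchronize()
    avg_s = (time.time() - start) / iters
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"avg step time: {avg_s:.3f}s, peak memory: {peak_gib:.2f} GiB")
```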
@wade3han any updates?