
feat: use torch2 implemented flash attn

Open wade3han opened this issue 1 year ago • 2 comments

The current flash_attn integration only supports certain GPUs (e.g. the A100; see https://github.com/haotian-liu/LLaVA/issues/153). Using the flash attention implementation built into PyTorch 2.0 allows other GPUs such as the A6000 to benefit as well.
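For reference, here is a minimal sketch (not the exact code in this PR) of routing attention through PyTorch 2.0's built-in `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to a flash-attention kernel where the hardware supports it; the helper name `sdpa_attention` and the tensor shapes are illustrative:

```python
# Minimal sketch: dispatch attention through PyTorch 2.0's built-in SDPA,
# which can use a flash-attention kernel on GPUs that the standalone
# flash_attn package does not support (e.g. A6000, RTX 3090).
import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, is_causal=True):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    # Restrict dispatch to the flash / memory-efficient kernels; the math
    # fallback is disabled so an unsupported configuration fails loudly.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_mem_efficient=True, enable_math=False
    ):
        return F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)

if __name__ == "__main__":
    q = torch.randn(2, 32, 256, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    print(sdpa_attention(q, k, v).shape)  # torch.Size([2, 32, 256, 64])
```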

wade3han · Jun 04 '23 09:06

Hi @wade3han

Thank you for your contribution! I tried this on RTX 3090s, and it does not seem to give me any memory savings or speedups compared with the existing train.py.

Can you please share the statistics on your side: (1) memory savings; (2) speedups?
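For anyone measuring this, a rough sketch of how such numbers could be collected (the `one_step` callable is a placeholder for a single training step, not anything defined in this repo):

```python
# Illustrative benchmarking helper: reports average step time and peak GPU
# memory so the two attention paths can be compared on the same hardware.
import time
import torch

def benchmark(one_step, warmup=3, iters=10):
    for _ in range(warmup):
        one_step()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(iters):
        one_step()
    torch.cuda.synchronize()
    avg_time_s = (time.perf_counter() - start) / iters
    peak_mem_gb = torch.cuda.max_memory_allocated() / 1024 ** 3
    return avg_time_s, peak_mem_gb
```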

Thank you!

haotian-liu · Jun 11 '23 05:06

@wade3han any updates?

pineking · Jul 12 '23 10:07