
Base transformers version needed for modifying models/modeling_llama.py

Open · yeahjack opened this issue 1 year ago · 3 comments

I noticed that models/modeling_llama.py is based on the code from here. However, your implementation does not support Flash Attention 2, so I would like to modify it further. To make this easier, could you please tell me the exact transformers version your implementation is based on? That would let me diff the two versions directly.

yeahjack · May 14 '24 12:05

Hi, our current code is based on transformers==4.34.0 (https://github.com/RenShuhuai-Andy/TimeChat/blob/master/environment.yml#L273).

To use Flash Attention 2, upgrade transformers and load the model with the following code:

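```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Inside the model's __init__: llama_model is the path to the LLaMA checkpoint.
# attn_implementation requires a newer transformers release than 4.34.0.
self.llama_model = AutoModelForCausalLM.from_pretrained(
    llama_model,
    torch_dtype="auto",  # keep the dtype stored in the checkpoint
    attn_implementation="flash_attention_2",
    low_cpu_mem_usage=True,
)
```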
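If you need to support both the pinned transformers==4.34.0 and a newer release, one option is to choose the Flash Attention 2 keyword at runtime. The sketch below rests on an assumption drawn from our reading of the transformers release history: 4.34 used the flag use_flash_attention_2=True, and attn_implementation="flash_attention_2" arrived in 4.36. Please double-check those cut-offs against the release notes. The helper names are hypothetical, and the code only reads installed package metadata through the standard library.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name: str) -> str | None:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def major_minor(ver: str) -> tuple[int, int]:
    """'4.36.2' -> (4, 36); anything after the minor number is ignored."""
    major, minor = ver.split(".")[:2]
    return int(major), int(minor)


def flash_attention_kwargs(transformers_version: str, flash_attn_available: bool) -> dict:
    """Extra from_pretrained kwargs requesting Flash Attention 2 (assumed version cut-offs)."""
    if not flash_attn_available:
        return {}  # keep the default attention implementation
    if major_minor(transformers_version) >= (4, 36):
        return {"attn_implementation": "flash_attention_2"}
    if major_minor(transformers_version) >= (4, 34):
        return {"use_flash_attention_2": True}  # older flag, assumed deprecated later
    return {}  # assumed: no Flash Attention 2 support before 4.34


if __name__ == "__main__":
    tf_ver = installed_version("transformers")
    fa_ver = installed_version("flash-attn")  # pip distribution name of flash-attn
    if tf_ver is None:
        print("transformers is not installed")
    else:
        print(flash_attention_kwargs(tf_ver, fa_ver is not None))
```

The result can then be splatted into the call, e.g. AutoModelForCausalLM.from_pretrained(llama_model, torch_dtype="auto", low_cpu_mem_usage=True, **kwargs). When flash-attn is missing, the model simply falls back to the default attention.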

RenShuhuai-Andy · May 14 '24 13:05

Hi, I tried your suggestion. When I run the code from demo.ipynb to generate text with Flash Attention 2 and 8-bit loading (with your low-resource mode on), it raises RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16. Loading the LLaMA model in 8-bit alone, without Flash Attention 2, works fine. Do you have any suggestions?
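The mat1/mat2 error means that a matrix multiply received one float32 operand ("float") and one bfloat16 operand. A first debugging step is to list which parameters are not in the dtype you expect. The helper below is a hypothetical debugging aid, not part of TimeChat or transformers. It works on a plain name-to-dtype mapping, which you would build with the standard PyTorch method model.named_parameters(). By default it skips torch.int8, on the assumption that the 8-bit LLaMA weights report that dtype. It only inspects parameters, so a float32 activation produced outside the model would not show up.

```python
from __future__ import annotations


def find_dtype_mismatches(
    param_dtypes: dict[str, str],
    compute_dtype: str,
    ignore: tuple[str, ...] = ("torch.int8",),
) -> list[str]:
    """Sorted names of parameters whose dtype differs from compute_dtype.

    param_dtypes maps parameter name -> dtype string, e.g. built with
    {name: str(p.dtype) for name, p in model.named_parameters()}.
    Dtypes listed in `ignore` (8-bit quantized weights by default) are skipped.
    """
    return sorted(
        name
        for name, dtype in param_dtypes.items()
        if dtype != compute_dtype and dtype not in ignore
    )


if __name__ == "__main__":
    # Hypothetical parameter names, for illustration only.
    example = {
        "llama_model.model.layers.0.self_attn.q_proj.weight": "torch.int8",
        "llama_model.model.norm.weight": "torch.bfloat16",
        "llama_proj.weight": "torch.float32",
    }
    print(find_dtype_mismatches(example, "torch.bfloat16"))  # ['llama_proj.weight']
```

Any module this reports is a candidate for an explicit .to(...) cast to the LLaMA dtype, or for loading everything in a single dtype from the start.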

My test

I searched online, found a suggestion here, and tried wrapping the call in with torch.cuda.amp.autocast():, but then it raises RuntimeError: query and key must have the same dtype. I hope this helps.
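"query and key must have the same dtype" is one of the input checks in the flash-attn kernels, which also accept only float16 and bfloat16. One plausible cause, not a confirmed diagnosis: on CUDA, torch.cuda.amp.autocast() casts to float16 by default, so in a bfloat16 model some tensors may come out as float16 while others stay bfloat16. Passing dtype=torch.bfloat16 to autocast is one thing to try. The sketch below mirrors those two checks on dtype strings so they can be logged before the attention call, e.g. check_flash_attn_dtypes(str(q.dtype), str(k.dtype), str(v.dtype)). The function name and where you call it are hypothetical, and it reproduces the rules, not flash-attn's exact error messages.

```python
FLASH_ATTN_DTYPES = {"torch.float16", "torch.bfloat16"}


def check_flash_attn_dtypes(q_dtype: str, k_dtype: str, v_dtype: str) -> str:
    """Return the shared dtype, or raise RuntimeError if flash-attn would reject it."""
    # Rule 1: query, key and value must all use the same dtype.
    if k_dtype != q_dtype or v_dtype != q_dtype:
        raise RuntimeError(
            f"query/key/value dtypes differ: q={q_dtype}, k={k_dtype}, v={v_dtype}"
        )
    # Rule 2: only half-precision inputs are supported.
    if q_dtype not in FLASH_ATTN_DTYPES:
        raise RuntimeError(f"Flash Attention 2 supports only fp16/bf16, got {q_dtype}")
    return q_dtype


if __name__ == "__main__":
    print(check_flash_attn_dtypes("torch.bfloat16", "torch.bfloat16", "torch.bfloat16"))
```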

Thank you very much!

yeahjack · May 14 '24 15:05

It seems that removing torch_dtype=torch.bfloat16 could work, but I am not sure it is the right solution.

yeahjack · May 14 '24 15:05