TimeChat
Which transformers version is models/modeling_llama.py based on?
I noticed that models/modeling_llama.py is based on the code from here. However, your implementation does not support Flash Attention 2, so I would like to make further modifications. Could you please specify the exact version of transformers that your implementation is based on? That would make it easier for me to compare the two files.
Hi, our current code is based on transformers==4.34.0 (https://github.com/RenShuhuai-Andy/TimeChat/blob/master/environment.yml#L273).
To use flash attn 2, you can upgrade transformers and use the following code:
from transformers import AutoTokenizer, AutoModelForCausalLM

self.llama_model = AutoModelForCausalLM.from_pretrained(
    llama_model,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    low_cpu_mem_usage=True,
)
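Since the repository pins transformers==4.34.0 and the snippet above relies on the attn_implementation argument, it may help to guard the extra keyword behind a version check. The following is only a sketch: the helper name pick_attn_kwargs is hypothetical, and the 4.36 threshold is an assumption based on when attn_implementation appeared in transformers, so please verify it against the release notes for your installed version.

```python
def pick_attn_kwargs(transformers_version: str, flash_attn_installed: bool) -> dict:
    """Return extra from_pretrained kwargs for Flash Attention 2, or {}.

    Hypothetical helper: the (4, 36) threshold is an assumption, not
    something confirmed in this thread.
    """
    # Compare only major.minor so suffixes like ".dev0" do not break parsing.
    major, minor = (int(part) for part in transformers_version.split(".")[:2])
    if flash_attn_installed and (major, minor) >= (4, 36):
        return {"attn_implementation": "flash_attention_2"}
    # Fall back to the library's default attention implementation.
    return {}


# With the version pinned in environment.yml, no extra kwargs are returned.
pinned_kwargs = pick_attn_kwargs("4.34.0", flash_attn_installed=True)
upgraded_kwargs = pick_attn_kwargs("4.36.2", flash_attn_installed=True)
```

In real use, the two arguments could be filled from importlib.metadata.version("transformers") and importlib.util.find_spec("flash_attn") is not None, and the result unpacked into from_pretrained with **kwargs.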
Hi, I tried your suggestion. When I use the code from demo.ipynb to generate text with flash-attn-2 and 8-bit loading (with your low-resource mode on), I get RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16. Loading the LLaMA model in 8-bit only works fine. Do you have any suggestions?
My test
I searched online and found a suggestion here to wrap the call in with torch.cuda.amp.autocast():, but then I get RuntimeError: query and key must have the same dtype. I hope this helps.
Thank you very much!
It seems that removing torch_dtype="torch.bfloat16" could work, but I am not sure it is the right solution.
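For anyone debugging the same mat1/mat2 error, here is a minimal CPU-only sketch of how that kind of dtype mismatch arises and one way to avoid it. It assumes the failure comes from float32 activations reaching bfloat16 weights; the thread does not confirm that this is what happens inside TimeChat's low-resource path, and the layer and tensor here are toy stand-ins rather than the project's modules.

```python
import torch

# Toy linear layer with bfloat16 weights, standing in for a model loaded
# with a bfloat16 torch_dtype.
layer = torch.nn.Linear(4, 2).to(torch.bfloat16)

# An activation that stayed in float32, e.g. produced by a module that was
# not cast together with the language model.
x = torch.randn(3, 4, dtype=torch.float32)

try:
    layer(x)
    mismatch_error = None
except RuntimeError as err:
    # PyTorch refuses to multiply float32 inputs with bfloat16 weights.
    mismatch_error = str(err)

# Casting the input to the weight dtype removes the mismatch.
y = layer(x.to(layer.weight.dtype))
```

Printing the dtype of the tensors just before the failing call in the real model should show which side is still float32, which in turn indicates whether dropping torch_dtype or casting the inputs is the more appropriate fix.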