Fix llama
Fix dtype mismatch error when load_in_low_bit='bf16'
Both CPU and GPU hit the same error: RuntimeError: expected m1 and m2 to have the same dtype, but got: float != c10::BFloat16. Details can be found in the issue https://github.com/analytics-zoo/nano/issues/1111
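For context, here is a minimal sketch that reproduces this class of mismatch (the exact error text varies by PyTorch version; the shapes are purely illustrative):

```python
import torch
import torch.nn.functional as F

# float32 activations hitting a bf16 weight, as happens when the model is
# loaded with load_in_low_bit='bf16' but the inputs stay float32
x = torch.randn(1, 4)                        # float32 input
w = torch.randn(8, 4, dtype=torch.bfloat16)  # bf16 linear weight
F.linear(x, w)  # raises RuntimeError: dtypes of mat1 and mat2 must match
```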
@rnwang04 Could you please take a look?
This PR fixes llama, but will other models hit similar issues?
Shall we add torch_dtype=torch.bfloat16 for load_in_low_bit='bf16', as we did for fp16?
> This PR fixes llama, but will other models hit similar issues? Shall we add torch_dtype=torch.bfloat16 for load_in_low_bit='bf16', as we did for fp16?
Yes
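For reference, a minimal sketch of what that default could look like, assuming a hypothetical resolve_torch_dtype helper in the loading path (the real entry point in the codebase may differ):

```python
import torch

# Hypothetical helper; the actual loading code in the repo may look different.
def resolve_torch_dtype(load_in_low_bit, torch_dtype=None):
    # Default the model dtype from load_in_low_bit so weights and
    # activations agree, mirroring the existing fp16 handling.
    if torch_dtype is None:
        if load_in_low_bit == 'bf16':
            return torch.bfloat16
        if load_in_low_bit == 'fp16':
            return torch.float16
    return torch_dtype
```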