qlora
torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)): is this intentional in qlora.py?
In qlora.py line https://github.com/artidoro/qlora/blob/main/qlora.py#L279, if fp16 is specified, torch_dtype is set to torch.float32. Should this be torch.float16 instead, or is it intentional? If it is intentional, what is the reason, and why can't we simply write torch_dtype=torch.bfloat16 if args.bf16 else torch.float32?
The line at https://github.com/artidoro/qlora/blob/main/qlora.py#L263 is correct, though.
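To make the question concrete, here is a small sketch that evaluates the expression from line 279, the simplification proposed above, and a hypothetical variant that maps fp16 to float16, for each combination of the two flags. It does not say whether the original choice is intentional. Assumptions: dtypes are written as plain strings standing in for torch.float16/torch.bfloat16/torch.float32 so the sketch runs without torch, and the function names are hypothetical.

```python
from argparse import Namespace

# Strings stand in for torch.float16 / torch.bfloat16 / torch.float32 (assumption).

def current_dtype(args):
    # The expression as written at qlora.py#L279
    return "float32" if args.fp16 else ("bfloat16" if args.bf16 else "float32")

def simplified_dtype(args):
    # The simplification proposed in the question
    return "bfloat16" if args.bf16 else "float32"

def fp16_variant(args):
    # Hypothetical alternative: map --fp16 to float16 instead of float32
    return "float16" if args.fp16 else ("bfloat16" if args.bf16 else "float32")

for fp16, bf16 in [(False, False), (True, False), (False, True), (True, True)]:
    args = Namespace(fp16=fp16, bf16=bf16)
    print(f"fp16={fp16!s:5} bf16={bf16!s:5} ->",
          current_dtype(args), simplified_dtype(args), fp16_variant(args))
```

Running it shows that the proposed simplification agrees with the current expression in every case except when both fp16 and bf16 are set: the current expression then yields float32, because the fp16 check comes first, while the simplification yields bfloat16.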
I'm also wondering about this.
See #172