
Understanding QLoRA memory consumption for inference

Optimox opened this issue 2 months ago • 5 comments

Hello,

I have a question regarding GPU memory consumption during inference.

Before finetuning a model with QLoRA, the torchtune LoRALinear modules convert the original LLM weights to NF4 and create two LoRA matrices in bfloat16 (if training with bfloat16).

However, after training, when saving the checkpoint, the LoRA matrices are merged into the original LLM weights by addition in the function get_merged_lora_ckpt. So the final checkpoint has fewer tensors, but the large original LLM weights will now be stored in bfloat16 precision, right?
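
If I understand correctly, the merge is essentially the standard LoRA fold-in, something like the sketch below (the real get_merged_lora_ckpt implementation surely handles scaling and the full state dict differently; this is just to illustrate where the upcast happens):

```python
import torch

def merge_lora_weight(base_weight_bf16, lora_a, lora_b, alpha, rank):
    # Standard LoRA fold-in: W' = W + (alpha / rank) * B @ A
    # base_weight_bf16: (out_features, in_features), the base weight already
    #                   dequantized from NF4 to bf16
    # lora_a:           (rank, in_features)
    # lora_b:           (out_features, rank)
    scale = alpha / rank
    return base_weight_bf16 + scale * (lora_b @ lora_a)
```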

So this means that when performing inference with the saved checkpoint, I will actually need more GPU memory than when doing inference during the finetuning stage (NF4 base weights + bf16 LoRA weights). Is that correct?
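
As a back-of-the-envelope check of what I mean (a sketch with a hypothetical 7B-parameter base model and made-up LoRA sizes, ignoring activations, the KV cache, and NF4 scaling metadata):

```python
# Rough per-parameter storage costs for the weights only.
NF4_BYTES_PER_PARAM = 0.5   # 4 bits per weight
BF16_BYTES_PER_PARAM = 2.0

base_params = 7e9    # hypothetical 7B base model
lora_params = 20e6   # hypothetical total number of LoRA A/B parameters

gb = 1024 ** 3

# During QLoRA finetuning / inference: NF4 base + bf16 LoRA adapters.
qlora_gb = (base_params * NF4_BYTES_PER_PARAM + lora_params * BF16_BYTES_PER_PARAM) / gb

# After merging: everything stored (and loaded) in bf16.
merged_gb = base_params * BF16_BYTES_PER_PARAM / gb

print(f"QLoRA-style weights: ~{qlora_gb:.1f} GB")   # roughly 3.3 GB
print(f"Merged bf16 weights: ~{merged_gb:.1f} GB")  # roughly 13.0 GB
```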

Does it mean that, if I care more about GPU memory consumption than inference speed, I should make sure to only save the LoRA params and reload them alongside the original LLM weights kept in NF4 format?

Thanks for your help!

Optimox avatar Apr 25 '24 12:04 Optimox

Hi @Optimox, thanks for creating the issue! Your understanding is correct: we merge the weights at the end of training, and for QLoRA this involves an upcast to bf16. So running inference with the merged model will require more memory than running it with the QLoRA version. The rationale here is interoperability: QLoRA's NF4 format is not a canonical torch dtype like bf16, so you may have trouble using such checkpoints in places where that format is not recognized.

I would actually propose an alternative option here: rather than trying to do the upcast on the fly during inference a la QLoRA, why not quantize the merged bf16 checkpoint and run inference on that? Then you get the best of both worlds: lower memory with fast inference. We integrate with the torchao quantization APIs; you can see in this example that applying e.g. 4-bit weight-only quantization will reduce the size of a Llama3-8B checkpoint to < 5 GB.
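
A minimal sketch of what that could look like (assuming a recent torchao release that exposes the quantize_ / int4_weight_only API, and a placeholder loading helper; see the linked example for the actual recipe and config):

```python
import torch
from torchao.quantization import quantize_, int4_weight_only

# Hypothetical helper: however you normally build and load the merged bf16 checkpoint.
model = load_merged_bf16_model("my_finetuned_llama3_8b")  # placeholder, not a real API
model = model.to(device="cuda", dtype=torch.bfloat16).eval()

# Swap eligible nn.Linear weights for int4 weight-only quantized versions in place.
quantize_(model, int4_weight_only(group_size=128))

# Inference now runs with ~4-bit weights, so the checkpoint / GPU footprint drops to a
# fraction of the bf16 size (plus activations and KV cache as usual).
```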

ebsmothers avatar Apr 25 '24 15:04 ebsmothers

Thank you, I will have a look at the quantization method, but I guess there is some quantization loss in terms of model quality to expect, right?

Optimox avatar Apr 25 '24 16:04 Optimox

cc @jerryzh168 for potential model accuracy / quality degradation during quantization.

rohan-varma avatar Apr 25 '24 17:04 rohan-varma

Hi @Optimox, by the way, we are currently looking into adding support for quantization-aware training (QAT) during finetuning. The goal is to still produce a quantized model for inference, but to mitigate the quantization loss, since QAT has been shown to give better model quality than pure post-training quantization techniques. This is still in its early stages, but it's something to keep in mind once it's ready.
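
Just to give a flavor of the intended flow (a rough sketch only; this relies on torchao's prototype QAT quantizer, whose name and prepare/convert API may still change):

```python
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert fake-quantization ops so the finetune "sees" quantization error during training.
model = qat_quantizer.prepare(model)

# ... run the normal finetuning loop on `model` ...

# Convert the fake-quantized modules into actually quantized ones for inference.
model = qat_quantizer.convert(model)
```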

andrewor14 avatar Apr 25 '24 20:04 andrewor14

Amazing! Keep up the good work!

Optimox avatar Apr 26 '24 07:04 Optimox