
Unexpected behaviour in inference with merged QLoRA weights

Open michele-milesi opened this issue 1 year ago • 3 comments

Hi, a few weeks ago @morettif and I fine-tuned Llama 70B with QLoRA on an H100, using these settings:

  • r=32
  • alpha=64
  • quantize=bnb.nf4-dq
  • precision=bf16-true
  • weight_decay=0
  • batch_size=32
  • micro_batch_size=2
  • lora_dropout=0.05
  • All LoRA layers
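For clarity, a small sketch of the quantities these settings imply (standard LoRA/litgpt conventions; not code from the repo):

```python
# Derived quantities from the hyperparameters above (illustrative only).
r, alpha = 32, 64
batch_size, micro_batch_size = 32, 2

# LoRA scales its low-rank update by alpha / r before adding it
# to the frozen pretrained weights.
scaling = alpha / r  # 2.0

# The effective batch size is reached via gradient accumulation.
grad_accum_steps = batch_size // micro_batch_size  # 16

print(scaling, grad_accum_steps)
```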

The validation loss after fine-tuning was about 0.5, but during inference on the validation set we obtained very strange results. We checked the loss on the validation set and it was ~2.5 (the same as before fine-tuning).

We used the generate/lora.py script and noticed that after commenting out this line: https://github.com/Lightning-AI/lit-gpt/blob/8a101b633dfeafd378f8fbaba6a80a4417c33576/generate/lora.py#L120 the model generated meaningful output, with a loss consistent with the one observed during training (~0.5).

We analyzed the dtype of the pre-trained model weights at the moment merge_lora_weights() is called: some of them are torch.uint8, which causes the following if in the merge() function to evaluate as True during merging: https://github.com/Lightning-AI/lit-gpt/blob/8a101b633dfeafd378f8fbaba6a80a4417c33576/lit_gpt/lora.py#L151
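To illustrate why this dtype mix is suspicious, here is a toy sketch (a made-up 4-bit codebook, not bnb's actual NF4 layout): quantized weights are stored as integer codes, not values, so arithmetic on the raw uint8 storage silently drops or corrupts a LoRA update.

```python
import numpy as np

# Toy 4-bit codebook standing in for NF4 (illustrative only).
codebook = np.linspace(-1.0, 1.0, 16, dtype=np.float32)
codes = np.array([3, 15, 8], dtype=np.uint8)  # stored "weights" are codes

w = codebook[codes]          # dequantized values: the real weights
delta = np.float32(0.1)      # a LoRA-style update

# Wrong: merging into the uint8 storage casts the delta to uint8,
# truncating 0.1 to 0 -- the update is silently lost.
merged_wrong = codes + np.uint8(delta)

# Right: dequantize first, then merge in floating point.
merged_right = w + delta
```

In this toy example `merged_wrong` is identical to `codes`, i.e. the fine-tuned update vanishes entirely, which would explain a post-merge loss that matches the pre-finetuning model.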

Could there be an error in this piece of code related to the management of quantization?

michele-milesi avatar Feb 20 '24 10:02 michele-milesi

Hi @michele-milesi Thanks for reporting.

We analyzed the dtype of the pre-trained model weights at the moment of the call to the merge_lora_weights() function, some of them are torch.uint8

When you run inference with the --quantize bnb.nf4 argument, the pre-trained weights are quantized to torch.uint8 format upon loading, so this is expected. The purpose of the merge method is to, well, merge the pretrained and LoRA weights to reduce the number of computations. It's solely a speed optimization, so the fact that the model works without merging is no surprise.
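A minimal sketch of that point, assuming the usual LoRA formulation (illustrative shapes, not lit-gpt code): in full precision, merging W' = W + (alpha/r) * B @ A produces exactly the same output as keeping the LoRA branch separate.

```python
import numpy as np

rng = np.random.default_rng(0)
in_f, out_f, r, alpha = 16, 8, 4, 8
scaling = alpha / r

W = rng.normal(size=(out_f, in_f)).astype(np.float32)  # frozen pretrained weight
A = rng.normal(size=(r, in_f)).astype(np.float32)      # LoRA down-projection
B = rng.normal(size=(out_f, r)).astype(np.float32)     # LoRA up-projection
x = rng.normal(size=(2, in_f)).astype(np.float32)

# Unmerged path: pretrained branch plus scaled low-rank branch.
y_unmerged = x @ W.T + scaling * (x @ A.T) @ B.T

# Merged path: fold the LoRA update into the weight once, then a single matmul.
W_merged = W + scaling * B @ A
y_merged = x @ W_merged.T

assert np.allclose(y_unmerged, y_merged, atol=1e-4)
```

The two paths only diverge when the merge is performed on quantized (uint8) storage instead of dequantized values.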

What's surprising is that the loss is different. I'll check this out. In the meantime, as I stated above, you can use the code with the commented-out merge_lora_weights function without any problems.


Btw, did you make any changes to the code? If you did, could you send me the result of git diff main > changes.diff (if the changes aren't proprietary, of course)?

Andrei-Aksionov avatar Feb 20 '24 13:02 Andrei-Aksionov

Hi @Andrei-Aksionov, thanks for your support.

We used commit 1e5afd6fb5653eddc15aafcae8c20f5222e4e1e3. The only two things we did are:

  1. Commented out the line that calls the merge_lora_weights() function
  2. Ran inference on the entire validation dataset, loading the model once and iterating over the samples in the dataset.
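The two steps above can be sketched as follows (placeholder names, not the actual lit-gpt API):

```python
# Hypothetical stand-ins for lit-gpt's checkpoint loading and generation.
def load_model():
    return object()

def generate(model, sample):
    return f"output for {sample}"

model = load_model()
# merge_lora_weights(model)  # step 1: intentionally commented out

# Step 2: load once, then reuse the same model across the validation set.
val_dataset = ["sample-0", "sample-1", "sample-2"]
outputs = [generate(model, s) for s in val_dataset]
```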

michele-milesi avatar Feb 20 '24 13:02 michele-milesi

Got it. I'll check later this week.

Andrei-Aksionov avatar Feb 20 '24 13:02 Andrei-Aksionov