litgpt
Unexpected behaviour in inference with merged QLoRA weights
Hi, a few weeks ago @morettif and I finetuned Llama 70B with QLoRA on an H100, with the following settings:

- `r=32`
- `alpha=64`
- `quantize=bnb.nf4-dq`
- `precision=bf16-true`
- `weight_decay=0`
- `batch_size=32`
- `micro_batch_size=2`
- `lora_dropout=0.05`
- All LoRA layers enabled
The validation loss after finetuning was about 0.5, but during inference on the validation set we obtained very strange results: the loss on the validation set was ~2.5 (the same as before finetuning).
We used the `generate/lora.py` script, and we noticed that after commenting out this line: https://github.com/Lightning-AI/lit-gpt/blob/8a101b633dfeafd378f8fbaba6a80a4417c33576/generate/lora.py#L120 the model generated meaningful output, with a loss consistent with the one observed during training (~0.5).
We analyzed the dtype of the pre-trained model weights at the moment `merge_lora_weights()` is called: some of them are `torch.uint8`. As a result, during merging the following `if` (in the `merge()` function) evaluates to `True`: https://github.com/Lightning-AI/lit-gpt/blob/8a101b633dfeafd378f8fbaba6a80a4417c33576/lit_gpt/lora.py#L151
Could there be an error in this piece of code related to the handling of quantization?
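To illustrate the dtype issue described above, here is a minimal sketch (not the lit-gpt code, and using a fake linear quantization rather than the real NF4 packing of bitsandbytes): once the base weight is stored as `torch.uint8` codes, adding a float LoRA delta to that storage operates on the codes, not on the dequantized values, so the merge is meaningless.

```python
import torch

torch.manual_seed(0)

W_fp = torch.randn(4, 4)            # full-precision base weight
delta = 0.1 * torch.randn(4, 4)     # LoRA update, i.e. (alpha/r) * B @ A

# Correct merge in full precision:
merged_fp = W_fp + delta

# Simulated quantized storage (real NF4 packing is more involved):
W_q = (W_fp * 16 + 128).clamp(0, 255).to(torch.uint8)

# Naively adding the float delta to the uint8 payload mixes codes with
# values, producing garbage instead of the merged weight:
broken = W_q.float() + delta
print((broken - merged_fp).abs().max())   # large error: codes != values
```

The correct procedure would be to dequantize first, merge in floating point, then (optionally) re-quantize.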
Hi @michele-milesi, thanks for reporting.
> We analyzed the dtype of the pre-trained model weights at the moment of the call to the merge_lora_weights() function, some of them are torch.uint8
When you run inference with the `--quantize bnb.nf4` argument, the pre-trained weights are quantized to `torch.uint8` format upon loading. So this is expected.
The purpose of the `merge` method is to, well, merge the pretrained and LoRA weights to reduce the number of computations. It's solely a speed optimization, so the fact that it works without merging is no surprise.
What's surprising is that the loss is different. I'll check this out.
In the meantime, as I stated above, you can use the code with the `merge_lora_weights` call commented out without any problems.
By the way, did you make any changes to the code? If you did, could you send me the result of `git diff main > changes.diff` (if the changes aren't proprietary, of course)?
Hi @Andrei-Aksionov, thanks for your support.
We used the commit with id 1e5afd6fb5653eddc15aafcae8c20f5222e4e1e3. The only two things we did are:

- Comment out the line that calls the `merge_lora_weights()` function
- Run inference on the entire validation dataset, loading the model once and iterating over the n samples in the dataset
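The second change above can be sketched as follows (a hypothetical stand-in model and loss, not the lit-gpt evaluation code): the model is instantiated once and stays in memory while the loop accumulates the loss over the validation samples.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(4, 4)     # stand-in for the loaded, finetuned LLM
loss_fn = nn.MSELoss()

# Stand-in validation set of (input, target) pairs.
val_samples = [(torch.randn(1, 4), torch.randn(1, 4)) for _ in range(3)]

model.eval()
total = 0.0
with torch.no_grad():
    for x, y in val_samples:            # model is loaded once, reused here
        total += loss_fn(model(x), y).item()
avg_loss = total / len(val_samples)
print(avg_loss)
```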
Got it. I'll check later this week.