
Export finetuned lora weights to base model

eyrs42 opened this issue 2 years ago • 4 comments

Hi

Is there an easy way to map the finetuned LoRA weights back onto the base lit-llama model? I see that the LoRA finetuning saves checkpoints that contain only the finetuned weights ("lora_A" and "lora_B"). I couldn't find a method that maps these back to the original base lit-llama model.

Thank you!

eyrs42 avatar May 11 '23 14:05 eyrs42

The reason we only save the LoRA weights is simply to save space, since the full checkpoints are quite large. But yes, we could provide such a conversion for a LoRA checkpoint, that's a good suggestion. The implementation would roughly be like this (a sketch follows the list):

  1. Load the pretrained weights and the LoRA weights from their checkpoints into the LoRA model (see generate_lora.py)
  2. Call model.eval() to merge the LoRA weights back into the regular weights
  3. state = model.state_dict()
  4. Drop all entries in the dict that correspond to LoRA
  5. torch.save(state, ...)
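
An untested sketch of those steps below. The checkpoint paths and LoRA hyperparameters are placeholders and have to match your finetuning run; the imports and the lora context manager mirror what generate_lora.py uses. The key filter relies on the LoRA parameters being named lora_A/lora_B as shown further down in this thread.

    import torch

    from lit_llama import LLaMA
    from lit_llama.lora import lora

    # Placeholder hyperparameters: they must match the ones used for finetuning.
    lora_r, lora_alpha, lora_dropout = 8, 16, 0.05

    with lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
        model = LLaMA.from_name("7B")

        # 1. load the pretrained weights and the LoRA weights into the LoRA model
        pretrained_checkpoint = torch.load("checkpoints/lit-llama/7B/lit-llama.pth")  # placeholder path
        lora_checkpoint = torch.load("out/lora/alpaca/lit-llama-lora-finetuned.pth")  # placeholder path
        model.load_state_dict(pretrained_checkpoint, strict=False)
        model.load_state_dict(lora_checkpoint, strict=False)

    # 2. eval() triggers the merge of the LoRA deltas into the regular weight matrices
    model.eval()

    # 3. + 4. take the state dict and drop the LoRA-specific entries
    state_dict = {k: v for k, v in model.state_dict().items() if "lora_" not in k}

    # 5. the result should load into the plain (non-LoRA) LLaMA model
    torch.save(state_dict, "lit-llama-lora-merged.pth")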

awaelchli avatar May 12 '23 08:05 awaelchli

interested in this!

Currently, we cannot run int8 on the LoRA weights, so I guess we need to merge the weights into the base model before running the quantization script.

timothylimyl avatar May 22 '23 02:05 timothylimyl

@awaelchli Can you provide some direction on this? It would be great to start a PR for this feature, I think it would be awesome to have!

I don't get how running model.eval() merges the weights.

The only difference between the LoRA and base models is:

          (c_attn): Linear(in_features=4096, out_features=12288, bias=False)

versus

          (c_attn): MergedLinear(
            in_features=4096, out_features=12288, bias=False
            (lora_dropout): Dropout(p=0.05, inplace=False)
          )

When I print out the model's named parameters, a snippet for one layer looks like this:

transformer.h.4.attn.c_attn.weight
transformer.h.4.attn.c_attn.lora_A
transformer.h.4.attn.c_attn.lora_B

So does model.eval() somehow merge the weights with the delta weight lora_B @ lora_A? I cannot figure out where this happens in the code.

timothylimyl avatar May 22 '23 04:05 timothylimyl

Realised the explanation is here: PR #185
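
For reference, the gist: nn.Module.eval() simply calls train(False) on every submodule, and the LoRA layer overrides train() to fold the low-rank delta into the dense weight at that point. Below is a simplified sketch of that mechanism, not the actual lit-llama MergedLinear (which additionally handles the enable_lora mask, transposes, and proper initialization):

    import torch
    import torch.nn as nn

    class LoRALinearSketch(nn.Linear):
        """Simplified stand-in for MergedLinear, just to show where the merge happens."""

        def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: int = 16, **kwargs):
            super().__init__(in_features, out_features, **kwargs)
            # real LoRA initializes lora_A randomly and lora_B to zeros; zeros here for brevity
            self.lora_A = nn.Parameter(torch.zeros(r, in_features))
            self.lora_B = nn.Parameter(torch.zeros(out_features, r))
            self.scaling = alpha / r
            self.merged = False

        def train(self, mode: bool = True):
            # model.eval() calls train(False) recursively on all submodules, which lands here
            super().train(mode)
            if not mode and not self.merged:
                # fold the delta into the dense weight: W <- W + (B @ A) * scaling
                self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
                self.merged = True
            elif mode and self.merged:
                # un-merge when switching back to training mode
                self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
            return self

So model.eval() does merge the weights, and the lora_A/lora_B entries left in the state dict afterwards are redundant for inference, which is why the export sketch earlier in this thread drops them.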

timothylimyl avatar May 22 '23 09:05 timothylimyl