How can I merge the LoRA weights back into the original model weights?

Status: Open. mzbac opened this issue 1 year ago • 6 comments

Maybe I'm missing something, but I can't find any information on how to merge LoRA weights back into the original model. Running the model with a LoRA adapter adds memory overhead and makes the model slightly harder to distribute. Could we provide a script for merging the LoRA weights back into the model?

mzbac, Jan 01 '24 09:01

Yeah, that's missing from our example. I think it would be nice to have an option on the LoRA layer that merges the adapters into the linear weights, either after the adapters are trained or when using them for generation.
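For reference, merging is possible because the adapter's contribution is linear in the input. Assuming the shapes implied by the merge snippet in this thread (base weight $W$ of shape (output_dims, input_dims), $A$ = lora_a of shape (input_dims, r), $B$ = lora_b of shape (r, output_dims)) and a scale $s$ (the 2.0 factor in that snippet), the layer output collapses into a single linear layer:

$$
y = xW^\top + s\,(xA)B = x\left(W + s\,(AB)^\top\right)^\top
\quad\Longrightarrow\quad
W_{\text{merged}} = W + s\,(AB)^\top
$$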

awni, Jan 01 '24 14:01

Thanks for sharing your thoughts. I plan to use MLX to fine-tune some of my models. I will try to make it work and contribute back if possible.

mzbac, Jan 02 '24 00:01

@awni The only approach I managed was to create a merge function (https://github.com/mzbac/mlx-lora/blob/main/models.py#L92-L107) in LoRALinear, loop through all the named modules to get the merged linear layers, and then update the modules (https://github.com/mzbac/mlx-lora/blob/main/utils.py#L22-L28). However, this seems inefficient because I have to make copies of the linear layers and update them. Does MLX have a method that lets us map over the modules with a lambda, so we can replace a layer without making additional copies?
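A minimal NumPy sketch of that merge-then-replace approach, assuming each layer is stored as a plain dict of arrays keyed by a dotted name. The helper names (`merge_lora`, `merge_all`), the layer names, and the dict layout are hypothetical, not the MLX API. In MLX itself, `Module.named_modules()` together with `Module.update_modules()` (available in newer releases; check your version) is the usual way to do the walk-and-replace step.

```python
import numpy as np


def merge_lora(weight, lora_a, lora_b, scale=2.0):
    """Fold a LoRA update into a dense weight.

    weight: (output_dims, input_dims); lora_a: (input_dims, r); lora_b: (r, output_dims).
    """
    return weight + scale * (lora_a @ lora_b).T


def merge_all(layers, scale=2.0):
    """Return a new dict in which every LoRA layer becomes a plain linear layer."""
    merged = {}
    for name, params in layers.items():
        if "lora_a" in params:
            merged[name] = {
                "weight": merge_lora(params["weight"], params["lora_a"], params["lora_b"], scale)
            }
        else:
            merged[name] = params  # non-LoRA layers are reused as-is, no copy
    return merged


rng = np.random.default_rng(0)
d_in, d_out, r = 16, 8, 4
layers = {
    "layers.0.attention.wq": {
        "weight": rng.standard_normal((d_out, d_in)),
        "lora_a": rng.standard_normal((d_in, r)),
        "lora_b": rng.standard_normal((r, d_out)),
    },
    "layers.0.attention.wo": {"weight": rng.standard_normal((d_out, d_in))},
}
merged = merge_all(layers)

# The merged layer should reproduce base output plus the scaled adapter output.
x = rng.standard_normal((3, d_in))
p = layers["layers.0.attention.wq"]
expected = x @ p["weight"].T + 2.0 * (x @ p["lora_a"]) @ p["lora_b"]
print(np.allclose(x @ merged["layers.0.attention.wq"]["weight"].T, expected))
```

Returning a new dict keeps the original adapters intact, so the un-merged model can still be restored.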

mzbac, Jan 07 '24 01:01

You needn't worry about this being inefficient:

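        # Fold the scaled low-rank update into the base weight
        self.linear.weight += (self.lora_a @ self.lora_b).T * 2.0
        new_linear = nn.Linear(input_dims, output_dims, bias=False)
        # Plain assignment: new_linear.weight refers to the same array, no copy is made
        new_linear.weight = self.linear.weight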

Assigning new_linear.weight does not make a deep copy; under the hood it just points to the same data as self.linear.weight.
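The no-copy behaviour is ordinary Python reference semantics: assigning an array to an attribute binds another name to the same object. A small NumPy illustration, where `Linear` is a hypothetical stand-in rather than `mlx.nn.Linear` (MLX arrays manage their own storage, but attribute assignment on a module likewise stores a reference):

```python
import numpy as np


class Linear:
    """Hypothetical stand-in for a linear layer holding a weight array."""

    def __init__(self, input_dims, output_dims):
        self.weight = np.zeros((output_dims, input_dims), dtype=np.float32)


old = Linear(4, 2)
old.weight += 1.0  # in-place update of the existing buffer

new = Linear(4, 2)
new.weight = old.weight  # rebinds the attribute; no data is copied

print(new.weight is old.weight)                  # same object
print(np.shares_memory(new.weight, old.weight))  # same underlying buffer
```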

Can I ask what your intention is with merging? My understanding is that it's fairly uncommon to save the fully merged model, because you can easily restore it from the original model and the adapters.

What is more common is to merge dynamically, to avoid the extra cost of computing the low-rank update every time you use the model. From that perspective, it might make more sense to have an "eval mode" on the LoRALinear layer that merges the adapter into the linear weight.
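One way to picture such an "eval mode": a LoRA layer that, once switched, folds the scaled update into its base weight and from then on runs a single matmul per call. This is a hypothetical NumPy sketch (the class name, the `merge()` method, and the 2.0 default scale are assumptions modelled on the snippet above), not MLX's `LoRALinear` implementation.

```python
import numpy as np


class LoRALinearSketch:
    """Hypothetical LoRA layer: y = x W^T + scale * (x A) B, with an optional merged mode."""

    def __init__(self, input_dims, output_dims, rank=8, scale=2.0, seed=0):
        rng = np.random.default_rng(seed)
        self.weight = rng.standard_normal((output_dims, input_dims))
        self.lora_a = rng.standard_normal((input_dims, rank)) / np.sqrt(input_dims)
        # Random (not zero) init so the adapter actually changes the output here.
        self.lora_b = rng.standard_normal((rank, output_dims))
        self.scale = scale
        self.merged = False

    def merge(self):
        """Fold the adapter into the base weight once (the "eval mode" switch)."""
        if not self.merged:
            self.weight = self.weight + self.scale * (self.lora_a @ self.lora_b).T
            self.merged = True

    def __call__(self, x):
        y = x @ self.weight.T
        if not self.merged:
            # Un-merged: pay for the extra low-rank matmuls on every call.
            y = y + self.scale * (x @ self.lora_a) @ self.lora_b
        return y


layer = LoRALinearSketch(input_dims=32, output_dims=16, rank=4)
x = np.random.default_rng(1).standard_normal((5, 32))
before = layer(x)
layer.merge()
after = layer(x)
print(np.allclose(before, after))  # same outputs, one matmul per call after merging
```

Guarding `merge()` with a flag keeps a second call from adding the update twice.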

awni, Jan 07 '24 14:01

Thanks for pointing out the shallow copying. I noticed that memory usage increased to around 100GB during the merging process, so I thought it might be a deep copy issue.

As for merging the base model and LoRA: it is how the open-source LLM community currently distributes fine-tuned models. The reason, from my understanding, may be that the major inference frameworks do not support base + adapter. FYI: https://www.reddit.com/r/LocalLLaMA/comments/17m8ock/why_do_we_always_download_fully_merged_baselora/

mzbac, Jan 07 '24 14:01

> I noticed that memory usage increased to around 100GB during the merging process, so I thought it might be a deep copy issue.

Wow! That's a lot. It could be due to dequantization. It might help to stream the merge so it doesn't blow up memory, by putting an mx.eval right after each weight update so each layer is computed before moving on to the next:

        self.linear.weight += (self.lora_a @ self.lora_b).T * 2.0
        new_linear = nn.Linear(input_dims, output_dims, bias=False)
        new_linear.weight = self.linear.weight
        # Evaluate now so this layer's merge is computed before moving to the next layer
        mx.eval(new_linear.weight)
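To see why dequantization alone can make memory jump, here is back-of-the-envelope arithmetic for a hypothetical 34B-parameter model. The model size is an assumption, and the estimate ignores quantization scales/biases, activations, and any temporary arrays created during the merge.

```python
def weight_gib(num_params, bits_per_param):
    """Approximate size of the raw weights in GiB."""
    return num_params * bits_per_param / 8 / 2**30


params = 34e9  # hypothetical model size
q4 = weight_gib(params, 4)     # 4-bit quantized weights
fp16 = weight_gib(params, 16)  # same weights after dequantizing to float16
print(f"4-bit: {q4:.1f} GiB, float16: {fp16:.1f} GiB, ratio: {fp16 / q4:.0f}x")
```

Going from 4-bit to float16 multiplies the weight footprint by four, so evaluating each layer as soon as it is merged helps avoid holding many unevaluated intermediate results on top of that.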

> The reason, from my understanding, may be that the major inference frameworks do not support base + adapter.

Makes sense, thanks for the explanation.

awni, Jan 07 '24 14:01