
Would it be possible to support LoRA fine-tuned models?

Open asalaria-cisco opened this issue 1 year ago • 6 comments

How easy or difficult would it be to support LoRA fine-tuned models? Would it need big changes to the vLLM engine, or is it something that can be done at a higher level by modifying the model?

asalaria-cisco avatar Jun 21 '23 07:06 asalaria-cisco

I think no change is needed on the vLLM side. You can simply combine the additional LoRA weights with the pretrained model weights. The resulting model then has the same architecture as the pretrained model, so it can be served on vLLM.
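For example, a minimal merge-then-serve sketch using huggingface/peft (paths and model names here are purely illustrative, not from this thread):

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM

# Merge the LoRA deltas into the base weights once, offline.
base = AutoModelForCausalLM.from_pretrained("llama-7b-hf", torch_dtype=torch.float16)
merged = PeftModel.from_pretrained(base, "my-lora-adapter").merge_and_unload()
merged.save_pretrained("llama-7b-merged")
AutoTokenizer.from_pretrained("llama-7b-hf").save_pretrained("llama-7b-merged")

# The merged checkpoint has the plain Llama architecture, so vLLM serves it as usual.
llm = LLM(model="llama-7b-merged")
print(llm.generate("Hello, my name is")[0].outputs[0].text)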

WoosukKwon avatar Jun 21 '23 08:06 WoosukKwon

Ah, I see. Point taken. However, in some applications we might have several fine-tuned models on top of the same base model. In those cases, recombining the weights might not be desirable (for example to ease storage or lifecycle management). Thus it would be great if we could use vLLM with LoRA models without recombining the weights.

asalaria-cisco avatar Jun 21 '23 08:06 asalaria-cisco

@asalaria-cisco Thanks for the further explanation! You're right. I also agree that that's a cool feature to have. Currently, we are focusing on fixing bugs and adding new models. After these are addressed, we will consider adding the feature.

To accelerate our development, could you share which framework you used for training LLMs with LoRA? I'm wondering what the weight format looks like.

WoosukKwon avatar Jun 24 '23 00:06 WoosukKwon

huggingface/peft is the most popular choice
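For reference, a peft LoRA checkpoint is a small state dict of low-rank weight pairs plus an adapter_config.json holding r, lora_alpha, target_modules, etc. A quick way to inspect the weight format (the path is illustrative; older peft versions save adapter_model.bin, newer ones use safetensors):

import torch

# Keys look like:
#   base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight  -> shape [r, in_features]
#   base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight  -> shape [out_features, r]
state_dict = torch.load("my-lora-adapter/adapter_model.bin", map_location="cpu")
for name, tensor in list(state_dict.items())[:6]:
    print(name, tuple(tensor.shape))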

creatorrr avatar Jun 24 '23 14:06 creatorrr

Also want this feature. Merging weights is kinda painful (it requires a lot of extra disk space).

skyshine102 avatar Jun 27 '23 09:06 skyshine102

I really want this feature, too. It will be awesome!

vkehfdl1 avatar Jun 28 '23 05:06 vkehfdl1

Adding to this, if someone gets around to support independent LoRA adapter weights, I'd like to request a particular architecture difference that makes it easier to switch between adapters.

Right now Peft implements this by modifying the base model's modules and replacing them with LoRA linear or embedding layers. This makes it impossible to access the base model directly, and it doesn't let you load a separate adapter alongside the base model. It's possible to work around this by loading/unloading the adapters between calls, but it'd be really, really awesome to be able to load adapters independently and still be able to run inference on the base model, all without doing something silly like set_adapter('this_calls_adapter') or disable_adapters().
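For illustration, the kind of juggling being described looks roughly like this with peft today (a simplified sketch; model names and adapter paths are made up):

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("llama-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("llama-7b-hf")
model = PeftModel.from_pretrained(base, "adapters/a", adapter_name="a")
model.load_adapter("adapters/b", adapter_name="b")

inputs = tokenizer("Hello", return_tensors="pt")

model.set_adapter("a")                 # mutates which LoRA layers are active, in place
out_a = model.generate(**inputs)

with model.disable_adapter():          # base-model output requires toggling the adapters off
    out_base = model.generate(**inputs)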

fozziethebeat avatar Jul 08 '23 11:07 fozziethebeat

I added this quick solution below for llama-hf model. The steps are

  1. Load the original llama into vLLM with llm = LLM("llama-7b") ...
  2. Load the LoRA state dict: lora_state_dict = torch.load("lora_states.pt")['module'].
  3. Merge the LoRA states into llm: lora_merge_unmerge_state_dict(llm, lora_state_dict, peft_config, merge=True)
  4. Do whatever inference job with llm ...
  5. To unmerge and recover the original llama, run lora_merge_unmerge_state_dict(llm, lora_state_dict, peft_config, merge=False)

def lora_reassign_weights(model, state_dict, r, lora_alpha, fan_in_fan_out=False, merge=True):
    # Adds (merge=True) or subtracts (merge=False) the LoRA delta, B @ A * scaling,
    # directly in vLLM's weights: the fused projections (qkv_proj, gate_up_proj)
    # are handled shard by shard, the unfused ones (o_proj, down_proj) in one go.
    is_merged = getattr(model, "is_merged", False)
    assert is_merged != merge, f'{is_merged} != {merge}: if is_merged, then must be unmerge; if not is_merged, then must merge'
    named_params = list(model.named_parameters())
    scaling = lora_alpha / r
    print(f'Lora configs: alpha={lora_alpha}, r={r}, scaling={scaling}')
    # Strip the peft prefix so keys line up with vLLM's module names.
    state_dict = {k.replace("base_model.model.", ""): v for k, v in state_dict.items()}
    replaced = set()
    merged_names = {
        # these are projector weights that got combined into single matrix in vllm
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }
    non_merged_names = ['o_proj', 'down_proj']
    for name, param in named_params:
        param.requires_grad = False
        if "_proj.weight" not in name:
            continue
        for wn, wn_series in merged_names.items():
            if name.endswith(f"{wn}.weight"):
                for stride_id, att_weight_name in enumerate(wn_series):
                    lora_a = name.replace(f"{wn}.weight", f"{att_weight_name}.lora_A.weight")
                    lora_b = name.replace(f"{wn}.weight", f"{att_weight_name}.lora_B.weight")
                    shard_size = param.shape[0] // len(wn_series)
                    if lora_a in state_dict:
                        assert lora_b in state_dict, f'{lora_b} not in state_dict'
                        assert state_dict[lora_b].shape[1] == r, f'{r=} != {state_dict[lora_b].shape}'
                        matrix = transpose(state_dict[lora_b] @ state_dict[lora_a], fan_in_fan_out) * scaling
                        assert param.data[shard_size * stride_id:shard_size * (stride_id + 1)].shape == matrix.shape
                        if merge:
                            param.data[shard_size * stride_id:shard_size * (stride_id + 1)] += matrix
                        else:
                            param.data[shard_size * stride_id:shard_size * (stride_id + 1)] -= matrix
                        replaced.add(lora_a)
                        replaced.add(lora_b)
        for wn in non_merged_names:
            if name.endswith(f"{wn}.weight"):
                lora_a = name.replace(f"{wn}.weight", f"{wn}.lora_A.weight")
                lora_b = name.replace(f"{wn}.weight", f"{wn}.lora_B.weight")
                if lora_a in state_dict:
                    assert lora_b in state_dict
                    matrix = transpose(state_dict[lora_b] @ state_dict[lora_a], fan_in_fan_out) * scaling
                    assert param.data.shape == matrix.shape, f'invalid shape: {name} {param.data.shape} != {matrix.shape}'
                    if merge:
                        param.data += matrix
                    else:
                        param.data -= matrix
                    replaced.add(lora_a)
                    replaced.add(lora_b)
    no_replaced = [k for k in state_dict.keys() if k not in replaced]
    assert len(no_replaced) == 0, f'some lora states not loaded, check again!: {no_replaced}'
    model.is_merged = merge


def lora_merge_unmerge_state_dict(llm, state_dict, peft_config, merge=True):
    # merge lora states to weights
    for worker in llm.llm_engine.workers:
        lora_reassign_weights(worker.model, state_dict, 
            r=peft_config.r, 
            lora_alpha=peft_config.lora_alpha, 
            fan_in_fan_out=peft_config.fan_in_fan_out, 
            merge=merge
        )
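
For completeness, a usage sketch tying the steps above together (paths are illustrative; the peft config is only used for its r / lora_alpha / fan_in_fan_out fields):

import torch
from peft import LoraConfig
from vllm import LLM

llm = LLM("llama-7b")

peft_config = LoraConfig.from_pretrained("my-lora-adapter")     # holds r, lora_alpha, fan_in_fan_out
lora_state_dict = torch.load("lora_states.pt")['module']

lora_merge_unmerge_state_dict(llm, lora_state_dict, peft_config, merge=True)
outputs = llm.generate(["Hello, my name is"])
lora_merge_unmerge_state_dict(llm, lora_state_dict, peft_config, merge=False)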


Hope this helps.

nxphi47 avatar Jul 19 '23 06:07 nxphi47

Hey, I tried to do this, but when the model is loaded using Ray it doesn't work. I get this error

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File <command-1454142161413636>:2
      1 # Merge Lora
----> 2 lora_merge_unmerge_state_dict(llm, lora_state_dict, peft_config, merge=True)

File <command-1454142161413207>:58, in lora_merge_unmerge_state_dict(llm, state_dict, peft_config, merge)
     55 def lora_merge_unmerge_state_dict(llm, state_dict, peft_config, merge=True):
     56     # merge lora states to weights
     57     for worker in llm.llm_engine.workers:
---> 58         lora_reassign_weights(worker.model, state_dict, 
     59             r=peft_config.r, 
     60             lora_alpha=peft_config.lora_alpha, 
     61             fan_in_fan_out=peft_config.fan_in_fan_out, 
     62             merge=merge
     63         )

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-673bda09-f309-427b-a99f-d419bb29211d/lib/python3.10/site-packages/ray/actor.py:1201, in ActorHandle.__getattr__(self, item)
   1199 def __getattr__(self, item):
   1200     if not self._ray_is_cross_language:
-> 1201         raise AttributeError(
   1202             f"'{type(self).__name__}' object has " f"no attribute '{item}'"
   1203         )
   1204     if item in ["__ray_terminate__"]:
   1206         class FakeActorMethod(object):

AttributeError: 'ActorHandle' object has no attribute 'model'

nivibilla avatar Jul 28 '23 17:07 nivibilla

It would be great to support multiple LoRA matrices that can be hotswapped using the same inference infrastructure running vLLM. There is a good example of how lmsys solved this here

https://github.com/lm-sys/FastChat/pull/1905

sam-h-bean avatar Jul 31 '23 19:07 sam-h-bean

I wrote that PR for FastChat. That's actually not the most preferred solution, since it requires walking through the model's list of modules and updating them to activate/deactivate the right adapter during each request. Ideally we'd have a way to call the base model plus the adapter of choice without having to rewrite the model on every request.

This adds a noticeable delay and requires locking the whole model per request.
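In other words, the per-request pattern ends up looking roughly like this (a simplified sketch, not the actual FastChat code; names are illustrative):

import threading

lock = threading.Lock()   # only one adapter can be active, so requests get serialized

def generate_with_adapter(model, tokenizer, adapter_name, prompt):
    inputs = tokenizer(prompt, return_tensors="pt")
    with lock:
        model.set_adapter(adapter_name)   # rewrites the active adapter on every request
        return model.generate(**inputs)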

fozziethebeat avatar Jul 31 '23 23:07 fozziethebeat

@nivibilla I don't use Ray, so I'm not sure. But you need to locate the vLLM LlamaForCausalLM model, wherever it lives in Ray, and apply the lora_reassign_weights function to it.

nxphi47 avatar Aug 01 '23 05:08 nxphi47

> I added this quick solution below for llama-hf model. The steps are ... [full merge/unmerge solution quoted above]

Thanks for sharing, I would like to ask about the matrix = transpose(state_dict[lora_b] @ state_dict[lora_a], fan_in_fan_out) * scaling. This transpose does not look like a torch method or a numpy method. Is this a pseudocode? If I add lora in the linear layer, is it okay to use matrix = state_dict[lora_b] @ state_dict[lora_a] * scaling ?

zuxinqi avatar Aug 08 '23 06:08 zuxinqi

@zuxinqi Sorry I forgot, transpose here:


def transpose(weight, fan_in_fan_out):
    return weight.T if fan_in_fan_out else weight

nxphi47 avatar Aug 11 '23 03:08 nxphi47

> Hey, I tried to do this, but when the model is loaded using Ray it doesn't work. I get this error ... AttributeError: 'ActorHandle' object has no attribute 'model'

I hit the same issue when tensor_parallel_size > 1. This solution seems to only work when tensor_parallel_size == 1.

blmoistawinde avatar Sep 06 '23 10:09 blmoistawinde

> Hey, I tried to do this, but when the model is loaded using Ray it doesn't work. I get this error ... AttributeError: 'ActorHandle' object has no attribute 'model'
>
> I hit the same issue when tensor_parallel_size > 1. This solution seems to only work when tensor_parallel_size == 1.

Yes, unfortunately this won't work with tensor_parallel_size > 1. With tensor_parallel_size > 1, vLLM uses Ray and shards each layer's weights into ColumnParallel and RowParallel linear layers. That is, W (4096x4096) will become W1 (4096x2048) on rank 1 and W2 (4096x2048) on rank 2.

This means the LoRA matmul (lora_B @ lora_A) must be computed on each rank and added to the main linear weight W before it is sharded.

Because Ray launches the workers in distributed mode, you would need to apply the above operations before the vLLM instance is created, right after the model is created and is about to load the pretrained weights on each rank. In other words, you have to modify load_weights here: https://github.com/vllm-project/vllm/blob/6368e777a8ead7fb62054d3779c6237361ec0d86/vllm/model_executor/models/llama.py#L311

The change may look like this:

...
def load_weights(...
...
    for weight_name, shard_size, offset in attention_weight_specs:
        if weight_name not in name or "qkv_proj" in name:
            continue
        param = state_dict[name.replace(weight_name, "qkv_proj")]
        # lora_a / lora_b are the matching LoRA keys for this projection and
        # scaling = lora_alpha / r, as in lora_reassign_weights above; the full
        # delta is applied here, before the weight gets sharded below
        matrix = transpose(state_dict[lora_b] @ state_dict[lora_a], fan_in_fan_out) * scaling
        loaded_weight = loaded_weight + matrix
        # from here, the full weight is sharded according to  tensor_model_parallel_rank
        loaded_weight = loaded_weight[
                        shard_size * tensor_model_parallel_rank:shard_size *
                        (tensor_model_parallel_rank + 1)]
...

I do find this very complicated and there should be a better way of loading this.

nxphi47 avatar Oct 13 '23 02:10 nxphi47

Does this solve the problem?

https://github.com/huggingface/peft/pull/227

pieria-elsevier avatar Oct 16 '23 23:10 pieria-elsevier

> Does this solve the problem?
>
> huggingface/peft#227

No. This just merges the weights into the base model. What I, and I think everyone else on this thread, want is a way to hot-swap which LoRA weights are being applied to a base model for a given batch.

For example, if we have a batch size of 4, in an ideal world you could compute the activations of the base model and then just add whatever LoRA weights you want on top, per request within the batch.
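Roughly, the batched computation being asked for looks like this (a plain PyTorch sketch with made-up shapes; a real serving implementation would use fused/gathered kernels rather than einsum over stacked adapters):

import torch

batch, d_in, d_out, r, n_adapters = 4, 4096, 4096, 16, 8
W = torch.randn(d_out, d_in) * 0.02           # frozen base projection, shared by the batch
A = torch.randn(n_adapters, r, d_in) * 0.02   # lora_A for each adapter
B = torch.randn(n_adapters, d_out, r) * 0.02  # lora_B for each adapter
scaling = 1.0

x = torch.randn(batch, d_in)
adapter_idx = torch.tensor([0, 2, 2, 5])      # each request picks its own adapter

base_out = x @ W.T                                                # computed once for all requests
lora_mid = torch.einsum("bi,bri->br", x, A[adapter_idx])          # x @ A_i.T per request
lora_out = torch.einsum("br,bor->bo", lora_mid, B[adapter_idx])   # ... @ B_i.T per request
y = base_out + scaling * lora_out                                 # adapter-specific delta on top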

jonhilgart22 avatar Oct 25 '23 04:10 jonhilgart22

I agree that this is a great idea! Happy to help work on this if that's useful to y'all

amanpyq avatar Oct 27 '23 21:10 amanpyq

Just noticed a paper discussing an efficient implementation of multi-LoRA serving called S-LoRA. Link to paper.

They implemented this in LightLLM but could vLLM potentially learn from this implementation?

iiLaurens avatar Nov 07 '23 14:11 iiLaurens

Just wanted to share the S-LoRA paper from Stanford and found that @iiLaurens has already shared it!

> Just noticed a paper discussing an efficient implementation of multi-LoRA serving called S-LoRA. Link to paper.
>
> They implemented this in LightLLM but could vLLM potentially learn from this implementation?

skyshine102 avatar Nov 08 '23 17:11 skyshine102

Since hot-swapping is still not implemented, what is the current best way to run LoRA weights using vLLM? Should I merge the LoRA weights and the base model in advance, or can it be done at runtime?

matankley avatar Nov 09 '23 10:11 matankley

@matankley try using S-LoRA. Looks interesting, and they even compare the performance against vLLM. I haven't tried it, but it looks pretty good for a multi-adapter setup.

nivibilla avatar Nov 09 '23 10:11 nivibilla

Tracking, would love this feature, thank you so much !

binarycrayon avatar Nov 14 '23 17:11 binarycrayon

Tracking, would love this feature, thank you so much !

TingFeng-7 avatar Nov 21 '23 05:11 TingFeng-7

Please check this issue and solution if you wanna add LoRA to llm.

SuperBruceJia avatar Nov 22 '23 01:11 SuperBruceJia

git clone --branch support_peft https://github.com/troph-team/vllm.git
cd vllm
pip install -e . --user

SuperBruceJia avatar Nov 22 '23 01:11 SuperBruceJia

mark

callanwu avatar Dec 20 '23 02:12 callanwu

Tracking.

tianqiwu1225 avatar Dec 21 '23 07:12 tianqiwu1225

Tracking :)

RajdeepBorgohain avatar Dec 27 '23 06:12 RajdeepBorgohain