vllm
Would it be possible to support LoRA fine-tuned models?
How easy or difficult would it be to support LoRA fine-tuned models? Would it need big changes to the vLLM engine, or could it be done at a higher level by modifying the model?
I think no change is needed on the vLLM side. You can simply merge the additional LoRA weights into the pretrained model weights. The resulting model then has the same architecture as the pretrained model, so it can be served by vLLM.
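To make "merging" concrete: a LoRA adapter stores two small matrices per adapted layer, and merging folds their scaled product into the base weight, so the merged layer keeps the base layer's shape. Below is a minimal NumPy sketch, assuming PEFT-style shapes (lora_A is r x in_features, lora_B is out_features x r) and the default scaling lora_alpha / r; merge_lora is a hypothetical helper, not a vLLM or PEFT API.

```python
import numpy as np

def merge_lora(W, A, B, lora_alpha, r):
    """Return W + (lora_alpha / r) * B @ A, the merged weight.

    W: (out_features, in_features) base weight
    A: (r, in_features)  LoRA down-projection (lora_A)
    B: (out_features, r) LoRA up-projection (lora_B)
    """
    return W + (lora_alpha / r) * (B @ A)

rng = np.random.default_rng(0)
out_f, in_f, r, alpha = 8, 6, 2, 16
W = rng.standard_normal((out_f, in_f))
A = rng.standard_normal((r, in_f))
B = rng.standard_normal((out_f, r))
x = rng.standard_normal((3, in_f))  # batch of 3 inputs

# Unmerged: base path plus the low-rank adapter path
y_adapter = x @ W.T + (alpha / r) * (x @ A.T) @ B.T
# Merged: one dense layer with exactly the same shape as the base weight
W_merged = merge_lora(W, A, B, alpha, r)
y_merged = x @ W_merged.T

print(np.allclose(y_adapter, y_merged))  # True
```

Because the merged weight has the base shape, the serving engine never needs to know an adapter existed; the cost is one full-size copy of the weights per adapter.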
Ah, I see. Point taken. However, in some applications we might have several fine-tuned models on top of the same base model. In those cases, recombining the weights might not be desirable (for example to ease storage or lifecycle management). Thus it would be great if we could use vLLM with LoRA models without recombining the weights.
@asalaria-cisco Thanks for the further explanation! You're right. I also agree that that's a cool feature to have. Currently, we are focusing on fixing bugs and adding new models. After these are addressed, we will consider adding the feature.
To accelerate our development, could you share which framework you used for training LLMs with LoRA? I'm wondering what the weight format looks like.
huggingface/peft is the most popular choice
Also want this feature. Merging weights is kind of painful (merging requires a lot of disk space).
I really want this feature, too. It will be awesome!
Adding to this, if someone gets around to support independent LoRA adapter weights, I'd like to request a particular architecture difference that makes it easier to switch between adapters.
Right now PEFT implements this by modifying the base model's modules and replacing them with LoRA linear or embedding layers. This makes it impossible to access the base model directly, and it doesn't let you load a separate adapter on top of the same base model. It's possible to work around this by loading/unloading the adapters between calls, but it would be really, really awesome to load adapters independently and still run inference on the base model, all without doing something silly like set_adapter('this_calls_adapter') or disable_adapters().
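The requested design can be sketched with a hypothetical MultiLoRALinear class (not a PEFT or vLLM API): adapters are stored beside, rather than inside, the frozen base weight, and the adapter name is a call argument. A minimal NumPy sketch, assuming PEFT-style shapes for lora_A and lora_B:

```python
import numpy as np

class MultiLoRALinear:
    """Linear layer whose base weight is never modified.

    Adapters live in a separate registry and are selected per call, so the
    base model stays usable and no global adapter state is toggled.
    """

    def __init__(self, weight):
        self.weight = weight   # (out_features, in_features), frozen
        self.adapters = {}     # name -> (A, B, scaling)

    def add_adapter(self, name, A, B, scaling):
        self.adapters[name] = (A, B, scaling)

    def __call__(self, x, adapter=None):
        y = x @ self.weight.T          # base model output
        if adapter is not None:        # None means "plain base model"
            A, B, scaling = self.adapters[adapter]
            y = y + scaling * (x @ A.T) @ B.T
        return y

rng = np.random.default_rng(1)
layer = MultiLoRALinear(rng.standard_normal((4, 3)))
layer.add_adapter("task_a", rng.standard_normal((2, 3)), rng.standard_normal((4, 2)), 0.5)
layer.add_adapter("task_b", rng.standard_normal((2, 3)), rng.standard_normal((4, 2)), 0.5)
x = rng.standard_normal((1, 3))
base, a, b = layer(x), layer(x, "task_a"), layer(x, "task_b")
```

The extra cost per call is two thin matmuls of rank r, and switching adapters is just a dictionary lookup instead of rewriting modules.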
I added this quick solution below for the llama-hf model. The steps are:
- Load the original llama into vLLM with
  llm = LLM("llama-7b")
- Load the LoRA state dict:
  lora_state_dict = torch.load("lora_states.pt")['module']
- Merge the LoRA states into llm (peft_config is the adapter's PEFT LoraConfig; it supplies r, lora_alpha and fan_in_fan_out):
  lora_merge_unmerge_state_dict(llm, lora_state_dict, peft_config, merge=True)
- Do whatever inference job with llm ...
- To unmerge and obtain the original llama, run
  lora_merge_unmerge_state_dict(llm, lora_state_dict, peft_config, merge=False)
import torch


def transpose(weight, fan_in_fan_out):
    # Layers that store weights as (fan_in, fan_out) need the LoRA product transposed
    return weight.T if fan_in_fan_out else weight


def lora_reassign_weights(model, state_dict, r, lora_alpha, fan_in_fan_out=False, merge=True):
    is_merged = getattr(model, "is_merged", False)
    assert is_merged != merge, f'{is_merged} != {merge}: if is_merged, then must be unmerge; if not is_merged, then must merge'
    named_params = [(n, p) for n, p in model.named_parameters()]
    scaling = lora_alpha / r
    print(f'Lora configs: alpha={lora_alpha}, r={r}, scaling={scaling}')
    # PEFT keys start with "base_model.model."; strip it to match vLLM parameter names
    state_dict = {k.replace("base_model.model.", ""): v for k, v in state_dict.items()}
    replaced = set()
    merged_names = {
        # these are projector weights that got combined into single matrix in vllm
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }
    non_merged_names = ['o_proj', 'down_proj']
    for name, param in named_params:
        param.requires_grad = False
        if "_proj.weight" not in name:
            continue
        for wn, wn_series in merged_names.items():
            if name.endswith(f"{wn}.weight"):
                for stride_id, att_weight_name in enumerate(wn_series):
                    lora_a = name.replace(f"{wn}.weight", f"{att_weight_name}.lora_A.weight")
                    lora_b = name.replace(f"{wn}.weight", f"{att_weight_name}.lora_B.weight")
                    # assumes equal shard sizes (true for llama-7b, not for grouped-query attention)
                    shard_size = param.shape[0] // len(wn_series)
                    if lora_a in state_dict:
                        assert lora_b in state_dict, f'{lora_b} not in state_dict'
                        assert state_dict[lora_b].shape[1] == r, f'{r=} != {state_dict[lora_b].shape}'
                        matrix = transpose(state_dict[lora_b] @ state_dict[lora_a], fan_in_fan_out) * scaling
                        # match the vLLM parameter's device and dtype before the in-place update
                        matrix = matrix.to(device=param.device, dtype=param.dtype)
                        assert param.data[shard_size * stride_id:shard_size * (stride_id + 1)].shape == matrix.shape
                        if merge:
                            param.data[shard_size * stride_id:shard_size * (stride_id + 1)] += matrix
                        else:
                            param.data[shard_size * stride_id:shard_size * (stride_id + 1)] -= matrix
                        replaced.add(lora_a)
                        replaced.add(lora_b)
        for wn in non_merged_names:
            if name.endswith(f"{wn}.weight"):
                lora_a = name.replace(f"{wn}.weight", f"{wn}.lora_A.weight")
                lora_b = name.replace(f"{wn}.weight", f"{wn}.lora_B.weight")
                if lora_a in state_dict:
                    assert lora_b in state_dict
                    matrix = transpose(state_dict[lora_b] @ state_dict[lora_a], fan_in_fan_out) * scaling
                    matrix = matrix.to(device=param.device, dtype=param.dtype)
                    assert param.data.shape == matrix.shape, f'invalid shape: {name} {param.data.shape} != {matrix.shape}'
                    if merge:
                        param.data += matrix
                    else:
                        param.data -= matrix
                    replaced.add(lora_a)
                    replaced.add(lora_b)
    no_replaced = [k for k in state_dict.keys() if k not in replaced]
    assert len(no_replaced) == 0, f'some lora states not loaded, check again!: {no_replaced}'
    model.is_merged = merge


def lora_merge_unmerge_state_dict(llm, state_dict, peft_config, merge=True):
    # merge lora states to weights
    # (needs in-process workers, i.e. tensor_parallel_size == 1)
    for worker in llm.llm_engine.workers:
        lora_reassign_weights(worker.model, state_dict,
                              r=peft_config.r,
                              lora_alpha=peft_config.lora_alpha,
                              fan_in_fan_out=peft_config.fan_in_fan_out,
                              merge=merge
                              )
Hope this helps.
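The trickiest part of this merge is that vLLM fuses q_proj/k_proj/v_proj (and gate_proj/up_proj) into one matrix, so each adapter delta must land in its own row slice. The toy NumPy sketch below uses a hypothetical merge_into_fused helper to check that slice-wise merging matches merging the separate projections and then fusing, and that unmerging restores the weights. Like the merging code in this thread, it assumes equal shard sizes; models whose k/v projections are smaller than q (grouped-query attention) would need per-projection offsets.

```python
import numpy as np

def merge_into_fused(fused, deltas, merge=True):
    """Add (or subtract) per-projection LoRA deltas into a fused weight in place.

    fused:  (n * shard, in_features), e.g. q, k, v stacked along dim 0
    deltas: list of n arrays, each (shard, in_features), in stacking order
    """
    shard = fused.shape[0] // len(deltas)
    sign = 1.0 if merge else -1.0
    for i, delta in enumerate(deltas):
        fused[shard * i: shard * (i + 1)] += sign * delta
    return fused

rng = np.random.default_rng(2)
d, r, s = 4, 2, 2.0
q, k, v = (rng.standard_normal((d, d)) for _ in range(3))
# each delta is scaling * (lora_B @ lora_A) for one projection
deltas = [s * rng.standard_normal((d, r)) @ rng.standard_normal((r, d)) for _ in range(3)]

fused = np.concatenate([q, k, v], axis=0)   # what a fused qkv_proj holds
original = fused.copy()
merge_into_fused(fused, deltas, merge=True)
expected = np.concatenate([q + deltas[0], k + deltas[1], v + deltas[2]], axis=0)
print(np.allclose(fused, expected))          # True
merge_into_fused(fused, deltas, merge=False)
print(np.allclose(fused, original))          # True
```

In float16 the merge/unmerge round trip is only approximate because of rounding, so keeping an untouched copy of the weights is the safer way to recover the exact base model.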
Hey, I tried to do this, but when the model is loaded using Ray it doesn't work. I get this error
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
File <command-1454142161413636>:2
1 # Merge Lora
----> 2 lora_merge_unmerge_state_dict(llm, lora_state_dict, peft_config, merge=True)
File <command-1454142161413207>:58, in lora_merge_unmerge_state_dict(llm, state_dict, peft_config, merge)
55 def lora_merge_unmerge_state_dict(llm, state_dict, peft_config, merge=True):
56 # merge lora states to weights
57 for worker in llm.llm_engine.workers:
---> 58 lora_reassign_weights(worker.model, state_dict,
59 r=peft_config.r,
60 lora_alpha=peft_config.lora_alpha,
61 fan_in_fan_out=peft_config.fan_in_fan_out,
62 merge=merge
63 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-673bda09-f309-427b-a99f-d419bb29211d/lib/python3.10/site-packages/ray/actor.py:1201, in ActorHandle.__getattr__(self, item)
1199 def __getattr__(self, item):
1200 if not self._ray_is_cross_language:
-> 1201 raise AttributeError(
1202 f"'{type(self).__name__}' object has " f"no attribute '{item}'"
1203 )
1204 if item in ["__ray_terminate__"]:
1206 class FakeActorMethod(object):
AttributeError: 'ActorHandle' object has no attribute 'model'
It would be great to support multiple LoRA matrices that can be hotswapped using the same inference infrastructure running vLLM. There is a good example of how lmsys solved this here
https://github.com/lm-sys/FastChat/pull/1905
I wrote that PR for FastChat. It's actually not the ideal solution, since it requires walking through the model's list of modules and updating them to activate or deactivate the right adapter on every request. Ideally we'd have a way to call the base model plus the adapter of choice without rewriting the model on every request.
The current approach adds a noticeable delay and requires locking the whole model per request.
@nivibilla I don't use Ray, so I'm not sure. But you need to locate the vLLM LlamaForCausalLM model, wherever it lives inside Ray, and apply the lora_reassign_weights function to it.
Thanks for sharing. I'd like to ask about matrix = transpose(state_dict[lora_b] @ state_dict[lora_a], fan_in_fan_out) * scaling. This transpose doesn't look like a torch or numpy method. Is it pseudocode? If I add LoRA to a linear layer, is it okay to use matrix = state_dict[lora_b] @ state_dict[lora_a] * scaling?
@zuxinqi Sorry, I forgot. Here is transpose:
def transpose(weight, fan_in_fan_out):
    return weight.T if fan_in_fan_out else weight
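A small NumPy shape check shows what fan_in_fan_out is for: B @ A always has shape (out_features, in_features), which matches an nn.Linear weight, while layers that store weights as (in_features, out_features), such as the Conv1D layer GPT-2 uses in Hugging Face transformers, need the transpose. For Llama's linear layers fan_in_fan_out is False, so transpose returns the product unchanged and state_dict[lora_b] @ state_dict[lora_a] * scaling is equivalent. The arrays here are illustrative stand-ins.

```python
import numpy as np

def transpose(weight, fan_in_fan_out):
    return weight.T if fan_in_fan_out else weight

rng = np.random.default_rng(3)
in_f, out_f, r = 5, 3, 2
A = rng.standard_normal((r, in_f))     # lora_A: (r, in)
B = rng.standard_normal((out_f, r))    # lora_B: (out, r)

linear_w = rng.standard_normal((out_f, in_f))   # nn.Linear layout: (out, in)
conv1d_w = rng.standard_normal((in_f, out_f))   # Conv1D layout: (in, out)

print(transpose(B @ A, False).shape)  # (3, 5), same as linear_w
print(transpose(B @ A, True).shape)   # (5, 3), same as conv1d_w
```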
I found the same issue when tensor_parallel_size > 1. This solution seems to work only when tensor_parallel_size == 1.
Yes, unfortunately this won't work with tensor_parallel_size > 1. With tensor_parallel_size > 1, vLLM uses Ray and shards each layer's weights into ColumnParallel and RowParallel linear layers. That is, W (4096x4096) becomes W1 (4096x2048) on rank 1 and W2 (4096x2048) on rank 2.
This means the LoRA product (lora_B @ lora_A) must be computed on each rank and added to the main linear weight W before W is sharded.
Because Ray launches the workers in distributed mode, you would need to apply these operations while the vLLM instance is being created: right after the model is created on each rank and while it loads the pretrained weights. In other words, you have to modify load_weights
here https://github.com/vllm-project/vllm/blob/6368e777a8ead7fb62054d3779c6237361ec0d86/vllm/model_executor/models/llama.py#L311
The change might look like this:
...
def load_weights(...):
    ...
    for weight_name, shard_size, offset in attention_weight_specs:
        if weight_name not in name or "qkv_proj" in name:
            continue
        param = state_dict[name.replace(weight_name, "qkv_proj")]
        # lora_state_dict, lora_a, lora_b, scaling and fan_in_fan_out are not part of vLLM;
        # they must be passed in, with lora_a/lora_b looked up for this weight name
        matrix = transpose(lora_state_dict[lora_b] @ lora_state_dict[lora_a], fan_in_fan_out) * scaling
        # add the LoRA delta to the full, unsharded weight
        loaded_weight = loaded_weight + matrix
        # from here, the full weight is sharded according to tensor_model_parallel_rank
        loaded_weight = loaded_weight[
            shard_size * tensor_model_parallel_rank:shard_size *
            (tensor_model_parallel_rank + 1)]
        ...
I do find this very complicated and there should be a better way of loading this.
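The sharding argument can be checked numerically. This NumPy sketch uses hypothetical shard_rows/shard_cols helpers that mimic a column-parallel split (each rank owns rows of W, i.e. output features) and a row-parallel split (each rank owns columns, i.e. input features), assuming dimensions divide evenly across ranks. It shows that merging before sharding gives the same per-rank weights as sharding lora_B by rows (column-parallel) or lora_A by columns (row-parallel) and merging on each rank, which suggests an alternative to merging the full weight inside load_weights.

```python
import numpy as np

def shard_rows(w, rank, world_size):
    """Column-parallel split: each rank owns a slice of the output dim (dim 0)."""
    n = w.shape[0] // world_size
    return w[n * rank: n * (rank + 1)]

def shard_cols(w, rank, world_size):
    """Row-parallel split: each rank owns a slice of the input dim (dim 1)."""
    n = w.shape[1] // world_size
    return w[:, n * rank: n * (rank + 1)]

rng = np.random.default_rng(4)
out_f, in_f, r, s, world = 8, 6, 2, 0.5, 2
W = rng.standard_normal((out_f, in_f))
A = rng.standard_normal((r, in_f))    # lora_A
B = rng.standard_normal((out_f, r))   # lora_B
merged = W + s * (B @ A)              # merge first, on the full weight

for rank in range(world):
    # Column-parallel: shard W and lora_B by rows, keep lora_A whole
    col = shard_rows(W, rank, world) + s * shard_rows(B, rank, world) @ A
    assert np.allclose(col, shard_rows(merged, rank, world))
    # Row-parallel: shard W and lora_A by columns, keep lora_B whole
    row = shard_cols(W, rank, world) + s * B @ shard_cols(A, rank, world)
    assert np.allclose(row, shard_cols(merged, rank, world))
print("ok")
```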
Does this solve the problem?
https://github.com/huggingface/peft/pull/227
No. This just merges the weights into the base model. What I, and I think everyone else on this thread, want is a way to hot-swap which LoRA weights are being applied to a base model for a given batch.
For example, say we have a batch size of 4. In an ideal world, you could compute the activations of the base model once and then add whichever LoRA weights each request in the batch needs on top.
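A NumPy sketch of that batching idea, using a hypothetical multi_lora_forward helper: the base weight is applied once to the whole batch, and each row then gets the low-rank delta of its own adapter (or none). The per-row Python loop is only for clarity; an efficient implementation would replace it with batched GPU kernels.

```python
import numpy as np

def multi_lora_forward(x, W, adapters, adapter_ids, scaling):
    """One base matmul for the whole batch, plus a per-request LoRA delta.

    x:           (batch, in_features)
    W:           (out_features, in_features) shared base weight
    adapters:    list of (A, B) pairs; A is (r, in), B is (out, r)
    adapter_ids: per-row index into adapters, or -1 for "base model only"
    """
    y = x @ W.T                        # shared base computation
    for i, idx in enumerate(adapter_ids):
        if idx >= 0:
            A, B = adapters[idx]
            y[i] += scaling * (B @ (A @ x[i]))
    return y

rng = np.random.default_rng(5)
in_f, out_f, r, s = 6, 4, 2, 2.0
W = rng.standard_normal((out_f, in_f))
adapters = [(rng.standard_normal((r, in_f)), rng.standard_normal((out_f, r))) for _ in range(2)]
x = rng.standard_normal((4, in_f))     # batch of 4 requests
ids = [0, 1, -1, 0]                    # adapter per request; -1 = plain base model
y = multi_lora_forward(x, W, adapters, ids, s)
```

Each row's result equals what a fully merged model for that row's adapter would produce, but no weights are copied or rewritten between requests.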
I agree that this is a great idea! Happy to help work on this if that's useful to y'all
Just noticed a paper discussing an efficient implementation of multi-LoRA serving called S-LoRA. Link to paper.
They implemented this in LightLLM but could vLLM potentially learn from this implementation?
Just wanted to share the S-LoRA paper from Stanford and found @iiLaurens has already shared!
Since hot-swapping is still not implemented, what is currently the best way to run LoRA weights with vLLM? Should I merge the LoRA weights into the base model in advance, or can it be done at runtime?
@matankley try S-LoRA. It looks interesting, and they even compare its performance against vLLM. I haven't tried it, but it looks pretty good for a multi-adapter setup.
Tracking, would love this feature, thank you so much !
Please check this issue and solution if you want to add LoRA to your LLM.
git clone --branch support_peft https://github.com/troph-team/vllm.git
cd vllm
pip install -e . --user
mark
Tracking.
Tracking :)