[Bugfix] Fix the lm_head in gpt_bigcode in LoRA mode
This is an alternative fix for #6314 that doesn't disable LoRA for the lm_head in GPTBigCode, which we worked on today with @tjohnson31415 and @tdoublep.
Because of the weight tying, in LoRA mode the lm_head implementation was being replaced with VocabParallelEmbeddingWithLoRA, which is not meant for the lm_head. To fix the issue, this PR initializes the lm_head as an instance of ParallelLMHead and assigns it the weights of the embedding module; modules of this class are not substituted during LoRA initialization. Because the two layers previously used different padding (64 vs. 256), there was a size mismatch, so the same padding is now applied to both. The resulting vocabulary dimension had to be added to bgmv_config.h.
Except for the padding adjustment, this is now essentially the same as in the Llama code.
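To make the distinction concrete, here is a minimal plain-PyTorch sketch (not the actual vLLM code) of tying by reusing the embedding module versus tying by sharing only the weight tensor; since the LoRA manager substitutes modules by class, a separate lm_head module does not get wrapped with the embedding LoRA layer:
```python
import torch.nn as nn

class TiedByModule(nn.Module):
    # Old behavior: lm_head *is* the embedding module, so any LoRA wrapper
    # applied to the embedding is also applied to the lm_head.
    def __init__(self, vocab, hidden):
        super().__init__()
        self.wte = nn.Embedding(vocab, hidden)
        self.lm_head = self.wte

class TiedByWeight(nn.Module):
    # New behavior: lm_head is a distinct module (ParallelLMHead in vLLM)
    # that only shares the weight tensor with the embedding.
    def __init__(self, vocab, hidden):
        super().__init__()
        self.wte = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.lm_head.weight = self.wte.weight  # weight tie, separate module

m = TiedByWeight(1024, 64)  # small illustrative sizes
assert m.lm_head.weight.data_ptr() == m.wte.weight.data_ptr()
```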
cc @robertgshaw2-neuralmagic
My (albeit limited) understanding of LoRA is that it is uncommon to train LoRA adapters on the embeddings / lm_head UNLESS you are trying to add special tokens to the vocabulary (e.g. for ChatML).
Does this implementation enable that? Or is it now the case that we can add LoRA adapters to the vocab embeddings, but not to the lm_head?
Yes, vLLM's multi-LoRA support in general allows adding tokens/embeddings, including to the lm_head layer (we have used it, e.g., for Mixtral).
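For anyone curious, here is a hedged example of what such an adapter can look like on the training side (standard PEFT usage, not something from this PR; the module names are assumptions for a GPTBigCode checkpoint):
```python
from peft import LoraConfig

# Illustrative only: a LoRA config that also adapts the embedding and lm_head,
# e.g. after adding special tokens such as ChatML markers.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["c_attn", "c_proj", "wte", "lm_head"],
)
```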
@followumesh any chance you could review this one?
@maxdebayser A sanity-check question: does LoRA correctly replace ParallelLMHead(VocabParallelEmbedding) with the corresponding LoRA layer (I assume VocabParallelEmbeddingWithLoRA)?
@followumesh, no, the lm_head should not be replaced with a LoRA class, because the LoRA weights are applied in LogitsProcessorWithLora: https://github.com/vllm-project/vllm/blob/db35186391a2abfc6c91d703527dac20d2488107/vllm/lora/layers.py#L1054
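For context, a rough plain-PyTorch sketch of the idea (not vLLM's actual LogitsProcessorWithLora): the base lm_head projection stays untouched and the low-rank update is added to the resulting logits, which is why the lm_head module itself must not be wrapped:
```python
import torch

# Conceptual sketch only; shapes are illustrative, not the model's real sizes.
hidden, vocab, rank = 64, 1024, 8

x = torch.randn(4, hidden)              # hidden states for 4 tokens
w_lm_head = torch.randn(vocab, hidden)  # tied to the wte weight
lora_a = torch.randn(rank, hidden)      # LoRA A: hidden -> rank
lora_b = torch.randn(vocab, rank)       # LoRA B: rank -> vocab

logits = x @ w_lm_head.T                     # base (unwrapped) lm_head projection
logits = logits + (x @ lora_a.T) @ lora_b.T  # LoRA delta added to the logits
```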
This is not only helpful for supporting lm_head as a LoRA module, it is also needed for wte to work correctly with LoRA. Without this fix I get gibberish output from a LoRA-trained GPTBigCode model, even without training a LoRA module for lm_head.
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @maxdebayser.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
@maxdebayser do you plan to continue this work?
@hmellor, yes. I need to resolve the merge conflicts, but otherwise this is done. It has been used in production at IBM for several months now. The only thing missing is to get someone to review and hopefully approve it.
Nice, once it's updated, I'll see about getting it reviewed!
I don't think the failure is related, as the failing test uses Gemma.
I've synced with main now to trigger a new build with the latest code.
I'm very sorry for missing this PR. I will look at it ASAP. Thank you.
@maxdebayser Perhaps directly deleting embedding_modules would be more appropriate?
@jeejeelee thanks for your suggestion. It works when the embedding and lm_head modules are tied, but not when they aren't.
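For readers following along, embedding_modules is a class-level mapping on the model that tells the LoRA manager which submodules are embedding-like; the GPTBigCode entries below are my assumption based on how other vLLM models declare it:
```python
# Assumed shape of the attribute under discussion (modeled on other vLLM models):
embedding_modules = {
    "wte": "input_embeddings",       # input embedding, tied to lm_head
    "lm_head": "output_embeddings",  # output projection
}
```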
Let me try to give a bit more context here. We have a GPTBigCode model with weight tying and some LoRA adapters that unfortunately are not public, but hopefully I can explain what is going on.
As of version 0.7.4, when I try to load the model with --enable-lora, it fails with the following error:
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/worker/model_runner.py", line 1243, in profile_run
ERROR 03-21 19:04:02 [engine.py:411] self._dummy_run(max_num_batched_tokens, max_num_seqs)
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/worker/model_runner.py", line 1354, in _dummy_run
ERROR 03-21 19:04:02 [engine.py:411] self.execute_model(model_input, kv_caches, intermediate_tensors)
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 03-21 19:04:02 [engine.py:411] return func(*args, **kwargs)
ERROR 03-21 19:04:02 [engine.py:411] ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/worker/model_runner.py", line 1669, in execute_model
ERROR 03-21 19:04:02 [engine.py:411] self.set_active_loras(model_input.lora_requests,
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/worker/model_runner.py", line 1371, in set_active_loras
ERROR 03-21 19:04:02 [engine.py:411] self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/worker_manager.py", line 167, in set_active_adapters
ERROR 03-21 19:04:02 [engine.py:411] set_active_adapters_worker(requests, mapping, self._apply_adapters,
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/adapter_commons/utils.py", line 54, in set_active_adapters_worker
ERROR 03-21 19:04:02 [engine.py:411] apply_adapters_func(requests)
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/worker_manager.py", line 227, in _apply_adapters
ERROR 03-21 19:04:02 [engine.py:411] self.add_adapter(lora)
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/worker_manager.py", line 250, in add_adapter
ERROR 03-21 19:04:02 [engine.py:411] self._adapter_manager.activate_adapter(lora_request.lora_int_id)
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/models.py", line 720, in activate_adapter
ERROR 03-21 19:04:02 [engine.py:411] result = super().activate_adapter(lora_id)
ERROR 03-21 19:04:02 [engine.py:411] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/models.py", line 405, in activate_adapter
ERROR 03-21 19:04:02 [engine.py:411] module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
ERROR 03-21 19:04:02 [engine.py:411] File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/layers.py", line 1070, in set_lora
ERROR 03-21 19:04:02 [engine.py:411] 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
ERROR 03-21 19:04:02 [engine.py:411] ^^^^^^
ERROR 03-21 19:04:02 [engine.py:411] RuntimeError: The size of tensor a (6144) must match the size of tensor b (49408) at non-singleton dimension 1
With the changes in this PR, and also with your suggestion, the model loads without errors and the results with and without the adapter are the same in both cases:
curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
"model": "my-gpt-bigcode-model-with-weight-tie",
"prompt": ["Input: Our waitress seemed less than happy about the prix fixe dinner choices and at one point said, Do you really need to hear the specials? Response:"],
"max_tokens": 10,
"temperature": 0
}'| jq
{
"object": "text_completion",
"model": "my-gpt-bigcode-model-with-weight-tie",
"choices": [
{
"index": 0,
"text": " I don't know, I just don't like",
"logprobs": null,
"finish_reason": "length",
"stop_reason": null,
"prompt_logprobs": null
}
]
}
curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
"model": "my-lora",
"prompt": ["Input: Our waitress seemed less than happy about the prix fixe dinner choices and at one point said, Do you really need to hear the specials? Response:"],
"max_tokens": 10,
"temperature": 0
}'| jq
{
"object": "text_completion",
"model": "my-lora",
"choices": [
{
"index": 0,
"text": " waitress: negative, specials: neutral,",
"logprobs": null,
"finish_reason": "length",
"stop_reason": null,
"prompt_logprobs": null
}
]
}
But, for testing purposes I have the same model where I duplicated the weights for the lm_head and set "tie_word_embeddings": false. When I run this model with the changes in this PR I get the same results as above. Whereas when I just delete embedding_modules, the model loads without crashing but the outputs change:
curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
"model": "my-gpt-bigcode-model-without-weight-tie",
"prompt": ["Input: Our waitress seemed less than happy about the prix fixe dinner choices and at one point said, Do you really need to hear the specials? Response:"],
"max_tokens": 10,
"temperature": 0
}'| jq
{
"object": "text_completion",
"model": "my-gpt-bigcode-model-without-weight-tie",
"choices": [
{
"index": 0,
"text": "",
"logprobs": null,
"finish_reason": "stop",
"stop_reason": null,
"prompt_logprobs": null
}
]
}
curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
"model": "my-lora",
"prompt": ["Input: Our waitress seemed less than happy about the prix fixe dinner choices and at one point said, Do you really need to hear the specials? Response:"],
"max_tokens": 10,
"temperature": 0
}'| jq
{
"object": "text_completion",
"model": "my-lora",
"choices": [
{
"index": 0,
"text": "",
"logprobs": null,
"finish_reason": "stop",
"stop_reason": null,
"prompt_logprobs": null
}
]
}
What's interesting is that the model without weight tying behaves this way even on version 0.7.4 without --enable-lora.
@maxdebayser Thanks for your explanation.
But, for testing purposes I have the same model where I duplicated the weights for the lm_head and set "tie_word_embeddings": false. When I run this model with the changes in this PR I get the same results as above. Whereas when I just delete embedding_modules, the model loads without crashing but the outputs change:
embedding_modules only applies to LoRA, so your second experiment is not related to deleting embedding_modules. Can you upgrade to the latest version of vLLM? I tested setting tie_word_embeddings to False, and it throws an error on the recent main branch.
If your LoRA needs to support embedding_modules, I think we can keep it.
@jeejeelee, I finally had time to come back to this. Thanks a lot for your suggestions. The only missing piece needed to support the model without weight tying was to prevent the loader from skipping the lm_head module in that case.
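Roughly, the loader change amounts to making the skip conditional on the tie. An illustrative sketch, not the exact diff (checkpoint-name remapping is omitted for brevity):
```python
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

def load_weights(self, weights):
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        # Only skip lm_head tensors when the weights are actually tied to the
        # embedding; for an untied checkpoint they must be loaded as usual.
        if name.startswith("lm_head.") and self.config.tie_word_embeddings:
            continue
        param = params_dict[name]
        default_weight_loader(param, loaded_weight)
```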
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @maxdebayser.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork