[CORE] Allow loading of quantized lm_head (ParallelLMHead)
Reason for PR:
- lm_head can be quantized with minimal loss in output quality
- Save VRAM by allowing a quantized lm_head
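Rough arithmetic makes the VRAM point concrete. This is a back-of-the-envelope sketch. It assumes Meta-Llama-3-8B's vocab_size=128256 and hidden_size=4096, and a simplified w4g128 layout with two 4-bit weights per byte plus one fp16 scale per group of 128 weights. Zero-points and other GPTQ metadata are ignored, so the quantized figure is slightly optimistic. The function names are hypothetical.

```python
# Rough lm_head weight-memory estimate: dense fp16/bf16 vs 4-bit with group size 128.
# Assumption: one fp16 scale per group; zero-points and g_idx metadata are ignored.

def lm_head_bytes_fp16(vocab_size: int, hidden_size: int) -> int:
    """Bytes for an unquantized lm_head stored in fp16/bf16 (2 bytes per weight)."""
    return vocab_size * hidden_size * 2


def lm_head_bytes_w4g128(vocab_size: int, hidden_size: int, group_size: int = 128) -> int:
    """Approximate bytes for a 4-bit lm_head: packed weights plus fp16 group scales."""
    n_weights = vocab_size * hidden_size
    packed = n_weights // 2                  # two 4-bit weights per byte
    scales = (n_weights // group_size) * 2   # one fp16 scale per group
    return packed + scales


# Meta-Llama-3-8B shape (assumed): vocab_size=128256, hidden_size=4096, untied lm_head.
dense = lm_head_bytes_fp16(128256, 4096)
quant = lm_head_bytes_w4g128(128256, 4096)
print(f"fp16 lm_head:   {dense / 2**20:.1f} MiB")
print(f"w4g128 lm_head: {quant / 2**20:.1f} MiB")
print(f"saved:          {(dense - quant) / 2**20:.1f} MiB")
```

Under these assumptions the dense head is about 1002 MiB and the 4-bit head about 258 MiB.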
Changes:
- Read lm_head from quantize_config
- Allow ParallelLMHead to be loaded quantized
- Rename QuantizeMethodBase to the more accurate QuantizableMethodBase, since non-quantized methods also inherit from it
- Add a QUANTIZED bool property to QuantizableMethodBase to avoid all the isinstance calls
- Refactor repeated GPTQ param check code into utils/skip_gptq_extra_param
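A minimal sketch can illustrate the QUANTIZED flag change. Only QuantizableMethodBase and QUANTIZED come from the list above. UnquantizedMethod, GPTQMethod, apply() and describe() are hypothetical stand-ins, not vLLM's actual classes or signatures.

```python
from abc import ABC, abstractmethod


class QuantizableMethodBase(ABC):
    """Base for every layer method, quantized or not (simplified sketch)."""

    # Subclasses that hold quantized weights override this class attribute.
    QUANTIZED: bool = False

    @abstractmethod
    def apply(self, x: float, weight: float) -> float:
        ...


class UnquantizedMethod(QuantizableMethodBase):
    def apply(self, x: float, weight: float) -> float:
        return x * weight  # plain dense multiply


class GPTQMethod(QuantizableMethodBase):
    QUANTIZED = True

    def apply(self, x: float, weight: float) -> float:
        return x * weight  # placeholder for a dequantize-and-multiply kernel


def describe(method: QuantizableMethodBase) -> str:
    # One attribute read replaces isinstance(method, GPTQMethod) or ... chains.
    return "quantized" if method.QUANTIZED else "unquantized"


print(describe(UnquantizedMethod()), describe(GPTQMethod()))
```

The flag is a class attribute, so a new quantized method only has to set `QUANTIZED = True`. Call sites do not need to learn about each new subclass.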
Tooling Cross Dependency (tools that make quantized lm_head using GPTQ):
- https://github.com/AutoGPTQ/AutoGPTQ/pull/648
- https://github.com/intel/auto-round/pull/87#issuecomment-2076371550
- https://github.com/AutoGPTQ/AutoGPTQ/pull/640
Test Model (quantized by auto-round and load tested with autogptq):
https://huggingface.co/LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse
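Reading the flag from a checkpoint could look like the following sketch. It assumes the flag is a top-level boolean "lm_head" key in quantize_config.json, and that a missing key means lm_head is not quantized. The example JSON is illustrative and not copied from the test model above. read_lm_head_flag is a hypothetical helper.

```python
import json


def read_lm_head_flag(quantize_config_text: str) -> bool:
    """Return True if the quantize_config marks lm_head as quantized (assumed key)."""
    cfg = json.loads(quantize_config_text)
    # Checkpoints produced before lm_head quantization have no such key: treat as False.
    return bool(cfg.get("lm_head", False))


example = '{"bits": 4, "group_size": 128, "sym": false, "lm_head": true}'
print(read_lm_head_flag(example))
```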
The Intel/auto-round project (@wenhuach21) has demonstrated that lm_head is a good candidate for quantization with minimal loss:
https://github.com/intel/auto-round/blob/8a3da144423322dfedb0b3fa702ae35d242496d8/docs/Meta-Llama-3-8B-Instruct-acc.md?plain=1#L3
| Metric | BF16 | w4g128 w/o lm-head | w4g128 with quantized lm-head |
|---|---|---|---|
| Avg. | 0.6352 | 0.6312 | 0.6303 |
| mmlu | 0.6386 | 0.6306 | 0.6318 |
| winogrande | 0.7143 | 0.7238 | 0.7269 |
| truthfulqa_mc1 | 0.3623 | 0.3537 | 0.3525 |
| rte | 0.6751 | 0.6859 | 0.6679 |
| piqa | 0.7867 | 0.7797 | 0.7802 |
| openbookqa | 0.3400 | 0.3300 | 0.3320 |
| lambada_openai | 0.7182 | 0.7200 | 0.7173 |
| hellaswag | 0.5769 | 0.5699 | 0.5701 |
| boolq | 0.8297 | 0.8309 | 0.8284 |
| arc_easy | 0.8152 | 0.8089 | 0.8106 |
| arc_challenge | 0.5299 | 0.5102 | 0.5154 |
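As a sanity check on the table, the Avg. row matches an unweighted mean of the 11 task scores. The extra cost of also quantizing lm_head is about 0.001 on that average. The short script below recomputes both from the table values.

```python
# Per-task scores from the table above, in row order (mmlu ... arc_challenge).
scores = {
    "BF16": [0.6386, 0.7143, 0.3623, 0.6751, 0.7867, 0.3400,
             0.7182, 0.5769, 0.8297, 0.8152, 0.5299],
    "w4g128 w/o lm-head": [0.6306, 0.7238, 0.3537, 0.6859, 0.7797, 0.3300,
                           0.7200, 0.5699, 0.8309, 0.8089, 0.5102],
    "w4g128 with quantized lm-head": [0.6318, 0.7269, 0.3525, 0.6679, 0.7802, 0.3320,
                                      0.7173, 0.5701, 0.8284, 0.8106, 0.5154],
}

# Unweighted mean over tasks; this reproduces the Avg. row to 4 decimals.
averages = {name: sum(vals) / len(vals) for name, vals in scores.items()}
for name, avg in averages.items():
    print(f"{name:30s} {avg:.4f}")

# Additional average drop caused by quantizing lm_head on top of w4g128.
lm_head_cost = averages["w4g128 w/o lm-head"] - averages["w4g128 with quantized lm-head"]
print(f"extra drop from quantized lm_head: {lm_head_cost:.4f}")
```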
@Qubitium thanks for this!
We can get this merged very quickly as it was something I was planning to do this week
@Qubitium could you also support passing a quantization_config to the Embedding?
Done.
Looks like you're headed in the right direction.
LMK if you need anything.
Please add a couple of models to the CI for testing.
@robertgshaw2-neuralmagic All tests are passing (except for the out-of-disk Neuro-test). Going through all the models, not every model's lm_head is of type ParallelLMHead. Loading of GPTQ quantization is currently enabled for ParallelLMHead only.
Added CI tests for lm_head quantized true/false for llama (TinyLlama 1.1B).
- [x] TODO: merge the 2 test files into one.
> Loading of gptq quantization is currently enabled for ParallelLMHead only.

What happens if this case occurs? Do we fail loudly?
@Qubitium
IMPORTANT - given this PR touches a lot of models, I expanded the testing strategy
https://github.com/Qubitium/vllm/pull/2
PR does three things:
- merged main into this branch (so I could have access to some new files I added for recent PR)
- expanded small model testing to cover bigger set of models
- added a "test_models_logprobs" strategy for testing medium ~7b models (we can ask sim
> Loading of gptq quantization is currently enabled for ParallelLMHead only.
>
> What happens if this case occurs? Do we fail loudly?
Raise an error if a quantized lm_head is enabled and the class is not ParallelLMHead: https://github.com/vllm-project/vllm/pull/4442/commits/fd1cf0e37c4c6ba0a1497c9de06ad3b45497e9a8
Added is_lm_head to get_quant_method, since it must be told whether the layer is the head or not. Many models have both VocabParallelEmbedding and ParallelLMHead; without the boolean, get_quant_method cannot tell them apart and catch this misconfiguration.
Modified the CI test so that the lm_head layer check is more explicit: https://github.com/vllm-project/vllm/pull/4442/commits/da0af51df35f5dc340806178d622b1f1aab6de6b
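The fail-loudly behaviour described above can be sketched as follows. The layer classes are empty stand-ins and GPTQConfigSketch is hypothetical. Only the is_lm_head argument and the "raise if the head is not a ParallelLMHead" rule come from the comment. As the next comment points out, this strict check would also reject models whose lm_head is a VocabParallelEmbedding.

```python
class VocabParallelEmbedding:
    """Stand-in for an input-embedding layer (hypothetical, not vLLM's class)."""


class ParallelLMHead:
    """Stand-in for an output-head layer (hypothetical, not vLLM's class)."""


class LinearLayer:
    """Stand-in for any other linear layer."""


class GPTQConfigSketch:
    def __init__(self, lm_head_quantized: bool) -> None:
        self.lm_head_quantized = lm_head_quantized

    def get_quant_method(self, layer, is_lm_head: bool = False) -> str:
        if is_lm_head:
            if not self.lm_head_quantized:
                return "unquantized"
            if not isinstance(layer, ParallelLMHead):
                # Fail loudly instead of silently mis-loading quantized head weights.
                raise ValueError(
                    "lm_head quantization is enabled, but the lm_head layer is "
                    f"{type(layer).__name__}, not ParallelLMHead")
            return "gptq"
        if isinstance(layer, VocabParallelEmbedding):
            return "unquantized"  # input embeddings stay dense in this sketch
        return "gptq"
```

The explicit flag is what lets the config distinguish "this embedding-shaped layer is the output head" from "this is the input embedding". Type alone cannot do that once both roles can share a class.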
Hold on. Some models use VocabParallelEmbedding as their lm_head, so lm_head can be either VocabParallelEmbedding or ParallelLMHead. The commit as it stands is not good.
@robertgshaw2-neuralmagic Merged, but tests are failing, and the failures appear unrelated to the changes made in this PR. Can you check? I am not familiar with the source/cause of the errors. Thanks. hf_runner is missing max_model_len.
https://buildkite.com/vllm/ci/builds/6110#018f2d04-809c-43bd-99b0-4e7aa6165a73
We are currently adding an opt-2.7b test. If it succeeds, we can remove the ParallelLMHead limitation. In many models, lm_head is the same as the embedding rather than a separate module, so the llama and opt-2.7b lm_head quantization tests together will cover all cases.
- [ ] Test OPT-2.7b lm_head quant, where lm_head is the same module as the embedding.
> hf_runner missing max_model_len
Will fix
Running into a big issue loading OPT with a quantized lm_head. For OPT (and other models), lm_head is just soft-linked in code to the embeddings, not a unique lm_head layer as in llama. However, the original OPT weights contain both lm_head and the embedding; they simply have the same tensor shape/size and values.
Checking with the intel/auto-round team to see whether this should be addressed at the quantization stage or by a remapping here in vLLM. I don't know the correct answer yet. This could be a bug in the vLLM OPT code, which assumes lm_head is always the same as the embedding and skips it on load. That may hold pre-quantization, but post-quantization it may not.
Once this is addressed, I believe the ParallelLMHead-only limit can be fully lifted and we just need to check for VocabParallelEmbedding.
Asking Intel/auto-round devs for clarification/insight: https://github.com/intel/auto-round/issues/100
Changing the PR to draft mode until this issue is resolved.
Let's handle this in a follow-up PR, since the scope of this one is already big.
@robertgshaw2-neuralmagic I overwrote your test changes in commit https://github.com/vllm-project/vllm/pull/4442/commits/2f63a723726bd4cd06a06c21c51fdaf360fbec8a Will re-merge with your changes later.
Fixed OPT model compat with lm_head. lm-head tests now passing.
New problem: Marlin kernel loading with lm_head is broken. For now, I disabled Marlin auto-upconvert when lm_head is detected/True.
- [x] Fixed OPT compat with lm_head: load both lm_head and embed_tokens separately. Do not assume they are the same.
- [x] Fixed: Marlin runtime conversion of compatible models with lm_head enabled is failing. Added a TODO and disabled Marlin upconvert when lm_head quant is enabled.
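The interim Marlin workaround amounts to a guard in the load path. This is a hypothetical sketch: use_marlin and its marlin_compatible input are illustrative names. The actual Marlin compatibility rules (bit width, group size, and so on) are deliberately left out.

```python
import warnings


def use_marlin(marlin_compatible: bool, lm_head_quantized: bool) -> bool:
    """Decide whether to auto-convert a GPTQ checkpoint to Marlin at load time."""
    if not marlin_compatible:
        return False
    if lm_head_quantized:
        # TODO (per the PR): runtime Marlin conversion does not yet handle a quantized
        # lm_head, so fall back to the regular GPTQ kernels.
        warnings.warn("Quantized lm_head detected; skipping Marlin auto-conversion.")
        return False
    return True
```

Warning instead of raising keeps such checkpoints loadable (just without Marlin) until the conversion path is fixed.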
@Qubitium no worries. To make it easier for us to both work on it, I'm going to move the model testing refactor to another PR.
@Qubitium I'm going to merge this first, so we can test properly
https://github.com/vllm-project/vllm/pull/4510
@robertgshaw2-neuralmagic I do not plan to make any more changes unless pending CI build shows something broken. https://buildkite.com/vllm/ci/builds/6198
Feel free to add/mod.
Changes and notes:
- Fixed: Identified that runtime auto-conversion to Marlin format for lm_head-enabled quants is not working. I am not sure what's going on here, since the loader path is different in this runtime conversion case. This needs to be fixed in a new PR; this PR is bloated as is. For now, auto-Marlin is disabled if lm_head is enabled.
- Moved the lm_head_quantized property from GPTQConfig to the base QuantizationConfig.
- Fixed OPT compat with lm_head quants
  - At the end of load_weights(), OPT deletes model.lm_head (a softlink to embed_tokens) if the model is not quantized, or is quantized with lm_head false, so as not to duplicate memory, since in this case the weights are the same.
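The OPT fix can be sketched in plain Python. The checkpoint keys follow the Hugging Face OPT naming (model.decoder.embed_tokens.weight, lm_head.weight), which is an assumption here. OPTSketch is hypothetical. Re-pointing lm_head at embed_tokens stands in for the "delete the duplicate" step described in the last bullet.

```python
class OPTSketch:
    """Hypothetical OPT-style model whose lm_head may share embed_tokens."""

    def __init__(self, lm_head_quantized: bool) -> None:
        self.lm_head_quantized = lm_head_quantized
        self.embed_tokens = None
        self.lm_head = None

    def load_weights(self, checkpoint: dict) -> None:
        # Load both tensors separately; do not assume they are identical.
        self.embed_tokens = checkpoint["model.decoder.embed_tokens.weight"]
        self.lm_head = checkpoint.get("lm_head.weight", self.embed_tokens)

        # End of load: an unquantized head holds the same values as embed_tokens,
        # so drop the separate copy and re-tie it to avoid duplicating memory.
        if not self.lm_head_quantized:
            self.lm_head = self.embed_tokens
```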
On the last point, it may be a good idea to add an API to the model (future PR): an on_load_weights_end method that performs post-load model weight cleanup/logic.
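One possible shape for that future hook, purely as a sketch: ModelSketch, on_load_weights_end and the attribute-naming scheme are hypothetical, not an existing vLLM API. The base class calls the hook once after all weights are loaded. Per-model cleanup (such as OPT's re-tying) then lives in one overridable place instead of being buried at the end of each load_weights().

```python
from typing import Iterable, Tuple


class ModelSketch:
    def load_weights(self, weights: Iterable[Tuple[str, object]]) -> None:
        for name, tensor in weights:
            # Store "embed_tokens.weight" as attribute "embed_tokens_weight", etc.
            setattr(self, name.replace(".", "_"), tensor)
        self.on_load_weights_end()  # single post-load hook

    def on_load_weights_end(self) -> None:
        """Default: nothing to clean up."""


class TiedHeadModelSketch(ModelSketch):
    def __init__(self, lm_head_quantized: bool) -> None:
        self.lm_head_quantized = lm_head_quantized

    def on_load_weights_end(self) -> None:
        # Share the embedding when the head was not quantized separately.
        if not self.lm_head_quantized:
            self.lm_head_weight = self.embed_tokens_weight
```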
Extra notes:
- Many models with code similar to OPT, where lm_head and the embedding are the same and are de-duplicated in model code via skipping logic in load_weights, need to be modified in a similar fashion as well. I did not check whether there are any others, nor do I want to bloat this PR even more.
All the relevant tests are passing. Of note: the code changes for OPT work but are not actually efficient. At the moment, enabling lm_head quantization for an OPT model is a net disadvantage, because it disables the sharing of embed_tokens and loads 2 different layers, causing the overall OPT memory footprint to increase. For llama models, where lm_head is separate, the memory saving is significant.
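The OPT memory regression is simple arithmetic. This sketch assumes OPT-2.7b's vocab_size=50272 and hidden_size=2560. A tied fp16 embedding is stored once, while a separately quantized 4-bit lm_head is added on top of it. Group scales are ignored, so this slightly underestimates the increase.

```python
def weight_bytes(vocab_size: int, hidden_size: int, bits: int) -> int:
    """Bytes for a vocab_size x hidden_size matrix at the given bit width."""
    return vocab_size * hidden_size * bits // 8


VOCAB, HIDDEN = 50272, 2560  # OPT-2.7b shape (assumed)

embed_fp16 = weight_bytes(VOCAB, HIDDEN, 16)
tied_total = embed_fp16                                     # lm_head shares embed_tokens
untied_total = embed_fp16 + weight_bytes(VOCAB, HIDDEN, 4)  # plus a separate 4-bit lm_head

print(f"tied:   {tied_total / 2**20:.1f} MiB")
print(f"untied: {untied_total / 2**20:.1f} MiB (+{(untied_total - tied_total) / 2**20:.1f} MiB)")
```

For an untied llama-style head, the 4-bit layer replaces the dense one instead of being added alongside it, which is why the saving there is real.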
To align with BaseConfig inheritance by interface rather than by property variables, refactored the BaseConfig.lm_head_quantized variable into a BaseConfig.is_lm_head_quantized() interface. https://github.com/vllm-project/vllm/pull/4442/commits/9981620f93d97b9fa50bfb0eb8fcabc20cd18df8
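Here is a minimal sketch of the interface-style config, assuming a simplified abstract base. QuantizationConfigSketch, GPTQConfigSketch and IncompleteConfig are hypothetical names, not vLLM's classes. Every config must implement is_lm_head_quantized(), so callers never depend on a subclass happening to define a particular attribute.

```python
from abc import ABC, abstractmethod


class QuantizationConfigSketch(ABC):
    @abstractmethod
    def is_lm_head_quantized(self) -> bool:
        """Whether this config quantizes the lm_head layer."""


class GPTQConfigSketch(QuantizationConfigSketch):
    def __init__(self, lm_head: bool = False) -> None:
        self._lm_head = lm_head  # e.g. read from quantize_config

    def is_lm_head_quantized(self) -> bool:
        return self._lm_head


class IncompleteConfig(QuantizationConfigSketch):
    """Forgets to implement the interface, so it cannot be instantiated."""
```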
Going to pick this back up this weekend
@Qubitium - I apologize for not getting this merged before - I never managed to get all the model tests to pass and this dropped off my radar. I am picking up the PR now
cc
- @Yard1 for VocabParallelEmbedding weight_loading --> anything to be aware of for LoRA?
- @comaniac for quantization changes
cc @njhill @cadedaniel: had to make a couple changes to MLPSpeculator, which uses the lm_head
When I did this, the e2e tests for MLPSpeculator failed because of logits divergence in the middle of a sequence at a random batch index. When I tweaked the precision the test runs at to fp32, the issue went away, suggesting to me that it is numerics.
This PR changes the logit_processor from running torch.matmul(hidden, embedding.t()) to running F.linear().
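For reference, torch.nn.functional.linear(input, weight) computes input @ weight.T (plus an optional bias), so the two forms are mathematically identical. The difference lies in which kernel runs and in what order it accumulates. The pure-Python sketch below uses hypothetical helper names and no torch dependency. It shows both definitions agreeing, and it shows that floating-point addition is not associative. That non-associativity is one plausible explanation, consistent with the divergence disappearing at fp32, for different kernels drifting apart at lower precision.

```python
def matmul_with_transpose(hidden, weight):
    """logits = hidden @ weight.T: transpose first, then a plain matmul."""
    weight_t = [list(col) for col in zip(*weight)]  # shape [hidden_dim][vocab]
    vocab = len(weight)
    return [[sum(h[k] * weight_t[k][v] for k in range(len(h))) for v in range(vocab)]
            for h in hidden]


def linear_no_bias(hidden, weight):
    """F.linear semantics without bias: out[i][v] = sum_k hidden[i][k] * weight[v][k]."""
    return [[sum(hk * wk for hk, wk in zip(h, w)) for w in weight] for h in hidden]


hidden = [[1.0, 2.0], [0.5, -1.0]]             # [batch][hidden_dim]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # [vocab][hidden_dim], like an lm_head
print(matmul_with_transpose(hidden, weight) == linear_no_bias(hidden, weight))

# Floating-point addition is not associative, so a kernel that sums in a different
# order (or precision) can produce slightly different logits.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))
```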
UPDATE: nick is okay with the change + adding a new model
@comaniac ready for re-review
(Ignore the spec decode test failure; @njhill is making a new model that avoids OOM at fp32.)
> @Yard1 for VocabParallelEmbedding weight_loading --> anything to be aware of for LoRA?
@robertgshaw2-neuralmagic @Yard1 it looks like there was some impact from this ... not sure if it actually exposed a latent bug where the lm_head for gpt_bigcode (and similar) was not previously adaptable: https://github.com/vllm-project/vllm/issues/6314
@njhill do you know if this was working before?
Hi! Is 'look into quantized embeddings' available now? Looking forward to this!