Quantization appears to be broken, at least for AWQ and BnB
System Info
I have tried the following LoRAX versions:
- (official version) ghcr.io/predibase/lorax:0.12
- (locally compiled) lorax:69bb989
CUDA: 12.4 and 12.6
Information
- [X] Docker
- [ ] The CLI directly
Tasks
- [X] An officially supported command
- [ ] My own modifications
Reproduction
- Obtain an AWQ- or BnB-quantized model, such as https://huggingface.co/unsloth/Meta-Llama-3.1-8B-bnb-4bit or https://huggingface.co/hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4
- Using otherwise default settings, configure LoRAX to load the quantized model (a quick sanity check on the model configs is sketched below)
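Not strictly part of the repro, but as a quick sanity check that the downloaded models really advertise AWQ / bitsandbytes quantization, something like the following can be used (assumes `huggingface_hub` is installed; it only reads `quantization_config` from each model's config.json):

```python
# Optional sanity check (not part of the LoRAX repro itself): confirm the models
# advertise a quantization method in their config.json.
import json
from huggingface_hub import hf_hub_download

for repo in [
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
]:
    cfg_path = hf_hub_download(repo_id=repo, filename="config.json")
    with open(cfg_path) as f:
        cfg = json.load(f)
    # expected: "bitsandbytes" and "awq" respectively (based on the repos' configs)
    print(repo, cfg.get("quantization_config", {}).get("quant_method"))
```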
Expected behavior
I would expect LoRAX to start normally, just as it does with a non-quantized model. However, I have tried several different AWQ and BnB 4-bit models, different LoRAX versions, and different CUDA/driver combinations, and all of them fail.
Example errors (also seen in other issues):
- AWQ: `ValueError: too many values to unpack (expected 3)`
- BnB: `AssertionError: [12582912, 1] != [6144, 4096]`
**Full AWQ Error:**
2024-12-21T13:26:16.679407Z INFO shard-manager: lorax_launcher: Starting shard rank=0
2024-12-21T13:26:25.250809Z ERROR lorax_launcher: server.py:317 Error when initializing model
Traceback (most recent call last):
File "/opt/conda/bin/lorax-server", line 8, in
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner model = get_model( File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/init.py", line 186, in get_model return FlashLlama( File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in init super().init( File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in init model = model_cls(prefix, config, weights) File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in init self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn) File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in init [ File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in
create_layer_fn( File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in init self.self_attn = FlashLlamaAttention( File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in init self.query_key_value = load_attention(config, prefix, weights, layer_id) File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention base_layer = load_attention_multi(config, prefix, weights) File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi return _load_gqa(config, prefix, weights) File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 202, in _load_gqa weight, input_scale, weight_scale = weight ValueError: too many values to unpack (expected 3)
2024-12-21T13:26:26.400592Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:
2024-12-21 13:26:20.746 | INFO | lorax_server.utils.state:
File "/opt/conda/bin/lorax-server", line 8, in
File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve server.serve(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete return future.result()
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner model = get_model(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/init.py", line 186, in get_model return FlashLlama(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in init super().init(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in init model = model_cls(prefix, config, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in init self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in init [
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in init self.self_attn = FlashLlamaAttention(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in init self.query_key_value = load_attention(config, prefix, weights, layer_id)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention base_layer = load_attention_multi(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi return _load_gqa(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 202, in _load_gqa weight, input_scale, weight_scale = weight
ValueError: too many values to unpack (expected 3) rank=0
2024-12-21T13:26:26.485424Z ERROR lorax_launcher: Shard 0 failed to start
2024-12-21T13:26:26.485464Z INFO lorax_launcher: Shutting down shards
Error: ShardCannotStart
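For context, the AWQ failure is a plain tuple-arity mismatch: for AWQ/GPTQ the weight loader appears to hand back a 7-element tuple (see the weights.py excerpt further down), while `_load_gqa` unconditionally unpacks three values. A minimal, self-contained illustration (not the actual LoRAX code path; the tuple contents are placeholders):

```python
# Illustration only: mimics what _load_gqa appears to receive for an AWQ/GPTQ
# model. The 7-element layout follows the weights.py excerpt quoted below;
# the values here are placeholders, not real tensors.
awq_weight = ("qweight", "qzeros", "scales", "g_idx", 4, 128, False)

try:
    # flash_llama_modeling.py line 202 assumes a 3-element (fp8-style) tuple:
    weight, input_scale, weight_scale = awq_weight
except ValueError as err:
    print(err)  # too many values to unpack (expected 3)
```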
**Full BnB Error:**
2024-12-21T13:33:35.146822Z INFO shard-manager: lorax_launcher: Starting shard rank=0
2024-12-21T13:33:43.387125Z ERROR lorax_launcher: server.py:317 Error when initializing model
Traceback (most recent call last):
File "/opt/conda/bin/lorax-server", line 8, in
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner model = get_model( File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/init.py", line 186, in get_model return FlashLlama( File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in init super().init( File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in init model = model_cls(prefix, config, weights) File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in init self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn) File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in init [ File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in
create_layer_fn( File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in init self.self_attn = FlashLlamaAttention( File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in init self.query_key_value = load_attention(config, prefix, weights, layer_id) File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention base_layer = load_attention_multi(config, prefix, weights) File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi return _load_gqa(config, prefix, weights) File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 210, in _load_gqa assert list(weight.shape) == [ AssertionError: [12582912, 1] != [6144, 4096]
2024-12-21T13:33:44.667486Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:
2024-12-21 13:33:39.687 | INFO | lorax_server.utils.state:
File "/opt/conda/bin/lorax-server", line 8, in
File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve server.serve(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete return future.result()
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner model = get_model(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/init.py", line 186, in get_model return FlashLlama(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in init super().init(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in init model = model_cls(prefix, config, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in init self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in init [
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in init self.self_attn = FlashLlamaAttention(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in init self.query_key_value = load_attention(config, prefix, weights, layer_id)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention base_layer = load_attention_multi(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi return _load_gqa(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 210, in _load_gqa assert list(weight.shape) == [
AssertionError: [12582912, 1] != [6144, 4096] rank=0
2024-12-21T13:33:44.753438Z ERROR lorax_launcher: Shard 0 failed to start
2024-12-21T13:33:44.753471Z INFO lorax_launcher: Shutting down shards
Error: ShardCannotStart
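The BnB numbers are also consistent with the weight never being unpacked from its quantized form: for Llama-3.1-8B the fused QKV projection works out to [6144, 4096] (4096 for Q plus 1024 each for K and V), i.e. 25,165,824 parameters, and bitsandbytes 4-bit stores weights packed two per byte as a flat [N/2, 1] uint8 tensor, which gives exactly [12582912, 1]. A quick arithmetic check (assumes the standard bitsandbytes 4-bit packing, not taken from the LoRAX source):

```python
# Sanity-check the BnB assertion numbers, assuming standard bitsandbytes 4-bit
# packing (two 4-bit values per uint8, stored as a flat [N/2, 1] tensor).
q_out, kv_out, hidden = 4096, 1024, 4096
expected_shape = (q_out + 2 * kv_out, hidden)   # [6144, 4096] fused QKV weight
n_params = expected_shape[0] * expected_shape[1]

print(n_params)        # 25165824
print(n_params // 2)   # 12582912 -> matches the reported [12582912, 1] shape
```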
The issue appears to be related to https://github.com/predibase/lorax/issues/595 and https://github.com/predibase/lorax/issues/607, both of which were closed without a clear resolution.
Here is another related case, https://github.com/predibase/lorax/issues/611, which references the container ghcr.io/predibase/lorax:07addea. I was able to get that container working with https://huggingface.co/hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4, but it is very old.
The problem appears to be here:
if quantize in ["gptq", "awq"]: ... weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)" ... else: ... return weight, input_scale, weight_scale weight = torch.cat(weight_list, dim=dim) ... https://github.com/predibase/lorax/blob/69bb989df8357fdc65bd5d8ce9df687baf521950/server/lorax_server/utils/weights.py#L120 "https://github.com/predibase/lorax/blob/69bb989df8357fdc65bd5d8ce9df687baf521950/server/lorax_server/utils/weights.py#L141
Depending on the quantization method, `weight` can therefore come back with a different number of elements, which is not accounted for here: https://github.com/predibase/lorax/blob/69bb989df8357fdc65bd5d8ce9df687baf521950/server/lorax_server/models/custom_modeling/flash_llama_modeling.py#L202
Hey @codybum! Thanks for the investigation. Would you be willing to contribute to LoRAX by pushing up a fix with your suggested changes?
I see where it is breaking, though not yet how to fix it, but I will look into this further and see if I can get something working.
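For what it's worth, the direction I am considering looks roughly like the following (an untested sketch only; the helper name and branch conditions are my assumptions, not existing LoRAX code): have `_load_gqa` check what the loader actually returned before unpacking, instead of always assuming the fp8-style 3-tuple.

```python
# Untested sketch of a more defensive unpack for _load_gqa. Only `weight`
# comes from the existing code; everything else here is hypothetical.
def split_weight_and_scales(weight):
    # fp8-style loaders appear to return (weight, input_scale, weight_scale);
    # AWQ/GPTQ return a 7-element tuple and BnB a packed tensor, so only
    # unpack when the value actually has the 3-element form.
    if isinstance(weight, tuple) and len(weight) == 3:
        return weight
    # Quantized formats: pass the value through unchanged with no scales,
    # and let the quantized linear layer deal with it downstream.
    return weight, None, None
```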