
Quantization appears to be broken, at least for AWQ and BnB

Open · codybum opened this issue 11 months ago • 5 comments

System Info

I have tried the following Lorax versions: ghcr.io/predibase/lorax:0.12 (official image) and lorax:69bb989 (locally compiled).

CUDA: 12.4 and 12.6

Information

  • [X] Docker
  • [ ] The CLI directly

Tasks

  • [X] An officially supported command
  • [ ] My own modifications

Reproduction

  1. Obtain an AWQ or BnB quantized model, such as https://huggingface.co/unsloth/Meta-Llama-3.1-8B-bnb-4bit or https://huggingface.co/hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4
  2. Using otherwise default settings, configure Lorax to use the quantized model (for example, with a launch command along the lines of the sketch below)
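For reference, the failing launch is roughly the stock invocation from the LoRAX docs pointed at the quantized model; the volume path and the --quantize value here are my guesses at a typical setup rather than something copied from the logs below:

docker run --gpus all --shm-size 1g -p 8080:80 -v $PWD/data:/data ghcr.io/predibase/lorax:0.12 --model-id hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 --quantize awq

The BnB case is the same command with the unsloth/Meta-Llama-3.1-8B-bnb-4bit model id and a bitsandbytes 4-bit value for --quantize.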

Expected behavior

I would expect Lorax to start normally, as it does with a non-quantized model. However, I have tried several different AWQ and BnB 4-bit models, different versions of Lorax, and different CUDA/driver versions, and startup fails in every case.

Example errors (also seen in other issues):
AWQ: "ValueError: too many values to unpack (expected 3)"
BnB: "AssertionError: [12582912, 1] != [6144, 4096]"
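For what it's worth, the BnB shape in that assertion looks consistent with bitsandbytes 4-bit packing (my reading of the numbers, not something stated in the logs): the fused QKV weight should be [6144, 4096], and packing two 4-bit values per byte into a flat column gives exactly 12,582,912 elements:

# Sanity check on the shapes in the BnB assertion (standard Llama-3.1-8B config values)
head_dim, n_heads, n_kv_heads, hidden = 128, 32, 8, 4096
qkv_rows = n_heads * head_dim + 2 * n_kv_heads * head_dim  # 4096 + 1024 + 1024 = 6144
packed_elems = qkv_rows * hidden // 2                      # two 4-bit values packed per byte
print(qkv_rows, packed_elems)                              # 6144 12582912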

Full AWQ Error:

2024-12-21T13:26:16.679407Z INFO shard-manager: lorax_launcher: Starting shard rank=0
2024-12-21T13:26:25.250809Z ERROR lorax_launcher: server.py:317 Error when initializing model
Traceback (most recent call last):
  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 186, in get_model
    return FlashLlama(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in __init__
    super().__init__(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in __init__
    model = model_cls(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in __init__
    self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in __init__
    [
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in <listcomp>
    create_layer_fn(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in __init__
    self.self_attn = FlashLlamaAttention(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in __init__
    self.query_key_value = load_attention(config, prefix, weights, layer_id)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
    base_layer = load_attention_multi(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
    return _load_gqa(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 202, in _load_gqa
    weight, input_scale, weight_scale = weight
ValueError: too many values to unpack (expected 3)

2024-12-21T13:26:26.400592Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:

2024-12-21 13:26:20.746 | INFO | lorax_server.utils.state:<module>:22 - Backend = fa2
2024-12-21 13:26:20.746 | INFO | lorax_server.utils.state:<module>:24 - Prefix caching = False
2024-12-21 13:26:20.746 | INFO | lorax_server.utils.state:<module>:25 - Chunked prefill = False
/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py:79: FutureWarning: You are using a Backend <class 'lorax_server.utils.dist.FakeGroup'> as a ProcessGroup. This usage is deprecated since PyTorch 2.0. Please use a public API of PyTorch Distributed instead.
  return func(*args, **kwargs)
Traceback (most recent call last):
  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 186, in get_model
    return FlashLlama(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in __init__
    super().__init__(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in __init__
    model = model_cls(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in __init__
    self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in __init__
    [
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in <listcomp>
    create_layer_fn(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in __init__
    self.self_attn = FlashLlamaAttention(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in __init__
    self.query_key_value = load_attention(config, prefix, weights, layer_id)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
    base_layer = load_attention_multi(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
    return _load_gqa(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 202, in _load_gqa
    weight, input_scale, weight_scale = weight
ValueError: too many values to unpack (expected 3) rank=0
2024-12-21T13:26:26.485424Z ERROR lorax_launcher: Shard 0 failed to start
2024-12-21T13:26:26.485464Z INFO lorax_launcher: Shutting down shards
Error: ShardCannotStart

Full BnB Error:

2024-12-21T13:33:35.146822Z INFO shard-manager: lorax_launcher: Starting shard rank=0
2024-12-21T13:33:43.387125Z ERROR lorax_launcher: server.py:317 Error when initializing model
Traceback (most recent call last):
  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 186, in get_model
    return FlashLlama(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in __init__
    super().__init__(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in __init__
    model = model_cls(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in __init__
    self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in __init__
    [
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in <listcomp>
    create_layer_fn(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in __init__
    self.self_attn = FlashLlamaAttention(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in __init__
    self.query_key_value = load_attention(config, prefix, weights, layer_id)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
    base_layer = load_attention_multi(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
    return _load_gqa(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 210, in _load_gqa
    assert list(weight.shape) == [
AssertionError: [12582912, 1] != [6144, 4096]

2024-12-21T13:33:44.667486Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:

2024-12-21 13:33:39.687 | INFO | lorax_server.utils.state:<module>:22 - Backend = fa2
2024-12-21 13:33:39.687 | INFO | lorax_server.utils.state:<module>:24 - Prefix caching = False
2024-12-21 13:33:39.687 | INFO | lorax_server.utils.state:<module>:25 - Chunked prefill = False
/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py:79: FutureWarning: You are using a Backend <class 'lorax_server.utils.dist.FakeGroup'> as a ProcessGroup. This usage is deprecated since PyTorch 2.0. Please use a public API of PyTorch Distributed instead.
  return func(*args, **kwargs)
Traceback (most recent call last):
  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 186, in get_model
    return FlashLlama(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in __init__
    super().__init__(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in __init__
    model = model_cls(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in __init__
    self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in __init__
    [
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in <listcomp>
    create_layer_fn(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in __init__
    self.self_attn = FlashLlamaAttention(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in __init__
    self.query_key_value = load_attention(config, prefix, weights, layer_id)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
    base_layer = load_attention_multi(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
    return _load_gqa(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 210, in _load_gqa
    assert list(weight.shape) == [
AssertionError: [12582912, 1] != [6144, 4096] rank=0
2024-12-21T13:33:44.753438Z ERROR lorax_launcher: Shard 0 failed to start
2024-12-21T13:33:44.753471Z INFO lorax_launcher: Shutting down shards
Error: ShardCannotStart

codybum · Dec 21 '24

This issue appears related to https://github.com/predibase/lorax/issues/595 and https://github.com/predibase/lorax/issues/607, both of which were closed without a clear resolution.

codybum · Dec 21 '24

Here is another related case, https://github.com/predibase/lorax/issues/611, which references the container ghcr.io/predibase/lorax:07addea. I could get that image to work with https://huggingface.co/hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4, but it is very old.

codybum · Dec 21 '24

The problem appears to be here:

if quantize in ["gptq", "awq"]:
    ...
    weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
    ...
else:
    ...
    return weight, input_scale, weight_scale

weight = torch.cat(weight_list, dim=dim)
...

https://github.com/predibase/lorax/blob/69bb989df8357fdc65bd5d8ce9df687baf521950/server/lorax_server/utils/weights.py#L120
https://github.com/predibase/lorax/blob/69bb989df8357fdc65bd5d8ce9df687baf521950/server/lorax_server/utils/weights.py#L141

The weight returned there can have a different number of elements depending on the quantization method, which is not accounted for here: https://github.com/predibase/lorax/blob/69bb989df8357fdc65bd5d8ce9df687baf521950/server/lorax_server/models/custom_modeling/flash_llama_modeling.py#L202
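To make the mismatch concrete, here is a minimal, self-contained illustration of the failure and of the kind of guard _load_gqa would need; this is a sketch with made-up placeholder values, not a tested change against the LoRAX code:

# AWQ/GPTQ weights come back from get_multi_weights_col as a 7-tuple,
# not as the (weight, input_scale, weight_scale) triple that _load_gqa unpacks.
awq_weight = ("qweight", "qzeros", "scales", "g_idx", 4, 128, False)  # placeholder values

try:
    weight, input_scale, weight_scale = awq_weight  # what _load_gqa does today
except ValueError as err:
    print(err)  # too many values to unpack (expected 3)

# A guard along these lines would only unpack the 3-tuple case and pass
# quantized tuples (or BnB's packed tensor) through untouched:
if isinstance(awq_weight, tuple) and len(awq_weight) == 3:
    weight, input_scale, weight_scale = awq_weight
else:
    weight, input_scale, weight_scale = awq_weight, None, None

The shape assert at line 210 would presumably need the same kind of bypass for quantized weights, since a packed BnB tensor will never match [6144, 4096].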

codybum · Dec 21 '24

Hey @codybum! Thanks for the investigation. Would you be willing to contribute to LoRAX by pushing up a fix with your suggested changes?

arnavgarg1 · Jan 03 '25

I see where it is breaking, though not necessarily how to fix it, but I will look into this further and see if I can get something working.

codybum · Jan 03 '25