Cannot run an FP8-quantized model with LoRAX
System Info
LoRAX version (`pip show lorax-client`):
Name: lorax-client
Version: 0.6.3
Summary: LoRAX Python Client
Home-page: https://github.com/predibase/lorax
Author: Travis Addair
Author-email: [email protected]
License: Apache-2.0
Location: /mnt/share/ai_studio/.venv/lib/python3.11/site-packages
Requires: aiohttp, certifi, huggingface-hub, pydantic
Required-by:
Platform: linux, x86_64
nvidia-smi output:
Mon Nov 11 14:54:42 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 |
| N/A 28C P0 75W / 700W | 1MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:8C:00.0 Off | 0 |
| N/A 29C P0 71W / 700W | 1MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
Information
- [X] Docker
- [ ] The CLI directly
Tasks
- [X] An officially supported command
- [ ] My own modifications
Reproduction
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/predibase/lorax:latest --model-id neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8
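For context, once the server is up I would query it with the Python client listed in the system info above. A minimal sketch of that step (the URL matches the `-p 8080:80` mapping from the command; the prompt and token budget are just placeholders):

```python
from lorax import Client

# Standard lorax-client usage; the endpoint matches the -p 8080:80 port mapping above.
client = Client("http://127.0.0.1:8080")

# Placeholder prompt/parameters for illustration only.
response = client.generate("What is FP8 quantization?", max_new_tokens=64)
print(response.generated_text)
```

With the FP8 model ID, the server never reaches the point where it can be queried; it fails during model initialization as shown below.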
Expected behavior
I am launching LoRAX via Docker with the command from the official LoRAX docs (section "Launch LoRAX Server"); the only modification is the model ID, which I set to the FP8-quantized Llama-3.1-8B shown above. However, it seems that LoRAX's backend does not support FP8 checkpoints, as model initialization fails with an FP8-related error:
2024-11-11T14:47:20.230969Z ERROR lorax_launcher: server.py:311 Error when initializing model
Traceback (most recent call last):
File "/opt/conda/bin/lorax-server", line 8, in <module>
sys.exit(app())
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
return get_command(self)(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
return self.main(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
return _main(
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
rv = self.invoke(ctx)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
return __callback(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
return callback(**use_params) # type: ignore
File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
server.serve(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 439, in serve
asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
self.run_forever()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
self._run_once()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
handle._run()
File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
> File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 296, in serve_inner
model = get_model(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 186, in get_model
return FlashLlama(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in __init__
super().__init__(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1128, in __init__
model = model_cls(prefix, config, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in __init__
self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in __init__
[
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in <listcomp>
create_layer_fn(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in __init__
self.self_attn = FlashLlamaAttention(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in __init__
self.query_key_value = load_attention(config, prefix, weights, layer_id)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
base_layer = load_attention_multi(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
return _load_gqa(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 194, in _load_gqa
weight = weights.get_multi_weights_col(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/weights.py", line 141, in get_multi_weights_col
weight = torch.cat(weight_list, dim=dim)
RuntimeError: "cat_cuda" not implemented for 'Float8_e4m3fn'
2024-11-11T14:47:21.122584Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:
2024-11-11 14:47:10.890 | INFO | lorax_server.utils.state:<module>:22 - Backend = fa2
2024-11-11 14:47:10.890 | INFO | lorax_server.utils.state:<module>:24 - Prefix caching = False
2024-11-11 14:47:10.890 | INFO | lorax_server.utils.state:<module>:25 - Chunked prefill = False
/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py:79: FutureWarning: You are using a Backend <class 'lorax_server.utils.dist.FakeGroup'> as a ProcessGroup. This usage is deprecated since PyTorch 2.0. Please use a public API of PyTorch Distributed instead.
return func(*args, **kwargs)
Traceback (most recent call last):
File "/opt/conda/bin/lorax-server", line 8, in <module>
sys.exit(app())
File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
server.serve(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 439, in serve
asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 296, in serve_inner
model = get_model(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 186, in get_model
return FlashLlama(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in __init__
super().__init__(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1128, in __init__
model = model_cls(prefix, config, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in __init__
self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in __init__
[
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in <listcomp>
create_layer_fn(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in __init__
self.self_attn = FlashLlamaAttention(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in __init__
self.query_key_value = load_attention(config, prefix, weights, layer_id)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
base_layer = load_attention_multi(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
return _load_gqa(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 194, in _load_gqa
weight = weights.get_multi_weights_col(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/weights.py", line 141, in get_multi_weights_col
weight = torch.cat(weight_list, dim=dim)
RuntimeError: "cat_cuda" not implemented for 'Float8_e4m3fn'
Could you please investigate?