
Cannot run an FP8-quantized model with LoRAX

Status: Open. Aktsvigun opened this issue 1 year ago • 6 comments

System Info

LoRAX version:

Name: lorax-client
Version: 0.6.3
Summary: LoRAX Python Client
Home-page: https://github.com/predibase/lorax
Author: Travis Addair
Author-email: [email protected]
License: Apache-2.0
Location: /mnt/share/ai_studio/.venv/lib/python3.11/site-packages
Requires: aiohttp, certifi, huggingface-hub, pydantic
Required-by:

Platform: linux, x86_64

nvidia-smi output:

Mon Nov 11 14:54:42 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:8B:00.0 Off |                    0 |
| N/A   28C    P0             75W /  700W |       1MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:8C:00.0 Off |                    0 |
| N/A   29C    P0             71W /  700W |       1MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Information

  • [X] Docker
  • [ ] The CLI directly

Tasks

  • [X] An officially supported command
  • [ ] My own modifications

Reproduction

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/predibase/lorax:latest --model-id neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8

Expected behavior

I am using the official Docker command from the LoRAX page (section "Launch LoRAX Server"). The only modification is the model ID: I'm using the FP8-quantized Llama-3.1-8B. However, LoRAX's backend appears not to support FP8 models, since I get an FP8-related error:

2024-11-11T14:47:20.230969Z ERROR lorax_launcher: server.py:311 Error when initializing model
Traceback (most recent call last):
  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 439, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
> File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 296, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 186, in get_model
    return FlashLlama(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in __init__
    super().__init__(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1128, in __init__
    model = model_cls(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in __init__
    self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in __init__
    [
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in <listcomp>
    create_layer_fn(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in __init__
    self.self_attn = FlashLlamaAttention(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in __init__
    self.query_key_value = load_attention(config, prefix, weights, layer_id)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
    base_layer = load_attention_multi(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
    return _load_gqa(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 194, in _load_gqa
    weight = weights.get_multi_weights_col(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/weights.py", line 141, in get_multi_weights_col
    weight = torch.cat(weight_list, dim=dim)
RuntimeError: "cat_cuda" not implemented for 'Float8_e4m3fn'

2024-11-11T14:47:21.122584Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:

2024-11-11 14:47:10.890 | INFO     | lorax_server.utils.state:<module>:22 - Backend = fa2
2024-11-11 14:47:10.890 | INFO     | lorax_server.utils.state:<module>:24 - Prefix caching = False
2024-11-11 14:47:10.890 | INFO     | lorax_server.utils.state:<module>:25 - Chunked prefill = False
/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py:79: FutureWarning: You are using a Backend <class 'lorax_server.utils.dist.FakeGroup'> as a ProcessGroup. This usage is deprecated since PyTorch 2.0. Please use a public API of PyTorch Distributed instead.
  return func(*args, **kwargs)
Traceback (most recent call last):

  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
    server.serve(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 439, in serve
    asyncio.run(

  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)

  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 296, in serve_inner
    model = get_model(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 186, in get_model
    return FlashLlama(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in __init__
    super().__init__(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1128, in __init__
    model = model_cls(prefix, config, weights)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in __init__
    self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in __init__
    [

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in <listcomp>
    create_layer_fn(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in __init__
    self.self_attn = FlashLlamaAttention(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in __init__
    self.query_key_value = load_attention(config, prefix, weights, layer_id)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
    base_layer = load_attention_multi(config, prefix, weights)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
    return _load_gqa(config, prefix, weights)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 194, in _load_gqa
    weight = weights.get_multi_weights_col(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/weights.py", line 141, in get_multi_weights_col
    weight = torch.cat(weight_list, dim=dim)

RuntimeError: "cat_cuda" not implemented for 'Float8_e4m3fn'

Could you please investigate?

Aktsvigun, Nov 11 '24 15:11