text-generation-inference

Error when loading an LLM with a classification head

Status: Open. oroojlooy opened this issue 11 months ago • 0 comments

System Info

  • OS: Oracle Linux 7.9
  • Python: 3.9.18
  • text-generation-server: 1.3.4
  • torch: 2.1.2
  • GPU: 8 × A100 (40 GB)

Information

  • [X] Docker
  • [ ] The CLI directly

Tasks

  • [X] An officially supported command
  • [ ] My own modifications

Reproduction

The following commands should reproduce the error (replace YOUR_HF_TOKEN with your own Hugging Face token). Note that I can load and call the model directly through the transformers package without any issues.

model=afshinO/mistral_with_classification_head
volume=$PWD/data   # host directory used to cache the weights (any path works)
token=YOUR_HF_TOKEN
sudo docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model --num-shard 8 --max-total-tokens 32000 --max-input-length 16000 --max-batch-prefill-tokens 16000
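Before looking at the model itself, the token limits in the command can be ruled out. The sketch below assumes the launcher enforces two rules on these flags (--max-input-length strictly below --max-total-tokens, and --max-batch-prefill-tokens at least --max-input-length); check_token_limits is a hypothetical helper for illustration, not part of TGI.

```python
# Sketch: sanity-check the token limits from the docker command.
# ASSUMPTION: these two rules mirror the launcher's argument validation;
# confirm against the docs for your TGI version.

def check_token_limits(max_input_length: int,
                       max_total_tokens: int,
                       max_batch_prefill_tokens: int) -> list:
    """Return a list of problems; an empty list means the limits are consistent."""
    problems = []
    if max_input_length >= max_total_tokens:
        problems.append("max_input_length must be < max_total_tokens")
    if max_batch_prefill_tokens < max_input_length:
        problems.append("max_batch_prefill_tokens must be >= max_input_length")
    return problems


# Values taken from the reproduction command.
print(check_token_limits(max_input_length=16000,
                         max_total_tokens=32000,
                         max_batch_prefill_tokens=16000))
```

With the values from the command it returns an empty list, so the limits look consistent. This also fits the logs, where the launcher gets past argument handling and only fails once the shards start loading the model weights.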

Expected behavior

Running the command above fails with the following error:

2024-03-14T02:46:47.492912Z  INFO text_generation_launcher: Sharding model on 8 processes
2024-03-14T02:46:47.492997Z  INFO download: text_generation_launcher: Starting download process.
2024-03-14T02:46:53.025672Z  INFO text_generation_launcher: Files are already present on the host. Skipping download.
2024-03-14T02:46:54.196449Z  INFO download: text_generation_launcher: Successfully downloaded weights.
2024-03-14T02:46:54.196733Z  INFO shard-manager: text_generation_launcher: Starting shard rank=0
2024-03-14T02:46:54.196774Z  INFO shard-manager: text_generation_launcher: Starting shard rank=1
2024-03-14T02:46:54.197218Z  INFO shard-manager: text_generation_launcher: Starting shard rank=2
2024-03-14T02:46:54.197245Z  INFO shard-manager: text_generation_launcher: Starting shard rank=3
2024-03-14T02:46:54.198503Z  INFO shard-manager: text_generation_launcher: Starting shard rank=5
2024-03-14T02:46:54.198438Z  INFO shard-manager: text_generation_launcher: Starting shard rank=4
2024-03-14T02:46:54.198567Z  INFO shard-manager: text_generation_launcher: Starting shard rank=6
2024-03-14T02:46:54.199566Z  INFO shard-manager: text_generation_launcher: Starting shard rank=7
2024-03-14T02:47:00.825279Z  WARN text_generation_launcher: Disabling exllama v2 and using v1 instead because there are issues when sharding
2024-03-14T02:47:01.133253Z  WARN text_generation_launcher: Disabling exllama v2 and using v1 instead because there are issues when sharding
2024-03-14T02:47:01.173272Z  WARN text_generation_launcher: Disabling exllama v2 and using v1 instead because there are issues when sharding
2024-03-14T02:47:01.188345Z  WARN text_generation_launcher: Disabling exllama v2 and using v1 instead because there are issues when sharding
2024-03-14T02:47:01.220290Z  WARN text_generation_launcher: Disabling exllama v2 and using v1 instead because there are issues when sharding
2024-03-14T02:47:01.233702Z  WARN text_generation_launcher: Disabling exllama v2 and using v1 instead because there are issues when sharding
2024-03-14T02:47:01.247665Z  WARN text_generation_launcher: Disabling exllama v2 and using v1 instead because there are issues when sharding
2024-03-14T02:47:01.252840Z  WARN text_generation_launcher: Disabling exllama v2 and using v1 instead because there are issues when sharding
2024-03-14T02:47:04.203171Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=0
2024-03-14T02:47:04.203394Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=3
2024-03-14T02:47:04.203894Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=1
2024-03-14T02:47:04.204261Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=2
2024-03-14T02:47:04.204998Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=4
2024-03-14T02:47:04.205321Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=5
2024-03-14T02:47:04.205689Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=6
2024-03-14T02:47:04.206907Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=7
2024-03-14T02:47:08.377542Z ERROR text_generation_launcher: Error when initializing model
Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 89, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 235, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
> File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 196, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 288, in get_model
    return FlashMistral(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_mistral.py", line 430, in __init__
    super(FlashMistral, self).__init__(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_mistral.py", line 333, in __init__
    model = model_cls(config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 421, in __init__
    self.model = MistralModel(config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 356, in __init__
    [
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 357, in <listcomp>
    MistralLayer(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 291, in __init__
    self.self_attn = MistralAttention(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 168, in __init__
    self.query_key_value = load_attention(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 97, in load_attention
    return _load_gqa(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 124, in _load_gqa
    assert list(weight.shape) == [
AssertionError: [1572864, 1] != [2560, 4096]
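
The numbers in the assertion can be checked by hand. The sketch below assumes the standard Mistral-7B attention configuration (hidden size 4096, 32 query heads, 8 key/value heads), which the report does not confirm for this checkpoint. Under that assumption each of the 8 shards should hold a fused q/k/v weight of shape [768, 4096]. The [2560, 4096] printed in the message matches the same formula with the unsharded key/value head count plugged in, so the message may not show the shape that was actually compared. The loaded tensor has exactly half the expected number of elements and a trailing dimension of 1. That layout would be consistent with, for example, weights saved in a packed 4-bit format (two values per byte), but this is a hypothesis the logs do not establish.

```python
# Sketch: reproduce the shapes in "AssertionError: [1572864, 1] != [2560, 4096]".
# ASSUMPTION: standard Mistral-7B attention config; the repo's config.json
# is not shown in the report.

HIDDEN_SIZE = 4096
NUM_ATTENTION_HEADS = 32
NUM_KEY_VALUE_HEADS = 8
NUM_SHARDS = 8  # --num-shard 8

head_size = HIDDEN_SIZE // NUM_ATTENTION_HEADS          # 128
heads_per_shard = NUM_ATTENTION_HEADS // NUM_SHARDS     # 4
kv_heads_per_shard = NUM_KEY_VALUE_HEADS // NUM_SHARDS  # 1

# Rows of the fused projection on one shard: q, k and v stacked along dim 0.
rows_per_shard = (heads_per_shard + 2 * kv_heads_per_shard) * head_size
expected_shape = [rows_per_shard, HIDDEN_SIZE]

# Same formula, but with the *unsharded* key/value head count.
message_shape = [(heads_per_shard + 2 * NUM_KEY_VALUE_HEADS) * head_size, HIDDEN_SIZE]

loaded_shape = [1572864, 1]  # left-hand side of the assertion
expected_elems = expected_shape[0] * expected_shape[1]
loaded_elems = loaded_shape[0] * loaded_shape[1]

print("expected per-shard shape:", expected_shape)
print("shape in error message:  ", message_shape)
print("expected / loaded elements:", expected_elems / loaded_elems)
```

Running it prints [768, 4096] for the expected per-shard shape, [2560, 4096] for the message shape, and a ratio of 2.0.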

The same error is raised on all 8 shards (one per GPU).
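A [1572864, 1] weight is a flat single-column tensor rather than the 2-D projection matrix the loader expects, so it can help to look at how the checkpoint actually stores the attention projections. A .safetensors file begins with an 8-byte little-endian header length followed by a JSON header that maps each tensor name to its dtype, shape and data_offsets, so the header can be read without torch. The sketch below is a stdlib-only illustration: the helper names are hypothetical, the q_proj/k_proj/v_proj filter assumes the usual Mistral parameter naming, and looks_packed is only a heuristic (1-byte integer storage with a trailing dimension of 1).

```python
import json
import struct
import sys


def parse_safetensors_header(buf: bytes) -> dict:
    """Parse the JSON header from the leading bytes of a .safetensors file."""
    # First 8 bytes: unsigned little-endian 64-bit length of the JSON header.
    (header_len,) = struct.unpack("<Q", buf[:8])
    header = json.loads(buf[8:8 + header_len].decode("utf-8"))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    return header


def read_safetensors_header(path: str) -> dict:
    """Read only the header of a .safetensors file; tensor data is not loaded."""
    with open(path, "rb") as f:
        prefix = f.read(8)
        (header_len,) = struct.unpack("<Q", prefix)
        return parse_safetensors_header(prefix + f.read(header_len))


def attention_projection_shapes(header: dict) -> dict:
    """Map each q/k/v projection weight name to its (dtype, shape)."""
    return {
        name: (info["dtype"], info["shape"])
        for name, info in header.items()
        if name.endswith(".weight")
        and any(p in name for p in ("q_proj", "k_proj", "v_proj"))
    }


def looks_packed(dtype: str, shape: list) -> bool:
    """Heuristic: 1-byte integer storage with a trailing dimension of 1."""
    return dtype in ("U8", "I8") and len(shape) == 2 and shape[1] == 1


if __name__ == "__main__":
    # Usage: python inspect_header.py <shard>.safetensors [<shard>.safetensors ...]
    for path in sys.argv[1:]:
        shapes = attention_projection_shapes(read_safetensors_header(path))
        for name, (dtype, shape) in shapes.items():
            print(name, dtype, shape, "packed?" if looks_packed(dtype, shape) else "")
```

A regular bf16 or fp16 checkpoint would be expected to list 2-D shapes such as [4096, 4096] for q_proj. A U8 tensor with a trailing dimension of 1 would point towards the checkpoint having been saved in a quantized, packed form rather than as plain weights.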

oroojlooy, Mar 14 '24 14:03