
Gemma-2b errors when sharding

seongminp opened this issue on Apr 2, 2024

System Info

Docker image: TGI 1.4.5
Target: x86_64-unknown-linux-gnu
Cargo version: 1.75.0
Commit sha: 4ee0a0c4010b6e000f176977648aa1749339e8cb
Docker label: sha-4ee0a0c
nvidia-smi: N/A

Information

  • [X] Docker
  • [ ] The CLI directly

Tasks

  • [X] An officially supported command
  • [ ] My own modifications

Reproduction

Launch command (huggingface authentication was taken care of):

docker run \
  --gpus '"device=0,1"' \
  --shm-size 1g \
  -p $port:80 \
  -v /path_to_model_on_disk:/model \
  ghcr.io/huggingface/text-generation-inference:1.4.5 \
    --model-id /model \
    --num-shard 2 \
    --dtype "bfloat16" \
    --max-input-length 3500 \
    --max-total-tokens 4092

Error message:

You are using a model of type gemma to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
Traceback (most recent call last):

  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 89, in serve
    server.serve(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 235, in serve
    asyncio.run(

  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)

  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 196, in serve_inner
    model = get_model(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 344, in get_model
    return FlashGemma(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_gemma.py", line 62, in __init__
    model = FlashGemmaForCausalLM(config, weights)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 427, in __init__
    self.model = FlashGemmaModel(config, weights)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 365, in __init__
    [

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 366, in <listcomp>
    FlashGemmaLayer(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 299, in __init__
    self.self_attn = FlashGemmaAttention(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 183, in __init__
    self.query_key_value = load_attention(config, prefix, weights)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 117, in load_attention
    return _load_gqa(config, prefix, weights)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 143, in _load_gqa
    assert list(weight.shape) == [

AssertionError: [1280, 2048] != [1536, 2048]
 rank=0
Error: ShardCannotStart

Expected behavior

The model should load and start serving across both shards.

I think the problem is the line in flash_gemma_modeling.py that computes num_key_value_heads by integer-dividing config.num_key_value_heads by the number of shards: since config.num_key_value_heads is 1 for gemma-2b, num_key_value_heads ends up as 0 when sharding across 2 GPUs.
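
To make the numbers in the assertion concrete, here is a minimal sketch (plain Python, not TGI code) of the arithmetic I believe goes wrong, assuming gemma-2b's config values (hidden_size=2048, num_attention_heads=8, num_key_value_heads=1, head_dim=256) and --num-shard 2:

# Minimal sketch of the suspected shape mismatch; the variable names are
# illustrative, not TGI's. Assumes gemma-2b's published config values.
num_shards = 2
num_attention_heads = 8
num_key_value_heads = 1
head_dim = 256

# The fused QKV weight has (num_heads + 2 * num_kv_heads) * head_dim rows in
# total; column-sharding it across 2 GPUs leaves 1280 rows on each shard.
rows_loaded_per_shard = (
    (num_attention_heads + 2 * num_key_value_heads) * head_dim
) // num_shards                                                          # 1280

# Integer division of the single KV head by the shard count gives 0, which is
# what I think breaks the shape check:
kv_heads_per_shard = num_key_value_heads // num_shards                   # 0

# The expected shape printed in the error message appears to keep the
# unsharded KV head count, which is where the 1536 comes from:
heads_per_shard = num_attention_heads // num_shards                      # 4
rows_expected = (heads_per_shard + 2 * num_key_value_heads) * head_dim   # 1536

print(rows_loaded_per_shard, rows_expected)  # 1280 != 1536 -> AssertionError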

seongminp · Apr 02 '24 18:04