Gemma-2b errors when sharding
System Info
Docker image: TGI 1.4.5
Target: x86_64-unknown-linux-gnu
Cargo version: 1.75.0
Commit sha: 4ee0a0c4010b6e000f176977648aa1749339e8cb
Docker label: sha-4ee0a0c
nvidia-smi: N/A
Information
- [X] Docker
- [ ] The CLI directly
Tasks
- [X] An officially supported command
- [ ] My own modifications
Reproduction
Launch command (Hugging Face authentication was already handled):
```
docker run \
    --gpus '"device=0,1"' \
    --shm-size 1g -p $port:80 -v /path_to_model_on_disk:/model \
    ghcr.io/huggingface/text-generation-inference:1.4.5 \
    --model-id /model \
    --num-shard 2 \
    --dtype "bfloat16" \
    --max-input-length 3500 \
    --max-total-tokens 4092
```
Error message:
```
You are using a model of type gemma to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 89, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 235, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 196, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 344, in get_model
    return FlashGemma(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_gemma.py", line 62, in __init__
    model = FlashGemmaForCausalLM(config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 427, in __init__
    self.model = FlashGemmaModel(config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 365, in __init__
    [
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 366, in <listcomp>
    FlashGemmaLayer(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 299, in __init__
    self.self_attn = FlashGemmaAttention(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 183, in __init__
    self.query_key_value = load_attention(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 117, in load_attention
    return _load_gqa(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_gemma_modeling.py", line 143, in _load_gqa
    assert list(weight.shape) == [
AssertionError: [1280, 2048] != [1536, 2048]
rank=0
Error: ShardCannotStart
```
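For context on the two shapes in the assertion, here is a minimal sketch of the arithmetic, assuming Gemma-2b's config values (`hidden_size=2048`, `num_attention_heads=8`, `head_dim=256`, `num_key_value_heads=1`) and two shards; the variable names are mine, not TGI's:

```python
# Illustrative arithmetic only; config values are from Gemma-2b's config.json.
hidden_size = 2048
num_attention_heads = 8
head_dim = 256
num_key_value_heads = 1
world_size = 2  # --num-shard 2

# Per-shard rows of the fused QKV projection if the single KV head is
# replicated on every shard (what the assert appears to expect):
q_rows = (num_attention_heads // world_size) * head_dim   # 4 * 256 = 1024
kv_rows_replicated = 2 * num_key_value_heads * head_dim   # 2 * 256 = 512
print(q_rows + kv_rows_replicated)                        # 1536

# Rows actually loaded if K and V are column-sharded like Q, so each
# shard receives half of the single KV head:
kv_rows_sharded = 2 * (num_key_value_heads * head_dim) // world_size  # 256
print(q_rows + kv_rows_sharded)                                       # 1280
```

That reproduces the mismatch `1280 != 1536` from the assertion, consistent with the single KV head being split across shards instead of replicated.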
Expected behavior
The model should launch and serve across both shards.
I think the problem is in this line, where `num_key_value_heads` becomes 0 because `config.num_key_value_heads` is 1: with `--num-shard 2`, the integer division by the shard count presumably truncates `1 // 2` to 0.
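A short sketch of that failure mode, and of a guarded division that would avoid it; the names are illustrative (assuming the shard count comes from the process group size), not TGI's exact code:

```python
# Illustrative sketch of the suspected failure mode, not TGI's exact code.
class Config:
    num_attention_heads = 8
    num_key_value_heads = 1  # Gemma-2b uses a single shared KV head

world_size = 2  # --num-shard 2

# Plain integer division truncates the single KV head to zero,
# which poisons all downstream shape math:
num_key_value_heads = Config.num_key_value_heads // world_size
print(num_key_value_heads)  # 0

# One common guard for MQA/GQA sharding is to replicate the KV head(s)
# on every shard whenever there are fewer KV heads than shards:
num_key_value_heads = max(Config.num_key_value_heads // world_size, 1)
print(num_key_value_heads)  # 1
```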