[BUG] Assertion error `self.config["num_attention_heads"] % self.world_size_ == 0` when not perfectly divisible
Issue description: An assertion error is thrown when the number of attention heads is not evenly divisible by the world size. For example, a world size of 5 (set with `--tp 5`) when running Llama-2 7B, which has 32 attention heads.
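For context, here is a minimal sketch of the check that fails; the path and the head count come from the Hugging Face `config.json` of Llama-2-7b-chat-hf, and the exact file layout is an assumption for illustration:

```python
# Minimal sketch of the constraint that trips the assert; assumes the
# Hugging Face config.json layout of Llama-2-7b-chat-hf (illustrative only).
import json
from pathlib import Path

model_dir = Path("~/models/Llama-2-7b-chat-hf").expanduser()
config = json.loads((model_dir / "config.json").read_text())

num_heads = config["num_attention_heads"]  # 32 for Llama-2 7B
world_size = 5                             # value passed via --tp 5

# Mirrors the check in basemodel.py::_verify_must
print(num_heads % world_size == 0)  # False -> AssertionError at startup
```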
Steps to reproduce:
`python -m lightllm.server.api_server --model_dir ~/models/Llama-2-7b-chat-hf --host 0.0.0.0 --port 8080 --tp 5`
Expected behavior:
The model can be sharded across all GPUs.
Error logging:
========= Remote Traceback (1) =========
Traceback (most recent call last):
  File "/anaconda3/envs/lightllm/lib/python3.9/site-packages/rpyc/core/protocol.py", line 359, in _dispatch_request
    res = self._HANDLERS[handler](self, *args)
  File "/anaconda3/envs/lightllm/lib/python3.9/site-packages/rpyc/core/protocol.py", line 837, in _handle_call
    return obj(*args, **dict(kwargs))
  File "/Projects/lightllm/lightllm/server/router/model_infer/model_rpc.py", line 119, in exposed_init_model
    raise e
  File "/Projects/lightllm/lightllm/server/router/model_infer/model_rpc.py", line 82, in exposed_init_model
    self.model = LlamaTpPartModel(model_kvargs)
  File "/Projects/lightllm/lightllm/models/llama/model.py", line 33, in __init__
    super().__init__(kvargs)
  File "/Projects/lightllm/lightllm/common/basemodel/basemodel.py", line 46, in __init__
    self._verify_must()
  File "/Projects/lightllm/lightllm/common/basemodel/basemodel.py", line 69, in _verify_must
    assert self.config["num_attention_heads"] % self.world_size_ == 0
AssertionError
Environment:
- OS: Linux pop-os 6.2.6-76060206-generic
- GPU info: NVIDIA-SMI 535.104.05, Driver Version: 535.104.05, CUDA Version: 12.2
- Graphics cards: 5x RTX 3090
- Python: CPython 3.9
- LightLLM: eaa1b96a34626dc857c353f106970c0138a7ac88
- openai-triton: 2.1.0
Thank you for pointing this out. We recommend using a world size that evenly divides num_attention_heads, so that the shards carry balanced loads without affecting overall performance.
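As a quick illustration of that recommendation, a minimal sketch (assuming the 32 attention heads of Llama-2 7B) that lists the world sizes satisfying the assert:

```python
# Sketch: world sizes that split the attention heads evenly.
# Llama-2 7B has num_attention_heads = 32 (assumption stated above).
num_heads = 32
valid_tp = [n for n in range(1, num_heads + 1) if num_heads % n == 0]
print(valid_tp)  # [1, 2, 4, 8, 16, 32] -> with 5 GPUs available, use --tp 4
```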