[Model Loading] Speedup model loading with distributed loading
Hello! The current model loading path is the same regardless of the tensor parallel size: each rank in a TP group reads the full weight file and then discards the weight tensors (or slices of them) that it does not need. When --tensor-parallel-size is greater than 1, each rank keeps only 1/tp_size of most parameters, so most of the bytes read from disk are wasted IO.
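For reference, here is a minimal sketch of the vanilla pattern described above (hypothetical function name, shown for a single column-parallel weight, not the actual vLLM loader); every rank pays for the full disk read even though it keeps only a 1/tp_size slice:

```python
import torch
from safetensors.torch import load_file

def vanilla_load_column_parallel(shard_file: str, name: str,
                                 rank: int, tp_size: int) -> torch.Tensor:
    # Every rank reads the entire file from disk ...
    full = load_file(shard_file)[name]
    # ... but keeps only its 1/tp_size slice; the rest is discarded.
    out_dim = full.shape[0] // tp_size
    return full[rank * out_dim:(rank + 1) * out_dim].clone()
```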
Since disk IO is slow (particularly for .bin files) while the transfer rate between GPUs is fast, we can adopt a distributed loading approach: each worker reads only ~1/tp_size of the checkpoint (split by file, or, for SafeTensors, split by tensor), and the parameters each worker needs are then exchanged via torch.distributed.scatter or torch.distributed.broadcast. This reduces disk IO per worker to 1/tp_size.
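As a rough illustration of the idea (not the exact code in this PR), the sketch below assumes SafeTensors shards, an already-initialized torch.distributed process group over the TP ranks, and hypothetical names such as `distributed_load`; each rank reads roughly 1/tp_size of the files and broadcasts the tensors it loaded to the other ranks:

```python
import torch
import torch.distributed as dist
from safetensors.torch import load_file

def distributed_load(shard_files: list[str], rank: int,
                     tp_size: int) -> dict[str, torch.Tensor]:
    """Each rank reads ~1/tp_size of the checkpoint files from disk,
    then broadcasts every tensor to the other ranks over the fast
    GPU interconnect."""
    # 1. Partition the checkpoint files round-robin across ranks.
    my_files = shard_files[rank::tp_size]
    local_tensors = {}
    for f in my_files:
        local_tensors.update(load_file(f, device=f"cuda:{rank}"))

    # 2. Exchange tensor names/shapes/dtypes so every rank knows
    #    what to receive from whom.
    local_meta = {name: (t.shape, t.dtype) for name, t in local_tensors.items()}
    all_meta = [None] * tp_size
    dist.all_gather_object(all_meta, local_meta)

    # 3. Broadcast each tensor from the rank that read it. Real code
    #    would keep only the 1/tp_size slice each rank needs (and could
    #    use scatter instead of broadcast); slicing is omitted here.
    full_state = {}
    for src, rank_meta in enumerate(all_meta):
        for name, (shape, dtype) in rank_meta.items():
            t = local_tensors[name] if src == rank else torch.empty(
                shape, dtype=dtype, device=f"cuda:{rank}")
            dist.broadcast(t, src=src)
            full_state[name] = t
    return full_state
```

With scatter, the source rank would pre-slice each tensor and send every receiver only its shard, which also cuts inter-GPU traffic; broadcast is simpler but transfers each full tensor once per parameter.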
I have implemented example distributed loading code in llama.py and baichuan.py; other models (if needed) can easily implement similar logic. To stay compatible with the existing code, the args introduced in this PR are optional, so if you do not wish to use distributed loading, the original code does not require any modifications.
When --tensor-parallel-size >= 4, distributed loading significantly accelerates loading, typically by 40% or more. Here are the experimental results on my machine (8×A100) for Llama-2-70b and Baichuan2-13B.
| | Llama-2-70b (TP8) | Baichuan2-13B (TP4) |
|---|---|---|
| Vanilla loading time | 249.5s | 45.3s |
| Distributed loading time | 141.5s | 25.2s |
| Speedup | 43.3% | 44.4% |