Loading Model on multiple GPUs
Describe the bug
I am currently building the model from source for meta-llama/Meta-Llama-3-8B-Instruct:

```python
# Excerpt from the build step; checkpoints, ckpt_dir, max_seq_len,
# max_batch_size, and tokenizer_path come from the surrounding build() call.
import json
from pathlib import Path

import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank

from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

# Each model-parallel rank loads its own checkpoint shard.
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(Path(ckpt_dir) / "params.json", "r") as f:
    params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
    max_seq_len=max_seq_len,
    max_batch_size=max_batch_size,
    **params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert model_args.vocab_size == tokenizer.n_words
if torch.cuda.is_bf16_supported():
    torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
```
However, only GPU 0 ends up holding the model; all the other GPUs stay empty. Assuming nothing else has been changed, how can I load this particular model across multiple GPUs (like how device_map="auto" works when loading a model through Hugging Face)?
(I have tried accelerate.load_checkpoint_in_model, but it didn't work.)
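For context, this is roughly what I mean by device_map="auto" on the Hugging Face side (a sketch only; it assumes the transformers/accelerate packages and the HF-format weights, not the original consolidated .pth used above):

```python
# Sketch for comparison only: loading the Hugging Face-format weights and
# letting accelerate place the layers across all visible GPUs.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",   # bf16 where supported
    device_map="auto",    # shard layers across available GPUs
)
```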
Minimal reproducible example
```shell
torchrun --nproc_per_node 1 example_chat_completion.py \
    --ckpt_dir Meta-Llama-3-8B-Instruct/ \
    --tokenizer_path Meta-Llama-3-8B-Instruct/tokenizer.model \
    --max_seq_len 512 --max_batch_size 6
```
Output
The whole model is loaded onto a single GPU.
Runtime Environment
- Model: meta-llama-3-8b-instruct
- Using via huggingface?: no
- OS: Springdale Open Enterprise Linux 8.6 (Modena), Kernel: Linux 4.18.0-372.32.1.el8_6.x86_64, Architecture: x86-64 (server chassis)
- CUDA version: 12.4
- PyTorch version: 2.3.1
- Python version: 3.8.12
- GPU:
Additional context
Thanks a lot!
Hi @DerrickYLJ, in your torchrun call you need to set --nproc_per_node to your number of GPUs. It will spin up one process per GPU to split the model.
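For example, something like the following (a sketch assuming 8 visible GPUs; note this only works when the process count matches the number of checkpoint shards, see below):

```shell
torchrun --nproc_per_node 8 example_chat_completion.py \
    --ckpt_dir Meta-Llama-3-8B-Instruct/ \
    --tokenizer_path Meta-Llama-3-8B-Instruct/tokenizer.model \
    --max_seq_len 512 --max_batch_size 6
```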
Same problem here: when I set --nproc_per_node to 8, I get the error "AssertionError: Loading a checkpoint for MP=1 but world size is 8".
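For context, this assertion comes from a shard-count check in the loader. A paraphrased sketch (the exact code in llama/generation.py may differ, and check_shards is just an illustrative wrapper name):

```python
from pathlib import Path

def check_shards(ckpt_dir: str, model_parallel_size: int) -> None:
    # Paraphrased: the number of *.pth shards in the checkpoint directory must
    # equal the model-parallel world size, i.e. the --nproc_per_node value.
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert model_parallel_size == len(checkpoints), (
        f"Loading a checkpoint for MP={len(checkpoints)} "
        f"but world size is {model_parallel_size}"
    )
```

The 8B download ships a single consolidated.00.pth, so MP=1 and only --nproc_per_node 1 passes this check.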
> Hi @DerrickYLJ, in your torchrun call you need to set --nproc_per_node to your number of GPUs. It will spin up one process per GPU to split the model.
Yes, I have tried that, but it produces exactly the same assertion failure reported in the other comment.
I think the problem is that Llama-3-8B-Instruct only has one checkpoint file. So how does setting nproc_per_node help, or more specifically, how can we solve this?
Thank you!
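For reference, the 8B download I have contains only one shard (roughly):

```
Meta-Llama-3-8B-Instruct/
├── consolidated.00.pth
├── params.json
└── tokenizer.model
```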
Sorry @ISADORAyt, I wasn't paying attention that @DerrickYLJ was loading the 8B model. The code in this repo can only load the 8B model on a single GPU and the 70B model on 8 GPUs. To run different splits you'll need to look into a different engine like vLLM, which you can run either standalone or through TorchServe's integration: https://github.com/pytorch/serve?tab=readme-ov-file#-quick-start-llm-deployment
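For what it's worth, a minimal sketch of a tensor-parallel split with vLLM (assuming vllm is installed and the Hugging Face-format weights are accessible; the GPU count and sampling values are illustrative):

```python
from vllm import LLM, SamplingParams

# Split the 8B model across 4 GPUs via tensor parallelism.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=4,
)
outputs = llm.generate(
    ["Explain tensor parallelism in one sentence."],
    SamplingParams(max_tokens=64, temperature=0.6),
)
print(outputs[0].outputs[0].text)
```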
> I think the problem is that Llama-3-8B-Instruct only has one checkpoint file. So how does setting nproc_per_node help, or more specifically, how can we solve this?
@DerrickYLJ Please see above, I misread your initial post.
Same issue! Could you please tell me how you solved this problem? I have 4 GPUs. Is it true that this repo's code can only load the 8B model on a single GPU, and not on any other number, like 4? Thank you so much!
> Sorry @ISADORAyt, I wasn't paying attention that @DerrickYLJ was loading the 8B model. The code in this repo can only load the 8B model on a single GPU and the 70B model on 8 GPUs. To run different splits you'll need to look into a different engine like vLLM, which you can run either standalone or through TorchServe's integration: https://github.com/pytorch/serve?tab=readme-ov-file#-quick-start-llm-deployment
Same issue! Is it true that this repo's code can only load the 8B model on a single GPU, and not on any other number, like 4? Are there other ways to deal with this problem? Thank you so much!
I'm having the same issue. Any updates here?