litgpt
A potential bug for multi-GPU training
Hi,
I observed the following strange behavior when running your code for TinyLlama pretraining.
- When using multiple GPUs, I got completely different results when running the same code twice. In addition, many loss spikes occur. See the 2-card training example below. I used all the default settings, except that I reduced the learning rate from 4e-4 to 2e-4 and the batch size from 1024 to 512.
AdamW 2-card: run 1
wandb: 🚀 View run at https://wandb.ai/yushunzhang0410/pretrain-tiny-llama-1.1b/runs/83b8yfjz
AdamW 2-card: run 2
wandb: 🚀 View run at https://wandb.ai/yushunzhang0410/pretrain-tiny-llama-1.1b/runs/8p6axrgw
The two runs are completely different, and training fails.
- When I run the same settings on a single GPU, these issues do not occur. The two runs are mostly the same (with slight differences), and the loss decreases stably without any spikes.
AdamW 1-card: run 1
wandb: 🚀 View run at https://wandb.ai/yushunzhang0410/pretrain-tiny-llama-1.1b/runs/kdg2qmj8
AdamW 1-card: run 2
wandb: 🚀 View run at https://wandb.ai/yushunzhang0410/pretrain-tiny-llama-1.1b/runs/vh23qd0u
The two runs are mostly the same, and the loss decreases stably.
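One point worth checking when comparing the 1-card and 2-card runs: in a typical data-parallel setup the global batch is split across devices, so the same global batch size of 512 should give the optimizer the same number of samples per update on one or two GPUs, with only the per-device gradient-accumulation count changing. The sketch below assumes the convention global_batch = devices × micro_batch × accumulation_steps; `accumulation_steps` is a hypothetical helper rather than litgpt's code, and the micro batch size of 4 is an illustrative value, not the reporter's setting.

```python
def accumulation_steps(global_batch_size: int, devices: int, micro_batch_size: int) -> int:
    """Return gradient-accumulation steps per optimizer update on each device.

    Assumes global_batch_size = devices * micro_batch_size * steps
    (a common data-parallel convention, not taken from litgpt's source).
    """
    per_device = global_batch_size // devices
    if per_device * devices != global_batch_size:
        raise ValueError("global batch size must be divisible by the number of devices")
    if per_device % micro_batch_size != 0:
        raise ValueError("per-device batch must be divisible by the micro batch size")
    return per_device // micro_batch_size


if __name__ == "__main__":
    # micro_batch_size=4 is an illustrative assumption.
    for devices in (1, 2):
        steps = accumulation_steps(512, devices, micro_batch_size=4)
        print(f"{devices} device(s): {steps} accumulation steps of 4 samples each")
```

Under this convention, 1 device needs 128 accumulation steps and 2 devices need 64, so the effective batch per update is 512 either way.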
Have you encountered a similar issue? Any idea why this happens?
Thanks for the report. Can you try:
- running without torch.compile (comment this line out https://github.com/Lightning-AI/litgpt/blob/main/litgpt/pretrain.py#L174)
- running with torch.compile but on PyTorch 2.3
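Instead of commenting the line out, the compile step can be gated behind a flag so both experiments run from the same code. This is a minimal sketch: `maybe_compile` and its `enabled` flag are hypothetical names, not part of litgpt. Only `torch.compile` is a real PyTorch API, and torch is imported lazily so the eager path runs without it.

```python
from typing import Any, Callable, Optional


def maybe_compile(model: Any, enabled: bool, compile_fn: Optional[Callable[[Any], Any]] = None) -> Any:
    """Return the model wrapped by torch.compile if enabled, otherwise unchanged.

    `compile_fn` lets a caller inject a different compiler (or a stub in tests).
    """
    if not enabled:
        # Eager mode: hand back the original model untouched.
        return model
    if compile_fn is None:
        import torch  # imported only when compilation is actually requested

        compile_fn = torch.compile
    return compile_fn(model)
```

In a training script this would be used as `model = maybe_compile(model, enabled=use_compile)`, where `use_compile` is a hypothetical command-line or config option.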
Thanks a lot for investigating this.
cc @awaelchli for visibility
Hi,
I still encounter this issue with your latest code on GitHub.
Setup: four A800-80GB GPUs, AdamW, TinyLlama, all default settings. I did not change anything except the data path. I still see loss spikes that do not occur in single-GPU training.
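To make "loss spike" and "completely different runs" concrete when comparing exported loss curves, a simple heuristic can flag steps whose loss jumps well above the recent median and measure how far two runs drift apart. This is a sketch under stated assumptions: the `window` and `ratio` thresholds are illustrative, the helper names are hypothetical, and the loss lists would have to be exported from each wandb run separately (not shown).

```python
from statistics import median
from typing import List, Sequence


def find_spikes(losses: Sequence[float], window: int = 20, ratio: float = 1.5) -> List[int]:
    """Return step indices whose loss exceeds `ratio` times the median of the previous `window` steps."""
    spikes = []
    for i in range(window, len(losses)):
        baseline = median(losses[i - window:i])
        if losses[i] > ratio * baseline:
            spikes.append(i)
    return spikes


def max_divergence(run_a: Sequence[float], run_b: Sequence[float]) -> float:
    """Largest absolute loss difference between two runs over their shared steps."""
    return max((abs(a - b) for a, b in zip(run_a, run_b)), default=0.0)
```

Using a median rather than a mean for the baseline keeps a single earlier spike from inflating the threshold for the steps that follow it.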
wandb: 🚀 View run at https://wandb.ai/yushunzhang0410/pretrain-tiny-llama-1.1b-litgpt-version/runs/bhiopo5z
I installed all the dependencies with pip install 'litgpt[all]', as suggested on GitHub. I checked your default pretrain.py and found that I am using model.compile with PyTorch 2.3.0. This matches your suggestion of "running with torch.compile but on PyTorch 2.3".
What should I do now? Am I the only one encountering this issue? Does it happen on your side? I think you can easily reproduce it with git clone, pip install 'litgpt[all]', and then running the code, just as I did.
Your wandb log metadata suggests you are using lightning 2.2dev, which probably came with an older version of litgpt you had installed. You might need this fix for pretraining, so I suggest updating lightning to the latest version first.
Hi,
Thanks for your prompt reply.
I think my lightning version is correct. My code runs in a fresh environment created yesterday, where I simply ran pip install 'litgpt[all]' as suggested on GitHub. As conda list confirms, I am using lightning 2.3.0.dev20240328.
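Besides conda list, the versions that the active Python interpreter actually sees can be read with the standard library's `importlib.metadata`. This helps when several environments exist on one machine. This is a minimal sketch; `installed_versions` is a hypothetical helper, and the package list is an assumption about which distributions matter for this issue.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    report: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ver in installed_versions(["torch", "lightning", "litgpt"]).items():
        print(f"{name}: {ver or 'not installed'}")
```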
Any other possible issue?
I made the initialization fix on April 11, so the nightly build you have (dated March 28) is still too old. The fix was then cherry-picked into lightning 2.2.2, so I would still update the package.
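The reasoning here is that a nightly version such as 2.3.0.dev20240328 encodes its build date, so it can be compared against the fix date, while stable releases can be compared against 2.2.2. The following is a minimal heuristic sketch. `has_init_fix` is a hypothetical helper. The year 2024 for the April 11 fix is an assumption based on the dev20240328 build. Treating every stable release at or above 2.2.2 as fixed is also an assumption. Other suffixes (for example rc builds) are rejected rather than guessed.

```python
import re
from datetime import date
from typing import Tuple

# Assumptions from the thread: fix made on April 11 (year assumed 2024)
# and cherry-picked into the 2.2.2 stable release.
FIX_DATE = date(2024, 4, 11)
FIRST_STABLE_WITH_FIX = (2, 2, 2)

_VERSION_RE = re.compile(r"^(\d+)\.(\d+)\.(\d+)(?:\.dev(\d{8}))?$")


def has_init_fix(ver: str) -> bool:
    """Heuristically decide whether a lightning version string includes the fix.

    Nightly builds (X.Y.Z.devYYYYMMDD) are judged by build date; a build dated
    on the fix day itself is conservatively treated as not including it.
    Stable releases are compared against 2.2.2.
    """
    match = _VERSION_RE.match(ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    major, minor, patch, stamp = match.groups()
    if stamp is not None:
        built = date(int(stamp[:4]), int(stamp[4:6]), int(stamp[6:]))
        return built > FIX_DATE
    release: Tuple[int, int, int] = (int(major), int(minor), int(patch))
    return release >= FIRST_STABLE_WITH_FIX
```

Combined with the standard library, `has_init_fix(importlib.metadata.version("lightning"))` would flag the reporter's 2.3.0.dev20240328 build as predating the fix.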