torchtitan
End-to-end training of DeepSeek V3
Any plans to have a real training script for DSV3?
Right now run.py only runs a forward pass on dummy data with DSV2, so it's unclear how much of DSV3 is supported and whether it actually works.
For the DSV3 forward pass, I was able to run on 32 H200s but had to lower config.max_seq_len quite a bit; otherwise it OOMed while setting up symmetric memory. Would love to be able to train DSV3!
Hi @EugenHotaj - yes, a full DSV3 training script is coming.
The current PRs are part of an iterative process... more is coming soon!
@lessw2020 amazing, looking forward to it!
Hi @EugenHotaj thanks for your interest. We have a stack of PRs to enable training: https://github.com/pytorch/torchtitan/pull/941 https://github.com/pytorch/torchtitan/pull/952 https://github.com/pytorch/torchtitan/pull/954 https://github.com/pytorch/torchtitan/pull/956
If you don't want to wait till they land in main, you can check out the top of the stack, and run it with:
torchrun --standalone --nproc-per-node 8 run.py
or scale it to the size you want (just remember to modify the init_device_mesh dimensions in run.py).
Please be aware that neither performance nor accuracy has been verified at this point.
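For the init_device_mesh change mentioned above, here is a minimal sketch of how the mesh dimensions might be derived from the process count. It assumes a 2-D mesh; the helper name `mesh_shape_for` and the dimension names "pp" and "ep" are hypothetical and may not match what run.py actually uses, so check the script before copying anything.

```python
def mesh_shape_for(world_size: int, pp_size: int) -> tuple[int, int]:
    """Return a 2-D mesh shape (pp_size, world_size // pp_size).

    Hypothetical helper: the product of the mesh dimensions has to equal
    the total number of processes launched by torchrun.
    """
    if world_size <= 0 or pp_size <= 0:
        raise ValueError("world_size and pp_size must be positive")
    if world_size % pp_size != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by pp_size={pp_size}"
        )
    return (pp_size, world_size // pp_size)


if __name__ == "__main__":
    # 8 processes on one node, split into 2 groups of 4.
    shape = mesh_shape_for(world_size=8, pp_size=2)
    print(shape)

    # Inside a torchrun-launched script the shape could then be passed on,
    # for example (dimension names are an assumption, not taken from run.py):
    #   from torch.distributed.device_mesh import init_device_mesh
    #   mesh = init_device_mesh("cuda", shape, mesh_dim_names=("pp", "ep"))
```

If the product of the mesh dimensions does not match the number of processes, mesh creation will fail, so it can help to validate the shape up front like this.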
@kwen2501 yes, I have been following along and have actually gotten a training loop working internally 🙂.
As you mentioned, I think there may be issues with respect to accuracy because the training loss is very high at step 0. I also tried to generate from the model and the output is gibberish.
One issue is that I don't think we're using the correct RoPE / RoPE params, but fixing that only helped slightly, so I think there are more issues.
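As a rough reference point for "very high at step 0": a model that predicts a uniform distribution over the vocabulary has a cross-entropy of ln(vocab_size), so a freshly initialized model should start near that value. The sketch below assumes a vocabulary size of 129280 for DSV3 (please verify against the model config you are using); the helper names and the tolerance are illustrative only.

```python
import math


def expected_initial_loss(vocab_size: int) -> float:
    """Cross-entropy (in nats) of a uniform prediction over the vocabulary."""
    if vocab_size <= 0:
        raise ValueError("vocab_size must be positive")
    return math.log(vocab_size)


def looks_suspicious(step0_loss: float, vocab_size: int, tolerance: float = 1.0) -> bool:
    """Flag a step-0 loss that is well above the uniform-prediction baseline.

    The tolerance is an arbitrary illustrative threshold, not a tuned value.
    """
    return step0_loss > expected_initial_loss(vocab_size) + tolerance


if __name__ == "__main__":
    vocab_size = 129280  # assumed DSV3 vocabulary size; check your config
    print(f"uniform baseline: {expected_initial_loss(vocab_size):.2f}")
    print(looks_suspicious(step0_loss=20.0, vocab_size=vocab_size))
```

A step-0 loss far above this baseline usually points at initialization or logit-scale problems rather than at the data.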
Yep, let me start with verifying the model's forward.
For training, I think I initially muted some init_weight calls. Let me add them back.
Hi @lessw2020 and @kwen2501, thanks for the awesome work. Was wondering if there is an update on this.