DDP (replicate) + TP?
Currently, when there are two device mesh dimensions (tp and dp), torchtitan forces FSDP as the only backend for DP. Ref:
https://github.com/pytorch/torchtitan/blob/d2a4904f58accc683c17c66a360026cb3c8109af/torchtitan/parallelisms/parallelize_llama.py#L97-L98
However, replicate should support a >1D mesh and therefore be usable with TP enabled. Ref.
Q1: Why does torchtitan not support DDP (replicate) + TP? Is it only an implementation choice?
I hand-wrote DDP + TP in torchtitan and, surprisingly, found that the loss never goes down. It seems the parameters have no gradients after loss.backward().
To reproduce, use the branch above and run run_llama_train.sh on an 8-GPU machine.
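Roughly, the hand-written combination looks like the sketch below: a simplified, self-contained toy (two nn.Linear layers instead of Llama, illustrative mesh sizes), not the exact code in the branch. Launch with torchrun on 8 GPUs.

```python
# Simplified sketch of the hand-written DDP (replicate) + TP combination.
# Toy two-layer model instead of Llama; launch with torchrun --nproc_per_node=8.
import os

import torch
import torch.nn as nn
from torch.distributed._composable import replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# 2D mesh: 4-way data parallel x 2-way tensor parallel on 8 GPUs.
world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 16)).cuda()

# 1) Tensor-parallelize the two linear layers over the tp sub-mesh.
parallelize_module(
    model,
    world_mesh["tp"],
    {"0": ColwiseParallel(), "1": RowwiseParallel()},
)

# 2) Data-parallelize with the composable replicate (DDP) API over the dp sub-mesh.
replicate(model, device_mesh=world_mesh["dp"])

loss = model(torch.randn(8, 16, device="cuda")).sum()
loss.backward()

# Symptom described above: parameter gradients stay empty after backward().
for name, p in model.named_parameters():
    print(name, p.grad)
```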
Q2: Is the missing-gradient behavior a bug, or is it expected because DDP + TP is intentionally unsupported?
And collect_env:
Collecting environment information...
PyTorch version: 2.5.0.dev20240903+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 9.13 (stretch) (x86_64)
GCC version: (Debian 6.3.0-18+deb9u1) 6.3.0 20170516
Clang version: Could not collect
CMake version: version 3.21.2
Libc version: glibc-2.24
Python version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.56.bsk.2-amd64-x86_64-with-glibc2.24
Is CUDA available: True
CUDA runtime version: 12.6.20
CUDA_MODULE_LOADING set to: LAZY
...
Nvidia driver version: 560.28.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
...
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.12.1
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0.dev20240903+cu118
[pip3] torchaudio==2.5.0.dev20240903+cu118
[pip3] torchdata==0.8.0
[pip3] torchvision==0.20.0.dev20240903+cu118
[conda] numpy 1.26.4 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] pytorch-triton 3.0.0+dedb7bdf33 pypi_0 pypi
[conda] torch 2.5.0.dev20240903+cu118 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20240903+cu118 pypi_0 pypi
[conda] torchdata 0.8.0 pypi_0 pypi
[conda] torchvision 0.20.0.dev20240903+cu118 pypi_0 pypi
P.S.
- Torch 2.4.0 shows similar abnormal results
- Using the DistributedDataParallel class rather than replicate behaves well (see the sketch below)
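Concretely, the well-behaved variant only swaps the data-parallel wrapper in the sketch above (same toy model, world_mesh, and TP plan assumed; not my actual patch):

```python
# Same setup as the sketch above, but wrapping with the DistributedDataParallel
# class on the dp sub-mesh instead of the composable replicate API.
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model, device_mesh=world_mesh["dp"])
```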
Thanks in advance!
We do not plan to support DDP + TP as we have not identified any major use cases for this combination. When working with large models, it is more common to use FSDP + TP instead of DDP + TP. Additionally, FSDP offers several features that are not available in DDP, such as fp8. Therefore, we believe that DDP is better suited for smaller models. In TorchTitan, we enabled DDP primarily for sanity-check purposes, such as verifying parallelism with the 8B model and a very small batch size. So we did not verify the correctness of DDP + TP.
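For reference, a rough sketch of that recommended FSDP + TP composition, using the composable fully_shard API on a toy model (illustrative only, not torchtitan's actual parallelize_llama code):

```python
# Illustrative sketch of FSDP + TP: TP on the tp dimension first, then FSDP2
# (fully_shard) on the dp dimension. Toy model; assumes per-rank device setup
# (torch.cuda.set_device) as in the earlier sketch.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 16)).cuda()

# Shard the layers over the tp sub-mesh, then shard the resulting DTensor
# parameters with FSDP over the dp sub-mesh.
parallelize_module(
    model,
    world_mesh["tp"],
    {"0": ColwiseParallel(), "1": RowwiseParallel()},
)
fully_shard(model, mesh=world_mesh["dp"])
```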
Thanks for the reply! I've learned that FSDP + TP should be the primary (if not the only) combination to use, especially for LLMs.
And just to check: I am wondering about the original comment indicating "This is a temporary work around to enable DDP + TP." in https://github.com/pytorch/pytorch/blob/7dc1788396fc9e2860c0c236e0c0e108e96b83c8/torch/distributed/_composable/replicate.py#L225-L237. Doesn't it suggest that DDP + TP is working now?
We didn't verify the accuracy, only the composability, so there may be accuracy issues.
Thanks for the reply!