
DDP does not work with torch_xla.compile

Open bfolie opened this issue 7 months ago • 2 comments

🐛 Bug

A model wrapped in PyTorch's DDP can be compiled and run, but on the second step it hits an unexpected recompilation and hangs indefinitely. This can be reproduced by running the DDP example. On a TPU v4-8 the first compilation takes about 2 minutes, after which the output below is emitted. The process then hangs for at least 60 minutes (the longest I've waited).

Epoch 1 train begin  9:44PM UTC on Jun 06, 2025
epoch: 1, step: 0, loss: 6.926775932312012, rate: 1.1839591635674727
epoch: 1, step: 0, loss: 6.926775932312012, rate: 1.1839480487292668
epoch: 1, step: 0, loss: 6.926775932312012, rate: 1.1839564272665863
epoch: 1, step: 0, loss: 6.926775932312012, rate: 1.1839555943683504

Unexpected Execution Analysis: ================================================================================
Unexpected Execution Analysis: Compilation Cause
Unexpected Execution Analysis:   most likely user code trying to access tensor value before torch_xla.sync
Unexpected Execution Analysis: Graph Info: 
Unexpected Execution Analysis:   Graph Name: resnet_step_fn
Unexpected Execution Analysis:   Graph Hash: 811e91c7d7f3c67640d8c2fc2d2a5e09
Unexpected Execution Analysis:   Number of Graph Inputs: 3
Unexpected Execution Analysis:   Number of Graph Outputs: 1
Unexpected Execution Analysis: Python Frame Triggered Execution: 
Unexpected Execution Analysis:   _pre_forward (/usr/local/lib/python3.10/site-packages/torch/nn/parallel/distributed.py:1533)
Unexpected Execution Analysis:   forward (/usr/local/lib/python3.10/site-packages/torch/nn/parallel/distributed.py:1644)
Unexpected Execution Analysis:   _call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1778)
Unexpected Execution Analysis:   _wrapped_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1767)
Unexpected Execution Analysis:   step_fn (/workspaces/torch/pytorch/xla/examples/train_resnet_base.py:47)
Unexpected Execution Analysis:   inner (/usr/local/lib/python3.10/contextlib.py:79)
Unexpected Execution Analysis:   train_loop_fn (/workspaces/torch/pytorch/xla/examples/train_resnet_base.py:58)
Unexpected Execution Analysis:   start_training (/workspaces/torch/pytorch/xla/examples/train_resnet_base.py:69)
Unexpected Execution Analysis:   ..........
Unexpected Execution Analysis: --------------------------------------------------------------------------------
Unexpected Execution Analysis: ================================================================================

The error message points to this line: _pre_forward at torch/nn/parallel/distributed.py:1533, the first frame in the trace above. Removing self.reducer._rebuild_buckets() fixes the issue.
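
For triage, the frame that triggered the execution can be pulled out of this output programmatically. The following is only a minimal sketch, assuming the log keeps the "Unexpected Execution Analysis:" format shown above; extract_frames and FRAME_RE are hypothetical helper names and are not part of torch_xla.

```python
import re

# Matches frame lines such as:
#   Unexpected Execution Analysis:   forward (/path/to/file.py:1644)
# Header lines ("Graph Name: ...") and the "....." elision line do not match.
FRAME_RE = re.compile(
    r"^Unexpected Execution Analysis:\s+(?P<func>\w+) \((?P<path>.+):(?P<line>\d+)\)\s*$"
)


def extract_frames(log_text):
    """Return (function, path, line) tuples in the order they appear in the log."""
    frames = []
    for raw in log_text.splitlines():
        match = FRAME_RE.match(raw)
        if match:
            frames.append(
                (match.group("func"), match.group("path"), int(match.group("line")))
            )
    return frames


# A shortened copy of the output above, used as sample input.
SAMPLE = """\
Unexpected Execution Analysis: Python Frame Triggered Execution:
Unexpected Execution Analysis:   _pre_forward (/usr/local/lib/python3.10/site-packages/torch/nn/parallel/distributed.py:1533)
Unexpected Execution Analysis:   forward (/usr/local/lib/python3.10/site-packages/torch/nn/parallel/distributed.py:1644)
Unexpected Execution Analysis:   step_fn (/workspaces/torch/pytorch/xla/examples/train_resnet_base.py:47)
Unexpected Execution Analysis:   ..........
"""

frames = extract_frames(SAMPLE)
innermost = frames[0]  # the first listed frame is the one closest to the trigger
print(innermost)
```

On the sample, the first extracted frame is _pre_forward in distributed.py, which is the frame containing the call to self.reducer._rebuild_buckets().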

bfolie avatar Jun 06 '25 21:06 bfolie

I'm not sure I follow. In the log you posted, there's no error. There's a compilation log, indicating that the line you linked triggered a recompilation. That said, again, it's not an error message.

ysiraichi avatar Jun 07 '25 15:06 ysiraichi

Sorry @ysiraichi, you're right that I neglected an important fact -- after hitting recompilation it hangs for an indeterminate period of time. I'll update the issue.

bfolie avatar Jun 09 '25 14:06 bfolie