PiPPy
[spmd] self-attention module's proj.bias isn't properly updated on all ranks but rank 0
What the problem is:
- The sharded `TensorParallelMultiheadAttention` (#477) module fails to update its `proj.bias` parameter, even though the back-propagated gradient is correct (one way to observe this is sketched below).
- This error doesn't occur on rank 0.
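For reference, a minimal sketch (not the actual test) of the kind of check that surfaces the symptom: one forward/backward/optimizer step, comparing `proj.bias` before and after on each rank. `model` and `x` are placeholders for the sharded `TensorParallelMultiheadAttention` and an input batch; the forward call may need adjusting to the module's real signature.

```python
import torch
import torch.distributed as dist

def local(t):
    # DTensor-aware helper: fall back to the plain tensor for ordinary parameters.
    return t.to_local() if hasattr(t, "to_local") else t

def check_proj_bias_update(model, x, lr=0.1):
    rank = dist.get_rank()
    opt = torch.optim.SGD(model.parameters(), lr=lr)

    before = local(model.proj.bias).detach().clone()
    out = model(x)  # placeholder call; adjust to the module's actual forward signature
    out.sum().backward()
    opt.step()
    after = local(model.proj.bias).detach().clone()

    # Expected: the bias changes on every rank; observed: it only changes on rank 0.
    print(f"rank {rank}: proj.bias grad norm = {local(model.proj.bias.grad).norm():.6f}, "
          f"updated = {not torch.equal(before, after)}")
```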
How to reproduce:
I created a branch `ad-hoc-self-attn-exp`, based on origin/main, with a bunch of print statements added to help reproduce the problem.
git checkout origin/ad-hoc-self-attn-exp
pytest test/spmd/tensor/parallel/test_tp_examples.py -s -k test_self_attn_megatron_e2e
Observation:
- Other parameters (`qkv.weight`, `qkv.bias`, `proj.weight`) don't have this issue.
- The model on rank 0 performs the same parameter update as the single-node model.
- Both `proj.bias` and its `grad` have Replicate placement on the mesh, while `qkv.bias` is sharded on the mesh (see the sketch after this list).
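A small sketch of how those placements can be dumped per rank; it only assumes the parameters (and their grads) are DTensors exposing a `placements` attribute, as in the spmd `DTensor` API, and the `Shard(0)` example dim is an assumption:

```python
import torch.distributed as dist

def dump_placements(model):
    # Print each parameter's (and gradient's) DTensor placements on this rank,
    # e.g. Replicate() for proj.bias vs something like Shard(0) for qkv.bias.
    rank = dist.get_rank()
    for name, p in model.named_parameters():
        param_placements = getattr(p, "placements", None)
        grad_placements = getattr(p.grad, "placements", None) if p.grad is not None else None
        print(f"rank {rank}: {name}: param={param_placements}, grad={grad_placements}")
```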
This is kind of OK because we have a trick for this: we only use the bias from rank 0 (local rank).
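To illustrate the idea behind that trick (this is only a sketch, not the actual `TensorParallelMultiheadAttention` code): in a Megatron-style row-parallel output projection, the partial matmuls are summed by an all-reduce, so the replicated `proj.bias` must be contributed exactly once, e.g. only by rank 0.

```python
import torch
import torch.distributed as dist

def row_parallel_proj(x_shard, weight_shard, bias):
    # Forward-only sketch; a real implementation needs an autograd-aware all-reduce.
    # x_shard:      (batch, hidden // world_size)  local slice of the input
    # weight_shard: (hidden // world_size, hidden) local row slice of proj.weight
    # bias:         (hidden,)                      replicated proj.bias
    partial = x_shard @ weight_shard
    if dist.get_rank() == 0:
        partial = partial + bias      # the bias is added exactly once, on rank 0
    dist.all_reduce(partial)          # sum the partial results across ranks
    return partial
```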