PiPPy
[spmd] self-attention module's proj.bias isn't properly updated on all ranks but rank 0
What the problem is:
- The sharded `TensorParallelMultiheadAttention` module (#477) fails to update the `proj.bias` parameter, even though the back-propagated gradient is correct (see the check sketched after this list).
- Also, this error doesn't occur on rank 0.
How to reproduce:
I created a branch `ad-hoc-self-attn-exp`, based on `origin/main`, with a bunch of print statements added to help reproduce the problem.

```
git checkout origin/ad-hoc-self-attn-exp
pytest test/spmd/tensor/parallel/test_tp_examples.py -s -k test_self_attn_megatron_e2e
```
Observation:
- Other parameters (`qkv.weight`, `qkv.bias`, `proj.weight`) don't have this issue.
- The model on rank 0 performs the same parameter update as the single-node model.
- Both `proj.bias` and its `grad` have replicate placement on the mesh, while `qkv.bias` is sharded on the mesh (see the placement sketch after this list).
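For reference, the placements described above match the usual Megatron-style sharding of an attention block. A minimal sketch, assuming a 1-D device mesh and written against the DTensor-style API (the repo's `spmd.tensor` module is analogous; dimensions and tensor names here are illustrative):

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, Replicate, distribute_tensor

# 1-D mesh over all ranks.
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

hidden = 16  # illustrative hidden size

# qkv is column-parallel: weight and bias are sharded along the output dim.
qkv_weight = distribute_tensor(torch.randn(3 * hidden, hidden), mesh, [Shard(0)])
qkv_bias = distribute_tensor(torch.randn(3 * hidden), mesh, [Shard(0)])

# proj is row-parallel: weight is sharded along the input dim,
# while the bias is replicated on every rank (the parameter that fails to update).
proj_weight = distribute_tensor(torch.randn(hidden, hidden), mesh, [Shard(1)])
proj_bias = distribute_tensor(torch.randn(hidden), mesh, [Replicate()])
```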
This is kind of OK because we have a trick for this: we only use the bias from rank 0 (the local rank), as sketched below.
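For context, my reading of that trick (a hedged sketch, not the actual implementation): in a row-parallel linear the partial outputs of all ranks are summed with an all-reduce, so a replicated bias must be contributed exactly once, e.g. only by rank 0:

```python
import torch
import torch.distributed as dist

def row_parallel_linear(x_shard, weight_shard, bias=None):
    # Each rank multiplies its input shard with its weight shard.
    partial = x_shard @ weight_shard.t()
    # Add the replicated bias on rank 0 only, so that the all-reduce sum
    # below contributes the bias exactly once instead of world_size times.
    if bias is not None and dist.get_rank() == 0:
        partial = partial + bias
    dist.all_reduce(partial)  # sum the partial results across ranks
    return partial
```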