
[spmd] self-attention module's proj.bias isn't properly updated on all ranks but rank 0

Open XilunWu opened this issue 2 years ago • 1 comment

What the problem is:

  • The sharded TensorParallelMultiheadAttention (#477) module fails to update its proj.bias parameter, even though the back-propagated gradient is correct.
  • Also, this error does not occur on rank 0 (see the check sketch below).
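
To make "isn't properly updated on all ranks but rank 0" concrete, here is a minimal sketch of a check; it is a hypothetical helper, not code from the repro branch. For a DTensor parameter such as proj.bias, pass its local tensor (e.g. obtained via to_local()), since dist.broadcast operates on plain tensors.

# Hypothetical helper: report ranks whose copy of a parameter diverged from rank 0.
import torch
import torch.distributed as dist

def check_matches_rank0(local: torch.Tensor, name: str) -> None:
    # Everyone receives rank 0's value, then compares it with the local copy.
    reference = local.detach().clone()
    dist.broadcast(reference, src=0)
    if not torch.allclose(local.detach(), reference):
        print(f"rank {dist.get_rank()}: {name} diverged from rank 0's value")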

How to reproduce:

I created a branch ad-hoc-self-attn-exp, based on origin/main, with a number of print statements added to help reproduce the problem.

git checkout origin/ad-hoc-self-attn-exp
pytest test/spmd/tensor/parallel/test_tp_examples.py -s -k test_self_attn_megatron_e2e

Observation:

  • The other parameters (qkv.weight, qkv.bias, proj.weight) don't have this issue.
  • The model on rank 0 performs the same parameter update as the single-node model.
  • Both proj.bias and its gradient have a Replicate placement on the mesh, while qkv.bias is sharded on the mesh (see the placement sketch below).
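
The placement layout above can be illustrated with a minimal standalone sketch, assuming the current public DTensor API (torch.distributed.tensor in recent PyTorch; the experimental spmd/tensor module this issue targets exposes a similar API under a different import path). It only shows the expected placements of the two biases in Megatron-style tensor parallelism and is not code from the ad-hoc-self-attn-exp branch. Run it under torchrun, e.g. torchrun --nproc_per_node=2 placements.py.

# Placement sketch: column-parallel qkv.bias is sharded, row-parallel proj.bias is replicated.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

def show_attention_bias_placements(hidden: int = 16) -> None:
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    # Column-parallel qkv: its bias is sharded along dim 0 across the mesh.
    qkv_bias = distribute_tensor(torch.zeros(3 * hidden), mesh, [Shard(0)])
    # Row-parallel output projection: its bias is replicated on every rank.
    proj_bias = distribute_tensor(torch.zeros(hidden), mesh, [Replicate()])

    print(f"rank {dist.get_rank()} qkv.bias placements:  {qkv_bias.placements}")
    print(f"rank {dist.get_rank()} proj.bias placements: {proj_bias.placements}")

if __name__ == "__main__":
    dist.init_process_group("gloo")
    show_attention_bias_placements()
    dist.destroy_process_group()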

XilunWu avatar Oct 07 '22 01:10 XilunWu

This is kind of OK, because we have a trick for this: we only use the bias from rank 0 (local rank).

fduwjj avatar Oct 11 '22 16:10 fduwjj
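
For illustration only, here is a minimal sketch of what such a "use the bias from rank 0" trick could look like; this is an assumption about the approach, not necessarily the actual workaround in the codebase: after each optimizer step, proj.bias on every rank is overwritten with rank 0's copy.

# Hypothetical workaround sketch (assumed, not the repo's actual trick):
# keep proj.bias consistent by broadcasting rank 0's copy after each step.
import torch
import torch.distributed as dist

def sync_proj_bias_from_rank0(proj_bias: torch.Tensor) -> None:
    # For a DTensor bias, broadcast its local replica, e.g. proj_bias.to_local().
    with torch.no_grad():
        dist.broadcast(proj_bias, src=0)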