[spmd] self-attention not converging
What the problem is:
Both the single-node and the sharded `TensorParallelMultiheadAttention` (#477) modules diverge: the forward output becomes `-inf` after fewer than 10 iterations. They also produce slightly different forward outputs, but the relative difference is too small for `self.assertEqual` to flag as an inequality.
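One way to make the small mismatch visible is to compare the two forward outputs with an explicit relative-error check instead of relying on `self.assertEqual`. The sketch below is illustrative and not part of the test; `output` and `output_tp` stand for the single-node and tensor-parallel forward results, and the tolerance values are placeholders.

```python
import torch

def report_mismatch(output: torch.Tensor, output_tp: torch.Tensor,
                    rtol: float = 1e-5, atol: float = 1e-8) -> None:
    """Print the largest absolute/relative gap, then assert within tolerance."""
    abs_diff = (output - output_tp).abs()
    # Clamp the denominator so exact zeros in the reference do not divide by zero.
    rel_diff = abs_diff / output.abs().clamp_min(1e-12)
    print(f"max abs diff: {abs_diff.max().item():.3e}")
    print(f"max rel diff: {rel_diff.max().item():.3e}")
    # Unlike assertEqual, this raises with a detailed report when the tensors
    # differ beyond rtol/atol.
    torch.testing.assert_close(output_tp, output, rtol=rtol, atol=atol)
```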
How to reproduce:
I created a branch `ad-hoc-self-attn-exp`, based on `origin/main`, with a number of print statements added to help reproduce the problem.
```
git checkout origin/ad-hoc-self-attn-exp
pytest test/spmd/tensor/parallel/test_tp_examples.py -s -k test_self_attn_megatron_e2e
```
Observation:
- Both modules produce output that goes from `-50` to `-inf` within 10 iterations.
- `output.sum()` and `output_tp.sum()` are not exactly identical; the numeric difference is relatively small.
Suggestion: use an MLE loss instead of a plain sum as the training loss.
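For example, here is a minimal sketch of that suggestion with purely illustrative shapes and names (nothing below is taken from the test): project the attention output to logits and train against a cross-entropy (maximum-likelihood) loss, which is bounded below by zero, instead of backpropagating through `output.sum()`.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes; the real test may use different ones.
batch, seq, embed_dim, vocab = 8, 16, 64, 100

# `output` stands in for the forward result of the attention module under test.
output = torch.randn(batch, seq, embed_dim, requires_grad=True)
targets = torch.randint(0, vocab, (batch, seq))
proj = torch.nn.Linear(embed_dim, vocab)  # projects attention output to logits

logits = proj(output)

# Cross-entropy (negative log-likelihood) cannot be driven toward -inf the way
# a raw `output.sum()` objective can.
loss = F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1))
loss.backward()
```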