PiPPy
[spmd] self-attention not converging
What the problem is:
Both the single-node and the sharded TensorParallelMultiheadAttention (#477) modules diverge: the forward output becomes `-inf` in fewer than 10 iterations. They also produce slightly different forward outputs, and the relative difference is small enough that `self.assertEqual` does not report them as unequal.
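For reference, a minimal sketch (my own, not from the test) of how the small relative difference could be surfaced with an explicit tolerance check instead of `self.assertEqual`; `output` and `output_tp` stand for the single-node and tensor-parallel forward results:

```python
import torch

def compare_outputs(output: torch.Tensor, output_tp: torch.Tensor) -> None:
    # Hypothetical helper, not part of the test: report the relative gap
    # between the single-node and tensor-parallel forward results.
    rel_diff = (output - output_tp).abs().max() / output.abs().max().clamp_min(1e-12)
    print(f"max relative difference: {rel_diff.item():.3e}")
    # Fail when the gap exceeds explicit tolerances, rather than relying on
    # self.assertEqual's default comparison.
    torch.testing.assert_close(output_tp, output, rtol=1e-5, atol=1e-6)
```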
How to reproduce:
I created a branch `ad-hoc-self-attn-exp`, based on `origin/main`, with a bunch of print statements added to help reproduce the problem:
`git checkout origin/ad-hoc-self-attn-exp`
`pytest test/spmd/tensor/parallel/test_tp_examples.py -s -k test_self_attn_megatron_e2e`
Observation:
- Both modules produce output that grows from `-50` to `-inf` within 10 iterations (see the sketch after this list).
- The outputs of `output.sum()` and `output_tp.sum()` are not exactly identical; there is a relatively small numeric difference.
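To illustrate the first observation, here is a toy sketch (an assumption about the mechanism, not the actual test code) of how an unbounded `output.sum()` objective keeps getting pushed more negative by the optimizer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 8)          # stand-in for the attention module
opt = torch.optim.SGD(model.parameters(), lr=0.25)
x = torch.randn(4, 8)

for step in range(10):
    loss = model(x).sum()        # objective with no lower bound
    opt.zero_grad()
    loss.backward()
    opt.step()
    # The printed value grows more negative at every step; with enough
    # iterations (or a larger model / learning rate) it heads toward -inf.
    print(step, loss.item())
```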
Suggestion: use an MLE loss as the training objective instead of the raw sum of the output.
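As a hedged sketch of that suggestion (shapes, targets, and the helper name are illustrative assumptions, not taken from `test_tp_examples.py`), a likelihood-based objective such as negative log-likelihood is bounded below and cannot be driven to `-inf`:

```python
import torch
import torch.nn.functional as F

def mle_loss(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: treat the forward output as logits of shape
    # (batch, seq, vocab) and the target as class indices of shape (batch, seq).
    # cross_entropy is the negative log-likelihood of the targets, so the loss
    # is >= 0 and the training signal cannot run off to -inf.
    return F.cross_entropy(output.flatten(0, 1), target.flatten())

# Example usage with dummy data:
logits = torch.randn(2, 16, 100, requires_grad=True)
labels = torch.randint(0, 100, (2, 16))
loss = mle_loss(logits, labels)
loss.backward()
```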