
DEFAULT_MASK_VALUE causes gradient explosion and nan loss on deep models

Open logicchains opened this issue 10 months ago • 2 comments

I was training a llama model on GPU, with a custom embedding. It worked fine with 12 layers, dim 1024, seq length 256, but the loss would become nan after the first step when num_layers was set to more than 17. I debugged the gradients and found that their magnitude grew by around 100x per layer, until it hit float32 max at around the 18th layer and became inf, leading to the nan loss.
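A minimal sketch of the kind of per-layer gradient check described above; the loss function and parameter pytree here are placeholders, not MaxText's actual training state:

```python
import jax
import jax.numpy as jnp

def gradient_norms(loss_fn, params, batch):
    """Return the L2 norm of the gradient for each parameter leaf.

    loss_fn(params, batch) is assumed to return a scalar loss; comparing the
    norms of leaves belonging to successive layers is one way to spot the
    ~100x-per-layer growth described in the report.
    """
    grads = jax.grad(loss_fn)(params, batch)
    return {
        jax.tree_util.keystr(path): jnp.linalg.norm(leaf)
        for path, leaf in jax.tree_util.tree_leaves_with_path(grads)
    }
```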

The gradient explosion seemed to be coming from local_exps = jnp.exp(attn_weights - local_max) in attentions.py.

Changing DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) to DEFAULT_MASK_VALUE = -jnp.inf fixed the issue, and the gradients' magnitude stopped increasing after each layer.
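For reference, a minimal sketch of the change, using a simplified masked softmax in place of MaxText's actual attentions.py code; only the two DEFAULT_MASK_VALUE definitions come from the report, the rest is illustrative:

```python
import jax.numpy as jnp

# Original definition (from the report): a large but finite negative value.
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)

# Change that fixed the nan loss in the reporter's GPU setup.
DEFAULT_MASK_VALUE = -jnp.inf


def masked_softmax(attn_weights, mask, mask_value=DEFAULT_MASK_VALUE):
    """Illustrative masked softmax, not MaxText's actual implementation."""
    attn_weights = jnp.where(mask, attn_weights, mask_value)
    local_max = jnp.max(attn_weights, axis=-1, keepdims=True)
    local_exps = jnp.exp(attn_weights - local_max)  # the line cited in the report
    return local_exps / jnp.sum(local_exps, axis=-1, keepdims=True)
```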

Presumably the issue wasn't noticed during TPU training, as that uses a separate codepath.

logicchains · Apr 23, 2024

@logicchains thanks for the tips on GPU convergence! We will experiment with this as we set up convergent regimes for GPUs.

@anfals please be aware of this as you do convergence testing on GPU

rwitten · Apr 30, 2024