RF upcast to f32 for certain ops (layernorm, attention, etc) (bf16/f16)

RF doesn't apply any special logic to upcast to f32 when the input is f16 or bf16 (e.g. via Torch AMP) for modules like LayerNorm. But maybe it should.

Of course, we should not just change the behavior. This should be an option, and maybe a new behavior version could change the default of that option.

Note, one simple test whether any of this has an effect at all is to train without AMP, just using f32. If there is no difference, then none of this here has an effect.

See also what others are doing:

  • https://unsloth.ai/blog/gemma-bugs

Note, the behavior has changed in more recent versions of PyTorch for some of the PyTorch modules (e.g. torch.nn.LayerNorm, changed in 2.6 or so?).

Examples of LayerNorm, RMSNorm casting to f32:

  • HF Transformers Llama4
  • PyTorch itself in more recent versions.
  • Torchtune RMSNorm (used for Llama and many others)
  • Torchtune Gemma2 (includes also the mult with scale as f32)
  • https://github.com/pytorch/pytorch/issues/66707
  • https://github.com/pytorch/pytorch/issues/72643#issuecomment-2113682073
  • https://github.com/unslothai/unsloth/blob/main/unsloth/kernels/rms_layernorm.py#L108
  • https://github.com/manuelciosici/transformers/commit/9bc465304a76b52bc904f1805ad64b0c3d87e69a
  • https://x.com/hsu_byron/status/1827072742291861975
  • https://github.com/huggingface/transformers/issues/30236
  • https://github.com/pytorch/pytorch/pull/66920
  • https://github.com/huggingface/transformers/issues/33133
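
The common pattern in most of these implementations looks roughly like the following sketch (plain PyTorch, not the exact code of any of the linked projects):

import torch

class RMSNormUpcast(torch.nn.Module):
    """RMSNorm which computes the statistics in f32 and casts back to the input dtype.
    Sketch of the pattern from HF Llama / Torchtune RMSNorm; not their exact code."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype  # e.g. bf16/f16 under AMP
        x = x.to(torch.float32)  # upcast before computing the variance
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        # Most implementations cast back before the scale mult;
        # Torchtune Gemma2 keeps the scale mult in f32 as well.
        return self.scale * x.to(in_dtype)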

Examples of rotary embedding (RoPE) casting to f32:

Examples of attention casting to f32:

  • (Torchtune uses Torch scaled_dot_product_attention or Torch flex_attention)
  • Llama3, only for the softmax
  • PyTorch scaled_dot_product_attention (I think...)
  • PyTorch FlexAttention (I think...)
  • Torchtune T5
  • Torchtune Gemma2
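
For the softmax-only variant (as in the Llama3 reference code), the pattern is roughly this (a sketch, not taken verbatim from any of the above):

import math
import torch

def attention_softmax_f32(q, k, v):
    # q, k, v may be bf16/f16; only the softmax is computed in f32, then cast back.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.matmul(probs, v)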

Examples of attention (intentionally) leaving the orig type:

Example of final logits casting to f32:

Cross entropy (CE) is done on the final logits, which are often cast to f32.

CTC loss (rf.ctc_loss)

Note, for the final log_softmax, CE, and/or CTC loss, the RF CE/CTC function can do the log_softmax internally, or this can be done externally. The user code can look like this:

log_prob = rf.log_softmax(logits_packed, axis=model.target_dim)
log_prob = rf.label_smoothed_log_prob_gradient(log_prob, 0.1, axis=model.target_dim)
loss = rf.cross_entropy(
    target=targets_packed, estimated=log_prob, estimated_type="log-probs", axis=model.target_dim
)

When the user code looks like this, the upcast cannot really be controlled by a RETURNN option. I think the user then needs to do it explicitly.
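
E.g. the user could explicitly cast the logits to f32 before the log_softmax, roughly like this (a sketch; assuming rf.cast with a dtype string, check the exact API):

logits_packed = rf.cast(logits_packed, "float32")  # explicit upcast in user code (sketch)
log_prob = rf.log_softmax(logits_packed, axis=model.target_dim)
log_prob = rf.label_smoothed_log_prob_gradient(log_prob, 0.1, axis=model.target_dim)
loss = rf.cross_entropy(
    target=targets_packed, estimated=log_prob, estimated_type="log-probs", axis=model.target_dim
)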

For the other cases (e.g. when the CE/CTC loss does the log_softmax internally, or also for RMSNorm, attention, etc.), RETURNN can handle it internally (controlled via some option, as mentioned before).

(cc @dorian-K)

albertz, Jul 24 '25 13:07

I just realized that Torch AMP already automatically upcasts to f32 for certain ops. That includes, among others: layer_norm, log, log_softmax, softmax, exp, sum, nll_loss, rsqrt, norm, etc.

So, take the RF RMSNorm code:

variance = rf.reduce_mean(rf.square(x), axis=self.in_dim)
norm_x = x * rf.rsqrt(variance + self.eps)
out = norm_x * self.scale

mean and square are not on the list, so those would still be in bf16? So maybe that's an issue here. (I might use torch.norm here instead, which is on the AMP list of ops upcast to f32...)
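
A quick check (a sketch; assuming a CUDA device, and assuming I read the AMP op lists right):

import torch

# Simulate an activation that already came out of an autocast matmul, i.e. bf16:
x = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    sq = x * x                     # plain mul: not on any AMP cast list -> stays bf16
    var = sq.mean(dim=-1)          # mean: also not on the list -> stays bf16
    inv = torch.rsqrt(var + 1e-6)  # rsqrt: on the AMP f32 list -> upcast to f32
    nrm = torch.norm(x, dim=-1)    # norm: on the AMP f32 list -> f32
print(sq.dtype, var.dtype, inv.dtype, nrm.dtype)
# expected: torch.bfloat16 torch.bfloat16 torch.float32 torch.float32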

albertz, Aug 02 '25 22:08