TransformerEngine

Massively reduce LayerNorm/RMSNorm training memory usage by sharing the saved tensor with other parts of the network

Open • RuiWang1998 opened this issue 2 years ago • 1 comment

Saving the output of the normalization layer instead of its input reduces memory cost in modern networks. The output will be saved anyway by the next layer (e.g., a Linear layer saves its input to compute the weight gradient), while the input would otherwise be kept alive only for the normalization's backward pass.

We do this by changing the norm backward kernels to load the output instead of the input, and to recompute the normalized tensor from the output (for LayerNorm, x̂ = (y − β) / γ; for RMSNorm, x̂ = y / γ). To stabilize the gradients, we also clamp the magnitude of the γ used as the divisor.
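A minimal NumPy sketch of this backward pass, assuming a row-wise LayerNorm over the last axis in float64. Only the output y, γ, β and the per-row rstd = 1/sqrt(var + eps) are used; the input x is never read. The helper names and the clamp threshold `min_abs=1e-4` are illustrative assumptions, not the kernel's actual names or values.

```python
import numpy as np


def layernorm_forward(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last axis; returns the output y and the per-row rstd."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    rstd = 1.0 / np.sqrt(var + eps)
    y = (x - mean) * rstd * gamma + beta
    return y, rstd  # only y and rstd are kept for backward; x can be freed


def clamp_by_magnitude(gamma, min_abs=1e-4):
    """Raise |gamma| to at least min_abs, keeping its sign (zero maps to +min_abs)."""
    sign = np.where(gamma >= 0, 1.0, -1.0)
    return sign * np.maximum(np.abs(gamma), min_abs)


def layernorm_backward_from_output(dy, y, gamma, beta, rstd, min_abs=1e-4):
    """LayerNorm gradients computed from the saved output y instead of the input x."""
    # y = gamma * xhat + beta  =>  xhat = (y - beta) / gamma, with gamma clamped
    xhat = (y - beta) / clamp_by_magnitude(gamma, min_abs)
    n = y.shape[-1]
    rows = tuple(range(y.ndim - 1))  # every axis except the normalized (hidden) one
    dgamma = (dy * xhat).sum(axis=rows)
    dbeta = dy.sum(axis=rows)
    g = dy * gamma  # gradient w.r.t. xhat
    dx = rstd / n * (n * g
                     - g.sum(axis=-1, keepdims=True)
                     - xhat * (g * xhat).sum(axis=-1, keepdims=True))
    return dx, dgamma, dbeta


# Usage: a (tokens, hidden) batch with gamma kept away from zero
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
gamma = rng.uniform(0.5, 1.5, size=8)
beta = rng.standard_normal(8)
dy = rng.standard_normal((4, 8))

y, rstd = layernorm_forward(x, gamma, beta)
dx, dgamma, dbeta = layernorm_backward_from_output(dy, y, gamma, beta, rstd)
```

For RMSNorm the same idea applies with β = 0 and no mean subtraction: x̂ = y / γ and dx = rstd · (g − x̂ · mean(g · x̂)). The clamp only matters when some |γ| is tiny; for ordinary γ the recovered x̂ matches (x − mean) · rstd up to rounding.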

However, this currently seems to come at the cost of somewhat lower numerical precision. This is still to be investigated.

For now, the operator tests pass if we increase the tolerances for LayerNorm's dgamma and dx.

Effect

According to the sequence parallelism paper, LayerNorm costs 4sbh bytes of activation memory per layer during training (s = sequence length, b = micro-batch size, h = hidden size). If we save the layer's output instead of its input, those inputs can be freed, since they are not needed anywhere else. This amounts to ~1/6 of the total activation cost of a transformer model.
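To make the 4sbh term concrete: in that paper's accounting, activations are 16-bit and each transformer layer has two LayerNorms, each saving an s×b×h input, so 2 × 2 × sbh = 4sbh bytes. The sketch below only evaluates that formula; the function name and the model shape are hypothetical examples, and the bytes are actually reclaimed only when the following layer already saves the norm's output.

```python
def layernorm_saved_bytes(seq_len, batch, hidden, num_layers, bytes_per_elem=2):
    """Activation bytes held by LayerNorm inputs (the paper's 4sbh per layer).

    Two LayerNorms per transformer layer, each saving an s*b*h input tensor,
    at 2 bytes per element for fp16/bf16 activations.
    """
    lns_per_layer = 2
    per_layer = lns_per_layer * seq_len * batch * hidden * bytes_per_elem
    return per_layer * num_layers


# Hypothetical model shape: s = 2048, b = 1, h = 4096, 32 layers
total = layernorm_saved_bytes(2048, 1, 4096, 32)
print(f"{total / 2**30:.2f} GiB")  # prints "1.00 GiB"
```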

Note that this may result in slightly higher numerical errors: the gradients are computed from the output, which has already been rounded by the time the backward pass uses it, so those rounding errors propagate into the gradients.
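One way to see how the rounding propagates, as an illustrative sketch rather than a measurement of the actual kernels: if y = γ·x̂ + β is stored in fp16, its rounding error is divided by γ when x̂ is recovered, so the error in x̂ scales roughly with |β| / |γ|. This is also why clamping γ bounds the damage. The values of β and γ below are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
xhat = rng.standard_normal(4096)  # exact normalized activations, float64
beta = 1.0                        # hypothetical bias; the error scale is ~|beta| / |gamma|


def recovery_error(gamma):
    """Max error in xhat recovered from an fp16-rounded output y = gamma * xhat + beta."""
    y_saved = (gamma * xhat + beta).astype(np.float16)  # rounding happens when y is stored
    xhat_rec = (y_saved.astype(np.float64) - beta) / gamma
    return float(np.max(np.abs(xhat_rec - xhat)))


err_unit = recovery_error(1.0)    # at most half an fp16 ulp of |y|
err_small = recovery_error(0.01)  # rounding of y (~|beta|) divided by gamma = 0.01
```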

TODO

For now, we pass the operator tests and the tests in test_numerics.py (with added LayerNorm tests and gradient tests). We will soon add more features, including support for fp8 and integration with the Python frameworks. However, this might require rethinking how LayerNorm+Linear/MLP interact in the presence of fp8. Eventually, we may simply make the normalization layers output half- or single-precision data and save that instead of the inputs.

Also

See https://github.com/NVIDIA/apex/pull/1715

RuiWang1998 • Sep 11 '23 14:09

@timmoon10 Could you help review this PR?

ptrendx • Sep 18 '23 19:09