
[PyTorch] Refactor FP8 workspaces in linear modules

Open · timmoon10 opened this pull request 9 months ago · 4 comments

This PR refactors the logic for FP8 weight workspaces in te.Linear, te.LayerNormLinear, and te.LayerNormMLP. The existing logic is somewhat convoluted since it was designed to pass around raw UINT8 buffers, and Float8Tensor support was kludged in via https://github.com/NVIDIA/TransformerEngine/pull/452. For example, when te.Linear has FP32 params, it maintains two workspace Float8Tensors (FP8 weight and transpose) and in each forward pass extracts their buffers to create another temporary Float8Tensor. This PR streamlines the process so that each module maintains just a single workspace Float8Tensor.
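As a rough illustration of the change (all names below are hypothetical stand-ins, not TE's actual classes or module code):

```python
import torch
from dataclasses import dataclass

# Illustrative stand-in for Float8Tensor: quantized uint8 storage plus the
# scale-inverse snapshotted when the data was produced.
@dataclass
class FakeFloat8Tensor:
    data: torch.Tensor       # uint8 storage
    scale_inv: torch.Tensor  # 1/scale at cast time

def fake_fp8_cast(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Crude stand-in for an FP8 cast kernel.
    return (w * scale).round().clamp(0, 255).to(torch.uint8)

# Old pattern: keep raw buffers and wrap them in a temporary
# Float8Tensor-like object on every forward pass.
def forward_old(w, scale, raw_weight_buf):
    raw_weight_buf.copy_(fake_fp8_cast(w, scale))
    tmp = FakeFloat8Tensor(raw_weight_buf, scale.reciprocal())  # per-call temporary
    return tmp

# New pattern: keep one persistent workspace and update it in place only
# when the FP8 weights actually need to be refreshed.
def forward_new(w, scale, workspace: FakeFloat8Tensor, update_workspace: bool):
    if update_workspace:
        workspace.data.copy_(fake_fp8_cast(w, scale))
        workspace.scale_inv.copy_(scale.reciprocal())
    return workspace
```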

Motivations:

  • This fixes an FP8 recipe bug introduced in https://github.com/NVIDIA/TransformerEngine/pull/575 (see https://github.com/NVIDIA/TransformerEngine/pull/786#pullrequestreview-2024087226). The FP8 scale update kernel updates the scales for weights in every forward pass, even ones where the FP8 weights are not updated, so we can run into situations where the FP8 scales don't match the FP8 weights. This PR fixes that by taking advantage of the fact that Float8Tensor holds a private copy of the FP8 scale-inverse, which is not affected by scale updates until the tensor's values are updated (see the sketch after this list).
  • FP8 compute can sometimes result in performance degradation due to CPU overheads (see https://github.com/NVIDIA/TransformerEngine/issues/761). The Float8Tensor constructor requires a CUDA kernel launch to initialize the FP8 scale-inverse, so creating unnecessary Float8Tensors adds non-trivial CPU overhead. In a benchmark of the forward pass of small te.Linear modules, this PR gives a 1.12x speedup.
  • I find the logic in this PR a bit easier to reason about, although I'd appreciate feedback. It feels nicer to let Float8Tensor internally handle things like FP8 casting and interacting with fp8_meta.
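The following is a hypothetical timeline (plain PyTorch, not TE code) showing how the recipe's scale can go stale relative to cached FP8 weight data, and why reading the scale-inverse snapshotted alongside the cached data stays consistent:

```python
import torch

scale = torch.ones(1)            # stands in for fp8_meta's weight scale
weight_fp8_data = None           # cached "FP8" weight values
snapshot_scale_inv = None        # scale-inverse stored with the cached data

def cast_weight(w):
    global weight_fp8_data, snapshot_scale_inv
    weight_fp8_data = (w * scale).round()          # crude stand-in for the FP8 cast
    snapshot_scale_inv = scale.reciprocal().clone()

def recipe_scale_update():
    # The recipe's scale update runs every step, whether or not the
    # FP8 weights were re-cast this step.
    scale.mul_(2.0)

w = torch.randn(4)
cast_weight(w)           # microbatch 0: weights cast, scale-inverse snapshotted
recipe_scale_update()    # later step: scale changes, cached FP8 data does not

stale  = weight_fp8_data * scale.reciprocal()   # mismatched dequantization
stable = weight_fp8_data * snapshot_scale_inv   # consistent with the cached data
```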

timmoon10 avatar Apr 27 '24 01:04 timmoon10

/te-ci pytorch

timmoon10 avatar Apr 27 '24 02:04 timmoon10

/te-ci pytorch

timmoon10 avatar May 15 '24 01:05 timmoon10

/te-ci pytorch

timmoon10 avatar May 16 '24 04:05 timmoon10

This PR has grown in scope as I've identified bugs:

  • Our current API for transpose caching in Float8Tensor doesn't work well with CUDA graphs. By the time we perform the FP8 GEMMs, the FP8 weights may or may not have the transpose cache filled (e.g. it's not filled when the model is newly created). The only way to access the cache is to call transpose_2d(cache=True), which fills the cache. But if you capture a CUDA graph with is_first_microbatch=None, the cache is filled during the warmup steps and the transpose kernel is never actually captured. The easiest fix is to modify Float8Tensor to support lazy transpose caching (see https://github.com/NVIDIA/TransformerEngine/pull/575#discussion_r1548888680, https://github.com/NVIDIA/TransformerEngine/pull/575#pullrequestreview-1985448624, and https://github.com/NVIDIA/TransformerEngine/pull/735).
  • ~The ONNX export tests were failing because they assume the FP8 scales can be represented with constant operations, which requires that the scales are initialized during the ONNX export.~ The ONNX export tests failed because we copy an FP8 scale with Tensor.copy_, which is translated into the ONNX Expand operation, I think to handle array broadcasting. This breaks the assumption that the FP8 scales can be represented as ONNX constant operations. The fix is to use Tensor.fill_ on the FP8 scales instead of Tensor.copy_ during the ONNX export process (a minimal example follows this list). See https://github.com/NVIDIA/TransformerEngine/pull/861#issuecomment-2125828363.
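For the ONNX item, a minimal standalone PyTorch comparison of the two update styles (illustrative only, not the TE code path):

```python
import torch

scale = torch.ones(1)

# Copying from another tensor is traced through ONNX's Expand op (presumably
# to handle broadcasting), so the scale no longer appears as a constant:
scale.copy_(torch.tensor([0.5]))

# Filling with a Python scalar keeps the scale representable as a constant
# in the exported graph, which is what the ONNX export tests assume:
scale.fill_(0.5)
```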

timmoon10 avatar May 16 '24 04:05 timmoon10

/te-ci pytorch

timmoon10 avatar May 20 '24 21:05 timmoon10

/te-ci pytorch

timmoon10 avatar May 21 '24 19:05 timmoon10

/te-ci pytorch

timmoon10 avatar May 21 '24 21:05 timmoon10

/te-ci pytorch

timmoon10 avatar May 22 '24 21:05 timmoon10

@timmoon10 Have you verified identical numerics with this change?

ksivaman avatar May 23 '24 16:05 ksivaman

For testing CUDA graphs with FP8 caching, did you use the noop_flag in transpose and the fp8_weight_caching flag in make_graphed_callables?

ksivaman avatar May 23 '24 16:05 ksivaman

With the tests and bugfixes in https://github.com/NVIDIA/TransformerEngine/pull/869, this PR seems to handle make_graphed_callables with fp8_weight_caching=True correctly.
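For reference, exercising that path looks roughly like the sketch below. The argument names (fp8_enabled, fp8_weight_caching, is_first_microbatch) are my assumptions based on this thread and my reading of the TE API at the time; check the current make_graphed_callables signature.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import make_graphed_callables

# Sketch: capture a te.Linear in a CUDA graph with FP8 weight caching.
model = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(32, 1024, device="cuda")

graphed_model = make_graphed_callables(
    model,
    (sample_input,),
    fp8_enabled=True,
    fp8_weight_caching=True,  # reuse FP8 weights across microbatches
)

x = torch.randn(32, 1024, device="cuda")
out0 = graphed_model(x, is_first_microbatch=True)   # re-cast FP8 weights
out1 = graphed_model(x, is_first_microbatch=False)  # reuse cached FP8 weights
```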

timmoon10 avatar May 28 '24 19:05 timmoon10

/te-ci pytorch

timmoon10 avatar May 28 '24 19:05 timmoon10

Training GPT for 100 steps (175B params, TP=2, PP=4), I don't see significant differences in the loss curves with and without this PR. I think this is ready to go.

timmoon10 avatar May 30 '24 02:05 timmoon10