TransformerEngine
[PyTorch] Refactor FP8 workspaces in linear modules
This PR refactors the logic for FP8 weight workspaces in `te.Linear`, `te.LayerNormLinear`, and `te.LayerNormMLP`. The existing logic is somewhat convoluted since it was designed to pass around raw UINT8 buffers, and `Float8Tensor` support was kludged in https://github.com/NVIDIA/TransformerEngine/pull/452. For example, when `te.Linear` has FP32 params, it maintains two workspace `Float8Tensor`s (FP8 weight and transpose) and in each forward pass it extracts their buffers to create another temporary `Float8Tensor`. This PR streamlines the process so that each module maintains just a single workspace `Float8Tensor`.
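Roughly, the change in workspace handling looks like the sketch below. This is purely illustrative: the class and the fake cast are stand-ins, not the actual TE internals.

```python
import torch

def fake_fp8_cast(w: torch.Tensor, scale: float) -> torch.Tensor:
    # Stand-in for the real FP8 cast kernel.
    return (w * scale).round().clamp_(-127, 127).to(torch.int8)

class Fp8Workspace:
    """Single persistent workspace, analogous to one cached Float8Tensor."""

    def __init__(self, weight: torch.Tensor, scale: float):
        # Allocated once; later updates reuse these buffers.
        self.data = fake_fp8_cast(weight, scale)
        self.scale_inv = torch.tensor([1.0 / scale])  # private copy, frozen at cast time

    def update_(self, weight: torch.Tensor, scale: float) -> None:
        # Cast the master weight into the existing buffers in place, instead of
        # extracting raw uint8 buffers and wrapping them in a new temporary
        # Float8Tensor every forward pass.
        self.data.copy_(fake_fp8_cast(weight, scale))
        self.scale_inv.fill_(1.0 / scale)

weight = torch.randn(16, 16)
workspace = Fp8Workspace(weight, scale=64.0)  # constructed once
for _ in range(3):                            # forward passes reuse the same buffers
    workspace.update_(weight, scale=64.0)
```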
Motivations:
- This fixes an FP8 recipe bug introduced in https://github.com/NVIDIA/TransformerEngine/pull/575 (see https://github.com/NVIDIA/TransformerEngine/pull/786#pullrequestreview-2024087226). The FP8 scale update kernel updates the scales for weights in every forward pass, even ones where the FP8 weights are not updated, so we can run into situations where the FP8 scales don't match the FP8 weights. This PR fixes that by taking advantage of the fact that `Float8Tensor` has a private copy of the FP8 scale-inverse, which won't be affected by scale updates until its values are updated (see the sketch after this list).
- FP8 compute can sometimes result in performance degradation due to CPU overheads (see https://github.com/NVIDIA/TransformerEngine/issues/761). The `Float8Tensor` constructor requires a CUDA kernel launch to initialize the FP8 scale-inverse, so creating unnecessary `Float8Tensor`s adds non-trivial CPU overhead. Benchmarking the forward pass of small `te.Linear`s, this PR gives a 1.12x speedup.
- I find the logic in this PR a bit easier to reason about, although I'd appreciate feedback. It feels nicer to let `Float8Tensor` internally handle things like FP8 casting and interacting with `fp8_meta`.
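To make the first bullet concrete, here is a toy numeric sketch (not `fp8_meta`'s real layout) of why a private scale-inverse copy keeps the cached FP8 data and its scale consistent while the global scale keeps updating:

```python
import torch

torch.manual_seed(0)
global_scale = torch.tensor([64.0])  # updated by the amax/scale kernel every step
weight = torch.randn(4, 4)

# Cast step: quantize with the current scale and snapshot its inverse.
fp8_data = (weight * global_scale).round()              # stand-in for the FP8 cast
private_scale_inv = global_scale.reciprocal().clone()   # frozen together with the data

# Later steps: the scale-update kernel changes global_scale, but the cached FP8
# data was not re-cast, so dequantizing with the *global* scale-inverse drifts.
global_scale.fill_(128.0)
wrong = fp8_data * global_scale.reciprocal()   # mismatched data/scale pair
right = fp8_data * private_scale_inv           # consistent pair
print(torch.allclose(right, weight, atol=0.01))  # True
print(torch.allclose(wrong, weight, atol=0.01))  # False
```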
/te-ci pytorch
This PR has grown in scope as I've identified bugs:
- Our current API for transpose caching in `Float8Tensor` doesn't work well with CUDA graphs. At the point we perform FP8 GEMMs, the FP8 weights may or may not have the transpose cache filled (e.g. it's not filled when the model is newly created). The only way to access the cache is to call `transpose_2d(cache=True)`, which fills the cache. But if you capture a CUDA graph with `is_first_microbatch=None`, this means that the cache is filled during the warmup steps and you never actually capture the transpose kernel. The easiest fix is to modify `Float8Tensor` to support lazy transpose caching (see https://github.com/NVIDIA/TransformerEngine/pull/575#discussion_r1548888680, https://github.com/NVIDIA/TransformerEngine/pull/575#pullrequestreview-1985448624, and https://github.com/NVIDIA/TransformerEngine/pull/735).
- ~The ONNX export tests were failing because they assume the FP8 scales can be represented with constant operations, which requires that the scales are initialized during the ONNX export.~ The ONNX export tests failed because we copy an FP8 scale with `Tensor.copy_`, which is translated into the ONNX Expand operation, I think to handle array broadcasting. This breaks the assumption that the FP8 scales can be represented as ONNX constant operations. The fix is to use `Tensor.fill_` on the FP8 scales instead of `Tensor.copy_` during the ONNX export process, as sketched below. See https://github.com/NVIDIA/TransformerEngine/pull/861#issuecomment-2125828363.
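For the ONNX point, the fix boils down to the following pattern (the scale buffer below is a toy stand-in; the ONNX tracing behavior is as described in the linked comment):

```python
import torch

# Toy stand-in for an FP8 scale buffer (shape [1], like the fp8_meta scales).
scale = torch.ones(1)
src = torch.tensor([0.0625])

# Before: copying a tensor into the buffer. Under ONNX tracing this shows up
# as an Expand op (broadcasting src into scale's shape), so the scale can no
# longer be folded into a Constant node.
scale.copy_(src)

# After: writing the scalar value directly. fill_ takes a Python number,
# which the tracer can keep as a constant.
scale.fill_(src.item())
```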
/te-ci pytorch
@timmoon10 Have you verified identical numerics with this change?
For testing CUDA graphs with FP8 caching, did you use the `noop_flag` in `transpose` and the `fp8_weight_caching` flag in `make_graphed_callables`?
With the tests and bugfixes in https://github.com/NVIDIA/TransformerEngine/pull/869, this PR seems to handle `make_graphed_callables` with `fp8_weight_caching=True` correctly.
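For reference, the usage pattern being exercised looks roughly like this. This is a minimal sketch: apart from `fp8_weight_caching` and `is_first_microbatch`, which are discussed above, the argument names are assumptions about the TE API of this era and may differ by version.

```python
import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda", requires_grad=True)

# Capture the module into CUDA graphs with FP8 weight caching enabled.
graphed_model = te.make_graphed_callables(
    model,
    (x,),
    fp8_enabled=True,
    fp8_weight_caching=True,
)

# The first microbatch casts/transposes the FP8 weights; later microbatches
# reuse the cached copies.
for step in range(4):
    out = graphed_model(x, is_first_microbatch=(step == 0))
    out.sum().backward()
```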
/te-ci pytorch
Training GPT for 100 steps (175B params, TP=2, PP=4), I don't see significant differences in the loss curves with and without this PR. I think this is ready to go.