Tim Moon
This PR has grown in scope as I've identified bugs:
- Our current API for transpose caching in `Float8Tensor` doesn't work well with CUDA graphs. At the point we perform...
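For background, here is a generic, minimal sketch of CUDA graph capture with plain `torch` (hypothetical shapes, not the `Float8Tensor` internals touched by this PR): whatever kernels run during capture are replayed verbatim against fixed buffers, which is why state that is computed lazily and cached, like a transpose, is awkward to manage around graph capture.

```python
import torch

# Hypothetical shapes; this only illustrates generic capture/replay behavior.
static_x = torch.randn(16, 16, device="cuda")
weight = torch.randn(16, 16, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    _ = static_x @ weight.t().contiguous()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    w_t = weight.t().contiguous()  # transpose copy materialized at capture time
    static_y = static_x @ w_t      # both kernels are recorded into the graph

# Replay re-runs exactly the captured kernels on the same buffers; nothing that
# wasn't executed during capture (e.g. a cache fill) can happen at replay time.
static_x.copy_(torch.randn(16, 16, device="cuda"))
g.replay()
```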
/te-ci pytorch
With the tests and bugfixes in https://github.com/NVIDIA/TransformerEngine/pull/869, this PR seems to handle `make_graphed_callables` with `fp8_weight_caching=True` correctly.
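For reference, a minimal sketch of how this combination is exercised (the module, sizes, recipe, and microbatch loop are my own assumptions, not taken from this PR): with `fp8_weight_caching=True`, the graphed callable is called with an `is_first_microbatch` flag so the FP8 weight buffers can be reused across microbatches.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Hypothetical module and sizes, just to exercise the code path discussed here.
model = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

# Capture the module into CUDA graphs with FP8 enabled and FP8 weight caching.
graphed_model = te.make_graphed_callables(
    model,
    (sample_input,),
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(),
    fp8_weight_caching=True,
)

# With fp8_weight_caching=True, the caller indicates whether the weights may
# have changed since the last call (i.e. the first microbatch after an
# optimizer step), so cached FP8 weights are reused for later microbatches.
for microbatch in range(4):
    x = torch.randn(32, 1024, device="cuda", requires_grad=True)
    out = graphed_model(x, is_first_microbatch=(microbatch == 0))
    out.sum().backward()
```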
/te-ci pytorch
After training GPT (175B params, TP=2, PP=4) for 100 steps, I don't see significant differences in the loss curves with and without this PR. I think this is ready to go.
`torch.nn.Linear`'s forward pass requires just a single cuBLAS GEMM kernel. `te.Linear` is more complicated since it also launches kernels for the FP8 cast, for caching FP8 scales for the backward pass, ...
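To make that concrete, a minimal sketch (sizes and recipe are hypothetical, and the comments are not an exhaustive list of what `te.Linear` launches):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Hypothetical sizes; the point is what each forward launches, not the numbers.
x = torch.randn(32, 1024, device="cuda")

# torch.nn.Linear: the forward is essentially one cuBLAS GEMM (plus bias epilogue).
torch_linear = torch.nn.Linear(1024, 1024, device="cuda")
y_ref = torch_linear(x)

# te.Linear under fp8_autocast: besides the FP8 GEMM, the forward also casts the
# input and weight to FP8, records amax history / updates scaling factors, and
# stashes the FP8 tensors needed by the backward pass, so it launches several
# small kernels in addition to the GEMM.
te_linear = te.Linear(1024, 1024).cuda()
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    y_fp8 = te_linear(x)
```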
/te-ci pytorch