Tim Moon


This PR has grown in scope as I've identified bugs:

- Our current API for transpose caching in `Float8Tensor` doesn't work well with CUDA graphs. At the point we perform...
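To illustrate the general pitfall only (a toy sketch, not `Float8Tensor`'s actual implementation): a lazily populated transpose cache allocates its buffer the first time it is used, which is awkward under CUDA graph capture, where every buffer the graph touches must be a static allocation that is reused on each replay.

```python
# Toy sketch of a lazily cached transpose. This is NOT Float8Tensor's real
# code; it only shows why lazy caching interacts badly with CUDA graphs.
import torch

class CachedTransposeTensor:
    def __init__(self, data: torch.Tensor):
        self.data = data
        self._transpose_cache = None  # populated lazily on first use

    def transpose(self) -> torch.Tensor:
        if self._transpose_cache is None:
            # Under CUDA graph capture this allocation happens once, at
            # capture time. Later in-place updates to self.data are not
            # reflected on replay unless the cache is refreshed into the
            # same static buffer.
            self._transpose_cache = self.data.t().contiguous()
        return self._transpose_cache
```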

With the tests and bugfixes in https://github.com/NVIDIA/TransformerEngine/pull/869, this PR seems to handle `make_graphed_callables` with `fp8_weight_caching=True` correctly.
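For reference, a minimal usage sketch of that combination. `make_graphed_callables` and `fp8_weight_caching` come from the comment above; the other argument names (`fp8_enabled`, `fp8_recipe`, `is_first_microbatch`) reflect my reading of the Transformer Engine PyTorch API and may differ by version.

```python
# Sketch: graph-capturing a TE module with FP8 weight caching enabled.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.Linear(1024, 1024)
sample_input = torch.randn(8, 1024, device="cuda")

graphed_layer = te.make_graphed_callables(
    layer,
    (sample_input,),
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(),
    fp8_weight_caching=True,  # reuse FP8-cast weights across microbatches
)

# With weight caching enabled, the caller signals when the FP8 weights must be
# re-cast, i.e. on the first microbatch after an optimizer step.
for microbatch in range(4):
    out = graphed_layer(sample_input, is_first_microbatch=(microbatch == 0))
    out.sum().backward()
```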

After training GPT for 100 steps (175B params, TP=2, PP=4), I see no significant differences in the loss curves with and without this PR. I think this is ready to go.

`torch.nn.Linear`'s forward pass just requires a single cuBLAS GEMM kernel. `te.Linear` is more complicated since it also launches kernels for the FP8 cast, caching FP8 scales for the backward pass,...
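One way to see the difference concretely (not from the PR, and kernel names will vary by TE version) is to profile a single forward pass of each module:

```python
# Sketch: compare kernel launches of torch.nn.Linear vs. te.Linear under FP8.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
torch_linear = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
te_linear = te.Linear(4096, 4096, params_dtype=torch.bfloat16)

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA]
) as prof:
    with torch.no_grad():
        torch_linear(x)  # roughly one cuBLAS GEMM kernel
        with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
            te_linear(x)  # GEMM plus FP8 cast and scale-bookkeeping kernels

print(prof.key_averages().table(sort_by="cuda_time_total"))
```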