[Feature Request] Add FP8 GEMM to Ada (SM89)
Hello, it's my understanding that Triton currently does not support the use of FP8 tensor cores on Ada Lovelace GPUs (SM_89), correct? I've noticed that cuBLAS has implemented support for it, and TransformerEngine also provides FP8 support on Ada.
I am eagerly hoping for prompt support so that we can utilize FP8 on the 4090 as well. 🥺
I would greatly appreciate your help!
Does https://github.com/NVIDIA/TransformerEngine use cuBLAS?
Yes.
https://github.com/NVIDIA/TransformerEngine/blob/b5e13a16611be162538f489f3fd7096518640e15/transformer_engine/common/gemm/cublaslt_gemm.cu#L41
Are you aware of any PTX instructions that support FP8 on SM89?
CUDA 12.4 or newer may support it now.
see: https://github.com/NVIDIA/cutlass/blob/c4e3e122e266644c61b4af33d0cc09f4c391a64b/include/cutlass/arch/mma_sm89.h#L57
Yeah, I think so.
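For reference, the CUTLASS header linked above wraps the SM89 FP8 tensor-core instruction `mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32`. As a rough sketch of what this enables at the Triton level (assuming a recent Triton build and CUDA 12.4+, per the report further down; the kernel and tensor names are illustrative), a `tl.dot` over FP8 inputs is what should lower to that instruction on SM89:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def fp8_dot_kernel(a_ptr, b_ptr, c_ptr,
                   M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # One-tile GEMM: load FP8 A (M x K) and B (K x N), accumulate in FP32.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # On SM89 with a new enough toolchain, this should lower to the
    # mma.sync...e4m3 tensor-core instruction; the accumulator is FP32.
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(32, 16, device="cuda").to(torch.float8_e4m3fn)
c = torch.empty(16, 16, device="cuda", dtype=torch.float32)
fp8_dot_kernel[(1,)](a, b, c, M=16, N=16, K=32)
```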
PyTorch 2.3 has support for FP8 GEMM on Ada Lovelace, which we use in vLLM; it would be great to have this supported in Triton. PyTorch PR: https://github.com/pytorch/pytorch/pull/118881
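As a hedged illustration of what that PR exposes (shapes and names here are made up), the FP8 GEMM lands in the private `torch._scaled_mm` op, which dispatches to cuBLASLt:

```python
import torch

# Sketch using PyTorch's private torch._scaled_mm. This is not a stable
# API: the scale arguments and the return type (an (out, amax) tuple on
# PyTorch 2.3 vs. a single tensor on newer releases) have changed across
# versions, so check the signature of your installed build.
a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
# cuBLASLt wants the second operand column-major, hence the transpose.
b = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn).t()
scale = torch.tensor(1.0, device="cuda")  # per-tensor scale in float32
result = torch._scaled_mm(a, b, scale_a=scale, scale_b=scale,
                          out_dtype=torch.bfloat16)
```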
@DD-DuDa I tested this on my 4090 and it works now, with CUDA 12.4, PTX 8.4, and Triton nightly. However, there are a few things that still don't work:
- upcasting FP8 to BF16 (although you can go FP8 -> FP32 -> BF16 as a workaround, sketched after this list; I have a patch that will do this in PTX so it works seamlessly)
- casting to FP8 in RTNE (round-to-nearest-even) mode
- tl.dot() with FP8 and IEEE precision (vs. the usual TF32 precision)
- of course, all of the Hopper features like fences, WGMMA instructions, etc.
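A minimal sketch of the FP8 -> FP32 -> BF16 workaround from the first bullet, assuming a Triton nightly that accepts `torch.float8_e4m3fn` inputs (the kernel name and block size are illustrative):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def upcast_fp8_to_bf16(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    # Direct FP8 -> BF16 casts don't work yet on SM89, so hop through FP32.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = x.to(tl.float32).to(tl.bfloat16)
    tl.store(y_ptr + offs, y, mask=mask)

x = torch.randn(4096, device="cuda").to(torch.float8_e4m3fn)
y = torch.empty_like(x, dtype=torch.bfloat16)
grid = (triton.cdiv(x.numel(), 1024),)
upcast_fp8_to_bf16[grid](x, y, x.numel(), BLOCK=1024)
```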
Thank you! @rationalism