TransformerEngine
How to cast 16/32-bit to FP8?
Hi, how do I cast a float32/bfloat16 tensor to FP8? I want to do W8A8 (FP8) quantization, but I couldn't find an example of quantizing activations to FP8.
The easiest approach is to use native PyTorch FP8 dtypes:
```python
import torch
x = torch.randn(128, device="cuda", dtype=torch.float32)
y = x.to(dtype=torch.float8_e4m3fn)  # or torch.float8_e5m2
```
You could also use transformer_engine.pytorch.Float8Tensor / float8_experimental.Float8Tensor:
```python
import transformer_engine.pytorch as te
import float8_experimental
scale = torch.ones(1, device="cuda", dtype=torch.float32)
y1 = te.Float8Tensor.to_float8(x)
y2 = float8_experimental.Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)
```
These two classes are closely related, and they offer some nice convenience features: support for scaling factors, casting to higher precision for ops that don't support FP8, and (in float8_experimental) torch.compile support.
Finally, you could directly use the FP8 kernels from Transformer Engine:
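```python
y = te.cpp_extensions.cast_to_fp8(
    x,
    fp8_meta,  # FP8 scaling metadata (scale, scale_inv, amax history)
    0,  # index of this tensor's entry in fp8_meta
    transformer_engine_torch.DType.kFloat8E4M3,
)
```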
I strongly advise against using these internal functions, though. Their APIs are unstable, messy, and tightly integrated with TE's logic for computing FP8 scaling factors.
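For intuition, the scaling-factor logic can be sketched in plain Python. This is loosely modeled on TE's delayed-scaling recipe (a rolling amax history, a `max` or `most_recent` reduction over it, and a `margin` exponent), but the `DelayedScaler` class is hypothetical and is not TE's implementation:

```python
import math
from collections import deque

E4M3_MAX = 448.0    # largest finite float8_e4m3fn value
E5M2_MAX = 57344.0  # largest finite float8_e5m2 value

class DelayedScaler:
    """Hypothetical helper: track a rolling amax history and derive an FP8 scale."""

    def __init__(self, fp8_max=E4M3_MAX, history_len=16, margin=0, algo="max"):
        self.fp8_max = fp8_max
        self.margin = margin
        self.algo = algo  # "max" over the window, or "most_recent"
        self.history = deque(maxlen=history_len)
        self.scale = 1.0  # used until a valid amax has been observed

    def update(self, amax):
        """Record this step's amax and recompute the scale for the next step."""
        self.history.append(amax)
        ref = max(self.history) if self.algo == "max" else self.history[-1]
        # Keep the previous scale if the reference amax is zero or non-finite.
        if ref > 0.0 and math.isfinite(ref):
            self.scale = (self.fp8_max / ref) / (2 ** self.margin)
        return self.scale
```

The scale used for the current cast comes from amax values already observed, so no extra reduction over the tensor is needed before casting; the trade-off is that a sudden amax spike may not fit in the FP8 range until the history catches up.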
Thanks @timmoon10. How do I do mixed-precision calculations, e.g. a matrix multiplication of an FP8 tensor and an FP16 tensor that produces an FP16 output?
If you just want the performance benefit of FP8 matmuls, I recommend using Transformer Engine modules (like te.Linear) in your model (see this FP8 tutorial). They will internally handle the FP8 casts and FP8 scaling factors.
If you want more control, you'll have to get a bit into the weeds. I'm not sure whether native PyTorch FP8 tensors support matmuls (even if they did, there would be numerical issues without FP8 scaling factors), but I see that float8_experimental.Float8Tensor does support matmuls with scaling factors (see addmm_float8_unwrapped). As far as I can tell, this just ends up calling cuBLAS (see scaled_gemm). Be advised that cuBLAS FP8 matmuls require both inputs to be FP8, so they cannot mix FP8 and FP16 inputs (see the FP8 support matrix for cublasLtMatmul). Implementing a custom matmul kernel that supports mixed FP8 and FP16 inputs may be possible with CUTLASS, but it would get quite involved (and would probably still be slower than TE for end-to-end training).