H100 fp8 GEMM with fp16-to-fp8 casting after the load makes the performance bad
Hi, when I tried the fp8 GEMM code in matmul.py with the input "a" kept in float16 and cast to fp8 only just before the dot-product op, by setting AB_DTYPE to tl.float8e4nv (link: https://github.com/openai/triton/blob/addd94e4a8d4dc0beefd5df6d97d57d18065436a/python/triton/ops/matmul.py#L128C1-L130C31), the throughput got really bad.
Why is that, and what would be the workaround? Thanks.
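For concreteness, here is a minimal sketch of the pattern I mean (not the upstream matmul.py itself): both inputs live in fp16 in global memory and are cast to tl.float8e4nv right before tl.dot, which is what setting AB_DTYPE does. The kernel/wrapper names and block sizes are made up for illustration, and masking is omitted, so M/N/K are assumed to be multiples of the block sizes.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def matmul_cast_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                       stride_am, stride_ak,
                       stride_bk, stride_bn,
                       stride_cm, stride_cn,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                       BLOCK_K: tl.constexpr,
                       AB_DTYPE: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak
    b_ptrs = b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)      # loaded as fp16
        b = tl.load(b_ptrs)      # loaded as fp16
        a = a.to(AB_DTYPE)       # cast to fp8e4nv just before the dot
        b = b.to(AB_DTYPE)
        acc = tl.dot(a, b, acc)  # accumulate in fp32
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    tl.store(c_ptrs, acc)


def matmul_cast(a, b):
    # a, b: fp16 CUDA tensors; shapes assumed divisible by the block sizes.
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    grid = (triton.cdiv(M, 128), triton.cdiv(N, 128))
    matmul_cast_kernel[grid](a, b, c, M, N, K,
                             a.stride(0), a.stride(1),
                             b.stride(0), b.stride(1),
                             c.stride(0), c.stride(1),
                             BLOCK_M=128, BLOCK_N=128, BLOCK_K=64,
                             AB_DTYPE=tl.float8e4nv)
    return c
```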
When taking A in fp16 and B in fp8 and casting only A to fp8 before the dot product, it gave this error:
unimplemented code path UNREACHABLE executed at /home/xxx/triton/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp:369!
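A sketch of that mixed-dtype variant, for reference: B is already fp8 in memory (torch.float8_e4m3fn on the host, assuming a Triton build that maps it to tl.float8e4nv) and only A is cast in-kernel. Same illustrative structure and imports as the sketch above; this is roughly the shape of code that hits the error above, not a fix.

```python
@triton.jit
def matmul_mixed_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                        stride_am, stride_ak,
                        stride_bk, stride_bn,
                        stride_cm, stride_cn,
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                        BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak
    b_ptrs = b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs).to(tl.float8e4nv)  # fp16 load, cast only A
        b = tl.load(b_ptrs)                    # already tl.float8e4nv
        acc = tl.dot(a, b, acc)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    tl.store(c_ptrs, acc)
```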
With the latest Triton nightly, I'm also running into this issue when casting bf16 inputs to fp8 right before tl.dot. I'm setting AB_DTYPE to tl.float8e4nv before calling the matmul kernel and am hitting the same error above. There seems to be another somewhat similar issue here; is this related? I see that @Jokeren you addressed that issue?