xla
[NVIDIA] Don't use c_scale when the operand c is non-fp8
For the current fp8 GEMM, we set c_scale to one, though it is effectively never used. Newer cuBLASLt versions, however, have a stricter requirement: c_scale may be set only when the operand c is fp8. This PR fixes the issue by removing c_scale. It shouldn't affect the current fp8 GEMM, since the value is not used anyway.
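As a rough illustration (not XLA's actual API; all names here are hypothetical), the fix amounts to only attaching c_scale to the custom-call operand list when the matrix bias c is itself an fp8 tensor:

```python
# Hypothetical sketch of assembling the fp8 GEMM custom-call operand list.
# Names and the dict-based tensor representation are illustrative only.

FP8_DTYPES = {"f8e4m3fn", "f8e5m2"}

def build_gemm_operands(a, b, a_scale, b_scale, d_scale,
                        c=None, c_scale=None):
    """Return the operand list for the fp8 GEMM custom call.

    c_scale is appended only when c exists and is fp8; newer cuBLASLt
    rejects a scale attached to a non-fp8 operand.
    """
    operands = [a, b]
    if c is not None:
        operands.append(c)
    operands += [a_scale, b_scale]
    # The key change: skip c_scale entirely for a non-fp8 matrix bias.
    if c is not None and c["dtype"] in FP8_DTYPES:
        if c_scale is None:
            raise ValueError("fp8 matrix bias requires c_scale")
        operands.append(c_scale)
    operands.append(d_scale)
    return operands
```

With a non-fp8 c, the resulting list simply has no c_scale slot, which matches the stricter cuBLASLt requirement.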
cc. @philipphack @reedwm
Right, I originally tried to fix it in the gemm rewriter, but I hit segfaults here and there in many other places, and I guess those places were built on the assumption of absolute operand positions. Basically, the current operand list is [a, b, [c], a_scale, b_scale, c_scale, d_scale, [vector_bias]], where [] means optional. I tried to convert it to [a, b, [c], a_scale, b_scale, [vector_bias], [d_scale]], but it seems it is not as simple as I expected. So, this PR is more of a hotfix, since the scaling factors for non-fp8 operands are dropped by cuBLASLt anyway, even in current versions. That said, I agree we still need to pursue the gemm rewriter fix. Maybe @philipphack @wenscarl can help?
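To make the fragility concrete, here is a small illustrative sketch (not XLA code) of the problem with absolute positions: under the layout [a, b, [c], a_scale, b_scale, c_scale, d_scale, [vector_bias]], any code that hardcodes "d_scale is operand 6" is wrong whenever the optional matrix bias c is absent, so the index has to be derived from which optional operands are present:

```python
# Illustrative sketch only: derive an operand's index from the presence of
# the optional operands, instead of hardcoding absolute positions.

def operand_index(name, has_c, has_vector_bias):
    """Compute the index of `name` in the custom-call operand list."""
    order = ["a", "b"]
    if has_c:
        order.append("c")           # optional matrix bias shifts everything after it
    order += ["a_scale", "b_scale", "c_scale", "d_scale"]
    if has_vector_bias:
        order.append("vector_bias")
    return order.index(name)
```

For example, d_scale sits at index 6 with a matrix bias present but at index 5 without one, which is exactly the kind of shift that hardcoded offsets miss.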
How urgent is fixing this? Is this broken in a current version of CUDA? If not, I would much prefer to go the gemm_rewriter route. If @philipphack and @wenscarl cannot help with this, I can help as well.
Also, IIRC we never actually have a c_scale; we just pass 1 in gemm_rewriter, since I don't recall ever supporting an FP8 matrix bias. I could be misremembering, though. But I assume d_scale also must not be passed when the output is not FP8, right?
It's blocking our internal tests with the newer cublaslt and hardware, which no longer allows the non-fp8 tensor to be associated with scaling factors.
Yes, c_scale is never used in real cases, and d_scale is simply ignored by the current cuBLASLt for non-fp8 output. The newer cuBLASLt simply disallows such cases, to avoid giving users the false impression that non-narrow dtypes can be scaled. So, I think this fix is safe.
@kaixih Would it be possible to add a simple test case that detects this? If it's urgent, I can proceed with the merge first.
Hi @penpornk, could you please help merge this if possible? It is currently blocking our CI tests for the new hardware. I agree that we should have a unit test for this change, and I will take care of that for the change mentioned by @reedwm. That unit test will ensure the custom-call node does not have a scale operand for a non-fp8 input. However, for this hotfix, I don't think it's necessary, as long as we ensure that all existing unit tests pass. As I mentioned earlier, the c_scales and d_scales are actually ignored and never used in the actual computation by cuBLASLt. Thank you.
@kaixih Thank you for the quick reply! Will work on merging now.