xla
Fix crash on FP8 cublas matmuls.
When cublas LT was enabled by default for FP8 matmuls, the test //tensorflow/compiler/xla/tests:float8_test_gpu started failing on Hopper. The test had always failed whenever cublas LT was enabled; I had previously forgotten to enable cublas LT for that test.
The issue was that when cublas LT did not support FP8 for a given gemm, we would fall back to creating a cublas custom call taking FP8 inputs, but plain cublas does not support FP8 at all. With this change, when cublas LT does not support FP8 we instead create an FP16 matmul, converting the inputs to FP16 and the output back to FP8.
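The fallback described above can be sketched numerically. This is not the actual XLA C++ code; it is a minimal Python illustration of the convert-compute-convert pattern, assuming the ml_dtypes package for an FP8 numpy dtype:

```python
import numpy as np
import ml_dtypes  # provides FP8 dtypes compatible with numpy

F8 = ml_dtypes.float8_e4m3fn  # one of the FP8 formats supported on Hopper

def fp8_matmul_via_fp16(a_f8, b_f8):
    """Emulates the fallback: upcast the FP8 operands to FP16, run the
    matmul in FP16, then downcast the result back to FP8."""
    a16 = a_f8.astype(np.float16)
    b16 = b_f8.astype(np.float16)
    out16 = a16 @ b16
    return out16.astype(F8)

a = np.ones((4, 8), dtype=F8)
b = np.ones((8, 2), dtype=F8)
c = fp8_matmul_via_fp16(a, b)
print(c.dtype, c.shape)  # float8_e4m3fn (4, 2)
```

The FP16 detour trades some precision and range for a matmul the library can actually execute, which is preferable to crashing on an unsupported FP8 gemm.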