scaled_matmul.cu produces wrong results by loading the scaling factors wrong
When I tested with the following snippet:

```python
import torch
import thunderkittens as tk  # TK's Python bindings (module name assumed)

FP8_e4m3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
FP8_e4m3_MIN = torch.finfo(torch.float8_e4m3fn).min  # -448.0

M = 128
N = 128
K = 128
slice_ = slice(47, 54)

def to_float8_e4m3fn(x: torch.Tensor):
    # Per-row dequantization scales (amax over the last dim).
    scales = x.abs().amax(dim=-1, keepdim=True).float().div(FP8_e4m3_MAX)
    x = x.div(scales).clamp(min=FP8_e4m3_MIN, max=FP8_e4m3_MAX)
    x = x.to(torch.float8_e4m3fn)
    return x, scales

A = torch.randn(M, K, device="cuda").mul(.3)
B = torch.randn(N, K, device="cuda").mul(.3)
C = torch.empty(M, N, device="cuda")

A_fp8, scale_a_inv_s = to_float8_e4m3fn(A)
B_fp8, scale_b_inv_s = to_float8_e4m3fn(B)

tk.fp8_gemm_scaled(A_fp8, B_fp8, C, scale_a_inv_s, scale_b_inv_s)
y = torch._scaled_mm(
    A_fp8,
    B_fp8.T,
    out_dtype=torch.bfloat16,
    scale_a=scale_a_inv_s,    # (128, 1)
    scale_b=scale_b_inv_s.T,  # (1, 128)
    use_fast_accum=True,
)  # bias=bias
```
y and C produce somewhat similar results, yet y is consistently closer to a torch.mm reference computed on the inputs before they are cast to torch.float8_e4m3fn: only 17-24% of the elements are more accurate in C, not the ~50% you would expect if TK's kernel were numerically equivalent to torch._scaled_mm.
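Concretely, the closeness comparison looks roughly like this (a minimal sketch; the per-element win-rate metric is illustrative, not part of the repro above):

```python
# Reference on the original (unquantized) inputs; B is (N, K), so mm with B.T.
ref = torch.mm(A, B.T)

# Per-element absolute error of each FP8 path against the reference.
err_tk = (C.float() - ref).abs()
err_pt = (y.float() - ref).abs()

# Fraction of elements where TK's output is closer to the reference;
# ~0.5 would mean the two kernels are on average equally accurate.
win_rate = (err_tk < err_pt).float().mean()
print(f"TK closer on {win_rate.item():.1%} of elements")
```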
However, if I set

```python
scale_a_inv_s.fill_(1)
scale_b_inv_s.fill_(2)
```

right after I cast A and B to float8_e4m3fn, the numerical accuracy matches that of torch._scaled_mm. To illustrate further, I used
```python
scale_a_inv_s.normal_(std=10)
scale_b_inv_s.normal_(std=20)
```

and the results diverge wildly between torch._scaled_mm and TK's version, while torch._scaled_mm still closely tracks a BF16 reference (I cast the scale-adjusted FP8 tensors back to BF16 for that reference). This suggests the kernel is loading the scaling factors incorrectly rather than merely accumulating differently.
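For completeness, the BF16 reference I compare against is constructed roughly like this (a sketch; it just undoes the quantization at BF16 precision):

```python
# Dequantize to BF16 by re-applying the (now randomized) scales,
# then matmul in BF16; torch._scaled_mm keeps tracking this result
# while the TK kernel does not.
A_bf16 = A_fp8.to(torch.bfloat16) * scale_a_inv_s.to(torch.bfloat16)
B_bf16 = B_fp8.to(torch.bfloat16) * scale_b_inv_s.to(torch.bfloat16)
y_ref = torch.mm(A_bf16, B_bf16.T)
```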