ThunderKittens

scaled_matmul.cu produces wrong results because it loads the scaling factors incorrectly

RuiWang1998 opened this issue on Feb 12, 2025 · 0 comments

When I tested with the following snippet:

```python
import torch
import thunderkittens as tk

# torch.finfo gives the e4m3 range: max = 448.0, min = -448.0
FP8_e4m3_MAX = torch.finfo(torch.float8_e4m3fn).max
FP8_e4m3_MIN = torch.finfo(torch.float8_e4m3fn).min

M = 128
N = 128
K = 128
slice_ = slice(47, 54)  # not used in this snippet

def to_float8_e4m3fn(x: torch.Tensor):
    # per-row scales: the largest |x| in each row maps to the e4m3 max
    scales = x.abs().amax(dim=-1, keepdim=True).float().div(FP8_e4m3_MAX)
    x = x.div(scales).clamp(min=FP8_e4m3_MIN, max=FP8_e4m3_MAX)
    x = x.to(torch.float8_e4m3fn)
    return x, scales

A = torch.randn(M, K, device="cuda").mul(.3)
B = torch.randn(N, K, device="cuda").mul(.3)
C = torch.empty(M, N, device="cuda")

A_fp8, scale_a_inv_s = to_float8_e4m3fn(A)
B_fp8, scale_b_inv_s = to_float8_e4m3fn(B)
tk.fp8_gemm_scaled(A_fp8, B_fp8, C, scale_a_inv_s, scale_b_inv_s)

y = torch._scaled_mm(
    A_fp8,
    B_fp8.T,
    out_dtype=torch.bfloat16,
    scale_a=scale_a_inv_s,    # shape (M, 1): one scale per row of A
    scale_b=scale_b_inv_s.T,  # shape (1, N): one scale per column of B.T
    use_fast_accum=True,
)  # bias=bias
```

y and C produce somewhat similar results, yet y is consistently closer to the reference torch.mm output computed without the cast to torch.float8_e4m3fn: only 17-24% of elements are more accurate in C, not the ~50% one would expect if TK's kernel were on par with torch._scaled_mm.
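
For reference, the per-element comparison was along these lines (a minimal sketch; "better" is assumed here to mean smaller absolute error against torch.mm on the un-cast fp32 inputs):

```python
# Minimal sketch of the comparison above (assumption: "better" means
# smaller absolute error against torch.mm on the un-cast fp32 inputs).
ref = torch.mm(A, B.T)            # fp32 reference, no FP8 cast
err_tk = (C.float() - ref).abs()  # error of TK's fp8_gemm_scaled
err_pt = (y.float() - ref).abs()  # error of torch._scaled_mm
frac = (err_tk < err_pt).float().mean().item()
print(f"C is closer to the reference for {frac:.1%} of elements")  # ~17-24%
```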

However, if I set

```python
scale_a_inv_s.fill_(1)
scale_b_inv_s.fill_(2)
```

right after casting A and B to float8_e4m3fn, the numerical accuracy matches that of torch._scaled_mm. (With a constant scale, a kernel that mis-indexes the scale vectors still fetches the correct value, so this isolates scale loading as the culprit.) To illustrate further, I used

```python
scale_a_inv_s.normal_(std=10)
scale_b_inv_s.normal_(std=20)
```

and the results diverge wildly between torch._scaled_mm and TK's version, while torch._scaled_mm still closely tracks a bf16 reference computed by casting the scale-adjusted FP8 tensors back to BF16.
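
Roughly, that bf16 reference looks like this (a minimal sketch; "scale-adjusted" is assumed to mean dequantizing each FP8 tensor by its per-row scale before the matmul):

```python
# Sketch of the bf16 reference (assumption: dequantize each FP8 tensor
# by its per-row scale, then matmul in bf16).
A_bf16 = A_fp8.to(torch.bfloat16) * scale_a_inv_s.to(torch.bfloat16)  # (M, K)
B_bf16 = B_fp8.to(torch.bfloat16) * scale_b_inv_s.to(torch.bfloat16)  # (N, K)
ref_bf16 = A_bf16 @ B_bf16.T  # what the scaled FP8 GEMM should approximate
```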
