Incorrect result of tl.max() in Triton 2.2.0
Hi, I observed that tl.max() returns a wrong value in Triton 2.2.0. For example, with qk being a [16, 1] tensor, tl.max() returned 1.004396 instead of the true maximum, 1.745240.
The same error also occurs with the tl.sum() op. The error does not occur with Triton 2.1.0.
Both Triton versions were installed from the pip packages.
It's worth noting that the error only occurs when the tensor has more than one dimension, for example when reducing from [N, 1] to [1]; tl.max() returns the correct result for a 1-D tensor.
Any suggestions please?
Could you please provide some code snippets that can reliably reproduce the issue? 🙂
Here is a simple script that reproduces the issue. Compare the printed output in v2.1.0 and v2.2.0.
import torch
import triton
import triton.language as tl


@triton.jit
def _reproduce_max_error(
    q,        # [dim0, dim2]
    k_cache,  # [dim0, dim1, dim2]
    out,      # [dim0]
    stride_q0,
    stride_k0,
    stride_k1,
    DIM1_SIZE: tl.constexpr,
    DIM2_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    dim1_offs = tl.arange(0, DIM1_SIZE)
    dim2_offs = tl.arange(0, DIM2_SIZE)
    # Load one row of q and the corresponding [DIM1_SIZE, DIM2_SIZE] block of k_cache.
    q = tl.load(q + pid * stride_q0 + dim2_offs)
    k = tl.load(k_cache + pid * stride_k0 + dim1_offs[:, None] * stride_k1 + dim2_offs)
    # qk has shape [DIM1_SIZE, 1]; reducing over axis 0 should yield the true maximum.
    qk = tl.sum((q[None, :] * k).to(tl.float32), axis=1)[:, None]
    qk_max = tl.max(qk, axis=0)
    if pid == 0:
        tl.device_print("qk: ", qk)
        tl.device_print("qk_max: ", qk_max)
    tl.store(out + pid + tl.arange(0, 1), qk_max)


def reproduce_max_error():
    dim0, dim1, dim2 = 16, 16, 16
    q = torch.randn((dim0, dim2), dtype=torch.half, device='cuda')
    k = torch.randn((dim0, dim1, dim2), dtype=torch.half, device='cuda')
    out = torch.randn((dim0,), dtype=torch.half, device='cuda')
    _reproduce_max_error[(dim0,)](q, k, out, q.stride(0), k.stride(0), k.stride(1),
                                  DIM1_SIZE=dim1, DIM2_SIZE=dim2)


reproduce_max_error()
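For anyone verifying locally, a host-side check against a plain PyTorch reference can make the discrepancy obvious. The sketch below is not part of the original report; the helper name verify_against_torch and the seeding are my own additions, and it assumes the kernel above has been defined. It computes the same per-row dot products and maximum with torch and compares them to the kernel output; on a correct build the difference should be within fp16 rounding, while with the buggy reduction it should be much larger.

def verify_against_torch():
    # Hypothetical helper; assumes _reproduce_max_error from the reproducer above.
    dim0, dim1, dim2 = 16, 16, 16
    torch.manual_seed(0)
    q = torch.randn((dim0, dim2), dtype=torch.half, device='cuda')
    k = torch.randn((dim0, dim1, dim2), dtype=torch.half, device='cuda')
    out = torch.empty((dim0,), dtype=torch.half, device='cuda')
    _reproduce_max_error[(dim0,)](q, k, out, q.stride(0), k.stride(0), k.stride(1),
                                  DIM1_SIZE=dim1, DIM2_SIZE=dim2)
    # Reference: qk[i, j] = sum_d q[i, d] * k[i, j, d], then max over j.
    ref = (q[:, None, :].float() * k.float()).sum(dim=-1).max(dim=-1).values
    print("max abs diff vs torch:", (out.float() - ref).abs().max().item())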
@zahimoud may want to take a look? I can reproduce it locally.
I am interested in taking this ticket.
Sorry, I missed this due to the high volume of GitHub emails. @sohale feel free to take a look. I can take a look after my current task if this is still an issue.
Should we close this issue?