
Incorrect result of tl.max() in Triton 2.2.0

Open yunzhongOvO opened this issue 1 year ago • 3 comments

Hi, I observed that tl.max() returns a wrong value in Triton 2.2.0. For example, in the screenshot below, qk is a [16, 1] tensor, and tl.max() returns 1.004396 instead of the true maximum, 1.745240.

The same error also occurs with the tl.sum() op. Neither error occurs with Triton 2.1.0.

[screenshots: kernel device_print output showing the qk values and the incorrect qk_max]

Triton was installed from the pip package (not built from source).

It's worth noting that the error only occurs when the tensor rank is greater than 1, e.g. when reducing from [N, 1] to [1]. tl.max() returns the correct result for a 1-D tensor.
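For reference, the expected semantics of the [N, 1] → [1] reduction can be checked on the CPU. This is a sketch of mine using NumPy in place of Triton (the bug is in the GPU lowering, so the CPU result is the ground truth here):

```python
import numpy as np

rng = np.random.default_rng(0)
qk = rng.standard_normal((16, 1)).astype(np.float32)  # shape [16, 1], like the kernel's qk

# Reducing over axis 0 collapses [16, 1] -> [1]; the single remaining
# element must equal the global maximum of the column.
qk_max = qk.max(axis=0)

assert qk_max.shape == (1,)
assert qk_max[0] == qk.max()
```

A correct tl.max(qk, axis=0) should behave the same way; in 2.2.0 it instead returns a value that is not the column maximum.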

Any suggestions please?

yunzhongOvO avatar Jan 25 '24 08:01 yunzhongOvO

Could you please provide some code snippets that can reliably reproduce the issue? 🙂

Li-dongyang avatar Jan 25 '24 09:01 Li-dongyang

Here is a simple script that reproduces the issue. Compare the printed output under v2.1.0 and v2.2.0.

import torch

import triton
import triton.language as tl


@triton.jit
def _reproduce_max_error(
    q,  # [dim0, dim2]
    k_cache,  # [dim0, dim1, dim2]
    out, # [dim0]
    stride_q0,
    stride_k0,
    stride_k1,
    DIM1_SIZE: tl.constexpr,
    DIM2_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)

    dim1_offs = tl.arange(0, DIM1_SIZE)
    dim2_offs = tl.arange(0, DIM2_SIZE)
    q = tl.load(q + pid * stride_q0 + dim2_offs)
    k = tl.load(k_cache + pid * stride_k0 + dim1_offs[:, None] * stride_k1 + dim2_offs)
    
    qk = tl.sum((q[None, :] * k).to(tl.float32), axis=1)[:, None]
    qk_max = tl.max(qk, axis=0)
        
    if pid == 0:
        tl.device_print("qk: ", qk)
        tl.device_print("qk_max: ", qk_max)

    tl.store(out + pid + tl.arange(0, 1), qk_max)
    

def reproduce_max_error():
    dim0, dim1, dim2 = 16, 16, 16
    q = torch.randn((dim0, dim2), dtype=torch.half, device='cuda')
    k = torch.randn((dim0, dim1, dim2), dtype=torch.half, device='cuda')
    out = torch.randn((dim0,), dtype=torch.half, device='cuda')
    _reproduce_max_error[(dim0,)](q, k, out, q.stride(0), k.stride(0), k.stride(1), DIM1_SIZE=dim1, DIM2_SIZE=dim2)
    

reproduce_max_error()
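As a sanity check (my addition, not part of the original report), the kernel's `out` can be compared against a plain CPU reference. A sketch in NumPy, assuming the same shapes and dtypes as the script above:

```python
import numpy as np

rng = np.random.default_rng(0)
dim0, dim1, dim2 = 16, 16, 16
q = rng.standard_normal((dim0, dim2)).astype(np.float16)
k = rng.standard_normal((dim0, dim1, dim2)).astype(np.float16)

# CPU reference for what each kernel instance computes:
#   qk[pid, j] = sum_d q[pid, d] * k[pid, j, d], accumulated in fp32,
#   out[pid]   = max_j qk[pid, j]
qk = (q[:, None, :].astype(np.float32) * k.astype(np.float32)).sum(axis=-1)  # [dim0, dim1]
ref = qk.max(axis=-1)                                                        # [dim0]

assert ref.shape == (dim0,)
```

On a correct Triton build, the kernel's `out` (cast to fp32) should match `ref` up to fp16 rounding, e.g. via `np.allclose(out, ref, atol=1e-2)`; under 2.2.0 the reported values differ far beyond rounding error.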

yunzhongOvO avatar Jan 25 '24 10:01 yunzhongOvO

@zahimoud may want to take a look? I can reproduce it locally

Jokeren avatar Jan 25 '24 16:01 Jokeren

I am interested in taking this ticket.

sohale avatar Feb 23 '24 00:02 sohale

Sorry, I missed this due to the high volume of GitHub emails. @sohale, feel free to take a look. I can look into it after my current task if it's still an issue.

zahimoud avatar Feb 23 '24 01:02 zahimoud

Should we close this issue?

zahimoud avatar Mar 07 '24 00:03 zahimoud