
Loading from TMA descriptor hangs

Open • saagarjha opened this issue 8 months ago • 4 comments

Describe the bug

I haven't looked into why this is happening yet, but I've been able to reduce it to a small reproducer. Running the following code hangs:

#!/usr/bin/env python3

import torch
import triton


# Autotune over BLOCK_SIZE = 128, 256, 512.
autotune_config = [triton.Config({"BLOCK_SIZE": 2**n}) for n in range(7, 10)]


@triton.autotune(configs=autotune_config, key=[])
@triton.jit
def test(
    matrix,
    n,
    BLOCK_SIZE: triton.language.constexpr,
):
    data = triton.language.make_tensor_descriptor(
        matrix, block_shape=(BLOCK_SIZE, 16), shape=(n, n), strides=(n, 1)
    ).load((0, 0))

    # Make sure the load happens and isn't optimized out
    triton.language.store(matrix + triton.language.zeros_like(data), data)


if __name__ == "__main__":
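    # Tensor descriptors created inside a kernel need global scratch memory,
    # which Triton obtains through this allocator.
    triton.set_allocator(lambda size, alignment, stream: torch.empty(size, device="cuda", dtype=torch.int8))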

    n = 4096

    matrix = torch.zeros(n, n, dtype=torch.int, device="cuda")

    test[(1,)](matrix, n)
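One detail that may be relevant (an observation, not something the reproducer establishes): the autotune sweep is 2**7, 2**8, 2**9 = 128, 256, 512, and the CUDA driver documentation for `cuTensorMapEncodeTiled` requires every `boxDim` entry to be between 1 and 256, with the innermost box row a multiple of 16 bytes when no interleave is used. BLOCK_SIZE=512 is the only config whose block shape exceeds that per-dimension limit. The sketch below is a hypothetical host-side check; `check_tma_block_shape` is an illustrative name, not a Triton API, and it assumes the block shape maps directly onto the TMA box.

```python
# Hypothetical helper, not part of Triton: checks a row-major block shape
# against the per-dimension TMA box limits documented for cuTensorMapEncodeTiled.
MAX_BOX_DIM = 256            # each boxDim entry must be in [1, 256]
INNER_BYTES_MULTIPLE = 16    # innermost box row must be a multiple of 16 bytes (no interleave)


def check_tma_block_shape(block_shape, element_size):
    """Return a list of problems found; an empty list means the shape looks valid."""
    problems = []
    for axis, extent in enumerate(block_shape):
        if not 1 <= extent <= MAX_BOX_DIM:
            problems.append(f"axis {axis}: extent {extent} is outside [1, {MAX_BOX_DIM}]")
    # Triton block shapes are row-major, so the last entry is the innermost
    # dimension (boxDim[0] in CUDA's ordering).
    inner_bytes = block_shape[-1] * element_size
    if inner_bytes % INNER_BYTES_MULTIPLE != 0:
        problems.append(
            f"inner row is {inner_bytes} bytes, not a multiple of {INNER_BYTES_MULTIPLE}"
        )
    return problems


if __name__ == "__main__":
    # The reproducer's sweep: BLOCK_SIZE in 128, 256, 512 with int32 (4-byte) data.
    for block_size in [2**n for n in range(7, 10)]:
        issues = check_tma_block_shape((block_size, 16), element_size=4)
        print(block_size, "ok" if not issues else issues)
```

Run as a script, this reports 128 and 256 as fine and flags only 512. Whether Triton at this commit passes the block shape straight through as the TMA box or splits oversized blocks into several copies is not something the report establishes, so treat this only as a lead for narrowing down the hang.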

I can reproduce this reliably once the autotuner reaches BLOCK_SIZE=512, and under the debugger it also hangs on the very first kernel launch. The kernel is stuck at its very end, spinning on this wait loop:

   0x0000000327796140 <+2368>:	SYNCS.PHASECHK.TRANS64.TRYWAIT P0[UR8+0x8000],RZ
=> 0x0000000327796150 <+2384>:	@!P0 BRA 0x327796140
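For context on the SASS: `SYNCS.PHASECHK.TRANS64.TRYWAIT` followed by a branch back to itself appears to be the compiled form of an mbarrier try-wait loop, where the thread waits for the barrier's phase to flip; for a TMA load that happens once the expected number of transaction bytes has arrived. This reading is an interpretation of the disassembly, not something confirmed in the report. One suggestive piece of arithmetic: the barrier operand is `UR8+0x8000`, and 0x8000 = 32768 bytes is exactly the size of one 512×16 int32 tile. The sketch below only does that arithmetic, assuming the shared-memory staging buffer holds one full (BLOCK_SIZE, 16) tile.

```python
# Back-of-the-envelope arithmetic for the shared-memory offset seen in the SASS.
# Assumption: the staging buffer holds one full (BLOCK_SIZE, 16) int32 tile.
ELEMENT_BYTES = 4   # torch.int is int32
INNER = 16          # second entry of block_shape in the reproducer


def tile_bytes(block_size, inner=INNER, element_bytes=ELEMENT_BYTES):
    """Size in bytes of one (block_size, inner) tile."""
    return block_size * inner * element_bytes


for block_size in (128, 256, 512):
    size = tile_bytes(block_size)
    print(f"BLOCK_SIZE={block_size}: {size} bytes = {hex(size)}")
```

Only the 512-row tile lands on 0x8000 (the 128- and 256-row tiles give 0x2000 and 0x4000), which is consistent with, but does not prove, the barrier being placed directly after the BLOCK_SIZE=512 staging buffer.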

Environment details

Triton: built from source at https://github.com/triton-lang/triton/commit/a39389aac290cb5764ba9a47a93f8af8c7197916
GPU: GH200

saagarjha, Mar 31 '25 23:03