
IndexError: map::at when using tl.dot

Open · n17s opened this issue on Jan 24, 2024 · 7 comments

I am working on my first Triton kernel and I am running into the following error at runtime.

  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 161, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 395, in run
    self.cache[device][key] = compile(
  File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 225, in compile
    next_module = compile_ir(module, metadata)
  File "/opt/conda/lib/python3.10/site-packages/triton/backends/cuda/compiler.py", line 560, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
  File "/opt/conda/lib/python3.10/site-packages/triton/backends/cuda/compiler.py", line 477, in make_llir
    pm.run(mod)
IndexError: map::at

Here's a minimal reproduction:


import torch
import triton
import triton.language as tl
import numpy as np


@triton.jit
def my_kernel(output_ptr, input_ptr, h_ptr, n_cols, stride, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one row of the input.
    row_idx = tl.program_id(0)

    input_start_ptr = input_ptr + row_idx * stride
    output_start_ptr = output_ptr + row_idx * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = input_start_ptr + col_offsets
    output_ptrs = output_start_ptr + col_offsets
    h_ptrs = h_ptr + col_offsets

    mask = col_offsets < n_cols
    x = tl.load(input_ptrs, mask=mask, other=0.0)
    h = tl.load(h_ptrs)
    assert BLOCK_SIZE == 256
    # View both the 256-element row and the Hadamard matrix as 16x16 tiles and multiply them.
    H = tl.reshape(h, (16, 16))  # will make this more general once BLOCK_SIZE == 256 works
    X = tl.reshape(x, (16, 16))
    # HX = H + X
    HX = tl.dot(H, X)  # this is the call that triggers the error
    HXflat = tl.ravel(HX)
    tl.store(output_ptrs, HXflat, mask=mask)

def makeH():
    # Build a 16x16 Hadamard matrix via Kronecker products and flatten it
    # so the kernel can load it with a 1D range of offsets.
    H2 = np.array([[1, -1], [1, 1]])
    H4 = np.kron(H2, H2)
    H16 = np.kron(H4, H4)
    return torch.tensor(H16.ravel(), dtype=torch.float32).cuda()

Hcuda = makeH()

def foo(x):
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    print(f'{BLOCK_SIZE=}')
    y = torch.empty_like(x)
    assert x.stride(0) == y.stride(0)
    my_kernel[(n_rows, )](
        y,
        x,
        Hcuda,
        n_cols,
        x.stride(0),
        BLOCK_SIZE=BLOCK_SIZE
    )
    return y


x = torch.randn(size=(32, 256)).cuda()
output = foo(x)
print(output)

My environment is this Dockerfile:

FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime

RUN apt update && apt install -y git wget clang gcc
RUN pip install notebook matplotlib pandas
RUN pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly

It uses the latest Triton nightly package.

My GPU has a compute capability of 7.5
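
For reference, the capability can be checked with PyTorch's standard API (a minimal check, not part of the repro):

import torch
# Returns a (major, minor) tuple, e.g. (7, 5) on Turing cards;
# TF32 tensor cores only arrived with capability 8.0 (Ampere).
print(torch.cuda.get_device_capability())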

The code runs without the above error if HX = tl.dot(H, X) is replaced with the commented-out line above it, HX = H + X.

n17s avatar Jan 24 '24 22:01 n17s

Can anyone confirm whether they can repro this?

n17s avatar Jan 25 '24 20:01 n17s

I can reproduce this error on an RTX 2070 with CUDA 12.1 and driver version 530, but not on an A100 or T4. Setting allow_tf32 to False resolves the problem, although a solution that does not involve disabling TF32 would be preferable.

BobMcDear avatar Jan 25 '24 21:01 BobMcDear

Thanks @BobMcDear for trying all these different setups.

I can also verify on my setup that

tl.dot(H, X, allow_tf32=False)

resolves the problem, and I am looking forward to a better solution from the team.

n17s avatar Jan 25 '24 22:01 n17s

Just ran into this while using JAX's Pallas on a Quadro RTX 6000 with CUDA 12.3.

davisyoshida avatar Jan 26 '24 08:01 davisyoshida

It is common for people to develop and debug on a lower-end GPU without support for TensorFloat32 and then deploy the same code on a GPU that supports it. I think a good default for the argument would be None rather than True/False, with None meaning that the behavior is determined by whether support for TensorFloat32 has been detected on the current device.
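
In the meantime, here is a rough host-side sketch of that idea (my own workaround, not an existing Triton default; it assumes the flag is threaded into the kernel as an extra tl.constexpr argument, named ALLOW_TF32 here for illustration):

import torch

def tf32_supported() -> bool:
    # TF32 tensor cores only exist on Ampere (compute capability 8.0) and newer;
    # Turing cards like the RTX 2070, Quadro RTX 6000, and T4 report 7.5.
    major, _ = torch.cuda.get_device_capability()
    return major >= 8

# Hypothetical usage: pass the result into the kernel as a constexpr, e.g.
#   my_kernel[grid](..., ALLOW_TF32=tf32_supported())
# and inside the kernel call
#   HX = tl.dot(H, X, allow_tf32=ALLOW_TF32)
# so the same code runs on both pre-Ampere and Ampere-class GPUs.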

n17s avatar Jan 26 '24 18:01 n17s

@n17s I believe the intended behaviour of allow_tf32=True is what you describe: it signals to Triton to utilize TF32 if and only if it is available, similar to enabling TF32 in PyTorch. Indeed, in versions 2.0.0 and 2.1.0 this error does not occur even with that argument set to True, and it is only after upgrading to 2.2.0 that I am able to reproduce it. Also, regarding my previous comment, I had 2.1.0 installed on the T4, which is why I did not encounter the error there.
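
A related guard, if one only wants to disable TF32 where the regression has actually been observed (a sketch based on the versions reported in this thread; it assumes the packaging library is available):

import triton
from packaging.version import Version

# The failure was observed here with Triton 2.2.0 on pre-Ampere GPUs, while
# 2.0.0 and 2.1.0 reportedly compiled fine with the default allow_tf32=True.
NEEDS_TF32_WORKAROUND = Version(triton.__version__) >= Version("2.2.0")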

BobMcDear avatar Jan 27 '24 14:01 BobMcDear

Can confirm I get this error on a T4 with Triton 2.2.0, and I can resolve it by setting allow_tf32=False (in your case: HX = tl.dot(H, X, allow_tf32=False)).

UmerHA avatar Apr 10 '24 23:04 UmerHA