IndexError: map::at when using tl.dot
I am working on my first Triton kernel and I am running into the following error at runtime:
File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 161, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 395, in run
self.cache[device][key] = compile(
File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 225, in compile
next_module = compile_ir(module, metadata)
File "/opt/conda/lib/python3.10/site-packages/triton/backends/cuda/compiler.py", line 560, in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
File "/opt/conda/lib/python3.10/site-packages/triton/backends/cuda/compiler.py", line 477, in make_llir
pm.run(mod)
IndexError: map::at
Here's a minimal reproduction:
```python
import torch
import triton
import triton.language as tl
import numpy as np


@triton.jit
def my_kernel(output_ptr, input_ptr, h_ptr, n_cols, stride, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    input_start_ptr = input_ptr + row_idx * stride
    output_start_ptr = output_ptr + row_idx * stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = input_start_ptr + col_offsets
    output_ptrs = output_start_ptr + col_offsets
    h_ptrs = h_ptr + col_offsets
    mask = col_offsets < n_cols
    x = tl.load(input_ptrs, mask=mask, other=0.0)
    h = tl.load(h_ptrs)
    assert(BLOCK_SIZE == 256)
    H = tl.reshape(h, (16, 16))  # will make this more general once BLOCK_SIZE == 256 works
    X = tl.reshape(x, (16, 16))
    # HX = H + X
    HX = tl.dot(H, X)
    HXflat = tl.ravel(HX)
    tl.store(output_ptrs, HXflat, mask=mask)


def makeH():
    H2 = np.array([[1, -1], [1, 1]])
    H4 = np.kron(H2, H2)
    H16 = np.kron(H4, H4)
    return torch.tensor(H16.ravel(), dtype=torch.float32).cuda()


Hcuda = makeH()


def foo(x):
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    print(f'{BLOCK_SIZE=}')
    y = torch.empty_like(x)
    assert x.stride(0) == y.stride(0)
    my_kernel[(n_rows, )](
        y,
        x,
        Hcuda,
        n_cols,
        x.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y


x = torch.randn(size=(32, 256)).cuda()
output = foo(x)
print(output)
```
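For reference (this check is my addition, not part of the repro): once the `tl.dot` compiles, each output row should be `H` times the corresponding input row viewed as a 16x16 matrix, which can be verified on the host like this:

```python
# Hypothetical host-side check: each 256-element row of x, viewed as a 16x16
# matrix X, should come back as H @ X flattened row-major.
H = Hcuda.reshape(16, 16)
expected = (H @ x.reshape(32, 16, 16)).reshape(32, 256)
print((output - expected).abs().max())  # should be ~0, up to FP32/TF32 rounding
```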
My environment is this Dockerfile:
```dockerfile
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
RUN apt update && apt install -y git wget clang gcc
RUN pip install notebook matplotlib pandas
RUN pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
```
This uses the latest Triton nightly package. My GPU has a compute capability of 7.5.
The code runs without the above error if `HX = tl.dot(H, X)` is replaced with the commented-out line above it, `HX = H + X`. Can anyone confirm whether they can reproduce this?
I can reproduce this error on an RTX 2070 with CUDA 12.1 and driver version 530, but not on an A100 or T4. Setting `allow_tf32` to `False` resolves the problem, although a solution that does not involve disabling TF32 would be preferable.
Thanks @BobMcDear for trying all these different setups. I can also verify on my setup that `tl.dot(H, X, allow_tf32=False)` resolves the problem, and I am looking forward to a better solution from the team.
Just ran into this while using JAX's Pallas on a Quadro RTX 6000 with CUDA 12.3.
It is common for people to develop and debug on a lower-end GPU without support for TensorFloat32 and then deploy the same code on a GPU that supports it. I think the default should be `None` rather than `True`/`False`, meaning that the behavior is determined by whether TensorFloat32 support has been detected.
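Until something like that exists, the decision can be approximated on the caller side. The sketch below is my own workaround, not an official Triton API: it assumes TF32 requires compute capability 8.0 (Ampere) or newer and forwards the choice into the kernel as a `tl.constexpr` argument; the names `tf32_supported` and `dot_kernel` are hypothetical.

```python
import torch
import triton
import triton.language as tl


def tf32_supported(device: int = 0) -> bool:
    # TF32 tensor cores were introduced with Ampere (compute capability 8.0);
    # the 7.5 (Turing) cards in this thread do not have them.
    major, _ = torch.cuda.get_device_capability(device)
    return major >= 8


@triton.jit
def dot_kernel(out_ptr, a_ptr, b_ptr, N: tl.constexpr, ALLOW_TF32: tl.constexpr):
    offs = tl.arange(0, N * N)
    a = tl.reshape(tl.load(a_ptr + offs), (N, N))
    b = tl.reshape(tl.load(b_ptr + offs), (N, N))
    # The TF32 decision is made on the host and passed in as a compile-time constant.
    c = tl.dot(a, b, allow_tf32=ALLOW_TF32)
    tl.store(out_ptr + offs, tl.ravel(c))


a = torch.randn(16, 16, device="cuda")
b = torch.randn(16, 16, device="cuda")
c = torch.empty_like(a)
dot_kernel[(1,)](c, a, b, N=16, ALLOW_TF32=tf32_supported())
print((c - a @ b).abs().max())
```

On the pre-Ampere cards mentioned in this thread `tf32_supported()` evaluates to `False`, which sidesteps the crash the same way the explicit `allow_tf32=False` workaround does.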
@n17s I believe the intended behaviour of `allow_tf32=True` is what you describe: it signals to Triton to utilize TF32 if and only if it is available, similar to activating TF32 in PyTorch. Indeed, in versions 2.0.0 and 2.1.0 this error does not occur even with the said argument set to `True`; it is only after upgrading to 2.2.0 that I am able to reproduce it. Also, regarding my previous comment, I had 2.1.0 installed on the T4, which is why I did not encounter the error.
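For comparison, the PyTorch mechanism referenced above is a pair of global flags; on GPUs without TF32 support they are effectively no-ops and matmuls stay in FP32, which is the "use it if and only if it is available" behaviour described here:

```python
import torch

# Opt in to TF32 for matmuls and cuDNN convolutions on hardware that supports it.
# On pre-Ampere GPUs these flags have no effect and plain FP32 is used.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```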
Can confirm I get this error on a T4 with Triton 2.2.0, and I can resolve it by setting `allow_tf32=False` (in your case: `HX = tl.dot(H, X, allow_tf32=False)`).