Indexing in Triton kernels
We'd like to do some indexing inside Triton kernels. Say we have x_ptr, idx_ptr, and out_ptr:
x = tl.load(x_ptr + offsets, mask = mask)
idx = tl.load(idx_ptr + offsets, mask = mask)
We have tried two approaches.

Approach 1:

idx = idx.to(tl.int32)
output = tl.load(x_ptr + idx)

This works.

Approach 2:

output = tl.zeros([BLOCK_SIZE, ], dtype=tl.float32)
for i in range(0, BLOCK_SIZE):
    output[i] = x[idx[i]]

This reports a compilation error (the full error message is attached at the end). **We want to know:
- If we use approach 1, is it memory efficient, given that it goes through another tl.load?
- If we try x[0], it also errors with "TypeError: 'constexpr' object is not iterable". We didn't find much on indexing in the docs, so are there any other ways of doing it?**
We are using Triton version 2.0.0.dev20221120 with Python 3.8.0 on an A100.

Error log of approach 2:
Traceback (most recent call last):
File "<string>", line 21, in tri_index_kernel
KeyError: ('2-.-0-.-0-1e8410f206c822547fb50e2ea86e45a6-2b0c5161c53c71b37ae20a9996ee4bb8-3aa563e00c5c695dd945e23b09a86848-42648570729a4835b21c1c18cebedbfe-ff946bd4b3b4a4cbdf8cedc6e1c658e0-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float32, torch.float32, torch.float32, 'i32'), (64,), (True, True, True, (True, False)))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 838, in make_triton_ir
generator.visit(fn.parse())
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
return super().visit(node)
File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
return visitor(node)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 260, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 368, in generic_visit
self.visit(item)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
return super().visit(node)
File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
return visitor(node)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 320, in visit_FunctionDef
has_ret = self.visit_compound_statement(node.body)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 254, in visit_compound_statement
self.last_ret = self.visit(stmt)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
return super().visit(node)
File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
return visitor(node)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 648, in visit_For
self.visit_compound_statement(node.body)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 254, in visit_compound_statement
self.last_ret = self.visit(stmt)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
return super().visit(node)
File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
return visitor(node)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 364, in visit_Assign
_names += [self.visit(target)]
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
return super().visit(node)
File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
return visitor(node)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 576, in visit_Subscript
assert node.ctx.__class__.__name__ == "Load"
AssertionError
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "ztest.py", line 51, in <module>
output = tri_index(x, idx)
File "ztest.py", line 44, in tri_index
tri_index_kernel[grid](x, idx, output, n_elements, BLOCK_SIZE=64)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/runtime/jit.py", line 106, in launcher
return self.run(*args, grid=grid, **kwargs)
File "<string>", line 41, in tri_index_kernel
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 1256, in compile
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 892, in _compile
module, _ = make_triton_ir(fn, signature, specialization, constants)
File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 843, in make_triton_ir
raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 23:8:
def tri_index_kernel(
x_ptr, # *Pointer* to first input vector
idx_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask = mask)
idx = tl.load(idx_ptr + offsets, mask = mask)
output = tl.zeros([BLOCK_SIZE, ], dtype=tl.float32)
min_off = tl.min(offsets, axis=0)
max_off = tl.max(offsets, axis=0)
# idx //= 1
idx = idx.to(tl.int32)
output = tl.load(x_ptr + idx)
for i in range(0, BLOCK_SIZE):
output[i] = x[idx[i]]
^
Yeah, on-chip indexing through shared memory isn't supported yet. It's on the roadmap, but it's a pretty advanced feature, so we haven't come up with a specific timeline yet.
Thanks for the reply! Although I'm still curious: if we store the values back and use the indexes as pointers to load them, will this be slow?
It's expected to be slow, since you store values in global memory, though in some cases the accesses will go through the cache.
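For reference, here is a minimal self-contained sketch of the gather-through-global-memory pattern under discussion (approach 1 from the original post); the other=0 fill value and the launch code are illustrative assumptions, not from this thread:

import torch
import triton
import triton.language as tl

@triton.jit
def gather_kernel(x_ptr, idx_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load this block's indices, then use them directly as offsets into x
    # in global memory; masked-off lanes read index 0 harmlessly.
    idx = tl.load(idx_ptr + offsets, mask=mask, other=0).to(tl.int32)
    out = tl.load(x_ptr + idx, mask=mask)
    tl.store(out_ptr + offsets, out, mask=mask)

x = torch.randn(1024, device="cuda")
idx = torch.randint(0, 1024, (1024,), device="cuda")
out = torch.empty_like(x)
gather_kernel[(1024 // 64,)](x, idx, out, 1024, BLOCK_SIZE=64)

Because idx is arbitrary, the x_ptr + idx load is generally uncoalesced, which is where the slowdown mentioned above comes from.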
Triton just raises an assertion error when trying to index a local tensor. I suppose it is related to this issue. Are there any workarounds?
Any updates on this? Is there still no way to do indexing in a Triton kernel?
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/_triton/k_index_select_cat.py
There's this in xformers, which seems similar to indexing into a sparse tensor.
Yes, but it goes through global memory, which is slow, as @Jokeren mentioned above.
I have a similar issue, but I only want to index whole rows of a block, e.g. to compute a spline function up to a certain order:
data = tl.zeros((4, BLOCK_SIZE), dtype=tl.float32)
data[0] = w
data[1] = 1 - w
.....
I get a similar compiler error. This particular case could be worked around by creating 4 differently named tensors, but then iterating over them with a for loop becomes the issue.
I could unroll and name everything to overcome the problem, but that would produce unmaintainable code. Is there a known trick to get this to work, other than going through global memory?
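One pattern that can avoid per-row assignment is to build the whole (4, BLOCK_SIZE) block at once with broadcasting and tl.where. A minimal sketch, assuming w is loaded from a (BLOCK_SIZE,) input; the kernel name and output layout are illustrative:

import triton
import triton.language as tl

@triton.jit
def spline_rows_kernel(w_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    cols = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = cols < n
    w = tl.load(w_ptr + cols, mask=mask, other=0.0)
    # Select each row's expression by comparing against a row-index column,
    # instead of assigning data[0], data[1], ... one by one.
    rows = tl.arange(0, 4)[:, None]               # shape (4, 1)
    wb = w[None, :]                               # shape (1, BLOCK_SIZE)
    zeros = tl.zeros((4, BLOCK_SIZE), dtype=tl.float32)
    data = tl.where(rows == 0, wb,
           tl.where(rows == 1, 1.0 - wb, zeros))  # shape (4, BLOCK_SIZE)
    # Store the block as 4 contiguous rows of length n.
    out_offsets = rows * n + cols[None, :]
    tl.store(out_ptr + out_offsets, data, mask=mask[None, :])

Since each row constant is a Python int, a regular Python loop over the order can also accumulate data = tl.where(rows == i, expr_i, data); it gets unrolled at trace time, so nothing is indexed at runtime.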
Are there any updates on indexing through shared memory?
#5262 adds support for this as
output = tl.gather(x, idx, axis=0)
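For context, a minimal sketch of how the new primitive might be used inside a kernel (the surrounding loads and stores are illustrative, not from the PR):

import triton
import triton.language as tl

@triton.jit
def gather_onchip_kernel(x_ptr, idx_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    idx = tl.load(idx_ptr + offsets).to(tl.int32)
    # The gather is performed on the in-register/shared-memory tile,
    # rather than through global-memory pointer arithmetic.
    out = tl.gather(x, idx, axis=0)
    tl.store(out_ptr + offsets, out)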
But gather() can't help with indexing across more than one dimension, correct? I suppose multiple gather()s can be applied sequentially, but does gather() create a temporary copy?