Locally scoped kernels are not found
I was writing a unit test in PyTorch and wanted to define some kernels scoped under my test. It looks like when kernels are defined in a local scope, the one you launch directly is accessible, but it cannot call other kernels defined in the same scope.
The repro below results in
NameError('kernel2 is not defined')
Nothing urgent. I can temporarily define these in global scope as a workaround, but I wanted to track this issue to keep the codebase clean.
import torch
import triton
from triton import language as tl


def test(t1, t2, out):
    @triton.jit
    def kernel(
        in_ptr0,
        in_ptr1,
        out_ptr0,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        kernel2(in_ptr0, in_ptr1, out_ptr0, n_elements, BLOCK_SIZE)

    @triton.jit
    def kernel2(
        in_ptr0,
        in_ptr1,
        out_ptr0,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        in_ptr = tl.where(n_elements == 1, in_ptr0, in_ptr1)
        x = tl.load(in_ptr + offsets, mask=mask)
        tl.store(out_ptr0 + offsets, x, mask=mask)

    kernel[(4,)](t1, t2, out, 1, 16)


t1 = torch.tensor([1, 1, 1], device='cuda')
t2 = torch.tensor([2, 2, 2], device='cuda')
out = torch.zeros_like(t1)
test(t1, t2, out)
It's a limitation of Triton. Maybe we could provide a better error message.
Could you explain why this is a limitation? I assume something in CodeGenerator needs to properly handle scopes. It uses ast.NodeVisitor, which can handle this case. There's probably a simple fix to visit_Call?
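To make the visit_Call question concrete, here is a minimal plain-Python sketch of what an ast.NodeVisitor sees at a call site. It is not Triton's CodeGenerator; the class name CallNameCollector and the source string are hypothetical, chosen only for illustration. It suggests that the visitor receives the callee as a bare name, so turning that name into an object still requires a namespace lookup that the AST itself does not supply.

```python
import ast


class CallNameCollector(ast.NodeVisitor):
    """Hypothetical visitor: records the bare names used as callees."""

    def __init__(self):
        self.called = []

    def visit_Call(self, node):
        # node.func is an ast.Name for calls like `kernel2(...)`;
        # it holds only the identifier string, not the function object.
        if isinstance(node.func, ast.Name):
            self.called.append(node.func.id)
        self.generic_visit(node)


# Stand-in for the source of the outer kernel (illustrative only).
src = "def kernel(a, b):\n    kernel2(a, b)\n"

collector = CallNameCollector()
collector.visit(ast.parse(src))
print(collector.called)
```

Whether a fix belongs in visit_Call or in whatever builds the name table the visitor consults is an open question here; the sketch only shows where the name would first appear.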
It has to do with local and global namespaces, though I don't remember the details.
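As a rough illustration of the namespace point, the plain-Python sketch below shows that a sibling function defined in the same local scope is not part of a nested function's `__globals__`; Python keeps it in a closure cell instead. That a globals-only lookup is what produces the NameError in Triton is an assumption, not something confirmed in this thread, and both resolver functions are hypothetical helpers written for this example.

```python
import inspect


def make_functions():
    def outer_fn():
        # `helper` is resolved through a closure cell, not module globals.
        return helper()

    def helper():
        return 42

    return outer_fn


fn = make_functions()


def resolve_globals_only(fn, name):
    """Hypothetical resolver that consults only the global namespace."""
    if name in fn.__globals__:
        return fn.__globals__[name]
    raise NameError(f"{name} is not defined")


def resolve_with_closure(fn, name):
    """Hypothetical resolver that also checks closure (nonlocal) variables."""
    if name in fn.__globals__:
        return fn.__globals__[name]
    nonlocals = inspect.getclosurevars(fn).nonlocals
    if name in nonlocals:
        return nonlocals[name]
    raise NameError(f"{name} is not defined")


try:
    resolve_globals_only(fn, "helper")
    globals_only_found = True
except NameError:
    globals_only_found = False

found = resolve_with_closure(fn, "helper")
```

If the assumption holds, the same distinction would explain why moving the kernels to module scope works: the callee then lives in the global namespace that a globals-only lookup can see.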