
Locally scoped kernels are not found

Open • oulgen opened this issue 1 year ago • 3 comments

I was writing a unit test in PyTorch and wanted to define some kernels scoped inside my test function. It looks like when kernels are defined in a local scope, the kernel you launch directly is found, but it cannot call other kernels defined in the same scope.

The repro below fails with:

NameError('kernel2 is not defined')

Nothing urgent: I can define these kernels in the global scope as a temporary workaround, but I wanted to track this issue to keep the codebase clean.

import torch
import triton
from triton import language as tl

def test(t1, t2, out):
    @triton.jit
    def kernel(
        in_ptr0,
        in_ptr1,
        out_ptr0,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        kernel2(in_ptr0, in_ptr1, out_ptr0, n_elements, BLOCK_SIZE)

    @triton.jit
    def kernel2(
        in_ptr0,
        in_ptr1,
        out_ptr0,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):

        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        in_ptr = tl.where(n_elements == 1, in_ptr0, in_ptr1)
        x = tl.load(in_ptr + offsets, mask=mask)
        tl.store(out_ptr0 + offsets, x, mask=mask)

    kernel[(4,)](t1, t2, out, 1, 16)

t1 = torch.tensor([1, 1, 1], device='cuda')
t2 = torch.tensor([2, 2, 2], device='cuda')
out = torch.zeros_like(t1)
test(t1, t2, out)

oulgen, Jan 29 '24 20:01

It's a limitation of Triton. Maybe we could have a better error message.

Jokeren, Jan 29 '24 22:01

Could you explain why this is a limitation? I assume something in CodeGenerator needs to handle scopes properly. It uses ast.NodeVisitor, which can handle this case, so maybe there's a simple fix in visit_Call?
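As a rough illustration of what such a fix could involve (a sketch, not Triton's actual code): when the visitor reaches a call like kernel2(...), it has to resolve kernel2 the way Python would. For a function defined inside another function, sibling definitions are free variables stored in closure cells, not entries in the function's __globals__. The helpers below (CallCollector, called_names, resolve_name, make_kernels) are hypothetical, and the sketch assumes the compiler has the original Python function object available when it resolves call targets.

```python
import ast
import builtins
import textwrap


class CallCollector(ast.NodeVisitor):
    """Hypothetical visitor: collect the names of plain name(...) calls."""

    def __init__(self):
        self.called = []

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            self.called.append(node.func.id)
        self.generic_visit(node)  # keep walking nested calls and arguments


def called_names(source):
    """Return the names of functions called by name in the given source text."""
    collector = CallCollector()
    collector.visit(ast.parse(textwrap.dedent(source)))
    return collector.called


def resolve_name(fn, name):
    """Hypothetical lookup: closure cells first, then globals, then builtins."""
    freevars = fn.__code__.co_freevars
    if name in freevars and fn.__closure__ is not None:
        cell = fn.__closure__[freevars.index(name)]
        return cell.cell_contents  # raises ValueError if the cell is still empty
    if name in fn.__globals__:
        return fn.__globals__[name]
    if hasattr(builtins, name):
        return getattr(builtins, name)
    raise NameError(f"{name} is not defined")


def make_kernels():
    # Plain Python stand-ins for the two locally scoped kernels in the repro.
    def kernel(x):
        return kernel2(x)  # kernel2 is defined later in the same scope

    def kernel2(x):
        return x + 1

    return kernel


kernel = make_kernels()
for name in called_names("def kernel(x):\n    return kernel2(x)\n"):
    print(name, resolve_name(kernel, name)(41))  # kernel2 42
```

Defining kernel before kernel2, as the repro does, is fine here: the closure cell is only read after make_kernels has returned, by which point kernel2 has been assigned.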

oulgen, Jan 29 '24 22:01

It has to do with local and global namespaces, though I don't remember the details.
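A plain-Python illustration of the local versus global namespace distinction (no Triton needed). A nested function sees its sibling definitions through closure cells, while a lookup in the function's __globals__ only finds module-level names. If the JIT resolves callee names only through __globals__ (an assumption here, not checked against the Triton source), that would explain the NameError in the repro. make_kernels is a hypothetical stand-in for the test function.

```python
def make_kernels():
    def kernel(x):
        return kernel2(x)  # kernel2 is a sibling local, defined later in this scope

    def kernel2(x):
        return x + 1

    return kernel


kernel = make_kernels()

# kernel2 is not a module-level name, so a lookup in __globals__ misses it.
print("kernel2" in kernel.__globals__)  # False
# Python records it as a free variable backed by a closure cell instead.
print(kernel.__code__.co_freevars)  # ('kernel2',)
print(kernel.__closure__[0].cell_contents(1))  # 2
# A globals-only lookup reproduces the kind of error seen in the issue.
try:
    eval("kernel2", kernel.__globals__)
except NameError as exc:
    print(exc)  # name 'kernel2' is not defined
```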

Jokeren, Jan 29 '24 22:01