triton

Inconsistency between constants as arguments and captured globals

Open: amjames opened this issue 9 months ago • 3 comments

TL;DR: After #3762, global variables captured by a kernel must be `tl.constexpr` or annotated as such. I was surprised that a kernel argument with the annotation is actually a `constexpr` object while `CodeGenerator.visit` runs, but an annotated captured global is not. Either that should be fixed, or the error message should only recommend defining globals as `VAR = tl.constexpr(<value>)`.

Details

I had some code that looks like this (the actual original is from the PyTorch tests):

```python
import triton
import triton.language as tl

STRING_CONSTANT_C = 'value'

@triton.jit
def kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr", CONSTANT_NAME: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    # CONSTANT_NAME arrives as a tl.constexpr, so its payload is read via .value
    if CONSTANT_NAME.value == STRING_CONSTANT_C:
        output = 2 * x
    else:
        output = x  # ensures `output` is defined on every path
    tl.store(out_ptr + offsets, output, mask=mask)
```

After getting the new error about globals needing to be `tl.constexpr`, I tried defining `STRING_CONSTANT_C` globally like this:

```python
STRING_CONSTANT_C = tl.constexpr('value')
```

Compilation then fails with `if conditionals can only accept values of type {int, NoneType, bool}, not objects of type NotImplementedType`. Digging into that a bit, I realized the comparison is dispatched to `str.__eq__`, which compares a string to a `tl.constexpr` object and returns `NotImplemented`. So I modified the kernel source so the conditional uses `STRING_CONSTANT_C.value`, which works.
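The failure mode can be reproduced in plain Python. This sketch uses a hypothetical stand-in class, `FakeConstexpr`, in place of `tl.constexpr` (assuming only that the wrapper keeps its payload in `.value` and is not a `str` subclass), and assumes, as the error message suggests, that the code generator calls the left operand's `__eq__` directly:

```python
class FakeConstexpr:
    """Hypothetical stand-in for tl.constexpr: wraps a value in `.value`."""

    def __init__(self, value):
        self.value = value


STRING_CONSTANT_C = FakeConstexpr('value')
constant_name = 'value'  # what CONSTANT_NAME.value evaluates to in the kernel

# Calling str.__eq__ directly skips Python's reflected-operand fallback,
# so a non-str operand yields NotImplemented rather than True/False.
raw = str.__eq__(constant_name, STRING_CONSTANT_C)
print(type(raw).__name__)  # NotImplementedType

# Unwrapping the global first gives an ordinary bool.
print(str.__eq__(constant_name, STRING_CONSTANT_C.value))  # True
```

In ordinary Python, `constant_name == STRING_CONSTANT_C` would not raise: after both `__eq__` calls return `NotImplemented`, it falls back to an identity check and quietly evaluates to `False`.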

The other recommendation in the error message introduced by #3762 is to annotate the captured global, so I tried that out:

```python
STRING_CONSTANT_C: tl.constexpr = 'value'
```

That fails with the modified conditional (the global is still a plain `str`, so it has no `.value`) and works with the original source.
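The reason is that a variable annotation is metadata only and does not change the object bound to the name. A minimal pure-Python illustration, again with a hypothetical `FakeConstexpr` standing in for `tl.constexpr`:

```python
class FakeConstexpr:
    """Hypothetical stand-in for tl.constexpr: wraps a value in `.value`."""

    def __init__(self, value):
        self.value = value


# The annotation is recorded, but the bound object is still the str 'value'.
STRING_CONSTANT_C: FakeConstexpr = 'value'

print(type(STRING_CONSTANT_C).__name__)  # str
# False, so STRING_CONSTANT_C.value would raise AttributeError
print(hasattr(STRING_CONSTANT_C, 'value'))

# The original conditional compares str to str, so it works unchanged.
constant_name = 'value'
print(constant_name == STRING_CONSTANT_C)  # True
```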

Proposal

Why not translate these captured variables to always be tl.constexpr instances the way that arguments with the annotation are handled?
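One way to picture this: when a captured global is resolved, normalize it the same way an annotated argument is treated, wrapping it in a constexpr unless it already is one. The sketch below is hypothetical and does not reflect Triton's actual internals; `FakeConstexpr` stands in for `tl.constexpr`, and `normalize_captured_global` and its `annotated` flag are illustrative names only.

```python
class FakeConstexpr:
    """Hypothetical stand-in for tl.constexpr: wraps a value in `.value`."""

    def __init__(self, value):
        # This stand-in unwraps an existing wrapper to avoid double-wrapping.
        self.value = value.value if isinstance(value, FakeConstexpr) else value


def normalize_captured_global(value, annotated):
    """Hypothetical: give captured globals the same type as annotated args.

    Both accepted spellings, VAR = tl.constexpr(x) and VAR: tl.constexpr = x,
    end up as a wrapper object, so kernel code sees one consistent type.
    """
    if isinstance(value, FakeConstexpr):
        return value
    if annotated:
        return FakeConstexpr(value)
    raise TypeError("captured globals must be constexpr or annotated as such")


from_wrapper = normalize_captured_global(FakeConstexpr('value'), annotated=False)
from_annotation = normalize_captured_global('value', annotated=True)
# Either spelling now supports the same `.value` access as a kernel argument.
print(from_wrapper.value == from_annotation.value)  # True
```

With this normalization, the kernel's `.value` comparison would behave the same regardless of which of the two recommended spellings the global uses.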

Reproducer script: https://gist.github.com/amjames/973b378f7c0fa8c92b6c92d05d90547b

amjames, May 15 '24 22:05