Inconsistency between constants as arguments and captured globals
TLDR: After #3762, global variables captured by a kernel must be `tl.constexpr` or annotated as such. It is surprising to me that a kernel argument carrying the annotation is already an object of type `constexpr` by the time `CodeGenerator.visit` runs, but a captured global with the annotation is not. Either that should be fixed, or the suggestion in the error message should only recommend that globals be defined as `VAR = tl.constexpr(<value>)`.
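Paraphrasing, the error message points at two ways to mark a captured global as a compile-time constant; both are used in the walkthrough below:

```python
import triton.language as tl

# Style 1: explicit wrapper -- the global becomes a tl.constexpr object.
VAR = tl.constexpr('value')

# Style 2: annotation only -- as observed below, the global stays a plain str.
VAR: tl.constexpr = 'value'
```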
Details
I had some code that looks like this (the actual original is from the pytorch tests):
```python
import triton
import triton.language as tl

STRING_CONSTANT_C = 'value'

@triton.jit
def kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr", CONSTANT_NAME: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    if CONSTANT_NAME.value == STRING_CONSTANT_C:
        output = 2 * x
        tl.store(out_ptr + offsets, output, mask=mask)
```
After getting the new error about globals needing to be `tl.constexpr`, I first tried defining `STRING_CONSTANT_C` globally like this:
```python
STRING_CONSTANT_C = tl.constexpr('value')
```
Compilation of the kernel then fails with `if conditionals can only accept values of type {int, NoneType, bool}, not objects of type NotImplementedType`. Digging into that a bit, I realized the comparison goes through `str.__eq__`, which returns `NotImplemented` when comparing a string to a `tl.constexpr` object. So I modified the kernel source so the conditional uses `STRING_CONSTANT_C.value`, which works.
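The exact edit isn't shown above, but the modified conditional is presumably along these lines, unwrapping both sides so two plain strings are compared:

```python
# With STRING_CONSTANT_C = tl.constexpr('value'), both operands of == must be
# unwrapped so that str.__eq__ compares two plain strings.
if CONSTANT_NAME.value == STRING_CONSTANT_C.value:
    output = 2 * x
    tl.store(out_ptr + offsets, output, mask=mask)
```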
The other recommendation from the error message introduced by #3762 is to use an annotation on the captured global, so I tried:

```python
STRING_CONSTANT_C: tl.constexpr = 'value'
```

That fails with the modified conditional and works with the original source.
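As a quick sanity check outside of the code generator (my own experiment, not part of the original report), the two suggested definition styles already produce different Python objects, which matches the behavior above:

```python
import triton.language as tl

A = tl.constexpr('value')   # explicit wrapper
B: tl.constexpr = 'value'   # annotation only; it has no runtime effect

print(isinstance(A, tl.constexpr))  # True  -- has a .value attribute
print(isinstance(B, tl.constexpr))  # False -- B is just the str 'value'
```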
Proposal
Why not translate these captured variables so that they are always `tl.constexpr` instances, the same way arguments with the annotation are handled?
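A rough sketch of the kind of normalization I have in mind (the helper name is made up and this is not how Triton's internals are actually organized; it only illustrates the idea):

```python
import triton.language as tl

def _normalize_captured_global(value):
    # Hypothetical helper: whichever way the user wrote the global, the kernel
    # body would always see a tl.constexpr, matching how an argument annotated
    # with tl.constexpr appears during CodeGenerator.visit.
    if isinstance(value, tl.constexpr):
        return value
    return tl.constexpr(value)
```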
Reproducer script: https://gist.github.com/amjames/973b378f7c0fa8c92b6c92d05d90547b