Calling .to(...) on scalar arg breaks when specialization kicks in
One of our kernels has a scalar parameter. To avoid an overflow, at some point we cast it to int64 (i.e., we call v.to(tl.int64)).
On occasion we pass a value of 1 to that argument, which triggers the specialization mechanism to build an ad-hoc kernel in which the parameter is replaced by a constexpr value. This, however, causes a compilation error, because .to(...) cannot be called on a constexpr.
Here is a repro:
In [1]: import triton
In [2]: import torch
In [3]: import triton.language as tl
In [4]: @triton.jit
   ...: def my_kernel(in_, out, mul):
   ...:     val = tl.load(in_)
   ...:     tl.store(out, val * mul.to(tl.int64))
   ...:
In [5]: in_ = torch.ones((1,), dtype=torch.long, device="cuda")
In [6]: out = torch.empty((1,), dtype=torch.long, device="cuda")
In [7]: my_kernel[1,](in_, out, 42)
Out[7]: <triton.compiler.compiler.CompiledKernel at 0x7fee5eea8850>
In [8]: out
Out[8]: tensor([42], device='cuda:0')
In [9]: my_kernel[1,](in_, out, 1)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
...
CompilationError: at 3:24:def my_kernel(in_, out, mul):
    val = tl.load(in_)
    tl.store(out, val * mul.to(tl.int64))
                        ^
AttributeError("'constexpr' object has no attribute 'to'")
I have triton-nightly 2.1.0.post20231118001511.
Note that we cannot use do_not_specialize to work around this, because that argument is in fact a stride and we do want it to be specialized!
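For reference, opting out of specialization would look roughly like the sketch below (an assumption on my side: I believe do_not_specialize takes the 0-based indices of the arguments to leave unspecialized), but it gives up exactly the stride specialization we rely on:

# Sketch only: disables specialization of the third argument (`mul`), so it is
# never turned into a constexpr and `.to(tl.int64)` keeps working. The cost is
# losing the stride specialization we want to keep.
@triton.jit(do_not_specialize=[2])
def my_kernel(in_, out, mul):
    val = tl.load(in_)
    tl.store(out, val * mul.to(tl.int64))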
It would be nice if constexpr values had the same interface as tensors, but in the meantime you can use tl.full to turn a constant into a scalar tensor of the expected dtype:
tl.full([], mul, tl.int64)
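Applied to the repro above, the kernel would look something like the sketch below (untested, reusing the names from the snippet in this thread):

import torch
import triton
import triton.language as tl

@triton.jit
def my_kernel(in_, out, mul):
    val = tl.load(in_)
    # tl.full([], ...) builds a 0-d int64 tensor, so the multiply happens in
    # int64 whether `mul` arrives as a runtime scalar or as a constexpr.
    mul_i64 = tl.full([], mul, tl.int64)
    tl.store(out, val * mul_i64)

in_ = torch.ones((1,), dtype=torch.long, device="cuda")
out = torch.empty((1,), dtype=torch.long, device="cuda")
my_kernel[1,](in_, out, 42)
my_kernel[1,](in_, out, 1)  # value 1 still gets specialized, but no longer crashes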