
Calling .to(...) on scalar arg breaks when specialization kicks in

Open lw opened this issue 1 year ago • 2 comments

One of our kernels has a scalar parameter. To avoid an overflow, at some point we cast it to int64 (i.e., we call v.to(tl.int64)).

On occasion we pass a value of 1 to that argument, which triggers the specialization mechanism to build an ad-hoc kernel where that parameter is replaced by a constexpr value. This, however, causes a compilation error, because .to(...) cannot be called on a constexpr.

Here is a repro:

In [1]: import triton

In [2]: import torch

In [3]: import triton.language as tl

In [4]: @triton.jit
   ...: def my_kernel(in_, out, mul):
   ...:     val = tl.load(in_)
   ...:     tl.store(out, val * mul.to(tl.int64))
   ...:

In [5]: in_ = torch.ones((1,), dtype=torch.long, device="cuda")

In [6]: out = torch.empty((1,), dtype=torch.long, device="cuda")

In [7]: my_kernel[1,](in_, out, 42)
Out[7]: <triton.compiler.compiler.CompiledKernel at 0x7fee5eea8850>

In [8]: out
Out[8]: tensor([42], device='cuda:0')

In [9]: my_kernel[1,](in_, out, 1)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
...
CompilationError: at 3:24:def my_kernel(in_, out, mul):
    val = tl.load(in_)
    tl.store(out, val * mul.to(tl.int64))
                        ^
AttributeError("'constexpr' object has no attribute 'to'")

I have triton-nightly 2.1.0.post20231118001511.

lw · Jan 15 '24 16:01

Note that we cannot use do_not_specialize to work around this, because that argument is in fact a stride and we therefore do want it to be specialized!

lw · Jan 15 '24 16:01

It would be nice if constexpr values had the same interface as tensors, but in the meantime you can use tl.full to turn a constant into a scalar tensor of the expected dtype:

tl.full([], mul, tl.int64)

peterbell10 · Jan 15 '24 17:01
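For reference, here is a minimal sketch (requires a CUDA device, so untested here) of the kernel from the repro above with the tl.full workaround applied. The scalar argument is first materialized as a 0-d int64 tensor, so the kernel compiles even when passing 1 triggers constexpr specialization:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def my_kernel(in_, out, mul):
    val = tl.load(in_)
    # tl.full([], ...) builds a 0-d tensor of the requested dtype; unlike a
    # constexpr, it supports the usual tensor interface, and no .to(...) call
    # on the raw argument is needed anymore.
    mul_i64 = tl.full([], mul, tl.int64)
    tl.store(out, val * mul_i64)


in_ = torch.ones((1,), dtype=torch.long, device="cuda")
out = torch.empty((1,), dtype=torch.long, device="cuda")
my_kernel[1,](in_, out, 1)  # no longer raises AttributeError when mul == 1
```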