triton
Why not allow a JITFunction as a parameter to another JITFunction (higher-order JIT function)?
I noticed that it is not possible to pass a JITFunction as a parameter to another JITFunction (I will call this a higher-order JITFunction for now).
The code below is an example of a pointwise function, _jit_function, whose elementwise operation on the input is defined in scalar_fn.
# test_pointwise.py
import triton
import triton.language as tl
import torch


@triton.jit
def scalar_fn(x):
    return tl.log(1 + tl.exp(x))


@triton.jit
def _jit_function(
    in_ptr, o_ptr,
    size,
    tile_size: tl.constexpr,
):
    pid = tl.program_id(0)
    tid = pid * tile_size + tl.arange(0, tile_size)
    mask = tid < size
    input_ = tl.load(in_ptr + tid, mask=mask)
    out = scalar_fn(input_)
    tl.store(o_ptr + tid, out, mask=mask)


def _wrapper(x: torch.Tensor):
    out = torch.empty_like(x)
    size = out.numel()
    tile_size = 512
    grid = triton.cdiv(size, tile_size), 1, 1
    _jit_function[grid](
        x, out, size,
        tile_size=tile_size,
        num_warps=4,
    )
    return out


x = torch.randn((3, 4), device="cuda")
print(_wrapper(x))
The JITFunction scalar_fn can be called in the body of another JITFunction (_jit_function). However, suppose I want to pass the scalar function as a parameter to _jit_function instead:
# test_pointwise_high_order.py
import triton
import triton.language as tl
import torch


@triton.jit
def scalar_fn(x):
    return tl.log(1 + tl.exp(x))


@triton.jit
def _jit_function(
    in_ptr, o_ptr,
    size,
    f: tl.constexpr,
    tile_size: tl.constexpr,
):
    pid = tl.program_id(0)
    tid = pid * tile_size + tl.arange(0, tile_size)
    mask = tid < size
    input_ = tl.load(in_ptr + tid, mask=mask)
    out = f(input_)
    tl.store(o_ptr + tid, out, mask=mask)


def _wrapper(x: torch.Tensor):
    out = torch.empty_like(x)
    size = out.numel()
    tile_size = 512
    grid = triton.cdiv(size, tile_size), 1, 1
    _jit_function[grid](
        x, out, size,
        f=scalar_fn,
        tile_size=tile_size,
        num_warps=4,
    )
    return out


x = torch.randn((3, 4), device="cuda")
print(_wrapper(x))
Running this raises an error:
Traceback (most recent call last):
  File "test_pointwise_high_order.py", line 38, in <module>
    print(_wrapper(x))
  File "test_pointwise_high_order.py", line 29, in _wrapper
    _jit_function[grid](
  File "<path_to_site_packages>/triton/runtime/jit.py", line 508, in run
    raise TypeError(f"Callable constexpr at index {i} is not supported")
TypeError: Callable constexpr at index 3 is not supported
The restriction that parameters with the type hint tl.constexpr (constant args) cannot be callable appears to have been first introduced in https://github.com/triton-lang/triton/pull/644. The relevant check is:
https://github.com/triton-lang/triton/blob/cfa8d18b835b10fc48449924aadf5982ac10d87c/python/triton/runtime/jit.py#L268C1-L270C79
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
    if callable(arg):
        raise TypeError(f"Callable constexpr at index {i} is not supported")
However, after commenting out these lines, the example above works as expected.
I think a JITFunction passed to another JITFunction is always a constexpr (in the sense that it is not a parameter of the CompiledKernel). In addition, a recent PR has started including the JITFunction arguments passed to tl.reduce or tl.associative_scan in the function body in the cache_key (https://github.com/triton-lang/triton/pull/3137). So here is my question:
Why allow only Triton builtins like tl.reduce and tl.associative_scan to take a JITFunction as an argument, while disallowing JITFunctions from taking JITFunctions as constant arguments? Allowing it would make higher-order JIT functions easier to write with Triton.
Consider a case where I want to write a utility that generates a _wrapper function and a _jit_function (like the example above) from a given _scalar_fn. Because Triton's JITFunction relies on inspect to get the source of the function that triton.jit decorates, I have to write the generated code to a file and import it dynamically (via importlib or some kind of exec).
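The write-to-file-then-import step can be sketched as below. This is only a minimal illustration in plain Python, not my actual generator: TEMPLATE, load_generated, and generated_fn are hypothetical names, and a real template would emit functions decorated with triton.jit. The point is that once the source lives in a real file, inspect.getsource can find it.

```python
import importlib.util
import inspect
import os
import sys
import tempfile

# Hypothetical template; a real generator would emit @triton.jit kernels here.
TEMPLATE = '''
def generated_fn(x):
    return x * {scale}
'''


def load_generated(source: str, module_name: str = "generated_kernels"):
    """Write `source` to a real .py file and import it as a module."""
    tmp_dir = tempfile.mkdtemp()
    path = os.path.join(tmp_dir, module_name + ".py")
    with open(path, "w") as f:
        f.write(source)
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    # Register the module before executing it so later lookups by name succeed.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


mod = load_generated(TEMPLATE.format(scale=3))
result = mod.generated_fn(2)
# The source is recoverable because the function is backed by a file on disk.
recovered = inspect.getsource(mod.generated_fn)
```

Note that anything the generated code calls, such as a user-provided scalar function, still has to be importable from inside the generated file.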
However, since a JITFunction cannot take another JITFunction as an argument, there are no higher-order JITFunctions. So if I want to generate a JITFunction that calls the user-provided scalar function, I have to make the scalar function available in the module where the generated _jit_function is defined. Locating where _scalar_fn is defined and generating code to import it into the generated module is difficult, since _scalar_fn may be defined in a module whose path is known, or just in a script (__name__ == "__main__") or an interactive REPL.
The one problem I can think of with allowing a JITFunction as a parameter to another JITFunction is that the cache key of a JITFunction would have to be completed at run, instead of being derived only from the body of the JITFunction itself, which can be done at __init__.
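To make the concern concrete, here is a toy sketch of a key that can only be finished at call time. It does not use Triton's real classes: ToyJITFunction and run_time_key are hypothetical stand-ins, and the hashing scheme is an assumption chosen for illustration.

```python
import hashlib


class ToyJITFunction:
    """Hypothetical stand-in for a JIT function.

    The key depends only on the function's own source text,
    so it can be computed once in __init__.
    """

    def __init__(self, src: str):
        self.src = src
        self.cache_key = hashlib.sha256(src.encode()).hexdigest()


def run_time_key(fn: ToyJITFunction, constexprs: dict) -> str:
    """Combine fn's own key with the keys of functions passed as constants.

    This part cannot be done in __init__, because the constant arguments
    are only known when the function is actually called.
    """
    h = hashlib.sha256(fn.cache_key.encode())
    for name in sorted(constexprs):
        value = constexprs[name]
        if isinstance(value, ToyJITFunction):
            part = value.cache_key
        else:
            part = repr(value)
        h.update(f"{name}={part};".encode())
    return h.hexdigest()


kernel = ToyJITFunction("def kernel(x, f): return f(x)")
softplus = ToyJITFunction("def softplus(x): return log(1 + exp(x))")
relu = ToyJITFunction("def relu(x): return max(x, 0)")

key_a = run_time_key(kernel, {"f": softplus, "tile_size": 512})
key_b = run_time_key(kernel, {"f": relu, "tile_size": 512})
```

Here key_a and key_b differ even though kernel itself is unchanged, which is the extra per-call work I am referring to.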
The relationship between higher-order and non-higher-order JITFunctions resembles that between function templates and functions in C++. Once all the JITFunction parameters of a higher-order JITFunction are specified, it is complete (instantiated) in some sense.
So maybe it would be convenient to add an interface to instantiate a higher-order JITFunction into a non-higher-order JITFunction, to avoid analysing the cache key of a higher-order JITFunction at each run? @ptillet
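A rough sketch of what I mean by instantiation, in plain Python rather than Triton. Instantiated, instantiate, and _code_digest are hypothetical names, not an existing Triton API, and hashing the bytecode is only a crude stand-in for hashing the source.

```python
import hashlib


def _code_digest(fn) -> bytes:
    # Crude stand-in for a source hash: digest of the function's bytecode.
    return hashlib.sha256(fn.__code__.co_code).digest()


class Instantiated:
    """A higher-order function with its function parameters bound."""

    def __init__(self, fn, **fn_args):
        for name, value in fn_args.items():
            if not callable(value):
                raise TypeError(f"{name} must be callable")
        self.fn = fn
        self.fn_args = fn_args
        # The complete key is computed once here, not on every call.
        h = hashlib.sha256(_code_digest(fn))
        for name in sorted(fn_args):
            h.update(name.encode())
            h.update(_code_digest(fn_args[name]))
        self.cache_key = h.hexdigest()

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **self.fn_args, **kwargs)


def instantiate(fn, **fn_args):
    return Instantiated(fn, **fn_args)


def pointwise(xs, f):
    return [f(x) for x in xs]


def double(x):
    return 2 * x


def square(x):
    return x * x


pointwise_double = instantiate(pointwise, f=double)
pointwise_square = instantiate(pointwise, f=square)
doubled = pointwise_double([1, 2, 3])
squared = pointwise_square([1, 2, 3])
```

After instantiation, pointwise_double behaves like an ordinary first-order function with a fixed key, similar to an instantiated template.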
Thank you