
Why not allow a JITFunction as a parameter to another JITFunction (higher-order JIT function)?

Open · iclementine opened this issue 9 months ago · 0 comments

I noticed that it is not possible to pass a JITFunction as a parameter to another JITFunction (let's call the receiving function a higher-order JITFunction for now).

The code below is an example of a pointwise kernel, _jit_function, whose element-wise operation on the inputs is defined in scalar_fn.

# test_pointwise.py
import triton
import triton.language as tl
import torch

@triton.jit
def scalar_fn(x):
    # Element-wise softplus: log(1 + exp(x))
    return tl.log(1 + tl.exp(x))

@triton.jit
def _jit_function(
    in_ptr, o_ptr,
    size,
    tile_size: tl.constexpr,
):
    pid = tl.program_id(0)
    tid = pid * tile_size + tl.arange(0, tile_size)
    mask = tid < size

    input_ = tl.load(in_ptr + tid, mask=mask)
    out = scalar_fn(input_)
    tl.store(o_ptr + tid, out, mask=mask)

def _wrapper(x: torch.Tensor):
    out = torch.empty_like(x)
    size = out.numel()
    tile_size = 512
    grid = triton.cdiv(size, tile_size), 1, 1
    _jit_function[grid](
        x, out, size,
        tile_size=tile_size,
        num_warps=4,
    )
    return out

x = torch.randn((3, 4), device="cuda")
print(_wrapper(x))
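To make the tiling concrete, here is a pure-Python sketch (no GPU or Triton needed) of what the launch above computes: each program instance handles one tile of tile_size elements, and the mask skips lanes past the end of the tensor. It is only a sequential reference under that assumption, not how Triton actually executes the kernel, and reference_pointwise is a hypothetical name.

```python
import math

def scalar_fn(x: float) -> float:
    # Same formula as the Triton scalar_fn above: softplus, log(1 + exp(x))
    return math.log(1 + math.exp(x))

def reference_pointwise(xs, tile_size=512):
    """Sequential CPU emulation of launching _jit_function on a 1-D grid."""
    size = len(xs)
    out = [0.0] * size
    num_programs = -(-size // tile_size)  # ceiling division, like triton.cdiv
    for pid in range(num_programs):       # one iteration per program instance
        for lane in range(tile_size):     # plays the role of tl.arange(0, tile_size)
            tid = pid * tile_size + lane
            if tid < size:                # the mask: lanes past `size` do nothing
                out[tid] = scalar_fn(xs[tid])
    return out

xs = [-1.0, 0.0, 1.0, 2.0, 3.0]
print(reference_pointwise(xs, tile_size=2))
```

With tile_size=2 and five elements, three program instances run and the last one has only one active lane.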

The JITFunction scalar_fn can be called in the body of another JITFunction (_jit_function). However, suppose I want to pass the scalar function to _jit_function as a parameter instead:

# test_pointwise_high_order.py
import triton
import triton.language as tl
import torch

@triton.jit
def scalar_fn(x):
    return tl.log(1 + tl.exp(x))

@triton.jit
def _jit_function(
    in_ptr, o_ptr,
    size,
    f: tl.constexpr,
    tile_size: tl.constexpr,
):
    pid = tl.program_id(0)
    tid = pid * tile_size + tl.arange(0, tile_size)
    mask = tid < size

    input_ = tl.load(in_ptr + tid, mask=mask)
    out = f(input_)
    tl.store(o_ptr + tid, out, mask=mask)

def _wrapper(x: torch.Tensor):
    out = torch.empty_like(x)
    size = out.numel()
    tile_size = 512
    grid = triton.cdiv(size, tile_size), 1, 1
    _jit_function[grid](
        x, out, size,
        f=scalar_fn,
        tile_size=tile_size,
        num_warps=4,
    )
    return out

x = torch.randn((3, 4), device="cuda")
print(_wrapper(x))

I get the following error:

Traceback (most recent call last):
  File "test_pointwise_high_order.py", line 38, in <module>
    print(_wrapper(x))
  File "test_pointwise_high_order.py", line 29, in _wrapper
    _jit_function[grid](
  File "<path_to_site_packages>/triton/runtime/jit.py", line 508, in run
    raise TypeError(f"Callable constexpr at index {i} is not supported")
TypeError: Callable constexpr at index 3 is not supported

I noticed that the restriction that parameters annotated with tl.constexpr (constant arguments) cannot be callables was first introduced in https://github.com/triton-lang/triton/pull/644.

https://github.com/triton-lang/triton/blob/cfa8d18b835b10fc48449924aadf5982ac10d87c/python/triton/runtime/jit.py#L268C1-L270C79

      # build stub signature -- includes arguments that are specialized
      for i, arg in constants.items():
        if callable(arg):
          raise TypeError(f"Callable constexpr at index {i} is not supported")
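The quoted loop can be reproduced stand-alone to see why the launch fails: every specialized (constexpr) argument is tested with callable(), and a JITFunction object counts as callable from Python's point of view. In this sketch, check_constexprs is a hypothetical name, constants is assumed to map argument index to value as in the quoted snippet, and a plain Python function stands in for the JITFunction.

```python
def check_constexprs(constants: dict) -> None:
    # Mirrors the quoted loop: any callable among the specialized
    # (constexpr) arguments is rejected before compilation.
    for i, arg in constants.items():
        if callable(arg):
            raise TypeError(f"Callable constexpr at index {i} is not supported")

def scalar_fn(x):
    return x  # stand-in for the JITFunction; any callable trips the check

# Index 3 matches the position of `f` in _jit_function's signature
# (in_ptr=0, o_ptr=1, size=2, f=3, tile_size=4).
try:
    check_constexprs({3: scalar_fn, 4: 512})
except TypeError as e:
    print(e)  # Callable constexpr at index 3 is not supported
```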

However, commenting out the relevant lines makes it work as expected.

I think a JITFunction passed to another JITFunction is always a constexpr (in the sense that it is not a parameter of the CompiledKernel). Moreover, recent PRs have included the JITFunction arguments passed to tl.reduce or tl.associative_scan in the function body in the cache_key (https://github.com/triton-lang/triton/pull/3137). So here is my question:

Why allow only Triton builtins such as tl.reduce and tl.associative_scan to take JITFunctions as arguments, while disallowing JITFunctions from taking JITFunctions as constant arguments? Allowing it would make higher-order JIT functions much easier to write in Triton.

Suppose I want to write a utility that generates a _wrapper function and a _jit_function (like the example above) from a given _scalar_fn. Because Triton's JITFunction relies on inspect to get the source of the function that triton.jit decorates, I have to write the generated code to a file and import it dynamically (via importlib or some kind of exec).
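A minimal sketch of that workaround, assuming the generated code is plain text: the source is written to a temporary .py file and imported with importlib, so that inspect.getsource() can later find it on disk. To keep the example runnable anywhere, plain Python functions stand in for @triton.jit functions; load_generated_module, SCALAR_SRC, KERNEL_SRC and the module name generated_kernels are all hypothetical. Note that the user's scalar function has to be pasted into the generated module as source, because the generated kernel can only call it by name.

```python
import importlib.util
import inspect
import os
import tempfile

def load_generated_module(source: str, name: str = "generated_kernels"):
    """Write generated source to a real file and import it from there.

    inspect.getsource() looks the code up through the function's file,
    so code that only ever lived in a string is generally not enough.
    """
    path = os.path.join(tempfile.mkdtemp(), f"{name}.py")
    with open(path, "w") as fh:
        fh.write(source)
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

# Hypothetical generator input: the user's scalar function as source text,
# copied into the generated module next to the kernel that calls it by name.
SCALAR_SRC = "def scalar_fn(x):\n    return 2 * x\n"
KERNEL_SRC = "def pointwise(xs):\n    return [scalar_fn(x) for x in xs]\n"

mod = load_generated_module(SCALAR_SRC + "\n" + KERNEL_SRC)
print(mod.pointwise([1, 2, 3]))          # [2, 4, 6]
print(inspect.getsource(mod.pointwise))  # retrievable because it lives in a file
```

This only works when the scalar function's source text is available, which is exactly what becomes awkward when it was defined in a script or a REPL.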

However, since a JITFunction cannot take another JITFunction as an argument, there are no higher-order JITFunctions. So if I want to generate a JITFunction that calls the user-provided scalar function, I have to make that scalar function available in the module where the generated _jit_function is defined. Locating where _scalar_fn is defined and generating code that imports it is difficult, because _scalar_fn may be defined in a module with a known path, in a script (__name__ == "__main__"), or in an interactive REPL.

The problem I can think of with allowing JITFunctions as parameters to another JITFunction is that the cache key of a JITFunction would only be complete at run time, instead of being computable at __init__ by inspecting the body of the JITFunction alone.
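A toy model of that trade-off, purely illustrative and not Triton's actual cache-key code: the kernel's own part of the key depends only on the kernel and could be computed once, but the part contributed by callable constexpr arguments is only known at launch, so every launch has to fold it in. Hashing bytecode plus constants is an assumed stand-in for hashing source text, and fn_key, launch_cache_key and kernel are hypothetical names.

```python
import hashlib

def fn_key(fn) -> str:
    # Stand-in for a source-based key: hash the bytecode and constants.
    code = fn.__code__
    return hashlib.sha256(code.co_code + repr(code.co_consts).encode()).hexdigest()

def launch_cache_key(kernel, constexprs: dict) -> str:
    # The kernel's own part could be computed once, when it is decorated...
    parts = [fn_key(kernel)]
    # ...but callable constexprs are only known at launch, so each launch
    # must fold their keys in as well.
    for name in sorted(constexprs):
        value = constexprs[name]
        parts.append(f"{name}={fn_key(value) if callable(value) else repr(value)}")
    return hashlib.sha256("|".join(parts).encode()).hexdigest()

def kernel(xs, f, tile_size):
    return [f(x) for x in xs]

k_double = launch_cache_key(kernel, {"f": lambda x: 2 * x, "tile_size": 512})
k_inc = launch_cache_key(kernel, {"f": lambda x: x + 1, "tile_size": 512})
print(k_double != k_inc)  # True: different f must not share a compiled kernel
```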

The relationship between higher-order and non-higher-order JITFunctions resembles that between function templates and functions in C++. When a higher-order JITFunction has all of its JITFunction parameters specified, it is, in some sense, complete (instantiated).

So maybe it would be convenient to add an interface that instantiates a higher-order JITFunction into a non-higher-order JITFunction, to avoid analysing the cache key of a higher-order JITFunction at every run? @ptillet
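To illustrate the proposal, here is a plain-Python model of such an "instantiate" interface. It is not an existing Triton API: Instantiated, pointwise and _code_key are hypothetical, and hashing bytecode plus constants stands in for hashing source. The point it shows is that once all callable parameters are bound, the key is complete at construction time, like a C++ template instantiation, so calls need not re-analyse the function arguments. A real version would presumably produce a new JITFunction rather than a Python callable.

```python
import functools
import hashlib

def _code_key(fn) -> bytes:
    # Bytecode + constants as a stand-in for hashing the function's source
    return fn.__code__.co_code + repr(fn.__code__.co_consts).encode()

class Instantiated:
    """Hypothetical 'instantiate' interface: bind all callable parameters of a
    higher-order function once, like instantiating a C++ function template."""

    def __init__(self, fn, **callables):
        for name, c in callables.items():
            if not callable(c):
                raise TypeError(f"parameter {name!r} must be callable")
        self._bound = functools.partial(fn, **callables)
        # The key is complete here, so later calls do not recompute it.
        h = hashlib.sha256(_code_key(fn))
        for name in sorted(callables):
            h.update(name.encode() + _code_key(callables[name]))
        self.cache_key = h.hexdigest()

    def __call__(self, *args, **kwargs):
        return self._bound(*args, **kwargs)

def pointwise(xs, f):
    # Plain-Python stand-in for a higher-order kernel
    return [f(x) for x in xs]

inc_map = Instantiated(pointwise, f=lambda x: x + 1)
print(inc_map([1, 2, 3]))  # [2, 3, 4]
```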

Thank you

iclementine, May 15 '24 04:05