
Running a warp kernel from Jax JIT-compiled function

Open marcelroed opened this issue 2 years ago • 9 comments

Hi! I'm using Warp as a part of a project that uses JIT-compiled jax functions. How could I integrate Warp in this context?

Example:

import jax
import warp as wp

@jax.jit
def function(x):
    y = x + 1
    wp.launch(...)  # Modify y with a Warp kernel
    return y

I'm looking for something like what can be done with Triton in the example in this README: https://github.com/jax-ml/jax-triton
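The naive pattern above can't work as written, because inside `jax.jit` values are abstract tracers that carry a shape and dtype but no data yet. Any Python-side launcher that needs the array's actual memory therefore has nothing to read during tracing. The sketch below uses `np.asarray` as a stand-in for such a launcher (an assumption: the exact error `wp.launch` itself would raise is not checked here). It shows JAX raising `jax.errors.TracerArrayConversionError`.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def function(x):
    y = x + 1
    # Inside jax.jit, y is an abstract tracer: it has a shape and dtype
    # but no concrete buffer. Anything that needs host data (np.asarray here,
    # standing in for an external kernel launcher) fails during tracing.
    host_copy = np.asarray(y)
    return jnp.asarray(host_copy)


try:
    function(jnp.arange(4, dtype=jnp.float32))
    error = None
except jax.errors.TracerArrayConversionError as exc:
    error = exc

print(type(error).__name__)
```

This is why a JIT-compatible integration has to be expressed as something XLA itself can call, rather than as a Python call on the traced values.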

marcelroed, Apr 23 '23

Hi @marcelroed, good question. I think this should be possible. Let me ask our JAX experts to see if they can offer some advice.

Best, Miles

mmacklin, Apr 25 '23

Hi,

There isn't an easy way to do this right now. To use it in JAX without JIT, you could make a JAX primitive in Python. But to get this to work with JIT, you would need to write C++ code.
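A Python-only workaround that does run under JIT is `jax.pure_callback`, which stages a call back into Python from inside a compiled function. It is a different technique from the native custom-call route described here. XLA copies the operands to the host, runs the callback, and copies the result back, so on a GPU it adds transfers. JAX also won't differentiate through it unless you add a custom rule. In the sketch, `host_triple` is a hypothetical stand-in for host-side work, such as launching a Warp kernel on the CPU. It is plain NumPy here.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_triple(y: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for host-side work (e.g. a CPU Warp launch).
    # It receives a concrete host array and must return an array that
    # matches the declared shape and dtype.
    return (3.0 * np.asarray(y)).astype(np.float32)


@jax.jit
def function(x):
    y = x + 1
    # Declare the callback's output shape/dtype so XLA can compile around it.
    result_spec = jax.ShapeDtypeStruct(y.shape, y.dtype)
    # pure_callback copies y to the host, runs host_triple, and copies back.
    return jax.pure_callback(host_triple, result_spec, y)


out = function(jnp.arange(4, dtype=jnp.float32))
print(out)
```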

JAX-Triton already did that (https://github.com/jax-ml/jax-triton/blob/main/lib/triton_kernel_call.cc). No one has done a JAX-Warp yet.

nouiz, Apr 25 '23

Hey @nouiz, thanks for looking into this! Do you think it would be reasonably straightforward to write something similar that's compatible with Warp? If so, I could have a go at contributing it.

marcelroed, Apr 25 '23

I don't think it is trivial, but it could be mechanical. That part of the code isn't simple. You can base your work on the jax-triton code. This JAX guide on custom operations for GPUs should help: https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html

Mostly, we need to create such a custom operation dynamically. I can help answer questions if you work on it, but don't expect this to be trivial. I can also find someone if there are questions on the Warp side of things.

nouiz, Apr 28 '23

@nouiz, @marcelroed, would it help to call precompiled Warp kernels from C++ like this? https://github.com/erwincoumans/warp_cpp

erwincoumans, May 04 '23

@erwincoumans, that looks like it'll be great for my use case for sure, thank you! However, I would like to try to get this supported entirely from Python. I'll be looking into contributing it this weekend.

marcelroed, May 04 '23


You could write pybind11 Python bindings for it?

erwincoumans, May 04 '23

Following

tfsingh, May 06 '23

Good news: we have a working implementation of this that doesn't require fiddling with any native code. The basic usage looks like this:

import warp as wp
import jax
import jax.numpy as jp

# secret sauce
from warp.jax_experimental import jax_kernel

@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = 3.0 * input[tid]

wp.init()

# create a Jax primitive from a Warp kernel
jax_triple = jax_kernel(triple_kernel)

# use the Warp kernel in a Jax jitted function
@jax.jit
def f():
    x = jp.arange(0, 64, dtype=jp.float32)
    return jax_triple(x)

print(f())

The feature is still experimental and there are some limitations, but we plan to release it shortly because it can already be quite useful.

nvlukasz, Mar 05 '24