Running a warp kernel from Jax JIT-compiled function
Hi! I'm using Warp as part of a project that uses JIT-compiled JAX functions. How could I integrate Warp in this context?
Example:
import jax
import warp as wp

@jax.jit
def function(x):
    y = x + 1
    wp.launch(...)  # modify y with a Warp kernel
    return y
I'm looking for something like what can be done with Triton in the example in this README: https://github.com/jax-ml/jax-triton
Hi @marcelroed, good question. I think this should be possible. Let me ask our JAX experts to see if they can offer some advice.
Best, Miles
Hi,
There isn't any easy way to do this right now. To use it in JAX without JIT, you could define a JAX primitive in Python. To make it work under JIT, however, you would need to write C++ code.
JAX-Triton already did that: https://github.com/jax-ml/jax-triton/blob/main/lib/triton_kernel_call.cc No one has built a JAX-Warp integration yet.
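The Python-only primitive route mentioned above can be sketched roughly as follows. This assumes a recent JAX where `Primitive` is importable from `jax.extend.core` (older releases expose it as `jax.core.Primitive`); the primitive name `warp_triple` and the NumPy body of `_triple_impl` are hypothetical stand-ins for a real `wp.launch(...)`. Only an eager implementation and a shape/dtype rule are registered, so the primitive works outside `jax.jit`, but tracing it under `jax.jit` should fail because it has no lowering rule; that missing lowering is the part that needs native code.

```python
import numpy as np
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Hypothetical primitive; a real integration would create one per Warp kernel.
triple_p = Primitive("warp_triple")

def _triple_impl(x):
    # Eager implementation, called with concrete arrays.
    # Stand-in for launching a Warp kernel on the input buffer.
    return jnp.asarray(3.0 * np.asarray(x))

def _triple_abstract_eval(x):
    # Shape/dtype rule JAX uses when it traces the primitive.
    return jax.core.ShapedArray(x.shape, x.dtype)

triple_p.def_impl(_triple_impl)
triple_p.def_abstract_eval(_triple_abstract_eval)

def triple(x):
    return triple_p.bind(x)

x = jnp.arange(4, dtype=jnp.float32)
print(triple(x))  # eager call, no jax.jit involved
```

Wrapping `triple` in `jax.jit` is where this stops working: JAX would need a lowering rule that emits the kernel into the compiled program, which is what the jax-triton C++ code provides.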
Hey @nouiz, thanks for looking into this! Do you think it would be reasonably straightforward to write something similar that's compatible with Warp? If so, I could have a go at contributing it.
I don't think it is trivial, but it could be mechanical. That part of the code isn't simple. You can base your work on the jax-triton code. This JAX guide on custom operations should help: https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html
Mostly, we need to dynamically create one such custom operation. I can help answer questions if you work on it, but don't expect it to be trivial. I can also find someone if there are questions on the Warp side of things.
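For a rough feel of "dynamically creating" one jittable operation per host function, JAX does ship a Python-only escape hatch, `jax.pure_callback`. This is a swapped-in technique, not the XLA custom call discussed above: data round-trips through host memory on every call, so it is mainly useful for prototyping. `make_host_op` and `triple_host` are hypothetical names, and the NumPy body stands in for a real Warp launch.

```python
import numpy as np
import jax
import jax.numpy as jnp

def make_host_op(host_fn):
    """Wrap a NumPy-in/NumPy-out host function so it can be called under jax.jit.

    Assumes the output has the same shape and dtype as the input.
    """
    def op(x):
        # Declare the result's shape/dtype so JAX can trace without running host_fn.
        out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
        # At run time JAX hands x to host_fn on the host and copies the result back.
        return jax.pure_callback(host_fn, out_spec, x)
    return op

def triple_host(x):
    # Hypothetical stand-in for a Warp launch on host memory.
    x = np.asarray(x)
    return (3.0 * x).astype(x.dtype)

host_triple = make_host_op(triple_host)

@jax.jit
def f():
    x = jnp.arange(0, 8, dtype=jnp.float32)
    # The callback result composes with ordinary jitted JAX ops.
    return host_triple(x) + 1.0

print(f())
```

The `astype` call matters because `pure_callback` expects the callback's result to match the declared dtype.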
@nouiz , @marcelroed would it help to call precompiled Warp kernels from C++ like this? https://github.com/erwincoumans/warp_cpp
@erwincoumans, that looks like it'll be great for my use-case for sure, thank you! I would like to try to get this supported entirely from Python, however. Will be looking into contributing it this weekend.
You could write PyBind11 Python bindings for it?
Following
Good news: We have a working implementation of this that doesn't require fiddling with any native bits. The basic usage looks like this:
import warp as wp
import jax
import jax.numpy as jp

# secret sauce
from warp.jax_experimental import jax_kernel

@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = 3.0 * input[tid]

wp.init()

# create a JAX primitive from a Warp kernel
jax_triple = jax_kernel(triple_kernel)

# use the Warp kernel in a JAX jitted function
@jax.jit
def f():
    x = jp.arange(0, 64, dtype=jp.float32)
    return jax_triple(x)

print(f())
The feature is still experimental and there are some limitations, but we plan to release it shortly because it can already be quite useful.