[BUG] Calling jit on wrapped jax_kernel ignores atomic_add
Bug Description
I want to access, from JAX, an input array that has been modified by a Warp kernel, but I'm seeing unexpected behavior.
For example, here is a simple Warp kernel that uses atomic_add to sum an array a and store the result in a 1-element array a_sum:
import jax.numpy as jnp
import warp as wp
from jax import jit
from warp.jax_experimental.ffi import jax_kernel
@wp.kernel
def sum_kernel(
    a: wp.array(dtype=wp.float32),
    a_sum: wp.array(dtype=wp.float32),
    dummy: wp.array(dtype=wp.float32),  # jax_kernel needs num_outputs > 0
):
    tidx = wp.tid()
    wp.atomic_add(a_sum, 0, a[tidx])
    dummy[tidx] = a[tidx]
a = jnp.array([1, 2, 3], dtype=jnp.float32) # Example data
# Create a jax_kernel with FFI
sum_fn = jax_kernel(sum_kernel)
# First call to the function works fine
a_sum = jnp.zeros(1, jnp.float32)
_ = sum_fn(a, a_sum)
print(a_sum) # [6.] Correct as expected
# Subsequent call does not build on previous value of a_sum
_ = sum_fn(a, a_sum)
print(a_sum) # [6.] Incorrect! Should be 12
# Touching the value of a_sum does something
a_sum = a_sum + 0
print(a_sum) # [12.] Suddenly correct!
I'm also seeing weird behavior when wrapping the jax_kernel in a jitted function:
def wrapped_sum(a):
    a_sum = jnp.zeros(1, jnp.float32)
    _ = sum_fn(a, a_sum)
    return a_sum
print(wrapped_sum(a)) # [6.] Correct as expected
print(jit(wrapped_sum)(a)) # [0.] Incorrect!
Maybe it's related to https://github.com/NVIDIA/warp/issues/378?
System Information
Warp 1.7.1, CUDA Toolkit 12.8, Python 3.13.2
Thanks for reporting this, and nice to hear from you again. I don't think we'll be able to look at this immediately since we're focused on the Newton release for the next few weeks, but the issue is noted.
Thanks for the reply! No problem at all.
@dongwoonhyun I believe this is not a Warp issue. Unlike NumPy arrays, JAX arrays are always immutable, and modifying a_sum inside a Warp kernel breaks that immutability.
I think we could add some additional checks that raise errors and alert users when they attempt to break immutability in a jax_callable or jax_kernel.
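A minimal sketch of the immutability point in plain JAX (no Warp involved): item assignment on a JAX array raises a TypeError, and the functional `.at[...]` update returns a new array while leaving the original untouched.

```python
import jax.numpy as jnp

a_sum = jnp.zeros(1, jnp.float32)

# In-place item assignment is rejected because JAX arrays are immutable.
try:
    a_sum[0] = 6.0
    assignment_failed = False
except TypeError:
    assignment_failed = True

# The functional update returns a *new* array; a_sum itself is unchanged.
updated = a_sum.at[0].add(6.0)

print(assignment_failed)  # True
print(a_sum)              # [0.]
print(updated)            # [6.]
```

A Warp kernel writing into a JAX buffer sidesteps this check, which is why the results in the report look inconsistent.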
As for the output of jit(wrapped_sum), I believe this is entirely expected behavior.
Quoting from the JAX documentation page 🔪 JAX - The Sharp Bits 🔪:
JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results.
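To make the purity requirement concrete, here is a small pure-JAX sketch (the function and list names are illustrative only): a Python side effect inside a jitted function runs while JAX traces it, not on every call, so the cached compiled version never repeats it.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def double(x):
    # This Python side effect runs only while JAX traces the function,
    # not each time the compiled function executes.
    trace_log.append("traced")
    return x * 2

x = jnp.arange(3.0)
double(x)
double(x)  # same shape/dtype: the cached compilation is reused, no re-trace

print(len(trace_log))  # 1
```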
Quoting from the JAX documentation page External callbacks:
By design functions passed to pure_callback are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely.
The output of sum_fn(a, a_sum) is assigned to _ and never used in the final output, so the JIT compiler drops the call to sum_fn() entirely.
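The dead-code elimination described in the quote can be observed with a sketch adapted from the JAX External callbacks example (host function and jitted function names are illustrative). A host-side counter records whether the callback actually ran; the comparison for the discarded case is printed rather than asserted, since the docs only say the compiler *may* drop it.

```python
import jax
import numpy as np

calls = []

def side_effecting_host_fn():
    # Host-side Python; records every time it actually runs.
    calls.append(1)
    return np.int32(0)

@jax.jit
def uses_result():
    # The callback's output is returned, so it has to be executed.
    return jax.pure_callback(side_effecting_host_fn, jax.ShapeDtypeStruct((), np.int32))

@jax.jit
def discards_result():
    # The output is thrown away; because pure_callback is assumed to be
    # side-effect free, the compiler is allowed to remove this call.
    _ = jax.pure_callback(side_effecting_host_fn, jax.ShapeDtypeStruct((), np.int32))
    return 1.0

uses_result()
print(len(calls))  # 1
before = len(calls)
discards_result()
print("callback ran:", len(calls) > before)
```

The unused sum_fn(a, a_sum) call in jit(wrapped_sum) is in the same position as the callback in discards_result.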
@dongwoonhyun this should be fixed with the support for in-out arguments we added in https://github.com/NVIDIA/warp/issues/815.
You can include the a_sum argument in the in_out_argnames list and drop the dummy arg:
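import jax.numpy as jnp
import warp as wp
from jax import jit
from warp.jax_experimental.ffi import jax_kernel
@wp.kernel
def sum_kernel(
    a: wp.array(dtype=wp.float32),
    a_sum: wp.array(dtype=wp.float32),
):
    tidx = wp.tid()
    wp.atomic_add(a_sum, 0, a[tidx])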
a = jnp.array([1, 2, 3], dtype=jnp.float32) # Example data
# Create a jax_kernel with FFI
sum_fn = jax_kernel(sum_kernel, in_out_argnames=["a_sum"])
# First call to the function works fine
a_sum = jnp.zeros(1, jnp.float32)
a_sum, = sum_fn(a, a_sum)
print(a_sum) # [6.] Correct as expected
# Subsequent call now builds on the previous value of a_sum
a_sum, = sum_fn(a, a_sum)
print(a_sum) # [12.] Also correct
# Touching the value of a_sum no longer changes anything
a_sum = a_sum + 0
print(a_sum) # [12.] Still correct
Note that you need to reassign a_sum to the value the function returns; otherwise, printing a_sum still shows the old value. As @liblaf mentioned above, JAX arrays are immutable, so JAX makes a copy that is returned from the function rather than modifying a_sum in place.
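As a rough pure-JAX analogue (not the Warp API), an in-out argument behaves like a function that takes a_sum and returns an updated copy. The helper `sum_into` is hypothetical and only stands in for the wrapped kernel:

```python
import jax.numpy as jnp

def sum_into(a, a_sum):
    # Hypothetical pure-JAX stand-in for the in-out kernel: instead of
    # mutating a_sum, return a new array with sum(a) added to element 0.
    return a_sum.at[0].add(jnp.sum(a))

a = jnp.array([1, 2, 3], dtype=jnp.float32)
a_sum = jnp.zeros(1, jnp.float32)

stale = a_sum
a_sum = sum_into(a, a_sum)  # reassign to pick up the result
a_sum = sum_into(a, a_sum)  # accumulates onto the previous value

print(a_sum)  # [12.]
print(stale)  # [0.]
```

The original array bound to stale is never modified, which is exactly why the return value has to be reassigned.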
The jitted variant should work too:
def wrapped_sum(a):
    a_sum = jnp.zeros(1, jnp.float32)
    a_sum, = sum_fn(a, a_sum)
    return a_sum
print(wrapped_sum(a)) # [6.] Correct as expected
print(jit(wrapped_sum)(a)) # [6.] Also correct
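For contrast, the same hypothetical pure-JAX stand-in `sum_into` (not Warp) mirrors the two jitted variants in this thread: discarding the returned array leaves a_sum at zeros, while reassigning it makes the update part of the output.

```python
import jax
import jax.numpy as jnp

def sum_into(a, a_sum):
    # Hypothetical pure-JAX stand-in for the in-out kernel (not the Warp API).
    return a_sum.at[0].add(jnp.sum(a))

@jax.jit
def discards_result(a):
    a_sum = jnp.zeros(1, jnp.float32)
    _ = sum_into(a, a_sum)  # result thrown away; a_sum is still the zeros array
    return a_sum

@jax.jit
def keeps_result(a):
    a_sum = jnp.zeros(1, jnp.float32)
    a_sum = sum_into(a, a_sum)  # reassigned, so the update reaches the output
    return a_sum

a = jnp.array([1, 2, 3], dtype=jnp.float32)
print(discards_result(a))  # [0.]
print(keeps_result(a))     # [6.]
```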