
[BUG] Calling jit on wrapped jax_kernel ignores atomic_add

Open dongwoonhyun opened this issue 7 months ago • 2 comments

Bug Description

I want to access an input array that has been modified by a warp kernel from jax, but I'm seeing unexpected behavior.

For example, here is a simple warp kernel that uses atomic_add to sum an array a and store the result in a 1-element array a_sum:

import jax.numpy as jnp
import warp as wp
from jax import jit
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def sum_kernel(
    a: wp.array(dtype=wp.float32),
    a_sum: wp.array(dtype=wp.float32),
    dummy: wp.array(dtype=wp.float32),  # jax_kernel needs num_outputs > 0
):
    tidx = wp.tid()
    wp.atomic_add(a_sum, 0, a[tidx])
    dummy[tidx] = a[tidx]

a = jnp.array([1, 2, 3], dtype=jnp.float32)  # Example data

# Create a jax_kernel with FFI
sum_fn = jax_kernel(sum_kernel)

# First call to the function works fine
a_sum = jnp.zeros(1, jnp.float32)
_ = sum_fn(a, a_sum)
print(a_sum)  # [6.]  Correct as expected

# Subsequent call does not build on previous value of a_sum
_ = sum_fn(a, a_sum)
print(a_sum)  # [6.]  Incorrect! Should be 12

# Touching the value of a_sum does something
a_sum = a_sum + 0
print(a_sum)  # [12.]  Suddenly correct!

I'm also seeing weird behavior when wrapping the jax_kernel in a jitted function:

def wrapped_sum(a):
    a_sum = jnp.zeros(1, jnp.float32)
    _ = sum_fn(a, a_sum)
    return a_sum

print(wrapped_sum(a))  # [6.]  Correct as expected
print(jit(wrapped_sum)(a))  # [0.]  Incorrect!

Maybe it's related to https://github.com/NVIDIA/warp/issues/378?

System Information

Warp 1.7.1, CUDA Toolkit 12.8, Python 3.13.2

dongwoonhyun, May 21 '25 01:05

Thanks for reporting this and nice to hear from you again. I don't think we'll be able to look at this immediately since we're focused on the Newton release for the next few weeks, but this issue is noted.

shi-eric, May 22 '25 16:05

Thanks for the reply! No problem at all.

dongwoonhyun, May 22 '25 16:05

@dongwoonhyun I believe this is not a Warp issue. Unlike NumPy arrays, JAX arrays are always immutable. Modifying a_sum inside a Warp kernel breaks that immutability guarantee.
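To illustrate the immutability point with plain JAX (no Warp involved), here is a minimal sketch. The variable names in_place_failed and updated are only for illustration. It shows that in-place assignment is rejected, while the functional .at[...] update returns a new array and leaves the original unchanged:

```python
import jax.numpy as jnp

a_sum = jnp.zeros(1, jnp.float32)

# In-place item assignment is rejected because JAX arrays are immutable.
try:
    a_sum[0] = 6.0
    in_place_failed = False
except TypeError:
    in_place_failed = True

# The functional update returns a new array; the original is untouched.
updated = a_sum.at[0].add(6.0)

print(in_place_failed)  # True
print(a_sum)            # [0.]
print(updated)          # [6.]
```

A kernel that writes into the buffer behind a_sum bypasses this contract, which is why the values JAX reports afterwards can look inconsistent.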

I think we could add some additional checks that raise errors and alert users when they attempt to break immutability in a jax_callable or jax_kernel.

As for the output of jit(wrapped_sum), I believe this is entirely expected behavior.

Quoting from 🔪 JAX - The Sharp Bits 🔪 — JAX documentation:

JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results.

Quoting from External callbacks — JAX documentation:

By design functions passed to pure_callback are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:

The output of sum_fn(a, a_sum) is assigned to _ and is not used in the final output, so JIT drops the call to sum_fn() entirely.
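A rough sketch of the same effect using jax.pure_callback as a stand-in (this is an assumption for illustration; it is not the FFI path that jax_kernel uses, and host_double, call_double, and the calls list are hypothetical names). Whether the unused callback is actually eliminated is up to the compiler, so the sketch only records the call count for that case instead of asserting a value:

```python
import jax
import jax.numpy as jnp
import numpy as np

calls = []  # records every time the host-side function actually runs

def host_double(x):
    calls.append(1)
    return np.asarray(x) * 2

def call_double(x):
    # Declare the output shape/dtype so JAX can trace through the callback.
    return jax.pure_callback(host_double, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@jax.jit
def result_unused(x):
    _ = call_double(x)  # output is discarded, so the call may be eliminated
    return x

@jax.jit
def result_used(x):
    return call_double(x)  # output is returned, so the call must run

x = jnp.arange(3, dtype=jnp.float32)

out_unused = jax.block_until_ready(result_unused(x))
calls_after_unused = len(calls)  # may be 0 if the compiler dropped the callback

out_used = jax.block_until_ready(result_used(x))
calls_after_used = len(calls)    # the callback ran at least once here

print(calls_after_unused, calls_after_used)
print(out_used)
```

The only version that is guaranteed to run the callback is the one whose result flows into the function's return value.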

liblaf, Aug 01 '25 17:08

@dongwoonhyun this should be fixed with the support for in-out arguments we added in https://github.com/NVIDIA/warp/issues/815.

You can include the a_sum argument in the in_out_argnames list and drop the dummy arg:

@wp.kernel
def sum_kernel(
    a: wp.array(dtype=wp.float32),
    a_sum: wp.array(dtype=wp.float32),
):
    tidx = wp.tid()
    wp.atomic_add(a_sum, 0, a[tidx])

a = jnp.array([1, 2, 3], dtype=jnp.float32)  # Example data

# Create a jax_kernel with FFI
sum_fn = jax_kernel(sum_kernel, in_out_argnames=["a_sum"])

# First call to the function works fine
a_sum = jnp.zeros(1, jnp.float32)
a_sum, = sum_fn(a, a_sum)
print(a_sum)  # [6.]  Correct as expected

# Subsequent call now builds on the previous value of a_sum
a_sum, = sum_fn(a, a_sum)
print(a_sum)  # [12.]  Also correct

# Touching the value of a_sum no longer changes the result
a_sum = a_sum + 0
print(a_sum)  # [12.]  Still correct

Note that you need to re-assign a_sum from the function's return value; otherwise printing a_sum will still show the old value. As @liblaf mentioned above, JAX arrays are immutable, so JAX makes a copy that is returned from the function.
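The re-assignment requirement is the usual functional-update pattern in JAX. Here is a small pure-JAX sketch of it, where accumulate is a hypothetical stand-in for the wrapped kernel (it is not how jax_kernel is implemented):

```python
import jax.numpy as jnp
from jax import jit

def accumulate(a, a_sum):
    # Hypothetical stand-in for the kernel: returns a new array
    # instead of writing into a_sum.
    return a_sum + jnp.sum(a)

a = jnp.array([1, 2, 3], dtype=jnp.float32)
a_sum = jnp.zeros(1, jnp.float32)

# Calling without re-assigning discards the result; a_sum is unchanged.
accumulate(a, a_sum)
ignored = a_sum
print(ignored)  # [0.]

# Re-binding the name to the returned array carries the state forward.
a_sum = accumulate(a, a_sum)
first = a_sum
print(first)    # [6.]

# The same pattern works under jit.
a_sum = jit(accumulate)(a, a_sum)
second = a_sum
print(second)   # [12.]
```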

The jitted variant should work too:

def wrapped_sum(a):
    a_sum = jnp.zeros(1, jnp.float32)
    a_sum, = sum_fn(a, a_sum)
    return a_sum

print(wrapped_sum(a))  # [6.]  Correct as expected
print(jit(wrapped_sum)(a))  # [6.]  Also correct

nvlukasz, Sep 17 '25 22:09