jax.debug.breakpoint gives UnexpectedTracerError when used with jax.lax.cond
Description
import jax
import jax.numpy as jnp

def f(x, example):
    # Enter the debugger only when example == 1; otherwise do nothing.
    jax.lax.cond(example == 1, jax.debug.breakpoint, lambda *args: None)
    return x

# Map over the leading axis of x; `example` is shared (not batched).
f_vmap = jax.vmap(f, in_axes=(0, None), out_axes=0)

def g(x, example):
    return f_vmap(x, example)

x = jnp.arange(4)
example = jnp.array(0, dtype=jnp.int32)
g(x, example)
example = jnp.array(1, dtype=jnp.int32)
g(x, example)
This gives the error:
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int32[] wrapped in a BatchTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Different traces at same level: Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
  val = Array([0, 1, 2, 3], dtype=int32)
  batch_dim = 0, BatchTrace(level=1/0)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
I believe this code should work. JAX expert Jake VanderPlas confirmed that I seem to have found a bug and asked me to file this GitHub issue. Jake said he had to set JAX_CHECK_TRACER_LEAKS=1 to observe the problem, but that was with a slightly different repro case. I did not need to set it (perhaps it is already enabled in my environment?).
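For anyone reproducing this, a minimal sketch of how to check whether tracer-leak checking is on and how to enable it, either for one region of code or for the whole process. It uses the `jax.checking_leaks` context manager named in the error message and the `jax_check_tracer_leaks` config option that the JAX_CHECK_TRACER_LEAKS environment variable controls; `double` is a placeholder, not part of the repro.

```python
import jax
import jax.numpy as jnp

# Current setting; True if, for example, JAX_CHECK_TRACER_LEAKS=1 was set
# in the environment when JAX was imported.
was_enabled = jax.config.jax_check_tracer_leaks
print("tracer leak checking enabled:", was_enabled)


def double(a):
    # Placeholder computation; substitute the code being debugged.
    return a * 2


# Scoped: leak checking is active only inside the `with` block.
with jax.checking_leaks():
    y = jax.jit(double)(jnp.arange(3))

# Global: the programmatic equivalent of launching Python with
# JAX_CHECK_TRACER_LEAKS=1 in the environment.
jax.config.update("jax_check_tracer_leaks", True)
z = jax.jit(double)(jnp.arange(3))

# Restore the original setting; leak checking can cost performance
# (some caching is disabled while it is on).
jax.config.update("jax_check_tracer_leaks", was_enabled)
```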
Additional context: The purpose of this code is to enter the debugger on a particular train step, so that I can examine variables inside of f() at that train step.
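Until this is fixed, one possible way to reach the same goal is sketched below; it is an assumption-laden workaround, not verified against this bug. Instead of choosing a branch with `jax.lax.cond`, the step comparison runs in Python on the host inside `jax.debug.callback`, and the values to inspect are passed to the callback explicitly rather than captured from the enclosing traced scope. The names `break_on_step`, `make_g`, and `_enter_debugger` are hypothetical helpers, not JAX APIs.

```python
import jax
import jax.numpy as jnp


def _enter_debugger(step, values):
    # Hypothetical default hook: `step` and `values` are concrete host-side
    # arrays here, so they can be inspected at the (Pdb) prompt.
    breakpoint()


def break_on_step(step, target_step, values, hook=_enter_debugger):
    """Hypothetical helper: call `hook` on the host when step == target_step."""

    def on_host(step, values):
        # Plain Python comparison on concrete values; no lax.cond involved.
        if int(step) == target_step:
            hook(step, values)

    # Every value the hook needs is passed in explicitly as an argument.
    jax.debug.callback(on_host, step, values)


def make_g(target_step, hook=_enter_debugger):
    def f(x, step):
        break_on_step(step, target_step, {"x": x}, hook=hook)
        return x

    # Same vmap layout as the repro: x is batched, step is shared.
    return jax.jit(jax.vmap(f, in_axes=(0, None), out_axes=0))


# Usage with a recording hook instead of the debugger.
hits = []  # steps at which the hook fired
g = make_g(target_step=1, hook=lambda step, values: hits.append(int(step)))
x = jnp.arange(4)

g(x, jnp.array(0, dtype=jnp.int32))
jax.effects_barrier()  # wait for pending host callbacks to finish
hits_after_step0 = len(hits)

g(x, jnp.array(1, dtype=jnp.int32))
jax.effects_barrier()
```

The callback runs on every step and only the host-side comparison decides whether the hook fires. Under `vmap` the callback may be invoked once per batch element, so the debugger could be entered several times for a single step.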
System info (python version, jaxlib version, accelerator, etc.)
Python: 3.11
JAX: top of tree inside Google as of 3 pm Pacific Time on Sept 10, 2024
Accelerator: TPU