
jax.debug.breakpoint gives UnexpectedTracerError when used with jax.lax.cond

Open · billmark opened this issue 5 months ago · 3 comments

Description

import jax
import jax.numpy as jnp

def f(x, example):
  jax.lax.cond(example == 1, jax.debug.breakpoint, lambda *args: None)
  return x

f_vmap = jax.vmap(f, in_axes=(0, None), out_axes=0)

def g(x, example):
  return f_vmap(x, example)

x = jnp.arange(4)
example = jnp.array(0, dtype=jnp.int32)
g(x, example)
example = jnp.array(1, dtype=jnp.int32)
g(x, example)

This gives the following error:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int32[] wrapped in a BatchTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Different traces at same level: Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
  val = Array([0, 1, 2, 3], dtype=int32)
  batch_dim = 0, BatchTrace(level=1/0)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

I believe this code should work. JAX expert Jake VanderPlas has confirmed that I appear to have uncovered a bug and asked me to file this GitHub issue. Jake said he had to turn on JAX_CHECK_TRACER_LEAKS=1 to observe the problem, but that was with a slightly different repro case. I did not have to do this (maybe it is already on in my environment?).
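For reference, the tracer-leak check mentioned above can be turned on either through the environment variable (for example `JAX_CHECK_TRACER_LEAKS=1 python repro.py`, where `repro.py` is a hypothetical filename) or from Python. Below is a minimal sketch, assuming a recent JAX release, that uses the `jax.checking_leaks` context manager the error message points to; the function `h` is a placeholder and not part of the original repro.

```python
import jax
import jax.numpy as jnp

def h(x):
    # Placeholder function with no side effects, so no leak should be reported.
    return jnp.sin(x) * 2.0

# Enables tracer-leak checking only inside this block; the global equivalent
# is jax.config.update("jax_check_tracer_leaks", True).
with jax.checking_leaks():
    out = jax.vmap(h)(jnp.arange(3.0))
```

Scoping the check to a `with` block limits any extra tracing overhead to the code under investigation.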

Additional context: the purpose of this code is to enter the debugger on a particular train step so that I can examine variables inside f() at that step.
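Since the goal is to stop at one particular step, a possible workaround (a sketch assuming a recent JAX release, not confirmed to sidestep this bug in every setting) is to drop `jax.lax.cond` and make the decision on the host inside `jax.debug.callback`, passing the values to inspect explicitly as callback arguments. The helper `_inspect_on_host` and the `captured` list are hypothetical; in real debugging the hook could call Python's built-in `breakpoint()` instead of recording values.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side record of values seen when the flag matches.
captured = []

def _inspect_on_host(x, example):
    # Runs on the host with concrete values, so this is an ordinary Python branch.
    # Replace the append with breakpoint() to stop in pdb instead.
    if int(example) == 1:
        captured.append(np.asarray(x))

def f(x, example):
    # Values are passed as explicit callback operands rather than captured
    # from the enclosing frame.
    jax.debug.callback(_inspect_on_host, x, example)
    return x

f_vmap = jax.vmap(f, in_axes=(0, None), out_axes=0)

def g(x, example):
    return f_vmap(x, example)

x = jnp.arange(4)
g(x, jnp.array(0, dtype=jnp.int32))  # flag off: nothing recorded
g(x, jnp.array(1, dtype=jnp.int32))  # flag on: x is recorded
jax.effects_barrier()  # wait for any pending callbacks to finish
```

Under `vmap` the callback may receive one batch element at a time or the whole batch, so anything reading `captured` should not assume a particular shape per entry.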

System info (python version, jaxlib version, accelerator, etc.)

Python: 3.11
JAX: top of tree inside Google as of 3pm Pacific Time on Sept 10, 2024
Accelerator: TPU
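JAX can also generate a version and device summary for reports like this one. A minimal sketch, assuming a JAX release that provides `jax.print_environment_info` with the `return_string` flag:

```python
import jax

# Returns the report as a string instead of printing it directly.
info = jax.print_environment_info(return_string=True)
print(info)
```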

billmark, Sep 10 '24 22:09