`jax.debug.breakpoint` crashes in a hard-to-describe way.
Description
This:
```python
import jax

@jax.jit
def brk():
    jax.debug.breakpoint()

def fn():
    x0 = jax.numpy.zeros(2)
    brk()

jax.eval_shape(fn)
fn()
```
produces:
jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
I think what's happening is something like the following. The `eval_shape` call means that `x0` is a tracer. Then `jax.debug.breakpoint` grabs `x0` from the stack frame above it. Then `jax.jit` saves it as a constant of its jaxpr. This means that JAX itself leaks the tracer.
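The failure mode described above, a tracer created under one transformation surviving after that transformation has finished, can be reproduced without `jax.debug.breakpoint` by leaking the tracer explicitly. This is a minimal sketch for comparison only: the `leaked` list is a hypothetical stand-in for the stack-frame capture that the breakpoint performs, and `outcome` is just an illustrative variable.

```python
import jax
import jax.numpy as jnp

# Hypothetical side channel standing in for the breakpoint's frame capture.
leaked = []

def fn():
    x0 = jnp.zeros(2)   # under eval_shape, x0 is a DynamicJaxprTracer
    leaked.append(x0)   # side effect: the tracer escapes the transformation
    return x0

jax.eval_shape(fn)      # traces fn abstractly; that trace is finished afterwards

try:
    jnp.sin(leaked[0])  # use the escaped tracer outside of its trace
    outcome = "no error"
except jax.errors.UnexpectedTracerError:
    outcome = "UnexpectedTracerError"

print(outcome)
```

Here the user code leaks the tracer on purpose; the point of this issue is that `jax.debug.breakpoint` appears to do the equivalent internally.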
Tagging @sharadmv.
Also, it doesn't suffice to just stop grabbing stack frames once you leave the JIT'd region. Consider this variant:

```python
import jax

@jax.jit
def brk():
    jax.debug.breakpoint()

def fn():
    x0 = jax.numpy.zeros(2)
    brk()

@jax.jit
def run():
    jax.eval_shape(fn)
    fn()

run()
```
What jax/jaxlib version are you using?
0.4.13
I'm having a similar issue: the code runs fine without breakpoints but raises jax.errors.UnexpectedTracerError when I try to debug it.
I tried checking where it could be coming from with jax.check_tracer_leaks, but it points me to tracers that aren't leaking.
Let me know if I can help with more details, in case this seems related.
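For reference, a minimal sketch of turning on JAX's tracer-leak checker through the `jax_check_tracer_leaks` config flag and seeing it fire on a deliberate leak. The function `f` and the `leaked` list are hypothetical. As far as I know, the flag's help text warns that some Python debuggers can cause false positives, which may explain reports that point at tracers that aren't really leaking; the exact exception type raised is an implementation detail, so the sketch catches broadly.

```python
import jax
import jax.numpy as jnp

leaked = []  # hypothetical side channel that keeps a tracer alive

def f(x):
    leaked.append(x)  # x is a tracer while jit traces f; it escapes here
    return x * 2

jax.config.update("jax_check_tracer_leaks", True)
try:
    jax.jit(f)(jnp.ones(2))
    detected = False
except Exception:  # catch broadly: the leak checker's exception type may vary
    detected = True
finally:
    # Restore the default; leak checking disables some caching and adds overhead.
    jax.config.update("jax_check_tracer_leaks", False)

print("leak detected:", detected)
```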
Hi folks. This problem is affecting me right now with jax==0.4.26. Solution aside, is there a workaround? I'm not sure how to debug my program otherwise...
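One possible workaround, sketched under the assumption that you mainly need to look at specific values rather than browse arbitrary locals: pass those values explicitly to `jax.debug.callback`. Unlike the breakpoint, the callback only receives the operands you hand it and does not inspect caller stack frames, so nothing from an enclosing `eval_shape` should get captured. The names `show_value` and `seen` are hypothetical; `brk` and `fn` mirror the reproduction above.

```python
import jax
import jax.numpy as jnp

seen = []  # collects host-side values (for illustration)

def show_value(x):
    # Runs on the host with a concrete array. For interactive inspection you
    # could try calling Python's built-in breakpoint() here instead.
    seen.append(x)

@jax.jit
def brk(x):
    # Only the explicitly passed operand reaches the callback; no stack
    # frames are inspected, so no tracer from the caller gets captured.
    jax.debug.callback(show_value, x)
    return x

def fn():
    x0 = jnp.zeros(2)
    return brk(x0)

jax.eval_shape(fn)     # abstract trace only; the callback does not execute
fn()                   # concrete run; the callback fires once
jax.effects_barrier()  # wait for the (unordered) callback to complete
print(len(seen), seen[0].shape)
```

The trade-off is that you lose the breakpoint's ability to walk up the stack and inspect arbitrary variables; you only see what you pass in.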