
`jax.debug.breakpoint` crashes in a hard-to-describe way.

Status: Open • patrick-kidger opened this issue 11 months ago • 2 comments

Description

This:

```python
import jax

@jax.jit
def brk():
    jax.debug.breakpoint()

def fn():
    x0 = jax.numpy.zeros(2)
    brk()

jax.eval_shape(fn)
fn()
```

produces:

jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.

I think what's happening is something like the following. Inside the eval_shape call, x0 is a tracer. jax.debug.breakpoint then grabs x0 from the stack frame above it, and jax.jit saves it as a constant of brk's jaxpr. The later, ordinary fn() call reuses that cached jaxpr, so JAX itself leaks the tracer.
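The frame-inspection part of this hypothesis can be illustrated in plain Python. The helper `peek_caller_locals` is hypothetical. It only mimics what a debugger that walks stack frames would see and is not how `jax.debug.breakpoint` is implemented. The sketch shows that the caller's `x0` is a tracer during `jax.eval_shape` and a concrete array during the ordinary call.

```python
import inspect

import jax
import jax.numpy as jnp

seen = []  # type name of the caller's x0, one entry per call


def peek_caller_locals():
    # Hypothetical helper: look one frame up, as a frame-walking debugger
    # would, and record the type of the caller's local `x0`.
    frame = inspect.currentframe().f_back
    try:
        seen.append(type(frame.f_locals["x0"]).__name__)
    finally:
        del frame  # don't keep the frame (and its locals) alive


def fn():
    x0 = jnp.zeros(2)
    peek_caller_locals()


jax.eval_shape(fn)  # traced abstractly: x0 is a tracer here
fn()                # run eagerly: x0 is a concrete array here
print(seen)
```

A debugger that stores whatever it finds in the first entry is holding a tracer whose trace has already ended.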

Tagging @sharadmv.

(Also, it doesn't suffice to just stop grabbing stack frames once you leave the JIT'd region. Consider this variant:

```python
import jax

@jax.jit
def brk():
    jax.debug.breakpoint()

def fn():
    x0 = jax.numpy.zeros(2)
    brk()

@jax.jit
def run():
    jax.eval_shape(fn)
    fn()

run()
```

)
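The same hypothetical helper helps explain why the variant defeats that fix. When `fn` runs inside the jitted `run`, no frame ever holds a concrete value: `x0` is a tracer in both calls, once from the `eval_shape` trace and once from `run`'s own trace. As before, this is a plain-Python analogue and not a description of JAX internals.

```python
import inspect

import jax
import jax.numpy as jnp

seen = []  # type name of the caller's x0, one entry per call


def peek_caller_locals():
    # Hypothetical helper standing in for a frame-walking debugger.
    frame = inspect.currentframe().f_back
    try:
        seen.append(type(frame.f_locals["x0"]).__name__)
    finally:
        del frame


def fn():
    x0 = jnp.zeros(2)
    peek_caller_locals()


@jax.jit
def run():
    jax.eval_shape(fn)  # x0 is a tracer of the eval_shape trace
    fn()                # x0 is a tracer of run's jit trace


run()
print(seen)
```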

What jax/jaxlib version are you using?

0.4.13

patrick-kidger, Jul 14 '23 16:07

I'm having a similar issue: the code runs fine without breakpoints but raises jax.errors.UnexpectedTracerError once I add them to debug.

I tried to find where it comes from with jax.check_tracer_leaks, but it points me to tracers that aren't leaking.

Let me know if I can help with more details, if this seems related.

DiegoRenner, Sep 08 '23 20:09

Hi folks. This problem is affecting me right now with jax==0.4.26. Solution aside, is there a workaround? I'm not sure how to debug my program otherwise...

cool-RR, Apr 29 '24 18:04
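One possible workaround, not confirmed anywhere in this thread, is to pass the values you want to inspect explicitly to `jax.debug.callback` instead of calling `jax.debug.breakpoint()`. Nothing is then pulled from enclosing stack frames, and because `brk` takes `x` as an argument, its cached jaxpr does not close over a tracer. The names `record` and `inspected` are hypothetical. This approach also gives up the interactive debugger in exchange for recorded values.

```python
import jax
import jax.numpy as jnp
import numpy as np

inspected = []  # values captured on the host, one entry per executed call


def record(x):
    # Runs on the host with a concrete array; copy it into NumPy.
    inspected.append(np.asarray(x))


@jax.jit
def brk(x):
    # Hand the value over explicitly rather than letting a debugger
    # grab it from the caller's stack frame.
    jax.debug.callback(record, x)


def fn():
    x0 = jnp.zeros(2)
    brk(x0)


jax.eval_shape(fn)   # only traces: the callback is staged, not run
fn()                 # compiles and runs: the callback fires once
jax.effects_barrier()  # wait for the pending callback to finish
print(inspected)
```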