
get_model_relations and get_dependencies give UnexpectedTracerError on seeded models

Open • danielward27 opened this issue 4 months ago • 1 comment

I'm assuming these functions are expected to work with seeded models, but both raise an `UnexpectedTracerError` when given one.

```python
import numpyro
import jax
import jax.random as jr
import numpyro.distributions as dist
from numpyro.infer.inspect import get_model_relations
from numpyro import handlers

def model():
    m = numpyro.sample('m', dist.Normal(0, 1))
    numpyro.sample('sd', dist.LogNormal(m, 1))

seeded_model = handlers.seed(model, jr.key(0))

with jax.checking_leaks():
    get_model_relations(seeded_model)  # UnexpectedTracerError, same for get_dependencies
```

The traceback for the seeded call is:

```
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "/home/dw16200/miniconda3/envs/softcvi_env/lib/python3.12/site-packages/numpyro/infer/inspect.py", line 323, in get_model_relations
    trace = jax.eval_shape(get_trace).trace
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dw16200/miniconda3/envs/softcvi_env/lib/python3.12/contextlib.py", line 144, in __exit__
    next(self.gen)
^^^^^^^^^^^^^^^^^^
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(key<fry>[])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line <stdin>:3:4 (model)
<DynamicJaxprTracer 129514377205776> is referred to by <seed 129515008443408>.rng_key
<seed 129515008443408> is referred to by __main__.seeded_model

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
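The last lines of the traceback suggest the cause: the leaked tracer is referenced by `<seed ...>.rng_key`, i.e. the seed handler held by `seeded_model` appears to store the key it splits on itself, and `jax.eval_shape` turns that split key into a tracer. The sketch below reproduces the same pattern in plain JAX. `StatefulSeed` is a hypothetical stand-in, not NumPyro's actual `seed` handler.

```python
import jax
import jax.random as jr


class StatefulSeed:
    """Hypothetical stand-in for a seed handler that keeps its key on `self`."""

    def __init__(self, rng_key):
        self.rng_key = rng_key

    def sample(self):
        # Split the stored key and write the new key back onto the object,
        # mirroring the `<seed ...>.rng_key` reference in the traceback.
        self.rng_key, subkey = jr.split(self.rng_key)
        return jr.normal(subkey)


seeded = StatefulSeed(jr.key(0))

# eval_shape traces `sample` abstractly, so jr.split returns a tracer even
# though the stored key was concrete. That tracer is written onto `seeded`
# and outlives the trace, which is what jax.checking_leaks() reports.
jax.eval_shape(seeded.sample)
print(type(seeded.rng_key).__name__)  # a tracer type, not a concrete key array
```

Because the handler object is long-lived (it is referenced by the module-level `seeded_model`), any later use of the stored key touches a tracer from a finished trace. Threading the key through function arguments and return values, rather than storing it on an object, avoids this.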

danielward27 • Oct 10 '24 17:10