
Jax throws error when tracing auto guide after passing it to get_model_relations

Open · mochar opened this issue 5 months ago · 4 comments

Bug Description

I create an auto guide with AutoNormal and pass it to get_model_relations. When I then try to record a trace of that guide, JAX throws an error complaining about side effects:

Error + traceback
---------------------------------------------------------------------------
UnexpectedTracerError                     Traceback (most recent call last)
Cell In[10], line 1
----> 1 handlers.trace(handlers.seed(guide, 0)).get_trace()

File ~/.venv/lib/python3.13/site-packages/numpyro/handlers.py:191, in trace.get_trace(self, *args, **kwargs)
    183 def get_trace(self, *args, **kwargs) -> OrderedDict[str, Message]:
    184     """
    185     Run the wrapped callable and return the recorded trace.
    186 
   (...)    189     :return: `OrderedDict` containing the execution trace.
    190     """
--> 191     self(*args, **kwargs)
    192     return self.trace

File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
    119     return self
    120 with self:
--> 121     return self.fn(*args, **kwargs)

File ~/.venv/lib/python3.13/site-packages/numpyro/handlers.py:846, in seed.__call__(self, *args, **kwargs)
    842     cloned_seeded_fn = seed(
    843         self.fn, rng_seed=self.rng_key, hide_types=self.hide_types
    844     )
    845     cloned_seeded_fn.stateful = True
--> 846     return cloned_seeded_fn.__call__(*args, **kwargs)
    847 return super().__call__(*args, **kwargs)

File ~/.venv/lib/python3.13/site-packages/numpyro/handlers.py:847, in seed.__call__(self, *args, **kwargs)
    845     cloned_seeded_fn.stateful = True
    846     return cloned_seeded_fn.__call__(*args, **kwargs)
--> 847 return super().__call__(*args, **kwargs)

File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
    119     return self
    120 with self:
--> 121     return self.fn(*args, **kwargs)

File ~/.venv/lib/python3.13/site-packages/numpyro/infer/autoguide.py:440, in AutoNormal.__call__(self, *args, **kwargs)
    435 site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
    436 if site["fn"].support is constraints.real or (
    437     isinstance(site["fn"].support, constraints.independent)
    438     and site["fn"].support.base_constraint is constraints.real
    439 ):
--> 440     result[name] = numpyro.sample(name, site_fn)
    441 else:
    442     with helpful_support_errors(site):

File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:250, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    235 initial_msg = {
    236     "type": "sample",
    237     "name": name,
   (...)    246     "infer": {} if infer is None else infer,
    247 }
    249 # ...and use apply_stack to send it to the Messengers
--> 250 msg = apply_stack(initial_msg)
    251 return msg["value"]

File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:61, in apply_stack(msg)
     58     if msg.get("stop"):
     59         break
---> 61 default_process_message(msg)
     63 # A Messenger that sets msg["stop"] == True also prevents application
     64 # of postprocess_message by Messengers above it on the stack
     65 # via the pointer variable from the process_message loop
     66 for handler in _PYRO_STACK[-pointer - 1 :]:

File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:32, in default_process_message(msg)
     30 if msg["value"] is None:
     31     if msg["type"] == "sample":
---> 32         msg["value"], msg["intermediates"] = msg["fn"](
     33             *msg["args"], sample_intermediates=True, **msg["kwargs"]
     34         )
     35     else:
     36         msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])

File ~/.venv/lib/python3.13/site-packages/numpyro/distributions/distribution.py:393, in Distribution.__call__(self, *args, **kwargs)
    391 sample_intermediates = kwargs.pop("sample_intermediates", False)
    392 if sample_intermediates:
--> 393     return self.sample_with_intermediates(key, *args, **kwargs)
    394 return self.sample(key, *args, **kwargs)

File ~/.venv/lib/python3.13/site-packages/numpyro/distributions/distribution.py:351, in Distribution.sample_with_intermediates(self, key, sample_shape)
    341 def sample_with_intermediates(self, key, sample_shape=()):
    342     """
    343     Same as ``sample`` except that any intermediate computations are
    344     returned (useful for `TransformedDistribution`).
   (...)    349     :rtype: numpy.ndarray
    350     """
--> 351     return self.sample(key, sample_shape=sample_shape), []

File ~/.venv/lib/python3.13/site-packages/numpyro/distributions/continuous.py:2198, in Normal.sample(self, key, sample_shape)
   2194 assert is_prng_key(key)
   2195 eps = random.normal(
   2196     key, shape=sample_shape + self.batch_shape + self.event_shape
   2197 )
-> 2198 return self.loc + eps * self.scale

File ~/.venv/lib/python3.13/site-packages/jax/_src/numpy/array_methods.py:1083, in _forward_operator_to_aval.<locals>.op(self, *args)
   1082 def op(self, *args):
-> 1083   return getattr(self.aval, f"_{name}")(self, *args)

File ~/.venv/lib/python3.13/site-packages/jax/_src/numpy/array_methods.py:583, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
    581 args = (other, self) if swap else (self, other)
    582 if isinstance(other, _accepted_binop_types):
--> 583   return binary_op(*args)
    584 # Note: don't use isinstance here, because we don't want to raise for
    585 # subclasses, e.g. NamedTuple objects that may override operators.
    586 if type(other) in _rejected_binop_types:

File ~/.venv/lib/python3.13/site-packages/jax/_src/numpy/ufunc_api.py:182, in ufunc.__call__(self, out, where, *args)
    180   raise NotImplementedError(f"where argument of {self}")
    181 call = self.__static_props['call'] or self._call_vectorized
--> 182 return call(*args)

    [... skipping hidden 3 frame]

File ~/.venv/lib/python3.13/site-packages/jax/_src/core.py:1053, in check_eval_args(args)
   1051 for arg in args:
   1052   if isinstance(arg, Tracer):
-> 1053     raise escaped_tracer_error(arg)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was get_trace at /home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/inspect.py:307 traced for jit.
------------------------------
The leaked intermediate value was created on line /home/mochar/.venv/lib/python3.13/site-packages/numpyro/util.py:141:15 (while_loop). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/autoguide.py:160:16 (AutoGuide._setup_prototype)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/util.py:750:40 (initialize_model)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/util.py:472:46 (find_valid_initial_params)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/util.py:465:52 (find_valid_initial_params.<locals>._find_valid_params)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/util.py:141:15 (while_loop)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.UnexpectedTracerError

I do not encounter this problem when making my own guide instead of using AutoGuide.

Steps to Reproduce

import numpyro
from numpyro import distributions as dist
from numpyro import handlers
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.inspect import get_model_relations

def model():
    numpyro.sample('a', dist.Normal())

guide = AutoNormal(model)
relations = get_model_relations(guide)  # this call succeeds
handlers.trace(handlers.seed(guide, 0)).get_trace()  # raises UnexpectedTracerError

mochar, Aug 14 '25 05:08

It looks like there's a NameError: name 'dist' is not defined in the traceback. Could you check if fixing that import resolves the error?

tillahoffmann, Aug 14 '25 14:08

@tillahoffmann My bad, that's an error I got while setting up the example and accidentally copied over. I've removed it now.

mochar, Aug 14 '25 18:08

Thanks for sharing! JAX sometimes complains about a leaked tracer some time after the leak actually happened. It turns out that get_model_relations already leaks tracers, as the following snippet illustrates.

I'll have a think. My hunch is that the guide call traced by jax.eval_shape sets attributes on the autoguide, i.e., get_model_relations is not pure. This might be similar to https://github.com/google/flax/issues/4520.
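The suspected mechanism can be reproduced without NumPyro. This toy sketch uses a hypothetical `CachingGuide` class that, like AutoNormal, stores a value on itself the first time it is called. If that first call happens under `jax.eval_shape`, the stored value is a tracer, and using it later in eager code should raise `UnexpectedTracerError`. Constructing a fresh object inside the traced function keeps the state local to each call. This only illustrates the hypothesis; it is not NumPyro's actual code.

```python
import jax
import jax.numpy as jnp


class CachingGuide:
    """Hypothetical stand-in for a stateful guide such as AutoNormal."""

    def __init__(self):
        self.init_loc = None

    def __call__(self, x):
        if self.init_loc is None:
            # Side effect: an intermediate value is saved on the object.
            self.init_loc = x * 2.0
        return self.init_loc + x


def stateful_use_leaks():
    guide = CachingGuide()
    # eval_shape traces the call, so guide.init_loc ends up holding a tracer
    # that outlives the trace.
    jax.eval_shape(guide, jnp.ones(()))
    try:
        guide(jnp.ones(()))  # eager computation with the leaked tracer
    except jax.errors.UnexpectedTracerError:
        return True
    return False


def fresh_guide(x):
    # A new CachingGuide per call, so its state never escapes the call.
    return CachingGuide()(x)


leaked = stateful_use_leaks()
jax.eval_shape(fresh_guide, jnp.ones(()))
result = fresh_guide(jnp.ones(()))  # 1.0 * 2.0 + 1.0
```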

import jax
import numpyro
from numpyro import distributions as dist
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.inspect import get_model_relations


def model():
    numpyro.sample('a', dist.Normal())

guide = AutoNormal(model)

with jax.check_tracer_leaks():
    relations = get_model_relations(guide)
Traceback (most recent call last):
  File "/Users/till/git/numpyro/playground/issue_2062.py", line 14, in <module>
    relations = get_model_relations(guide)
  File "/Users/till/git/numpyro/numpyro/infer/inspect.py", line 326, in get_model_relations
    trace = jax.eval_shape(get_trace).trace
            ~~~~~~~~~~~~~~^^^^^^^^^^^
  File "/Users/till/.local/share/uv/python/cpython-3.13.3-macos-aarch64-none/lib/python3.13/contextlib.py", line 149, in __exit__
    except StopIteration:
        return False
Exception: Leaked trace DynamicJaxprTrace. Leaked tracer(s):

JitTracer<uint32[2]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545092688> is referred to by <dict 4545289664>['rng_key']
<dict 4545289664> is referred to by <dict 4545291072>['kwargs']
<dict 4545291072> is referred to by <OrderedDict 4544968000>['a']
<OrderedDict 4544968000> is referred to by <AutoNormal 4545055104>.prototype_trace
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<uint32[1,2]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545092576> is referred to by <list 4545291392>[0]
<list 4545291392> is referred to by <TracingEqn 4544973904>.in_tracers
<TracingEqn 4544973904> is referred to by <DynamicJaxprTracer 4545092688>
<DynamicJaxprTracer 4545092688> is referred to by <dict 4545289664>['rng_key']
<dict 4545289664> is referred to by <dict 4545291072>['kwargs']
<dict 4545291072> is referred to by <OrderedDict 4544968000>['a']
<OrderedDict 4544968000> is referred to by <AutoNormal 4545055104>.prototype_trace
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<uint32[2]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545093584> is referred to by <seed 4544784560>.rng_key
<seed 4544784560> is referred to by <AutoNormal 4545055104>._postprocess_fn
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<key<fry>[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545097168> is referred to by <list 4545416960>[0]
<list 4545416960> is referred to by <TracingEqn 4545279696>.in_tracers
<TracingEqn 4545279696> is referred to by <DynamicJaxprTracer 4545093584>
<DynamicJaxprTracer 4545093584> is referred to by <seed 4544784560>.rng_key
<seed 4544784560> is referred to by <AutoNormal 4545055104>._postprocess_fn
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<float32[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545093248> is referred to by <dict 4545291072>['value']
<dict 4545291072> is referred to by <OrderedDict 4544968000>['a']
<OrderedDict 4544968000> is referred to by <AutoNormal 4545055104>.prototype_trace
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<float32[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<uint32[2,2]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545092240> is referred to by <list 4545293568>[0]
<list 4545293568> is referred to by <TracingEqn 4545170352>.in_tracers
<TracingEqn 4545170352> is referred to by <DynamicJaxprTracer 4545092576>
<DynamicJaxprTracer 4545092576> is referred to by <list 4545291392>[0]
<list 4545291392> is referred to by <TracingEqn 4544973904>.in_tracers
<TracingEqn 4544973904> is referred to by <DynamicJaxprTracer 4545092688>
<DynamicJaxprTracer 4545092688> is referred to by <dict 4545289664>['rng_key']
<dict 4545289664> is referred to by <dict 4545291072>['kwargs']
<dict 4545291072> is referred to by <OrderedDict 4544968000>['a']
<OrderedDict 4544968000> is referred to by <AutoNormal 4545055104>.prototype_trace
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<int32[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545095936> is referred to by <list 4545317824>[0]
<list 4545317824> is referred to by <TracingEqn 4545281872>.in_tracers
<TracingEqn 4545281872> is referred to by <DynamicJaxprTracer 4545097168>
<DynamicJaxprTracer 4545097168> is referred to by <list 4545416960>[0]
<list 4545416960> is referred to by <TracingEqn 4545279696>.in_tracers
<TracingEqn 4545279696> is referred to by <DynamicJaxprTracer 4545093584>
<DynamicJaxprTracer 4545093584> is referred to by <seed 4544784560>.rng_key
<seed 4544784560> is referred to by <AutoNormal 4545055104>._postprocess_fn
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<~int32[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545092912> is referred to by <tuple 4545315200>[2]
<tuple 4545315200> is referred to by <TracingEqn 4545276624>.in_tracers
<TracingEqn 4545276624> is referred to by <DynamicJaxprTracer 4545093248>
<DynamicJaxprTracer 4545093248> is referred to by <dict 4545291072>['value']
<dict 4545291072> is referred to by <OrderedDict 4544968000>['a']
<OrderedDict 4544968000> is referred to by <AutoNormal 4545055104>.prototype_trace
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<~int32[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545093136> is referred to by <tuple 4545315200>[1]
<tuple 4545315200> is referred to by <TracingEqn 4545276624>.in_tracers
<TracingEqn 4545276624> is referred to by <DynamicJaxprTracer 4545093248>
<DynamicJaxprTracer 4545093248> is referred to by <dict 4545291072>['value']
<dict 4545291072> is referred to by <OrderedDict 4544968000>['a']
<OrderedDict 4544968000> is referred to by <AutoNormal 4545055104>.prototype_trace
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<key<fry>[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545092800> is referred to by <tuple 4545315200>[0]
<tuple 4545315200> is referred to by <TracingEqn 4545276624>.in_tracers
<TracingEqn 4545276624> is referred to by <DynamicJaxprTracer 4545093248>
<DynamicJaxprTracer 4545093248> is referred to by <dict 4545291072>['value']
<dict 4545291072> is referred to by <OrderedDict 4544968000>['a']
<OrderedDict 4544968000> is referred to by <AutoNormal 4545055104>.prototype_trace
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<bool[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545096496> is referred to by <list 4545429056>[5]
<list 4545429056> is referred to by <TracingEqn 4545448016>.in_tracers
<TracingEqn 4545448016> is referred to by <DynamicJaxprTracer 4545096832>
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<~float32[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545096272> is referred to by <list 4545429056>[3]
<list 4545429056> is referred to by <TracingEqn 4545448016>.in_tracers
<TracingEqn 4545448016> is referred to by <DynamicJaxprTracer 4545096832>
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<uint32[2]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545091904> is referred to by <list 4545429056>[1]
<list 4545429056> is referred to by <TracingEqn 4545448016>.in_tracers
<TracingEqn 4545448016> is referred to by <DynamicJaxprTracer 4545096832>
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<~int32[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545095376> is referred to by <list 4545429056>[0]
<list 4545429056> is referred to by <TracingEqn 4545448016>.in_tracers
<TracingEqn 4545448016> is referred to by <DynamicJaxprTracer 4545096832>
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<key<fry>[2]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545092128> is referred to by <list 4545292608>[0]
<list 4545292608> is referred to by <TracingEqn 4544804048>.in_tracers
<TracingEqn 4544804048> is referred to by <DynamicJaxprTracer 4545092240>
<DynamicJaxprTracer 4545092240> is referred to by <list 4545293568>[0]
<list 4545293568> is referred to by <TracingEqn 4545170352>.in_tracers
<TracingEqn 4545170352> is referred to by <DynamicJaxprTracer 4545092576>
<DynamicJaxprTracer 4545092576> is referred to by <list 4545291392>[0]
<list 4545291392> is referred to by <TracingEqn 4544973904>.in_tracers
<TracingEqn 4544973904> is referred to by <DynamicJaxprTracer 4545092688>
<DynamicJaxprTracer 4545092688> is referred to by <dict 4545289664>['rng_key']
<dict 4545289664> is referred to by <dict 4545291072>['kwargs']
<dict 4545291072> is referred to by <OrderedDict 4544968000>['a']
<OrderedDict 4544968000> is referred to by <AutoNormal 4545055104>.prototype_trace
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<uint32[1,2]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545091792> is referred to by <list 4545285696>[0]
<list 4545285696> is referred to by <TracingEqn 4544745552>.in_tracers
<TracingEqn 4544745552> is referred to by <DynamicJaxprTracer 4545091904>
<DynamicJaxprTracer 4545091904> is referred to by <list 4545429056>[1]
<list 4545429056> is referred to by <TracingEqn 4545448016>.in_tracers
<TracingEqn 4545448016> is referred to by <DynamicJaxprTracer 4545096832>
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<key<fry>[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545092016> is referred to by <list 4545292480>[0]
<list 4545292480> is referred to by <TracingEqn 4544803808>.in_tracers
<TracingEqn 4544803808> is referred to by <DynamicJaxprTracer 4545092128>
<DynamicJaxprTracer 4545092128> is referred to by <list 4545292608>[0]
<list 4545292608> is referred to by <TracingEqn 4544804048>.in_tracers
<TracingEqn 4544804048> is referred to by <DynamicJaxprTracer 4545092240>
<DynamicJaxprTracer 4545092240> is referred to by <list 4545293568>[0]
<list 4545293568> is referred to by <TracingEqn 4545170352>.in_tracers
<TracingEqn 4545170352> is referred to by <DynamicJaxprTracer 4545092576>
<DynamicJaxprTracer 4545092576> is referred to by <list 4545291392>[0]
<list 4545291392> is referred to by <TracingEqn 4544973904>.in_tracers
<TracingEqn 4544973904> is referred to by <DynamicJaxprTracer 4545092688>
<DynamicJaxprTracer 4545092688> is referred to by <dict 4545289664>['rng_key']
<dict 4545289664> is referred to by <dict 4545291072>['kwargs']
<dict 4545291072> is referred to by <OrderedDict 4544968000>['a']
<OrderedDict 4544968000> is referred to by <AutoNormal 4545055104>.prototype_trace
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<uint32[2,2]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545091456> is referred to by <list 4545288000>[0]
<list 4545288000> is referred to by <TracingEqn 4545026240>.in_tracers
<TracingEqn 4545026240> is referred to by <DynamicJaxprTracer 4545091792>
<DynamicJaxprTracer 4545091792> is referred to by <list 4545285696>[0]
<list 4545285696> is referred to by <TracingEqn 4544745552>.in_tracers
<TracingEqn 4544745552> is referred to by <DynamicJaxprTracer 4545091904>
<DynamicJaxprTracer 4545091904> is referred to by <list 4545429056>[1]
<list 4545429056> is referred to by <TracingEqn 4545448016>.in_tracers
<TracingEqn 4545448016> is referred to by <DynamicJaxprTracer 4545096832>
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<key<fry>[2]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545091344> is referred to by <list 4545284992>[0]
<list 4545284992> is referred to by <TracingEqn 4544784256>.in_tracers
<TracingEqn 4544784256> is referred to by <DynamicJaxprTracer 4545091456>
<DynamicJaxprTracer 4545091456> is referred to by <list 4545288000>[0]
<list 4545288000> is referred to by <TracingEqn 4545026240>.in_tracers
<TracingEqn 4545026240> is referred to by <DynamicJaxprTracer 4545091792>
<DynamicJaxprTracer 4545091792> is referred to by <list 4545285696>[0]
<list 4545285696> is referred to by <TracingEqn 4544745552>.in_tracers
<TracingEqn 4544745552> is referred to by <DynamicJaxprTracer 4545091904>
<DynamicJaxprTracer 4545091904> is referred to by <list 4545429056>[1]
<list 4545429056> is referred to by <TracingEqn 4545448016>.in_tracers
<TracingEqn 4545448016> is referred to by <DynamicJaxprTracer 4545096832>
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<key<fry>[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545091232> is referred to by <list 4545286272>[0]
<list 4545286272> is referred to by <TracingEqn 4544781216>.in_tracers
<TracingEqn 4544781216> is referred to by <DynamicJaxprTracer 4545091344>
<DynamicJaxprTracer 4545091344> is referred to by <list 4545284992>[0]
<list 4545284992> is referred to by <TracingEqn 4544784256>.in_tracers
<TracingEqn 4544784256> is referred to by <DynamicJaxprTracer 4545091456>
<DynamicJaxprTracer 4545091456> is referred to by <list 4545288000>[0]
<list 4545288000> is referred to by <TracingEqn 4545026240>.in_tracers
<TracingEqn 4545026240> is referred to by <DynamicJaxprTracer 4545091792>
<DynamicJaxprTracer 4545091792> is referred to by <list 4545285696>[0]
<list 4545285696> is referred to by <TracingEqn 4544745552>.in_tracers
<TracingEqn 4544745552> is referred to by <DynamicJaxprTracer 4545091904>
<DynamicJaxprTracer 4545091904> is referred to by <list 4545429056>[1]
<list 4545429056> is referred to by <TracingEqn 4545448016>.in_tracers
<TracingEqn 4545448016> is referred to by <DynamicJaxprTracer 4545096832>
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<uint32[2]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545091120> is referred to by <list 4544767488>[0]
<list 4544767488> is referred to by <TracingEqn 4545251920>.in_tracers
<TracingEqn 4545251920> is referred to by <DynamicJaxprTracer 4545091232>
<DynamicJaxprTracer 4545091232> is referred to by <list 4545286272>[0]
<list 4545286272> is referred to by <TracingEqn 4544781216>.in_tracers
<TracingEqn 4544781216> is referred to by <DynamicJaxprTracer 4545091344>
<DynamicJaxprTracer 4545091344> is referred to by <list 4545284992>[0]
<list 4545284992> is referred to by <TracingEqn 4544784256>.in_tracers
<TracingEqn 4544784256> is referred to by <DynamicJaxprTracer 4545091456>
<DynamicJaxprTracer 4545091456> is referred to by <list 4545288000>[0]
<list 4545288000> is referred to by <TracingEqn 4545026240>.in_tracers
<TracingEqn 4545026240> is referred to by <DynamicJaxprTracer 4545091792>
<DynamicJaxprTracer 4545091792> is referred to by <list 4545285696>[0]
<list 4545285696> is referred to by <TracingEqn 4544745552>.in_tracers
<TracingEqn 4544745552> is referred to by <DynamicJaxprTracer 4545091904>
<DynamicJaxprTracer 4545091904> is referred to by <list 4545429056>[1]
<list 4545429056> is referred to by <TracingEqn 4545448016>.in_tracers
<TracingEqn 4545448016> is referred to by <DynamicJaxprTracer 4545096832>
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<key<fry>[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545090896> is referred to by <list 4544767552>[0]
<list 4544767552> is referred to by <TracingEqn 4545250640>.in_tracers
<TracingEqn 4545250640> is referred to by <DynamicJaxprTracer 4545091120>
<DynamicJaxprTracer 4545091120> is referred to by <list 4544767488>[0]
<list 4544767488> is referred to by <TracingEqn 4545251920>.in_tracers
<TracingEqn 4545251920> is referred to by <DynamicJaxprTracer 4545091232>
<DynamicJaxprTracer 4545091232> is referred to by <list 4545286272>[0]
<list 4545286272> is referred to by <TracingEqn 4544781216>.in_tracers
<TracingEqn 4544781216> is referred to by <DynamicJaxprTracer 4545091344>
<DynamicJaxprTracer 4545091344> is referred to by <list 4545284992>[0]
<list 4545284992> is referred to by <TracingEqn 4544784256>.in_tracers
<TracingEqn 4544784256> is referred to by <DynamicJaxprTracer 4545091456>
<DynamicJaxprTracer 4545091456> is referred to by <list 4545288000>[0]
<list 4545288000> is referred to by <TracingEqn 4545026240>.in_tracers
<TracingEqn 4545026240> is referred to by <DynamicJaxprTracer 4545091792>
<DynamicJaxprTracer 4545091792> is referred to by <list 4545285696>[0]
<list 4545285696> is referred to by <TracingEqn 4544745552>.in_tracers
<TracingEqn 4544745552> is referred to by <DynamicJaxprTracer 4545091904>
<DynamicJaxprTracer 4545091904> is referred to by <list 4545429056>[1]
<list 4545429056> is referred to by <TracingEqn 4545448016>.in_tracers
<TracingEqn 4545448016> is referred to by <DynamicJaxprTracer 4545096832>
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

JitTracer<int32[]>
The error occurred while tracing the function get_trace at /Users/till/git/numpyro/numpyro/infer/inspect.py:307 for jit. 
<DynamicJaxprTracer 4545091008> is referred to by <list 4544767808>[0]
<list 4544767808> is referred to by <TracingEqn 4545057456>.in_tracers
<TracingEqn 4545057456> is referred to by <DynamicJaxprTracer 4545090896>
<DynamicJaxprTracer 4545090896> is referred to by <list 4544767552>[0]
<list 4544767552> is referred to by <TracingEqn 4545250640>.in_tracers
<TracingEqn 4545250640> is referred to by <DynamicJaxprTracer 4545091120>
<DynamicJaxprTracer 4545091120> is referred to by <list 4544767488>[0]
<list 4544767488> is referred to by <TracingEqn 4545251920>.in_tracers
<TracingEqn 4545251920> is referred to by <DynamicJaxprTracer 4545091232>
<DynamicJaxprTracer 4545091232> is referred to by <list 4545286272>[0]
<list 4545286272> is referred to by <TracingEqn 4544781216>.in_tracers
<TracingEqn 4544781216> is referred to by <DynamicJaxprTracer 4545091344>
<DynamicJaxprTracer 4545091344> is referred to by <list 4545284992>[0]
<list 4545284992> is referred to by <TracingEqn 4544784256>.in_tracers
<TracingEqn 4544784256> is referred to by <DynamicJaxprTracer 4545091456>
<DynamicJaxprTracer 4545091456> is referred to by <list 4545288000>[0]
<list 4545288000> is referred to by <TracingEqn 4545026240>.in_tracers
<TracingEqn 4545026240> is referred to by <DynamicJaxprTracer 4545091792>
<DynamicJaxprTracer 4545091792> is referred to by <list 4545285696>[0]
<list 4545285696> is referred to by <TracingEqn 4544745552>.in_tracers
<TracingEqn 4544745552> is referred to by <DynamicJaxprTracer 4545091904>
<DynamicJaxprTracer 4545091904> is referred to by <list 4545429056>[1]
<list 4545429056> is referred to by <TracingEqn 4545448016>.in_tracers
<TracingEqn 4545448016> is referred to by <DynamicJaxprTracer 4545096832>
<DynamicJaxprTracer 4545096832> is referred to by <dict 4545287424>['a']
<dict 4545287424> is referred to by <AutoNormal 4545055104>._init_locs
<AutoNormal 4545055104> is referred to by __main__.guide

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

tillahoffmann, Aug 14 '25 18:08

AutoNormal is a stateful object that may set new attributes holding JAX values when it is called. I think you can work around this by constructing the guide inside a closure, something like

guide = lambda: AutoNormal(model)()

fehiepsi, Aug 15 '25 01:08