Tarmo Äijö
[This](https://jax.readthedocs.io/en/latest/_autosummary/jax.clear_caches.html) seems to solve the issue.
I'm not sure how actionable this is but I used the following code:

```python
import numpy as np
import numpyro, jax
from numpyro import sample
import numpyro.distributions as dist
from...
```
@fehiepsi I did run a couple of experiments without NumPyro to see the effects of `jax.jit()`, `jax.clear_caches()`, and dynamic shapes. Did you mean something like that? These results suggest that...
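For anyone following along, here is a minimal sketch of that kind of NumPyro-free experiment. It is not the code I actually ran: the function `total` and the `trace_count` counter are hypothetical names made up for illustration. It relies on the fact that Python side effects inside a `jax.jit`-wrapped function run only while JAX traces it, so the counter shows when a new input shape forces a retrace and what happens after `jax.clear_caches()`.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented each time JAX traces `total`


@jax.jit
def total(x):
    global trace_count
    trace_count += 1  # Python side effect, so it runs only during tracing
    return jnp.sum(x)


total(jnp.ones(3))  # first call: traces and compiles for shape (3,)
total(jnp.ones(3))  # same shape and dtype: served from the jit cache
after_same_shape = trace_count

total(jnp.ones(4))  # new shape: a second trace and compilation
after_new_shape = trace_count

jax.clear_caches()  # clears JAX's compilation and staging caches
total(jnp.ones(3))  # shape (3,) was seen before, but the cache is now empty
after_clear = trace_count

print(after_same_shape, after_new_shape, after_clear)
```

If the count goes up again after `jax.clear_caches()`, the previously cached entry for shape `(3,)` was dropped. With dynamic shapes every distinct shape adds another cache entry, which is why clearing the caches can help when memory keeps growing.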
This should work:

```python
print(potential_fn_gen()(init_params.z))
```

Please check what `init_params.z` looks like.