Tarmo Äijö

4 comments by Tarmo Äijö

Calling [`jax.clear_caches()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.clear_caches.html) seems to solve the issue.

I'm not sure how actionable this is, but I used the following code:

```python
import numpy as np
import numpyro, jax
from numpyro import sample
import numpyro.distributions as dist
from...
```

@fehiepsi I did run a couple of experiments without NumPyro to see the effects of `jax.jit()`, `jax.clear_caches()`, and dynamic shapes. Did you mean something like that? These results suggest that...
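A NumPyro-free experiment of this kind might look like the sketch below. It is not the code behind the results mentioned above. The function `total` and the counter `trace_count` are hypothetical names introduced for illustration. The counter relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so it advances on a cache miss and stays put on a cache hit.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter; advances only while JAX traces `total`


@jax.jit
def total(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.sum(x * 2.0)


total(jnp.ones(3))  # first call with shape (3,): trace and compile
after_first = trace_count

total(jnp.ones(3))  # same shape and dtype: expected to hit the cache
after_same_shape = trace_count

total(jnp.ones(4))  # new shape (4,): expected to trigger a fresh trace
after_new_shape = trace_count

jax.clear_caches()  # drop JAX's compilation and staging caches
total(jnp.ones(3))  # shape (3,) was seen before, but has to be traced again
after_clear = trace_count

print(after_first, after_same_shape, after_new_shape, after_clear)
```

If the counter advances on every call in a real workload, the input shapes are probably changing between calls, and each new shape adds another cache entry until the caches are cleared.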

This should work:

```python
print(potential_fn_gen()(init_params.z))
```

Please check how `init_params.z` looks.
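One way to check how a parameter pytree looks is to map each leaf to its shape and dtype. This is only a sketch: the dictionary `z` below is a hypothetical stand-in for `init_params.z`, whose real structure depends on the model.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for `init_params.z`; the real keys and shapes will differ.
z = {"mu": jnp.zeros(3), "sigma": jnp.ones((2, 2))}

# Replace every array leaf with a (shape, dtype) pair so the structure is easy to read.
summary = jax.tree_util.tree_map(lambda leaf: (leaf.shape, leaf.dtype), z)
print(summary)
```

A leaf with an unexpected shape or dtype is often the reason a potential function rejects its input.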