JAX backend doesn't use `jax.tree_util`
In the code linked below, `expand_composites` is set to `False`, which prevents `jax.tree_util` from kicking in.
https://github.com/tensorflow/probability/blob/64a70d0850f2ab6b69693b7fb728b7c6eb759a7e/tensorflow_probability/python/internal/backend/numpy/nest.py#L314-L320
As a result, we cannot use pytrees that are incompatible with `dm_tree`.
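To illustrate the kind of pytree affected, here is a minimal sketch. The linked notebook uses Equinox modules; the hand-rolled `Params` class below is a hypothetical stand-in registered only with JAX's pytree registry. `jax.tree_util` flattens it into its array leaves, whereas `dm_tree` does not consult that registry and would treat the whole object as a single opaque leaf.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Params:
    """Hypothetical container registered with JAX's pytree registry only.

    dm_tree knows nothing about this registration, so it would treat a
    Params instance as one leaf instead of recursing into it.
    """

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


params = Params(jnp.ones(3), jnp.zeros(()))

# jax.tree_util sees through the custom class and reaches both arrays.
leaves = jax.tree_util.tree_leaves(params)

# tree_map rebuilds a Params instance with each leaf transformed.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
```

Only a code path that goes through `jax.tree_util` (rather than `dm_tree`) can map over such a structure leaf by leaf.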
Here is a notebook reproducing the bug: https://colab.research.google.com/gist/krzysztofrusek/a9fa71ca2bf3952a9f18358309225107/eqx_tfp.ipynb
`expand_composites` should probably be set to `True` in that location. Other call sites of a similar nature already set it to `True` (e.g. https://github.com/tensorflow/probability/blob/23a292a64f255fe7a98b32317d238b3e84f50c7f/tensorflow_probability/python/internal/loop_util.py#L183 and https://github.com/tensorflow/probability/blob/23a292a64f255fe7a98b32317d238b3e84f50c7f/tensorflow_probability/python/mcmc/internal/util.py#L116), so I don't see an obvious reason why it shouldn't be `True` here as well.