JAX 0.7 compatibility
File ".../venvserver/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py", line 681, in <module>
jax.interpreters.xla.pytype_aval_mappings[onp.ndarray])
...
AttributeError: jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and removed in JAX v0.7.0. jax.core.pytype_aval_mappings can be used as a replacement in most cases.
It seems TFP's JAX backend still uses a deprecated attribute that was removed in the recent JAX release, v0.7.0.
Please use tfp-nightly for JAX support; more discussion here: https://github.com/tensorflow/probability/issues/1994
@csuter unfortunately, tfp-nightly runs into the same issue.
Can you make sure the old tensorflow-probability pip package isn't still installed?
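One way to check this is to list which TFP distributions are installed, since both `tensorflow-probability` and `tfp-nightly` provide the same `tensorflow_probability` module. This is a minimal sketch using only the standard library; the helper name `installed_versions` is made up for illustration.

```python
from importlib import metadata

# Both distributions install the same `tensorflow_probability` module,
# so having both present can leave the older code on the import path.
TFP_DISTRIBUTIONS = ("tensorflow-probability", "tfp-nightly")


def installed_versions(names=TFP_DISTRIBUTIONS):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

If both show a version, uninstalling both and reinstalling only `tfp-nightly` would be the thing to try.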
@csuter yes, installed in a fresh conda environment, but tfp-nightly didn't do the trick, unfortunately.
To give the full context, the issue is when installing distrax, which relies on tfp, see https://github.com/google-deepmind/distrax/issues/295#issuecomment-3180184263. It seems very likely that this is a distrax issue rather than a tfp one, even when forcing tfp-nightly.
distrax is installing the non-nightly tfp package. This will have to be fixed on the distrax side. tfp-nightly works fine with the latest JAX.
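To see which TFP package distrax declares as a dependency in a given environment, a sketch like the following reads the installed package metadata. It assumes distrax is installed under the distribution name `distrax`; the helper names `filter_tfp` and `declared_tfp_requirements` are hypothetical.

```python
from importlib import metadata

TFP_PREFIXES = ("tensorflow-probability", "tfp-nightly")


def filter_tfp(requirements):
    """Keep only requirement strings that name a TFP distribution."""
    kept = []
    for req in requirements:
        # Treat underscores and hyphens as equivalent in distribution names.
        normalized = req.lower().replace("_", "-")
        if normalized.startswith(TFP_PREFIXES):
            kept.append(req)
    return kept


def declared_tfp_requirements(dist_name="distrax"):
    """Return the TFP requirements declared by `dist_name`, or None if it is not installed."""
    try:
        requirements = metadata.requires(dist_name) or []
    except metadata.PackageNotFoundError:
        return None
    return filter_tfp(requirements)


if __name__ == "__main__":
    print(declared_tfp_requirements())
```

If the result lists `tensorflow-probability`, installing distrax will pull in the stable package alongside any `tfp-nightly` already present.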