
JAX 0.7 compatibility

[Open] davindicode opened this issue 6 months ago • 5 comments

```
File ".../venvserver/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py", line 681, in <module>
    jax.interpreters.xla.pytype_aval_mappings[onp.ndarray])
...
AttributeError: jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and removed in JAX v0.7.0. jax.core.pytype_aval_mappings can be used as a replacement in most cases.
```

It seems that TFP's JAX backend relies on deprecated features that were removed in the recent JAX v0.7.0 release.

davindicode, Aug 06 '25 04:08
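For code that has to run on JAX versions both before and after v0.7.0, one option is to look the mapping up in the replacement location named by the error message and fall back to the old one. The sketch below is a hypothetical helper (`resolve_attr` and `PYTYPE_AVAL_MAPPING_CANDIDATES` are illustrative names, not TFP's or JAX's API), and it assumes, as the error message says, that `jax.core.pytype_aval_mappings` exists in the JAX version you run.

```python
import importlib
from typing import Any, Sequence, Tuple


def resolve_attr(candidates: Sequence[Tuple[str, str]]) -> Any:
    """Return the first attribute that exists among (module_path, attr_name) pairs."""
    for module_path, attr_name in candidates:
        try:
            module = importlib.import_module(module_path)
        except ImportError:
            # Module missing in this environment/version: try the next candidate.
            continue
        # hasattr() is False when a removed attribute raises AttributeError.
        if hasattr(module, attr_name):
            return getattr(module, attr_name)
    raise AttributeError(f"none of {list(candidates)!r} could be resolved")


# Hypothetical candidate list: the replacement suggested by the JAX v0.7.0
# error message first, then the pre-0.7 location that was removed.
PYTYPE_AVAL_MAPPING_CANDIDATES = [
    ("jax.core", "pytype_aval_mappings"),
    ("jax.interpreters.xla", "pytype_aval_mappings"),
]
```

With JAX installed, `resolve_attr(PYTYPE_AVAL_MAPPING_CANDIDATES)` returns whichever mapping the installed version provides. Trying the new location first avoids touching the deprecated attribute on versions where both exist.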

Please use tfp-nightly for JAX support; more discussion here: https://github.com/tensorflow/probability/issues/1994

csuter, Aug 06 '25 20:08

@csuter unfortunately, tfp-nightly runs into the same issue.

epignatelli, Aug 12 '25 14:08

Can you make sure the old tensorflow-probability pip package isn't still installed?

csuter, Aug 12 '25 14:08
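One way to answer that question from inside the affected environment is to ask the installed package metadata which TFP distributions are present. Both `tensorflow-probability` and `tfp-nightly` provide the same `tensorflow_probability` import package, so having the stable one installed can shadow the nightly fix. This is a minimal sketch using only the standard library's `importlib.metadata`; the helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

# Distribution names that both install the `tensorflow_probability` package.
TFP_DISTRIBUTIONS = ("tensorflow-probability", "tfp-nightly")


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def tfp_install_report() -> Dict[str, Optional[str]]:
    """Map each TFP distribution name to its installed version (or None)."""
    return {name: installed_version(name) for name in TFP_DISTRIBUTIONS}


def stable_tfp_present() -> bool:
    """True if the stable `tensorflow-probability` distribution is installed."""
    return tfp_install_report()["tensorflow-probability"] is not None
```

Printing `tfp_install_report()` shows at a glance whether only `tfp-nightly` is installed or whether a stable `tensorflow-probability` was also pulled in, for example as a dependency of another package.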

@csuter Yes, I installed it in a fresh conda environment, but tfp-nightly didn't solve the problem, unfortunately.

To give the full context: the issue occurs when installing distrax, which depends on tfp (see https://github.com/google-deepmind/distrax/issues/295#issuecomment-3180184263). It seems very likely that this is a distrax issue rather than a tfp one, since it persists even when forcing tfp-nightly.

epignatelli, Aug 19 '25 09:08

distrax installs the non-nightly tensorflow-probability package, so this will have to be fixed on the distrax side. tfp-nightly works fine with the latest JAX.

csuter, Aug 19 '25 12:08
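To confirm which TFP distribution a package such as distrax pulls in, you can inspect the requirements it declares in its installed metadata. The sketch below is hypothetical (the helper names and the sample requirement strings are illustrative, not distrax's actual metadata). It uses `importlib.metadata.requires` and a simplified name parser rather than a full PEP 508 parser.

```python
import re
from importlib.metadata import PackageNotFoundError, requires
from typing import Iterable, List, Optional


def normalize(name: str) -> str:
    """PEP 503-style normalization: runs of '-', '_' and '.' become '-'."""
    return re.sub(r"[-_.]+", "-", name).lower()


def requirement_names(req_strings: Iterable[str]) -> List[str]:
    """Extract normalized project names from strings like 'foo>=1.0; marker'."""
    names = []
    for req in req_strings:
        # Simplified: take the leading project name, ignore versions and markers.
        match = re.match(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)", req)
        if match:
            names.append(normalize(match.group(1)))
    return names


def requires_stable_tfp(dist_name: str) -> Optional[bool]:
    """True if `dist_name` declares a dependency on `tensorflow-probability`.

    Returns None when `dist_name` is not installed.
    """
    try:
        reqs = requires(dist_name) or []
    except PackageNotFoundError:
        return None
    return "tensorflow-probability" in requirement_names(reqs)
```

In an environment with distrax installed, `requires_stable_tfp("distrax")` returning `True` would mean that installing distrax also installs the stable package, which is consistent with the diagnosis above. Normalizing names makes `tensorflow_probability` and `tensorflow-probability` compare equal.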