BayesNewton
Cannot run demo, possible incompatibility with latest Jax
Dear all,
I am trying to run the demo examples, but I run into the following error:
ImportError                               Traceback (most recent call last)
Input In [22], in <cell line: 1>()
----> 1 import bayesnewton
      2 import objax
      3 import numpy as np

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:5, in
ImportError: cannot import name 'index_add' from 'jax.ops' (/Users/Daniel/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/ops/__init__.py)
I think it's related to this note on the JAX website:
The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.
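For reference, a minimal sketch of the migration that note describes, assuming a JAX version recent enough to provide the .at property. The removed call jax.ops.index_add(x, idx, y) becomes x.at[idx].add(y), which, like the old function, returns a new array instead of modifying x in place.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Before (no longer available in newer JAX releases):
#   from jax.ops import index_add
#   y = index_add(x, 1, 2.0)

# After: the functional .at[...] property returns a new, updated array
y = x.at[1].add(2.0)

# x itself is unchanged, because JAX arrays are immutable
print(x)
print(y)
```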
I now realise that your pip installation pins a specific jax version. This is a bit problematic for me: I am running on an M1 and installed jax via conda-forge, so I am not sure I can match a compatible version. I will try and let you know if I succeed.
I managed to downgrade jax, but there is no jaxlib 0.1.60 available on conda-forge, which seems like it could be the source of this error I get when trying to load objax 1.31:
TypeError                                 Traceback (most recent call last)
Input In [22], in <cell line: 1>()
----> 1 import bayesnewton
      2 import objax
      3 import numpy as np

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:1, in
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/objax/__init__.py:17, in
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/objax/_patch_jax.py:20, in
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/__init__.py:93, in
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/image/__init__.py:18, in
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/_src/image/scale.py:20, in
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/lax/__init__.py:324, in
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/_src/lax/fft.py:87, in
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/api.py:184, in jit(fun, static_argnums, device, backend, donate_argnums)
    129   """Sets up fun for just-in-time compilation with XLA.
    130
    131   Args:
   (...)
    181   -0.85743 -0.78232 0.76827 0.59566 ]
    182   """
    183   if FLAGS.experimental_cpp_jit and config.omnistaging_enabled:
--> 184     return _cpp_jit(fun, static_argnums, device, backend, donate_argnums)
    185   else:
    186     return _python_jit(fun, static_argnums, device, backend, donate_argnums)

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/api.py:370, in _cpp_jit(fun, static_argnums, device, backend, donate_argnums)
    367     return config.read("jax_disable_jit")
    369   static_argnums_ = (0,) + tuple(i + 1 for i in static_argnums)
--> 370   cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
    371                              get_jax_enable_x64, get_jax_disable_jit_flag,
    372                              static_argnums_)
    374   # TODO(mattjj): make cpp callable follow descriptor protocol for bound methods
    375   @wraps(fun)
    376   @api_boundary
    377   def f_jitted(*args, **kwargs):

TypeError: jit(): incompatible function arguments. The following argument types are supported:
    1. (fun: function, cache_miss: function, get_device: function, static_argnums: List[int], static_argnames: List[str] = [], donate_argnums: List[int] = [], cache: jaxlib.xla_extension.CompiledFunctionCache = None) -> object

Invoked with: <function _rfft_transpose at 0x7f93913f20d0>, <function _cpp_jit.
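Since the failure here happens while jax itself is being imported, one way to check which jax/jaxlib pair is actually installed is to read the package metadata without importing anything. A minimal sketch using the standard-library importlib.metadata (Python 3.8+); the helper name installed_versions is hypothetical, and the distribution names passed to it are assumed to match what pip/conda-forge registered.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(*packages):
    """Return {package: version string or None}, without importing the packages."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Not installed in this environment (or installed without metadata)
            found[name] = None
    return found


if __name__ == "__main__":
    for pkg, ver in installed_versions("jax", "jaxlib", "objax", "bayesnewton").items():
        print(f"{pkg}: {ver}")
```

Reading metadata rather than checking jax.__version__ avoids triggering the very import that raises the TypeError above.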
Hi,
Sorry for the slow response and apologies that you've been having issues with the package versions. This is indeed frustrating. I would love to update the package to work with the most recent versions, but I don't currently have the spare time.
I am using an M1 Mac and things are working OK for me, but I'm not using conda-forge.
The index_update issue should be fairly easy to fix. However, I recall seeing some performance issues when I tried updating objax in the past, and I never managed to debug them. I hope to get around to this at some point in the future.
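As an illustration of why the index_update part is the easy bit, here is a minimal sketch of a hypothetical compatibility shim that keeps the old names importable on both older and newer JAX versions. This is not necessarily how the package itself resolves it; the fallbacks only cover the three-argument call form (the removed functions also accepted optional keyword arguments).

```python
import jax.numpy as jnp

# Hypothetical shim: use the old helpers if this JAX version still ships them,
# otherwise define equivalents on top of the .at property.
try:
    from jax.ops import index_add, index_update  # removed in newer JAX releases
except ImportError:

    def index_update(x, idx, y):
        """Return a copy of x with x[idx] set to y (replacement for jax.ops.index_update)."""
        return x.at[idx].set(y)

    def index_add(x, idx, y):
        """Return a copy of x with y added at x[idx] (replacement for jax.ops.index_add)."""
        return x.at[idx].add(y)


# Usage: both calls return new arrays and leave x untouched
x = jnp.arange(4.0)
updated = index_update(x, 0, 10.0)
added = index_add(x, 2, 5.0)
```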
This issue should now be fixed here: https://github.com/AaltoML/BayesNewton/pull/18 (updated to the current release of JAX)