pyscfad
JAX v0.5 breaks the simple DFT example.
I know nothing about JAX and was just beginning to understand pyscfad. I installed it with:
pip install pyscfad
Then I tried to run your simple example at https://github.com/fishjojo/pyscfad/blob/main/examples/dft/00-simple.py, and it raises the following error:
File "/home/pfebrer/miniforge3/envs/pyscf/lib/python3.10/site-packages/pyscfad/backend/_jax/lax/linalg.py", line 81, in _eigh_gen_jvp_rule
eji = w[..., jnp.newaxis, :] - w[..., jnp.newaxis]
AttributeError: module 'jax._src.numpy.lax_numpy' has no attribute 'newaxis'
If I downgrade JAX to v0.4.38 the example works fine 👍
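For context, here is a minimal sketch. It assumes (from the traceback only) that `linalg.py` reaches `newaxis` through the private module `jax._src.numpy.lax_numpy`, which no longer exposes it in JAX v0.5. `newaxis` is simply an alias for `None`, so indexing with `None` (or using the public `jax.numpy` namespace) gives the same eigenvalue-difference matrix as the failing line. The function name `eigenvalue_differences` is hypothetical and this is not the upstream pyscfad fix, just a standalone reproduction of that one expression:

```python
# Hypothetical standalone reproduction of the failing expression from
# _eigh_gen_jvp_rule; this is NOT the pyscfad source or its official fix.
import jax.numpy as jnp


def eigenvalue_differences(w):
    """Return eji with eji[..., i, j] = w[..., j] - w[..., i].

    Indexes with None, which is what the public jax.numpy.newaxis aliases,
    instead of relying on the private jax._src.numpy.lax_numpy module.
    """
    # w[..., None, :] has shape (..., 1, n); w[..., None] has shape (..., n, 1).
    # Broadcasting the subtraction yields an (..., n, n) matrix.
    return w[..., None, :] - w[..., None]


w = jnp.array([1.0, 2.0, 4.0])
eji = eigenvalue_differences(w)
print(eji)
```

Because the public `jax.numpy.newaxis` is still `None`, the same expression written with the public module should behave identically on both JAX v0.4.38 and v0.5.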