
JAX v0.5 breaks example.

Open · pfebrer opened this issue 11 months ago · 2 comments

I know nothing about JAX and was just beginning to understand pyscfad. I installed it with:

pip install pyscfad

Then I tried to run your simple example at https://github.com/fishjojo/pyscfad/blob/main/examples/dft/00-simple.py, and it raises the following error:

  File "/home/pfebrer/miniforge3/envs/pyscf/lib/python3.10/site-packages/pyscfad/backend/_jax/lax/linalg.py", line 81, in _eigh_gen_jvp_rule
    eji = w[..., jnp.newaxis, :] - w[..., jnp.newaxis]
AttributeError: module 'jax._src.numpy.lax_numpy' has no attribute 'newaxis'
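For context, the failing line only builds the matrix of pairwise eigenvalue differences for the eigh JVP rule. The traceback shows that `jnp` in that file is the private module `jax._src.numpy.lax_numpy`, which no longer exposes `newaxis` in JAX v0.5. Below is a minimal sketch of the same broadcasting that uses only public API and plain `None` (which is what `numpy.newaxis` is). The function name `eigval_differences` is hypothetical, and this is not pyscfad's actual fix.

```python
import jax.numpy as jnp  # public API, not jax._src.numpy.lax_numpy


def eigval_differences(w):
    """Hypothetical helper: eji[..., i, j] = w[..., j] - w[..., i].

    Same broadcasting as `w[..., jnp.newaxis, :] - w[..., jnp.newaxis]`
    from the traceback, with `None` in place of `newaxis`.
    """
    # w[..., None, :] has shape (..., 1, n); w[..., None] has shape (..., n, 1),
    # so the difference broadcasts to (..., n, n).
    return w[..., None, :] - w[..., None]


w = jnp.array([1.0, 2.0, 4.0])
eji = eigval_differences(w)
print(eji)
```

Because `None` is accepted as an indexing axis by both NumPy and `jax.numpy`, this form does not depend on which module happens to re-export `newaxis`.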

If I downgrade JAX to v0.4.38, the example works fine 👍
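Until this is resolved, a script could check the installed JAX version up front instead of failing deep inside the eigh JVP rule. This is a minimal sketch that treats v0.5 as the cutoff reported in this issue. The helper names `parse_version` and `jax_predates_0_5` are hypothetical, and downgrading as described above is still the simpler route.

```python
from importlib.metadata import version


def parse_version(v: str) -> tuple:
    """Hypothetical helper: keep leading numeric parts, e.g. "0.4.38" -> (0, 4, 38)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as "rc1" or "dev"
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def jax_predates_0_5(jax_version: str = None) -> bool:
    """True if the given (or installed) JAX version is older than 0.5."""
    if jax_version is None:
        jax_version = version("jax")  # reads installed package metadata
    return parse_version(jax_version) < (0, 5)
```

Tuple comparison keeps the check simple: `(0, 4, 38) < (0, 5)` is true, while `(0, 5, 0) < (0, 5)` is false.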

pfebrer · Mar 12 '25 16:03