JAX BDF solver tests failing / update `[jax]` versions (due to `scipy.linalg.tril` deprecation)

Open agriyakhetarpal opened this issue 10 months ago • 13 comments

The JAX BDF solver tests are failing on all PRs (#3846, #3945, etc.) for Python 3.9 and later because SciPy removed some linear algebra routines in v1.13.0. The Python 3.8 tests still pass because SciPy dropped support for Python 3.8 in an earlier release, so those jobs install an older SciPy that still has the routines.
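To confirm which environments are affected, a small check along these lines could help. It is only a sketch: `has_routine` is a hypothetical helper (not part of PyBaMM, JAX, or SciPy), and the module and attribute names passed to it are taken from the issue title.

```python
import importlib


def has_routine(module_name: str, attr: str) -> bool:
    """Return True if `module_name` can be imported and exposes `attr`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The module is not installed in this environment.
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    # "scipy.linalg" and "tril" are the names from the issue title.
    print(has_routine("scipy.linalg", "tril"))
```

Going by the description above, this would be expected to report the routine as missing wherever SciPy 1.13.0 or later is installed, and as present in the Python 3.8 jobs.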

I'm guessing we need to bump the jax and jaxlib versions now or relax the pin in the requirements, because there have been quite a few releases since v0.4.20; the current version available at the time of writing is v0.4.25.

Checklist

  • [x] Fix Jax test issues with updated versions
  • [ ] Unpin scipy version (#3962)

agriyakhetarpal avatar Apr 03 '24 11:04 agriyakhetarpal

It's probably not as trivial as bumping the JAX version, because there are a few other errors involving JAX's JIT and spectral volumes that I don't understand. I'm putting this aside for a bit and will return to it soon; others are welcome to proceed in the meantime if they make progress.

agriyakhetarpal avatar Apr 03 '24 12:04 agriyakhetarpal

Bumping to v0.4.24 fixes at least some of the tests; earlier versions still hit the SciPy error.

agriyakhetarpal avatar Apr 03 '24 12:04 agriyakhetarpal

It is worthwhile to bump JAX up as high as possible. We have people who are experienced with JAX and might be able to help. We are going to run into more compatibility issues as the code ages.

kratman avatar Apr 03 '24 12:04 kratman

I agree with you. v0.4.26 is their latest release; should we drop the pin altogether? It might break on v0.5.X, so having >0.4, <0.5 bounds is another option.
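As a rough illustration of what such bounds would accept and reject, here is a minimal sketch. The helpers `parse_version` and `within_bounds` are hypothetical, handle only plain `X.Y.Z` release numbers, and treat the lower bound as inclusive for simplicity; in practice the `packaging` library (`packaging.version.Version`, `packaging.specifiers.SpecifierSet`) is the usual tool for this, and it also handles pre-releases.

```python
def parse_version(text: str) -> tuple:
    """Turn a plain 'X.Y.Z' string (optional leading 'v') into a tuple of ints."""
    return tuple(int(part) for part in text.lstrip("v").split("."))


def within_bounds(version: str, lower: str, upper: str) -> bool:
    """True if lower <= version < upper, comparing release numbers only."""
    return parse_version(lower) <= parse_version(version) < parse_version(upper)


if __name__ == "__main__":
    # Versions mentioned in this thread, plus a hypothetical 0.5.0.
    for candidate in ("v0.4.20", "v0.4.26", "0.5.0"):
        print(candidate, within_bounds(candidate, "0.4", "0.5"))
```

With these bounds, any v0.4.X release passes while a v0.5.X release is rejected, which is the trade-off being weighed here: more freedom than an exact pin, but no protection against API changes between v0.4 patch releases.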

agriyakhetarpal avatar Apr 03 '24 13:04 agriyakhetarpal

Pinning is fine so that there are no unexpected changes. Realistically, we should have all major dependencies pinned. Something like Dependabot should do the updates so that the failures are all in one place.

kratman avatar Apr 03 '24 13:04 kratman

Do you need help with this one?

kratman avatar Apr 03 '24 13:04 kratman

> we should have all major dependencies pinned

We shouldn't pin to exact versions as that may cause compatibility issues for our users (if they try to use pybamm + another package that happens to pin e.g. numpy to a different version). We can specify ranges but they should be as wide as possible

valentinsulzer avatar Apr 03 '24 13:04 valentinsulzer

jax is an exception where we have to pin the exact version since every release changes the API

valentinsulzer avatar Apr 03 '24 14:04 valentinsulzer

> Do you need help with this one?

I would appreciate that, being someone who hasn't used JAX a lot. I was able to get the tests to pass with newer versions of JAX (some of those can be ignored because it's probably not caching the solves properly on my machine). Some spatial methods tests are still failing, where I received IndexErrors – and my debugger doesn't help there

> We can specify ranges but they should be as wide as possible

To add to this, we have been keeping the lower bounds in sync with the versions of the packages available on conda-forge (too high a lower bound caused some trouble earlier, around the time of the PyBaMM 23.9 release). It might make sense to drop Python 3.8 soon, since it has only been passing due to the use of deprecated code?

agriyakhetarpal avatar Apr 03 '24 14:04 agriyakhetarpal

> It might make sense to drop Python 3.8 soon since it has been passing due to the use of deprecated code?

I was planning on putting up a PR for that this week. It seemed to align with the removal of ODEs and the removal of the JAX Windows restrictions. I will probably just go ahead and make that PR while helping with the JAX stuff. I should have a bit of time to take a look this afternoon. Just share the branch you are working on and I will see what I can do to help out.

kratman avatar Apr 03 '24 14:04 kratman

I don't have a branch or anything concrete, I was debugging only locally. I'll add the link here once I get back to it

agriyakhetarpal avatar Apr 03 '24 15:04 agriyakhetarpal

Yeah, let's follow NumPy's lead on which Python versions we support; they have dropped support for 3.8.

valentinsulzer avatar Apr 03 '24 17:04 valentinsulzer

A few related issues were solved with #3963, #3961, and #3962. I will take another stab at updating JAX in a few days.

kratman avatar Apr 03 '24 22:04 kratman

I was checking this issue: hasn't it been solved by the PRs Eric referenced above? When I looked, the CI tests seemed to be passing.

brosaplanella avatar May 20 '24 11:05 brosaplanella

Ah, that is only one part of the issue. The other thing is that we still need to unpin SciPy, which is currently set to <1.13.0, IIRC.

agriyakhetarpal avatar May 20 '24 11:05 agriyakhetarpal

So, if we unpin SciPy then tests fail, right? Is there any branch where this is done so I can see the errors?

brosaplanella avatar May 20 '24 12:05 brosaplanella

Yes. I tried only locally last time and I was just going to open up a PR to show you the logs, but I'm facing a strange error locally right now:

nox > python run-tests.py --unit
nox > Command python run-tests.py --unit failed with exit code -9
nox > Session unit failed.

and zsh kills my shell for some reason. I think this is caused by #4092, which we merged a while ago and which wasn't caught in CI on either of the architectures. Maybe this is related to the fact that I upgraded my macOS version a few hours ago.
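For reading the log above: assuming nox reports the child's return code using the convention of Python's `subprocess` module, a negative value `-N` means the process was terminated by signal `N` rather than exiting on its own. A small sketch (the `describe_exit_code` helper is hypothetical) that decodes such a value on a POSIX system:

```python
import signal


def describe_exit_code(code: int) -> str:
    """Describe a subprocess-style return code (negative means a signal)."""
    if code >= 0:
        return f"exited with status {code}"
    try:
        # Map the signal number to its symbolic name, e.g. 9 -> SIGKILL.
        name = signal.Signals(-code).name
    except ValueError:
        # Signal number not known on this platform.
        name = f"signal {-code}"
    return f"killed by {name}"


if __name__ == "__main__":
    print(describe_exit_code(-9))
```

Under that assumption, exit code -9 corresponds to SIGKILL, which is sent from outside the process and cannot be caught, so no Python traceback would be expected from the test run itself.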

agriyakhetarpal avatar May 20 '24 13:05 agriyakhetarpal

Edit: I see that you opened a PR just at the time I commented :)

agriyakhetarpal avatar May 20 '24 13:05 agriyakhetarpal