PyBaMM
JAX BDF solver tests failing / update `[jax]` versions (due to `scipy.linalg.tril` deprecation)
The JAX BDF solver tests are failing on all PRs (#3846, #3945, etc.) for Python 3.9 and later because SciPy v1.13.0 removed some deprecated linear algebra routines (such as `scipy.linalg.tril`). The Python 3.8 tests still pass because SciPy dropped support for Python 3.8 earlier, so those jobs still get an older SciPy release.
I'm guessing we need to bump the `jax` and `jaxlib` versions now, or relax the pin in the requirements, because there have been quite a few releases since v0.4.20; the latest version available at the time of writing is v0.4.25.
Checklist
- [x] Fix Jax test issues with updated versions
- [ ] Unpin scipy version (#3962)
It's probably not as trivial as bumping the JAX version: there are a few other errors, involving JAX's JIT and spectral volumes, that I don't understand yet. I'm putting this aside for now and will return to it soon; others are welcome to continue if they make progress.
Bumping to v0.4.24 fixes at least some of the tests; earlier versions still hit the SciPy error.
It is worth bumping JAX as high as possible. We have people who are experienced with JAX and might be able to help. We will run into more compatibility issues as the code ages.
I agree. v0.4.26 is their latest release; should we drop the pin altogether? It might break on v0.5.x, so `>0.4, <0.5` bounds are another option.
Pinning is fine, since it avoids unexpected changes. Realistically, we should have all major dependencies pinned, and something like Dependabot should handle the updates so that the failures all show up in one place.
Do you need help with this one?
> we should have all major dependencies pinned
We shouldn't pin to exact versions, as that may cause compatibility issues for our users (e.g. if they try to use PyBaMM alongside another package that happens to pin NumPy to a different version). We can specify ranges, but they should be as wide as possible.
JAX is an exception: we have to pin the exact version, since every release changes the API.
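To make the trade-off between a `>0.4, <0.5` range and an exact pin concrete, here is a minimal stdlib-only sketch. The helper names `parse_version`, `in_range`, and `matches_pin` are hypothetical; the sketch writes the lower bound as inclusive (`>=0.4`) and only understands plain release numbers. Real installers such as pip follow PEP 440 through the `packaging` library, which also handles pre-releases and other edge cases.

```python
# Hypothetical helpers: plain release strings only (no pre-, post-, or local
# versions), unlike PEP 440-aware tools such as the `packaging` library.


def parse_version(text: str) -> tuple:
    """Turn a release string like "0.4.25" into a comparable tuple (0, 4, 25)."""
    return tuple(int(part) for part in text.split("."))


def in_range(version: str, lower: str, upper: str) -> bool:
    """True if lower <= version < upper, i.e. a ">=lower, <upper" requirement."""
    # Tuples compare element-wise and a shorter prefix sorts first,
    # so (0, 4) <= (0, 4, 25) < (0, 5) holds, while (0, 5, 0) < (0, 5) does not.
    return parse_version(lower) <= parse_version(version) < parse_version(upper)


def matches_pin(version: str, pinned: str) -> bool:
    """True only for the exact pinned release, i.e. a "==pinned" requirement."""
    return parse_version(version) == parse_version(pinned)


if __name__ == "__main__":
    for candidate in ["0.4.20", "0.4.25", "0.4.26", "0.5.0"]:
        print(candidate, in_range(candidate, "0.4", "0.5"), matches_pin(candidate, "0.4.20"))
    # 0.4.20 True True
    # 0.4.25 True False
    # 0.4.26 True False
    # 0.5.0 False False
```

A range accepts every 0.4.x release, which leaves users' resolvers room to coexist with other packages; an exact pin accepts a single release, which is the stricter behaviour argued for JAX because its API changes between releases. The same check shows that an upper bound such as SciPy's `<1.13.0` excludes v1.13.0 itself.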
> Do you need help with this one?
I would appreciate that, as someone who hasn't used JAX much. I was able to get the tests passing with newer versions of JAX (some remaining failures can be ignored, since the solves probably aren't being cached properly on my machine). Some spatial methods tests are still failing with `IndexError`s, and my debugger doesn't help there.
> We can specify ranges but they should be as wide as possible
To add to this, we have been keeping the lower bounds in sync with the package versions available on conda-forge (a lower bound that was too high caused some trouble around the PyBaMM 23.9 release). It might make sense to drop Python 3.8 soon, since its tests have only been passing because they still use deprecated code?
> It might make sense to drop Python 3.8 soon since it has been passing due to the use of deprecated code?
I was planning on putting up a PR for that this week; it seemed to align with the removal of ODEs and of the JAX Windows restrictions. I will probably go ahead and make that PR while helping with the JAX work. I should have some time to take a look this afternoon. Just share the branch you are working on and I will see what I can do to help.
I don't have a branch or anything concrete; I was only debugging locally. I'll add the link here once I get back to it.
Yeah, let's follow NumPy's lead on which Python versions we support; they have dropped support for 3.8.
A few related issues were solved with #3963, #3961, and #3962. I will take another stab at updating JAX in a few days.
I was checking this issue: hasn't it been solved by the PRs Eric referenced above? When I looked, the CI tests seemed to be passing.
Ah, that covers only one part of the issue. The other part is that we still need to unpin SciPy, which is currently pinned to `<1.13.0`, IIRC.
So, if we unpin SciPy, the tests fail, right? Is there a branch where this is done so I can see the errors?
Yes. I only tried it locally last time, and I was about to open a PR to show you the logs, but I'm hitting a strange error locally right now:
```
nox > python run-tests.py --unit
nox > Command python run-tests.py --unit failed with exit code -9
nox > Session unit failed.
```
and `zsh` kills my shell for some reason. I think this may be caused by #4092, which we merged a while ago and which wasn't caught in CI on either architecture. Maybe it's also related to the fact that I upgraded my macOS version a few hours ago.
Edit: I see that you opened a PR just as I commented :)
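The `exit code -9` in the nox log can be read using Python's own `subprocess` conventions, assuming nox reports the child's `returncode` directly: on POSIX systems, a negative `returncode` of `-N` means the child was terminated by signal `N`, and signal 9 is `SIGKILL`, which cannot be caught and is sent from outside the process (for example, by the operating system when memory runs low). The sketch below uses a hypothetical helper, `describe_returncode`, and assumes macOS or Linux; it reproduces that code by killing a child process, and does not diagnose the actual cause here.

```python
import signal
import subprocess
import sys


def describe_returncode(returncode: int) -> str:
    """Explain a subprocess return code using POSIX semantics."""
    if returncode < 0:
        # Popen.returncode == -N means the child was terminated by signal N.
        return f"terminated by {signal.Signals(-returncode).name}"
    return f"exited with status {returncode}"


if __name__ == "__main__":
    # Start a child Python process that would otherwise sleep for 30 seconds...
    child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(30)"])
    # ...and kill it right away; on POSIX, Popen.kill() sends SIGKILL (signal 9).
    child.kill()
    child.wait()
    print(child.returncode, describe_returncode(child.returncode))  # -9 terminated by SIGKILL
```

On Windows, `Popen.kill()` calls `TerminateProcess` instead, so the negative-return-code convention shown here does not apply there.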