PyBaMM
Hangs encountered in `IDAKLUJax` unit tests (`test_jacrev_vmap` and others)
PyBaMM Version
develop
Python Version
3.11.8
Describe the bug
The `test_jacrev_vmap` test case in the `TestIDAKLUJax` class (present in `tests/unit/test_solvers/test_idaklu_jax.py`) hangs quite a lot during local development. It is one of the slowest tests to pass, to the point that coverage logging almost gets stuck indefinitely at 99%, and this test in particular takes several orders of magnitude longer to complete than the rest of the tests.
This is most likely coming from the recent migration to `pytest` for running the unit tests (#3857), which also brought support for `pytest-xdist` for parallel execution of the unit tests; in parallel mode, the JAX-related unit tests take up a lot of time in CI.
Here's an SVG from a profiling sample taken with the `pytest-profiling` plugin, shared by @prady0t earlier in the #infrastructure channel on Slack:
[Profiling SVG not included]
It reveals that something is up with the JAX-related tests.
Steps to Reproduce
There isn't a better reproducer at this time, but to reproduce the issue, one can run `nox -s coverage` or its `pytest --cov` equivalent in the root directory. It is a bit slower than `nox -s unit`, but both of them seem to have the same issue.
Relevant log output
No response
Temporary solution: wrap all classes and their methods with `@pytest.mark.xdist_group(name="serial execution")` to run them inside the same worker in serial mode.
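A minimal sketch of that workaround, assuming `pytest-xdist` is installed and the suite is run with `--dist loadgroup` (the `xdist_group` mark only takes effect under that distribution mode). The class and method names below are hypothetical placeholders, not the actual contents of `test_idaklu_jax.py`.

```python
import pytest


# Hypothetical stand-in for a class such as TestIDAKLUJax. Marking the class
# applies the mark to every test method in it, so under `--dist loadgroup`
# all of them are scheduled on the same xdist worker and run one after another.
@pytest.mark.xdist_group(name="serial execution")
class TestSerialJaxExample:
    def test_placeholder_a(self):
        assert 1 + 1 == 2

    def test_placeholder_b(self):
        assert sum([1, 2, 3]) == 6
```

Usage would be something like `pytest -n auto --dist loadgroup tests/unit`; without `--dist loadgroup`, the mark is ignored and the tests are distributed as usual.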
This resolves the `test_jacrev_vmap` execution and it runs as normal, but `test_jacrev_vector_getvars`, plus `test_solver_` and `test_solver_sensitivities` from the JAX BDF solver, are three other culprits that still aren't happy: they still take a lot longer than any other test. I assume we can extract more speedups from the newly enabled parallel testing; however, these test cases do not seem to budge and are causing bottlenecks. Maybe @jsbrittain would have some suggestions here?
To give a rough sense of the issue at hand, here's what I see locally on an M-series macOS machine.
Running the entire coverage suite with the JAX tests included:
1596 passed, 2 skipped in 663.93s (0:11:03)
Running it again while excluding `test_idaklu_jax.py` and `test_jax_bdf_solver.py` brings 98% of the test suite to completion in just 87 seconds!
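For anyone wanting to repeat that comparison, here is a small sketch that builds the pytest arguments for the run without the two JAX modules. It assumes the unit tests live under `tests/unit` and that `pytest-cov` provides `--cov`; the helper name `build_pytest_args` is hypothetical. It uses pytest's `--ignore-glob` option so the exact directory of `test_jax_bdf_solver.py` does not need to be known.

```python
# Glob patterns for the two slow modules named in this issue; pytest matches
# --ignore-glob patterns against the full path, and "*" also spans "/".
SLOW_JAX_MODULES = ("*test_idaklu_jax.py", "*test_jax_bdf_solver.py")


def build_pytest_args(test_dir="tests/unit", exclude=SLOW_JAX_MODULES, coverage=True):
    """Return a pytest argument list that skips modules matching `exclude`."""
    args = [test_dir]
    if coverage:
        args.append("--cov")  # provided by the pytest-cov plugin
    args.extend(f"--ignore-glob={pattern}" for pattern in exclude)
    return args
```

The list can be passed straight to `pytest.main(build_pytest_args())`, or joined into an equivalent `pytest --cov --ignore-glob=...` command line.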
I saw this on Python 3.9 as well.
@agriyakhetarpal To me, it looks like the functions in there are themselves parallel, so running parallel tests on parallel code means a ton of extra threads. On my Mac, those tests seem to use 4 threads each.
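To put rough numbers on the thread argument, a back-of-the-envelope sketch: it assumes `pytest -n auto` starts about one xdist worker per CPU core and that each worker runs a test using the 4 threads observed above at the same time. The function name and the 8-core figure are hypothetical, not measurements.

```python
import os


def oversubscription(workers, threads_per_test, cores=None):
    """Ratio of busy threads to CPU cores when every xdist worker runs a
    multi-threaded test simultaneously; values above 1 mean threads compete
    for cores and context switching adds overhead."""
    cores = cores or os.cpu_count() or 1
    return workers * threads_per_test / cores


# Hypothetical 8-core machine with `-n auto` (8 workers), 4 threads per test:
# 8 * 4 / 8 = 4.0, i.e. four runnable threads per core.
ratio = oversubscription(workers=8, threads_per_test=4, cores=8)
```

Under these assumptions, either grouping the JAX tests onto fewer workers (as in the `xdist_group` workaround) or capping threads per test pushes the ratio back towards 1.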
I want to help with this issue. Please let me know if you need another hand.
Thanks, @prady0t. Well, it's really just the attempt in https://github.com/pybamm-team/PyBaMM/issues/3948#issuecomment-2028878236 that has helped with one of the four tests, so we need to dig deeper: the tests slow down when running in parallel, but they are also really slow in serial execution. Since the tests pass, at least, I think this can be tackled at a later time. I think you're working on #3940 with @lorenzofavaro, and that is a higher-priority issue we need to tackle at the moment :)
Oops, I just realised that I tagged the wrong person, sorry! By all means, please feel free to help out here, @cringeyburger!