Hangs encountered in `IDAKLUJax` unit tests (`test_jacrev_vmap` and others)

Open agriyakhetarpal opened this issue 3 months ago • 7 comments

PyBaMM Version

develop

Python Version

3.11.8

Describe the bug

The `test_jacrev_vmap` test case in the `TestIDAKLUJax` class (in `tests/unit/test_solvers/test_idaklu_jax.py`) frequently hangs during local development. It is one of the slowest tests to pass: coverage logging stalls at 99% almost indefinitely, and this test in particular takes several orders of magnitude longer to complete than the rest of the tests.

This most likely stems from the recent migration to pytest for running the unit tests (#3857), which also brought support for pytest-xdist for parallel execution of unit tests. The JAX-related unit tests take up a lot of time in CI in parallel mode.

Here's an SVG from a profiling sample with the pytest-profiling plugin from @prady0t earlier in the #infrastructure channel on Slack:

[Figure: combined profiling SVG]

which reveals that something is up with the JAX-related tests.

Steps to Reproduce

There isn't a better reproducer at this time, but to reproduce one can run `nox -s coverage` or its `pytest --cov` equivalent in the root directory; it is a bit slower than `nox -s unit`, but both of them seem to have the same issue.

Relevant log output

No response

agriyakhetarpal avatar Mar 31 '24 18:03 agriyakhetarpal

Temporary solution: decorate all the affected test classes and methods with `@pytest.mark.xdist_group(name="serial execution")` so that they run serially inside the same worker.

This fixes `test_jacrev_vmap`, which now runs as normal, but `test_jacrev_vector_getvars`, and `test_solver_` + `test_solver_sensitivities` from the JAX BDF solver, are three other culprits that still aren't happy: they still take a lot longer than any other test. I assume we can extract more speedups with the newly enabled parallel testing; however, these test cases do not seem to budge and are causing bottlenecks. Maybe @jsbrittain would have some suggestions here?
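The serial-group workaround described above can be sketched roughly as follows. This is only an illustration: the class `TestExampleSlowSolver`, its test bodies, and the `SERIAL_GROUP` constant are hypothetical placeholders, not the actual PyBaMM tests.

```python
import pytest

# Hypothetical constant: reusing one group name across classes asks
# pytest-xdist to schedule every marked test onto the same worker.
SERIAL_GROUP = "serial execution"


@pytest.mark.xdist_group(name=SERIAL_GROUP)
class TestExampleSlowSolver:
    """Placeholder class standing in for a JAX-heavy test class."""

    def test_first(self):
        assert sum(range(4)) == 6

    def test_second(self):
        assert max(1, 2) == 2
```

Note that the `xdist_group` mark only affects scheduling when pytest-xdist is run with `--dist loadgroup` (for example `pytest -n auto --dist loadgroup`); under other distribution modes the mark is ignored.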

agriyakhetarpal avatar Mar 31 '24 19:03 agriyakhetarpal

To give a rough sense of the issue at hand, here's what I can see locally on an M-series macOS machine:

Running the entire coverage suite with the JAX tests included:

1596 passed, 2 skipped in 663.93s (0:11:03)

Running it again while excluding `test_idaklu_jax.py` and `test_jax_bdf_solver.py` brings 98% of the test suite to completion in just 87 seconds!
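As a back-of-the-envelope check on those two figures, the cost of the two JAX test files can be estimated like this. It is a rough sketch only: it assumes the 87 s figure is representative of the run without those files, and the two runs were separate parallel runs, so the numbers are indicative at best.

```python
# Wall-clock figures quoted above, in seconds.
full_suite_s = 663.93   # whole coverage run, JAX tests included
without_jax_s = 87.0    # same run minus the two JAX test files

# Rough estimate: difference and ratio between the two runs.
jax_cost_s = full_suite_s - without_jax_s
jax_share = jax_cost_s / full_suite_s
slowdown = full_suite_s / without_jax_s

print(f"JAX test files add about {jax_cost_s:.0f} s")
print(f"that is about {jax_share:.0%} of the total wall time")
print(f"the full run is about {slowdown:.1f}x slower")
```

Under these assumptions, the two files account for roughly 87% of the total wall time.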

agriyakhetarpal avatar Mar 31 '24 19:03 agriyakhetarpal

I saw this on Python 3.9 as well.

kratman avatar Apr 03 '24 18:04 kratman

@agriyakhetarpal To me it looks like the functions in there are themselves parallel, so parallel tests running parallel code means a ton of extra threads. On my Mac, those tests seem to use 4 threads each.
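The oversubscription argument above can be made concrete with a small estimate. This is a sketch under stated assumptions: the figure of 4 threads per test comes from the observation above, while the worker count (one per CPU) and the helper `estimated_threads` are assumptions introduced purely for illustration.

```python
import os


def estimated_threads(xdist_workers: int, threads_per_test: int) -> int:
    """Upper bound on busy threads when every worker runs a threaded test."""
    return xdist_workers * threads_per_test


# Assumption: one pytest-xdist worker per CPU, each running a test that
# itself spawns 4 threads (the figure observed above).
cpus = os.cpu_count() or 1
total = estimated_threads(xdist_workers=cpus, threads_per_test=4)
print(f"{cpus} CPUs, up to {total} threads -> {total / cpus:.0f}x oversubscribed")
```

If this picture is right, either fewer workers or fewer threads per test would reduce contention, which would be consistent with the serial-group workaround helping one of the tests.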

kratman avatar Apr 03 '24 18:04 kratman

I want to help with this issue. Please let me know if you need another hand.

cringeyburger avatar Apr 06 '24 22:04 cringeyburger

Thanks, @prady0t. So far, it's really just the attempt in https://github.com/pybamm-team/PyBaMM/issues/3948#issuecomment-2028878236 that has helped, and only with one out of the four tests, so we need to dig deeper: running the tests in parallel slows them down, but they are also really slow in serial execution. I think this can be tackled at a later time since the tests pass, at least. I think you're doing #3940 with @lorenzofavaro, and that is a higher-priority issue we need to tackle at this moment :)

agriyakhetarpal avatar Apr 06 '24 22:04 agriyakhetarpal

Oops, I just realised that I tagged the wrong person, I am sorry. By all means, please feel free to help out here, @cringeyburger!

agriyakhetarpal avatar Apr 07 '24 21:04 agriyakhetarpal