array-api-tests
Tests are too slow for libraries with eager-mode dispatch overhead
For example, this job: https://github.com/google/jax/actions/runs/6804780902/job/18502996983 (associated with https://github.com/google/jax/pull/16099) took 4.5 hours to run about 1000 tests.
========== 567 failed, 439 passed, 116 skipped in 16419.47s (4:33:39) ==========
This is slow enough to be virtually unusable.
~Is there anything I can do to speed up the tests during development of the JAX array API?~
Edit: as mentioned below, this is due to the tests' frequent use of patterns that check every value in each array via indexing, combined with the fact that in eager execution each of these indexing operations has a small dispatch overhead, which adds up to very slow tests when arrays are large.
They shouldn't be that slow. The NumPy tests finish in a fraction of that time https://github.com/data-apis/array-api-compat/actions/runs/6565885145/job/17835396340. It might be worth investigating what is going on.
Although even there, the NumPy tests took an hour to run, which seems like a lot. We should investigate if something slowed down recently.
Generally, though, you can speed things up by lowering the hypothesis `--max-examples`. By default it is 100, but something like 50 should make the tests run in roughly half the time.
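If it is ever more convenient to control this from code rather than the command line, hypothesis's own settings profiles offer an equivalent knob. This is a generic hypothesis sketch (e.g. in a local `conftest.py`), not a description of how the suite wires up its `--max-examples` flag:

```python
# Generic hypothesis sketch: register and load a profile with fewer examples
# so every @given-decorated test draws fewer cases.
from hypothesis import settings

settings.register_profile("fast", max_examples=50)
settings.load_profile("fast")
```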
I just ran `ARRAY_API_TESTS_MODULE=numpy.array_api pytest array_api_tests/` locally and it took 12 minutes. So that's about the time the tests should run in (there were a few failures, which speed up the runtime a bit, but in general it shouldn't take more than 20 minutes).
Can you run something like `pytest --durations=10` to print the 10 slowest tests? That would help to pin down what is going on.
I think the slowness comes from test patterns that look like this: https://github.com/data-apis/array-api-tests/blob/f82c7bc8627cc2c3a44fa3e425f53a253a609aa8/array_api_tests/test_creation_functions.py#L358-L364
In JAX, each operation outside the context of JIT compilation has a small amount of overhead related to dispatch and device placement for the output, so running $\mathcal{O}[N^2]$ indexing operations in a loop will accumulate that overhead and be very slow.
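Roughly, the linked check looks like this (paraphrased, not the verbatim test code):

```python
# Every element of `out` is indexed individually, so an n_rows x n_cols
# result triggers O(N^2) separate dispatches in an eager-mode library.
for i in range(n_rows):
    for j in range(_n_cols):
        if j - i == kw.get("k", 0):
            assert out[i, j] == 1
        else:
            assert out[i, j] == 0
```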
There's a variable that limits the max array size https://github.com/data-apis/array-api-tests/blob/37bbb580975b21a40c23463d651830dfc4dd35d0/array_api_tests/hypothesis_helpers.py#L167. The default is 10000 but we should make it configurable. Lowering it to 1000 or so for JAX would probably fix this.
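A minimal sketch of what making it configurable could look like, assuming an environment variable override (the variable name here is illustrative, not an existing option of the suite):

```python
import os

# Hypothetical override: fall back to the current default of 10000 when the
# environment variable is not set.
MAX_ARRAY_SIZE = int(os.environ.get("ARRAY_API_TESTS_MAX_ARRAY_SIZE", "10000"))
```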
I pasted the slowest 20 tests below.
For the most part I think it comes down to the slow repeated indexing I mentioned above: every one of these that I've checked indexes each element of the output to check it against a reference implementation.
54.85s call array_api_tests/test_creation_functions.py::test_linspace
35.72s call array_api_tests/test_creation_functions.py::test_eye
25.87s call array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)]
21.24s call array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
20.68s call array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
20.64s call array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[bitwise_and(x1, x2)]
19.93s call array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x, s)]
19.88s call array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
19.72s call array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
19.59s call array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)]
18.77s call array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[bitwise_or(x1, x2)]
18.14s call array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
17.24s call array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[bitwise_left_shift(x1, x2)]
16.71s call array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)]
14.99s call array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
13.70s call array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)]
13.69s call array_api_tests/test_operators_and_elementwise_functions.py::test_logical_or
13.46s call array_api_tests/test_operators_and_elementwise_functions.py::test_atan2
13.00s call array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]
Would you be open to changing logic that looks like this: https://github.com/data-apis/array-api-tests/blob/f82c7bc8627cc2c3a44fa3e425f53a253a609aa8/array_api_tests/test_creation_functions.py#L358-L364 to something more like this?
```python
k = kw.get("k", 0)
expected = [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)]
assert xp.all(xp.asarray(expected) == out)
```
It does depend on `asarray` and `all` working properly, but the previous approach depends on indexing working properly. You lose the granularity of the error, but that could be addressed using `where` when generating the error message.
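For instance, something along these lines could recover the offending index for the message (a rough sketch, assuming `any` and `nonzero` behave as specified):

```python
expected_arr = xp.asarray(expected)
mismatch = expected_arr != out
if xp.any(mismatch):
    # Locate the first mismatching element; this is a handful of extra
    # operations rather than one dispatch per element of the output.
    i, j = (int(idx[0]) for idx in xp.nonzero(mismatch))
    raise AssertionError(f"out[{i}, {j}]={out[i, j]}, expected {expected_arr[i, j]}")
```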
Rewriting tests this way would make the test suite usable for JAX and other libraries that have non-negligible dispatch overhead in eager mode.
I would be happy to prepare a PR if you think this is the right direction for the testing suite.
That would work for the creation functions like `eye`, but how would you change the elementwise function tests? For example, `test_bitwise_left_shift` (one of your slowest tests):
https://github.com/data-apis/array-api-tests/blob/f82c7bc8627cc2c3a44fa3e425f53a253a609aa8/array_api_tests/test_operators_and_elementwise_functions.py#L822-L837
The loop is done in this helper https://github.com/data-apis/array-api-tests/blob/f82c7bc8627cc2c3a44fa3e425f53a253a609aa8/array_api_tests/test_operators_and_elementwise_functions.py#L311
You have to loop through the input arrays to generate the exact outputs. Maybe there's some way to extract the original input array list from hypothesis (assuming it didn't go through any further transformations after being converted to an array). @honno, thoughts?
Anyway, I think changing it to be like that for the creation functions is fine.
Anyway, I think the best solution for you is to make `MAX_ARRAY_SIZE` configurable. I expect 1000 (or even 100) would be fine for most test cases, and would drop the runtime of those tests down correspondingly.
Another option would be to use DLPack to export the full array to NumPy (or some other library where eager-mode indexing is not problematic).
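For example, a rough sketch of the `eye` check done that way (assuming NumPy's `np.from_dlpack` can import the library's arrays):

```python
import numpy as np

# Export the output once to NumPy, then do the per-element checks on the
# NumPy copy, where eager indexing overhead is negligible.
out_np = np.from_dlpack(out)
k = kw.get("k", 0)
for i in range(n_rows):
    for j in range(_n_cols):
        assert out_np[i, j] == (1 if j - i == k else 0)
```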
I did a bit of profiling and testing to figure out what's going on. The result of:
$ py-spy record -o profile.svg -- pytest array_api_tests/
is the flamegraph below.

[py-spy flamegraph of the pytest run]
That takes about 15 minutes, and ~75% of the time is spent in `test_special_cases.py`. Special cases are really not of interest to start with when developing an array API implementation, so the next step is to add `--ignore=array_api_tests/test_special_cases.py`.
That brings it down to ~4 minutes. With NumPy 1.26.0, there are still 4 failures when running on `numpy.array_api`. These 4 failures can also be reproduced with `--max-examples=1`, and that runs in about 3 seconds. Adding back the special cases tests makes it run in 18 seconds with `--max-examples=1`.
Right now we are in the position that we need this 3-second run to pass on libraries. That is far more important than anything else; there are, as of now, zero implementations that actually pass 100%. The `array_api_compat` layer is missing array methods like `.mT` and `.to_device` for NumPy and shows casting rule issues (as expected until NumPy 2.0); plain `numpy.array_api` comes closest but still needs a couple of fixes.
It looks to me like we need to focus on that. And in addition, vectorize some of the extra-costly checks for JAX, as we're discussing in gh-200 right now.
Thanks - with `--max-examples=5` and a number of skips related to mutation and other issues, the `jax.experimental.array_api` PR now passes the test suite! https://github.com/google/jax/pull/16099
I did find that it errors with this week's hypothesis 6.88.4 release; perhaps I should file a bug for that separately.
Awesome, that's great to see!