array-api-tests

Tests are too slow for libraries with eager-mode dispatch overhead

Status: Open. jakevdp opened this issue 1 year ago • 13 comments

For example, this job: https://github.com/google/jax/actions/runs/6804780902/job/18502996983, which is associated with https://github.com/google/jax/pull/16099, took 4.5 hours to run about 1000 tests.

========== 567 failed, 439 passed, 116 skipped in 16419.47s (4:33:39) ==========

This is slow enough to be virtually unusable.

~Is there anything I can do to speed up the tests during development of the JAX array API?~

Edit: as mentioned below, this is due to the frequent use of test patterns that check every value in each array via indexing. In eager execution, each of these indexing operations has a small dispatch overhead, which adds up to slow tests when arrays are large.

jakevdp, Nov 09 '23 16:11

They shouldn't be that slow. The NumPy tests finish in a fraction of that time https://github.com/data-apis/array-api-compat/actions/runs/6565885145/job/17835396340. It might be worth investigating what is going on.

Although even there, the NumPy tests took an hour to run, which seems like a lot. We should investigate if something slowed down recently.

Generally, though, you can speed things up by lowering the Hypothesis --max-examples setting. By default it is 100, but something like 50 should make the tests run in roughly half the time.

asmeurer, Nov 09 '23 19:11

I just ran ARRAY_API_TESTS_MODULE=numpy.array_api pytest array_api_tests/ locally and it took 12 minutes. So that's about how long the tests should take (there were a few failures, which speed up the run a bit, but in general it shouldn't take more than 20 minutes).

asmeurer, Nov 09 '23 20:11

Can you run something like pytest --durations=10 to print the 10 slowest tests? That would help to pin down what is going on.

asmeurer, Nov 09 '23 20:11

I think the slowness comes from test patterns that look like this: https://github.com/data-apis/array-api-tests/blob/f82c7bc8627cc2c3a44fa3e425f53a253a609aa8/array_api_tests/test_creation_functions.py#L358-L364

In JAX, each operation outside the context of JIT compilation has a small amount of overhead related to dispatch and device placement for the output, so running $\mathcal{O}(N^2)$ indexing operations in a loop will accumulate that overhead and be very slow.
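To make the scaling concrete, here is a toy Python model (not JAX code) in which every eager operation is simply counted as one "dispatch". The helper names are hypothetical. Checking an n_rows by n_cols eye output element by element costs n_rows * n_cols dispatches, while a vectorized comparison costs a constant number.

```python
# Toy model: "dispatches" counts eager operations; no real array library is involved.
dispatches = 0


def eager_getitem(arr, i, j):
    """Hypothetical stand-in for out[i, j] on an eager array: one dispatch per call."""
    global dispatches
    dispatches += 1
    return arr[i][j]


def eager_all_equal(a, b):
    """Hypothetical stand-in for xp.all(a == b): two dispatches, whatever the size."""
    global dispatches
    dispatches += 2  # one for ==, one for all()
    return all(x == y for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))


def eye_list(n_rows, n_cols, k=0):
    """Reference eye-like matrix as nested Python lists."""
    return [[1 if j - i == k else 0 for j in range(n_cols)] for i in range(n_rows)]


def check_eye_loop(out, n_rows, n_cols, k=0):
    """Per-element pattern: n_rows * n_cols dispatches."""
    for i in range(n_rows):
        for j in range(n_cols):
            assert eager_getitem(out, i, j) == (1 if j - i == k else 0)


def check_eye_vectorized(out, n_rows, n_cols, k=0):
    """Vectorized pattern: a constant number of dispatches."""
    assert eager_all_equal(eye_list(n_rows, n_cols, k), out)


out = eye_list(100, 100)
check_eye_loop(out, 100, 100)
loop_dispatches = dispatches  # 10000
dispatches = 0
check_eye_vectorized(out, 100, 100)
vectorized_dispatches = dispatches  # 2
```

Real runtime is roughly this count multiplied by the library's per-call overhead, so the looped form grows with array size while the vectorized form stays flat.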

jakevdp, Nov 09 '23 20:11

There's a variable that limits the max array size https://github.com/data-apis/array-api-tests/blob/37bbb580975b21a40c23463d651830dfc4dd35d0/array_api_tests/hypothesis_helpers.py#L167. The default is 10000, but we should make it configurable. Lowering it to 1000 or so for JAX would probably fix this.
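A minimal sketch of making the cap configurable, assuming a hypothetical environment variable ARRAY_API_TESTS_MAX_ARRAY_SIZE (not an existing option of the suite) and the current default of 10000. shape_fits is likewise a hypothetical predicate that a shape-generating strategy could apply as a filter.

```python
import math
import os

DEFAULT_MAX_ARRAY_SIZE = 10_000  # the current default mentioned above


def max_array_size() -> int:
    """Read the cap from a hypothetical env var, falling back to the default."""
    raw = os.environ.get("ARRAY_API_TESTS_MAX_ARRAY_SIZE")  # hypothetical name
    return DEFAULT_MAX_ARRAY_SIZE if raw is None else int(raw)


def shape_fits(shape, limit=None) -> bool:
    """True if an array of this shape has at most `limit` elements."""
    limit = max_array_size() if limit is None else limit
    return math.prod(shape) <= limit
```

With the variable set to 1000, a 100 by 100 shape would be rejected, cutting the worst-case per-example indexing work by a factor of ten.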

asmeurer, Nov 09 '23 20:11

I pasted the slowest 20 tests below.

For the most part I think it comes down to the slow repeated indexing I mentioned above: every one of these that I've checked indexes each element of the output to check it against a reference implementation.

54.85s call     array_api_tests/test_creation_functions.py::test_linspace
35.72s call     array_api_tests/test_creation_functions.py::test_eye
25.87s call     array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)]
21.24s call     array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
20.68s call     array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
20.64s call     array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[bitwise_and(x1, x2)]
19.93s call     array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x, s)]
19.88s call     array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
19.72s call     array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
19.59s call     array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)]
18.77s call     array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[bitwise_or(x1, x2)]
18.14s call     array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
17.24s call     array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[bitwise_left_shift(x1, x2)]
16.71s call     array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)]
14.99s call     array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
13.70s call     array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)]
13.69s call     array_api_tests/test_operators_and_elementwise_functions.py::test_logical_or
13.46s call     array_api_tests/test_operators_and_elementwise_functions.py::test_atan2
13.00s call     array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]

jakevdp, Nov 09 '23 23:11

Would you be open to changing logic that looks like this: https://github.com/data-apis/array-api-tests/blob/f82c7bc8627cc2c3a44fa3e425f53a253a609aa8/array_api_tests/test_creation_functions.py#L358-L364 to something more like this?

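```python
k = kw.get("k", 0)  # kw, n_rows, _n_cols, out, xp: from the surrounding test_eye
expected = [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)]
# dtype=out.dtype avoids a mixed int/float comparison, which the spec leaves undefined
assert xp.all(xp.asarray(expected, dtype=out.dtype) == out)
```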

It does depend on asarray and all working properly, but the previous approach depends on indexing working properly. You lose the granularity of the error, but that could be addressed using where when generating the error message.
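One way to keep error granularity with a single vectorized check is to locate mismatches only after the check has failed. This sketch uses NumPy and its nonzero function in place of where (a substitution, not the exact approach described here), and the helper name assert_equal_vectorized is hypothetical.

```python
import numpy as np


def assert_equal_vectorized(out, expected, max_report=3):
    """Compare in one vectorized step; only locate mismatches when the check fails."""
    mismatch = out != expected  # one elementwise op, no Python loop
    if not np.any(mismatch):  # fast path: a single reduction
        return
    # Slow path runs only on failure, so its per-element cost does not matter.
    locations = list(zip(*(axis.tolist() for axis in np.nonzero(mismatch))))
    shown = ", ".join(
        f"{loc}: got {out[loc].item()}, expected {expected[loc].item()}"
        for loc in locations[:max_report]
    )
    raise AssertionError(f"{len(locations)} mismatching element(s); first: {shown}")
```

In the passing case, which is the common one, the cost is two vectorized operations regardless of array size.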

Rewriting tests this way would make the test suite usable for JAX and other libraries that have non-negligible dispatch overhead in eager mode.

I would be happy to prepare a PR if you think this is the right direction for the testing suite.

jakevdp, Nov 10 '23 19:11

That would work for the creation functions like eye, but how would you change the elementwise function tests? For example, test_bitwise_left_shift (one of your slowest tests):

https://github.com/data-apis/array-api-tests/blob/f82c7bc8627cc2c3a44fa3e425f53a253a609aa8/array_api_tests/test_operators_and_elementwise_functions.py#L822-L837

The loop is done in this helper https://github.com/data-apis/array-api-tests/blob/f82c7bc8627cc2c3a44fa3e425f53a253a609aa8/array_api_tests/test_operators_and_elementwise_functions.py#L311

You have to loop through the input arrays to generate the exact outputs. Maybe there's some way to extract the original input array list from hypothesis (assuming it didn't go through any further transformations after being converted to an array). @honno, thoughts?

Anyway, I think changing it to be like that for the creation functions is fine.

asmeurer, Nov 10 '23 21:11

Anyway, I think the best solution for you is to make MAX_ARRAY_SIZE configurable. I expect 1000 (or even 100) would be fine for most test cases, and it would reduce the runtime of those tests correspondingly.

asmeurer, Nov 10 '23 21:11

Another option would be to use dlpack to export the full array to numpy or another library where eager-mode indexing is not problematic.
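A minimal sketch of the DLPack route, assuming the arrays implement __dlpack__ on a CPU device and a NumPy version that provides np.from_dlpack. check_bitwise_and is a hypothetical helper, and bitwise_and is only an example operation. Each array is exported once, after which the per-element loop runs on cheap NumPy indexing.

```python
import numpy as np


def check_bitwise_and(x1, x2, out):
    """Verify out == x1 & x2 element by element after a DLPack export to NumPy.

    Assumes x1, x2, out implement __dlpack__ on CPU (NumPy arrays do as well).
    """
    # One export per array instead of one eager dispatch per element.
    a, b, o = (np.from_dlpack(v) for v in (x1, x2, out))
    a, b = np.broadcast_arrays(a, b)  # match the output shape
    for idx in np.ndindex(o.shape):
        # Indexing NumPy arrays is cheap, so this Python-level loop is affordable.
        expected = int(a[idx]) & int(b[idx])
        assert int(o[idx]) == expected, f"out{idx}={int(o[idx])}, expected {expected}"
```

The exact Python-int reference computation stays the same as in the current tests; only the source of the per-element values changes.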

jakevdp, Nov 10 '23 22:11

I did a bit of profiling and testing to figure out what's going on. The result of:

$ py-spy record -o profile.svg -- pytest array_api_tests/

is this:

[Image: py-spy flame graph of the test run (profile.svg)]

That takes about 15 minutes, and ~75% of the time is spent in test_special_cases.py. Special cases are really not of interest when first developing an array API implementation, so the next step was to add:

--ignore=array_api_tests/test_special_cases.py

That brings it down to ~4 minutes. With NumPy 1.26.0, there are still 4 failures when running on numpy.array_api. These 4 failures can also be reproduced with

--max-examples=1

and that runs in about 3 seconds. Adding back the special cases tests makes it run in 18 seconds with --max-examples=1.
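The steps above can be bundled into a quick first pass. This is a sketch under assumptions: smoke_run_args and run_smoke_tests are hypothetical helpers, --max-examples is the option used in this thread, --ignore is pytest's own flag, and numpy.array_api is only the example module.

```python
import os


def smoke_run_args(max_examples=1, skip_special_cases=True):
    """Build pytest arguments for a fast first pass over array_api_tests/."""
    args = ["array_api_tests/", f"--max-examples={max_examples}"]
    if skip_special_cases:
        # Special-case tests took most of the profiled runtime, so skip them at first.
        args.append("--ignore=array_api_tests/test_special_cases.py")
    return args


def run_smoke_tests(module="numpy.array_api", max_examples=1):
    """Run the smoke pass in-process from the repo root; returns pytest's exit code."""
    import pytest  # imported lazily so building the arguments needs no test deps

    # ARRAY_API_TESTS_MODULE selects the library under test, as in the runs above.
    os.environ["ARRAY_API_TESTS_MODULE"] = module
    return pytest.main(smoke_run_args(max_examples))
```

Raising max_examples (for example to 5) trades some speed for more coverage once the single-example run passes.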

Right now we are in a position where we need this 3-second run to pass for libraries. That is far more important than anything else; as of now, zero implementations actually pass 100%. The array_api_compat layer is missing array methods like .mT and .to_device for numpy and shows casting-rule issues (as expected until numpy 2.0); plain numpy.array_api comes closest but still needs a couple of fixes.

It looks to me like we need to focus on that. In addition, we should vectorize some of the extra-costly checks for JAX, as we're discussing in gh-200 right now.

rgommers, Nov 14 '23 10:11

Thanks - with --max-examples=5 and a number of skips related to mutation and other issues, the jax.experimental.array_api PR now passes the test suite! https://github.com/google/jax/pull/16099

I did find that it errors with this week's hypothesis 6.88.4 release; perhaps I should file a bug for that separately.

jakevdp, Nov 14 '23 19:11

Awesome, that's great to see!

rgommers, Nov 14 '23 19:11