
test_sum can fail with unstable summation

Open · asmeurer opened this issue · 6 comments

This test fails with pytorch's sum():

```
E           AssertionError: out=5.0, but should be roughly 7.0 [sum()]
E           Falsifying example: test_sum(
E               x=tensor([[ 0.0000e+00,  3.3554e+07, -3.3554e+07],
E                       [ 1.0000e+00,  1.0000e+00,  1.0000e+00]]), data=data(...),
E           )
E           Draw 1 (kw): {}
```

Note that you need to use https://github.com/data-apis/array-api-compat/pull/14, because torch.sum has other issues with its signature and type promotion that would otherwise fail the test earlier.

torch truncates the printed float values, but here is the full tensor:

```python
torch.tensor([[ 0.0000000000e+00,  3.3554432000e+07, -3.3554428000e+07],
              [ 1.0000000000e+00,  1.0000000000e+00,  1.0000000000e+00]])
```

Note that this is float32, the torch default dtype. The problem is:

```python
>>> import numpy as np
>>> np.float32(33554432.0) + np.float32(1.0)
33554432.0
```

So if you add the elements in the wrong order, you get the wrong sum (5 instead of 7).
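To make the order dependence concrete, here is a small float32 sketch. I don't know what order torch actually uses internally; the pairing in the second half is just one hypothetical order that reproduces the 5.0:

```python
import numpy as np

# The six float32 elements of the falsifying example.
zero = np.float32(0.0)
big  = np.float32(3.3554432e7)   # 2**25, exactly representable
nbig = np.float32(-3.3554428e7)
one  = np.float32(1.0)

# Row-major (C) order: the two large values cancel to 4.0 before any 1.0
# is added, so all three ones survive.
acc = np.float32(0.0)
for v in [zero, big, nbig, one, one, one]:
    acc = np.float32(acc + v)
print(acc)  # 7.0

# One hypothetical order that reproduces the 5.0: pair each large value
# with a 1.0 first, then combine the partial sums.
p1 = np.float32(big + one)   # 33554433.0 rounds back to 33554432.0; the 1.0 is lost
p2 = np.float32(nbig + one)  # -33554427.0 rounds to -33554428.0; lost again
p3 = np.float32(zero + one)  # 1.0
print(np.float32(np.float32(p1 + p2) + p3))  # 4.0 + 1.0 = 5.0
```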

NumPy's sum does this correctly. I don't know whether, if we ran the tests on NumPy for long enough, a similar failure would eventually come up. It's really hard to do summation stably. My guess is that the only reason NumPy passes is that NumPy happens to do exactly what we are doing in the test, i.e., upcast to float64 (i.e., Python's float) and sum in C order. But PyTorch's default dtype is float32, so it isn't required to return a float64 from sum(dtype=None) and therefore doesn't upcast internally.
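For reference, on the failing input both NumPy's own float32 sum and the test's upcast-to-float64 reference come out to 7.0:

```python
import numpy as np

x = np.asarray([[0.0, 3.3554432e7, -3.3554428e7],
                [1.0, 1.0, 1.0]], dtype=np.float32)
print(x.sum())                     # 7.0: float32 accumulation in C order
print(x.astype(np.float64).sum())  # 7.0: the test's upcast-to-float reference
```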

The spec doesn't require sum to use a stable algorithm (like Kahan summation) rather than naive term-by-term summation, and it doesn't require the terms to be summed in any particular order. So we should avoid generating examples that can lead to loss of significance. This might require some research, but I think it would be sufficient to avoid generating elements that are too far apart in magnitude. Another idea would be to compute the exact sum (e.g., using Fraction) and compare it against a naive summation. One problem with that is that the amount of loss of significance depends on the order of summation, and there's no guarantee about the order in which the library will sum the array.
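Here is a minimal sketch of that Fraction idea. The tolerance is an assumption on my part: the textbook worst-case bound for naive summation, (n - 1) * eps * sum(|x_i|), which holds to first order for any summation order, so it doesn't matter how the library orders the terms:

```python
from fractions import Fraction
import numpy as np

def exact_sum(x):
    # Every finite float is an exact rational, so this sum has no error.
    return sum((Fraction(float(v)) for v in x.ravel()), Fraction(0))

def max_sum_error(x):
    # Worst-case error of naive summation in x's dtype, valid to first
    # order for any ordering of the terms: (n - 1) * eps * sum(|x_i|).
    eps = float(np.finfo(x.dtype).eps)
    return (x.size - 1) * eps * float(np.abs(x.astype(np.float64)).sum())

x = np.asarray([[0.0, 3.3554432e7, -3.3554428e7],
                [1.0, 1.0, 1.0]], dtype=np.float32)
out = np.float32(5.0)  # torch's answer
err = abs(float(out) - float(exact_sum(x)))
print(err, "<=", max_sum_error(x), "->", err <= max_sum_error(x))
# 2.0 <= ~40.0 -> True: 5.0 is within what an unlucky naive order could produce
```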

Also, even if we do this, it's problematic to check the summation by first upcasting to float when the input is float32, because the upcasted computation will inherently be more accurate than anything an algorithm working directly on float32 could produce.
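For contrast, here is a rough sketch (not what any library actually does) of compensated summation, using Neumaier's variant of the Kahan algorithm mentioned above, carried out entirely in float32. It recovers 7.0 even in an ordering where naive float32 summation gives 5.0, which is roughly the level of accuracy that comparing against a float64-upcast reference implicitly demands:

```python
import numpy as np

def neumaier_sum_f32(values):
    # Compensated summation (Neumaier's variant of Kahan), entirely in
    # float32: c accumulates the low-order bits each addition rounds away.
    s = np.float32(0.0)
    c = np.float32(0.0)
    for v in values:
        v = np.float32(v)
        t = np.float32(s + v)
        if abs(s) >= abs(v):
            c = np.float32(c + (np.float32(s - t) + v))
        else:
            c = np.float32(c + (np.float32(v - t) + s))
        s = t
    return np.float32(s + c)

# An ordering where naive left-to-right float32 summation gives 5.0.
unlucky = [3.3554432e7, 1.0, 1.0, -3.3554428e7, 0.0, 1.0]
print(neumaier_sum_f32(unlucky))  # 7.0
```

The spec requires none of this, which is exactly why the test needs either friendlier inputs or an order-insensitive tolerance.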

asmeurer · Feb 03 '23 00:02