array-api-compat
Some PyTorch fixes
Thanks @asmeurer. That all seems reasonable. For context: does this fix any open issues, or any CI or test-suite failures that are visible somewhere?
Yes. The failures don't show up on CI yet because I still need to finish fixing the test suite for PyTorch (https://github.com/data-apis/array-api-tests/pull/266). These are the failures I was seeing:
================================================================================ FAILURES =================================================================================
_______________________________________________________________________________ test_solve ________________________________________________________________________________
@pytest.mark.unvectorized
> @pytest.mark.xp_extension('linalg')
array_api_tests/test_linalg.py:640:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
x1 = tensor([[[1.]]]), x2 = tensor([[0.]])
@pytest.mark.unvectorized
@pytest.mark.xp_extension('linalg')
@given(*solve_args())
def test_solve(x1, x2):
res = linalg.solve(x1, x2)
ph.assert_dtype("solve", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
if x2.ndim == 1:
expected_shape = x1.shape[:-2] + x2.shape[-1:]
_test_stacks(linalg.solve, x1, x2, res=res, dims=1,
matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
else:
stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
expected_shape = stack_shape + x2.shape[-2:]
_test_stacks(linalg.solve, x1, x2, res=res, dims=2)
> ph.assert_result_shape("solve", in_shapes=[x1.shape, x2.shape],
out_shape=res.shape, expected=expected_shape)
E AssertionError: out.shape=torch.Size([1, 1]), but should be (1, 1, 1) [solve( torch.Size([1, 1, 1]) . torch.Size([1, 1]) )]
E Falsifying example: test_solve(
E x1=tensor([[[1.]]]),
E x2=tensor([[0.]]),
E )
array_api_tests/test_linalg.py:655: AssertionError
____________________________________________________________________________ test_vector_norm _____________________________________________________________________________
@pytest.mark.unvectorized
> @pytest.mark.xp_extension('linalg')
array_api_tests/test_linalg.py:960:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
x = tensor([0.+0.j, 0.+0.j]), data = data(...)
@pytest.mark.unvectorized
@pytest.mark.xp_extension('linalg')
@given(
x=arrays(dtype=all_floating_dtypes(), shape=shapes(min_side=1)),
data=data(),
)
def test_vector_norm(x, data):
kw = data.draw(
# We use data because axes is parameterized on x.ndim
kwargs(axis=axes(x.ndim),
keepdims=booleans(),
ord=one_of(
sampled_from([2, 1, 0, -1, -2, float("inf"), float("-inf")]),
integers(-max_ord, max_ord),
floats(-max_ord, max_ord),
)), label="kw")
res = linalg.vector_norm(x, **kw)
axis = kw.get('axis', None)
keepdims = kw.get('keepdims', False)
# TODO: Check that the ord values give the correct norms.
# ord = kw.get('ord', 2)
_axes = sh.normalise_axis(axis, x.ndim)
> ph.assert_keepdimable_shape('linalg.vector_norm', out_shape=res.shape,
in_shape=x.shape, axes=_axes,
keepdims=keepdims, kw=kw)
E AssertionError: out.shape=torch.Size([1]), but should be (2,) [linalg.vector_norm(axis=())]
E Falsifying example: test_vector_norm(
E x=tensor([0.+0.j, 0.+0.j]),
E data=data(...),
E )
E Draw 1 (kw): {'axis': ()}
array_api_tests/test_linalg.py:985: AssertionError
========================================================================= short test summary info =========================================================================
FAILED array_api_tests/test_linalg.py::test_solve - AssertionError: out.shape=torch.Size([1, 1]), but should be (1, 1, 1) [solve( torch.Size([1, 1, 1]) . torch.Size([1, 1]) )]
FAILED array_api_tests/test_linalg.py::test_vector_norm - AssertionError: out.shape=torch.Size([1]), but should be (2,) [linalg.vector_norm(axis=())]
I'm not completely sure why these only showed up now. Maybe something changed in a recent PyTorch version, or else something was broken in the tests.