
torch matmul does not handle different dtypes

Open · ev-br opened this issue on Jan 30, 2025 • 3 comments

The spec requires that matmul follow the type promotion rules for its arguments, but PyTorch requires that the dtypes match:

In [3]: import array_api_strict as xp

In [5]: xp.ones(3, dtype=xp.float32) @ xp.ones(3, dtype=xp.float64)
Out[5]: Array(3., dtype=array_api_strict.float64)

In [6]: torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)

RuntimeError: dot : expected both vectors to have same dtype, but found Float and Double

It's not immediately clear to me whether we want to paper over this in compat or leave the conversion to end users: it's easy to imagine a use case where the copying overhead is significant.

ev-br · Jan 30, 2025
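[Editor's note: a minimal sketch of what papering over this in compat could look like, assuming promotion via torch.result_type. This wrapper is hypothetical, not the actual array-api-compat implementation:]

import torch

def matmul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    # Promote both operands to their common dtype, then defer to
    # torch.matmul, which requires matching dtypes.
    dtype = torch.result_type(x1, x2)
    return torch.matmul(x1.to(dtype), x2.to(dtype))

[Tensor.to is a no-op when the dtype already matches, so the copying overhead only arises for genuinely mixed-dtype inputs.]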

There's no good way to override @ behavior, I think. For matmul we can do same-kind type promotion; there shouldn't be extra overhead, since no other library has mixed-dtype implementations either AFAIK (e.g., see np.matmul.types).

rgommers · Jan 30, 2025
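[Editor's note: np.matmul is a ufunc, so its registered loops can be inspected directly via the .types attribute referenced above:]

import numpy as np

# Each signature reads "inputs->output", e.g. 'dd->d' for float64.
# Every registered loop takes both inputs in the same dtype, i.e.
# NumPy ships no mixed-dtype matmul kernel.
print(np.matmul.types)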

NumPy seems to do it:

In [8]: np.ones(3) @ np.ones(3, dtype=complex)
Out[8]: np.complex128(3+0j)

Cross-ref https://discuss.pytorch.org/t/matmul-mixed-dtypes/216044 for a related PyTorch Discourse question.

ev-br · Jan 30, 2025

Yeah, I know; I didn't say it doesn't. I meant NumPy does internal upcasting and then calls a routine with both dtypes being the same.

rgommers · Jan 30, 2025
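[Editor's note: a small sanity check, consistent with the last comment, that NumPy's mixed-dtype matmul matches promote-then-same-dtype matmul:]

import numpy as np

a = np.ones(3)
b = np.ones(3, dtype=complex)

# Mixed-dtype matmul gives the same result as promoting both operands
# to the common dtype first and then calling the same-dtype kernel.
common = np.result_type(a, b)
assert (a @ b) == (a.astype(common) @ b.astype(common))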