torch matmul does not handle different dtypes
The spec requires that matmul follow the type promotion rules for its arguments, but PyTorch requires that the dtypes match:
In [3]: import array_api_strict as xp
In [5]: xp.ones(3, dtype=xp.float32) @ xp.ones(3, dtype=xp.float64)
Out[5]: Array(3., dtype=array_api_strict.float64)
In [6]: torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)
RuntimeError: dot : expected both vectors to have same dtype, but found Float and Double
It's not immediately clear to me whether we want to paper over it in compat or leave the conversion to end users: it's easy to imagine a use case where the copying overhead is significant. (A sketch of that user-side conversion is below.)
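For context, a minimal sketch of leaving the conversion to the user, assuming plain torch tensors (torch.result_type gives the promoted dtype; .to() copies whichever operand actually changes dtype, which is where the overhead would come from):

import torch

a = torch.ones(3, dtype=torch.float32)
b = torch.ones(3, dtype=torch.float64)

# Explicit user-side promotion before matmul: a.to(common) makes a float64
# copy of a, while b.to(common) is a no-op because b already matches.
common = torch.result_type(a, b)      # torch.float64
result = a.to(common) @ b.to(common)  # tensor(3., dtype=torch.float64)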
There's no good way to override the @ behavior, I think. For matmul we can do same-kind type promotion; there shouldn't be extra overhead - no other library has mixed-dtype implementations either, AFAIK (e.g., see np.matmul.types). A sketch of such a wrapper is below.
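Roughly, a compat-layer matmul could look like the following (a hypothetical sketch, not the actual array-api-compat code; it uses torch.result_type for the promotion, which for tensor-tensor inputs should agree with the spec's rules, and .to() returns the input unchanged when the dtype already matches, so the already-matching case pays nothing extra):

import torch

def matmul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    # Promote both operands to their common dtype, then defer to
    # torch.matmul, which requires matching dtypes.
    dtype = torch.result_type(x1, x2)
    return torch.matmul(x1.to(dtype), x2.to(dtype))

This only helps calls that go through the wrapped namespace (xp.matmul); the @ operator on raw torch tensors still dispatches to torch's own __matmul__, as noted above.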
NumPy seems to do it:
In [8]: np.ones(3) @ np.ones(3, dtype=complex)
Out[8]: np.complex128(3+0j)
Cross-ref https://discuss.pytorch.org/t/matmul-mixed-dtypes/216044 for a related PyTorch Discourse question.
Yeah, I know; I didn't say it doesn't - I meant that NumPy does the upcasting internally and then calls a routine with both dtypes the same.
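That's visible from the ufunc metadata, for what it's worth: np.matmul only registers same-dtype inner loops, so a mixed-dtype call promotes the operands first (per np.result_type) and then picks one of those loops. A quick check along these lines:

import numpy as np

# Every registered inner loop takes matching dtypes (entries like 'dd->d',
# 'DD->D', ...), so float64 @ complex128 promotes both to complex128 first.
print(np.matmul.types)
print(np.result_type(np.float64, np.complex128))  # complex128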