Output of `torch.sum` with unsigned input should be unsigned
According to the standard, the documentation of `sum` states for the `dtype` parameter:

> If `None`, the returned array must have the same data type as `x`, unless `x` has an integer data type supporting a smaller range of values than the default integer data type... In those latter cases: ... if `x` has an unsigned integer data type, the returned array must have an unsigned integer data type having the same number of bits as the default integer data type.
If I understand correctly, the sums for the unsigned dtypes below should have uint64 dtype:
```python
from array_api_compat import torch as xp

for dtype in [xp.int8, xp.int16, xp.int32, xp.int64,
              xp.uint8, xp.uint16, xp.uint32, xp.uint64,
              xp.float32, xp.float64, xp.complex32, xp.complex64]:
    x = xp.asarray([1, 2, 3], dtype=dtype)
    try:
        print(xp.sum(x).dtype, dtype)
    except RuntimeError as e:
        print(e)
```
But the output is:
```
torch.int64 torch.int8
torch.int64 torch.int16
torch.int64 torch.int32
torch.int64 torch.int64
torch.int64 torch.uint8
torch.int64 torch.uint16
torch.int64 torch.uint32
torch.int64 torch.uint64
torch.float32 torch.float32
torch.float64 torch.float64
"sum_cpu" not implemented for 'ComplexHalf'
torch.complex64 torch.complex64
```
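For comparison, NumPy appears to follow the quoted rule (at least NumPy 2.x on a 64-bit platform, where the default integer dtype is int64):

```python
import numpy as np

# Unsigned inputs narrower than the default integer dtype sum to uint64,
# signed ones to int64, per the rule quoted above.
for dt in [np.uint8, np.uint16, np.uint32, np.int8]:
    print(np.asarray([1, 2, 3], dtype=dt).sum().dtype, np.dtype(dt))
# uint64 uint8
# uint64 uint16
# uint64 uint32
# int64 int8
```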
I think this is at least partially fixable within array-api-compat.
Also, torch doesn't seem to natively support `sum` for most uint dtypes or for `complex32`. If we change `xp.sum(x)` to `xp.sum(x, dtype=dtype)` in the code above, the output is:
```
torch.int8 torch.int8
torch.int16 torch.int16
torch.int32 torch.int32
torch.int64 torch.int64
torch.uint8 torch.uint8
"sum_cpu" not implemented for 'UInt16'
"sum_cpu" not implemented for 'UInt32'
"sum_cpu" not implemented for 'UInt64'
torch.float32 torch.float32
torch.float64 torch.float64
"sum_cpu" not implemented for 'ComplexHalf'
torch.complex64 torch.complex64
```
It would be helpful if array-api-compat implemented `sum` for these dtypes, even if that means upcasting to a supported type before summing and then downcasting the result. (Summing in int64 has a slightly larger chance of overflow than summing in uint64 would, and the downcast back may not be value-safe, so it's up for discussion what should happen in those cases.)
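For illustration, here is a minimal sketch of that upcast-then-downcast idea. It is not array-api-compat's actual implementation; the wrapper, the `_UNSIGNED`/`_COMPUTE` tables, and the assumption that casting to and from these dtypes works on CPU even where `torch.sum` does not are all hypothetical.

```python
import torch

# Hypothetical table: dtypes whose CPU sum is unimplemented, mapped to a wider
# dtype that torch.sum does support.
_UNSIGNED = (torch.uint8, torch.uint16, torch.uint32, torch.uint64)
_COMPUTE = {
    torch.uint16: torch.int64,
    torch.uint32: torch.int64,
    torch.uint64: torch.int64,        # values >= 2**63 would wrap on the round trip
    torch.complex32: torch.complex64,
}

def sum(x, dtype=None):
    if dtype is None and x.dtype in _UNSIGNED:
        # The standard's default: unsigned integer input sums to uint64.
        dtype = torch.uint64
    if dtype in _COMPUTE:
        # Upcast to a dtype torch.sum can handle, then cast the result back.
        return torch.sum(x.to(_COMPUTE[dtype])).to(dtype)
    if dtype is None and x.dtype in _COMPUTE:
        # e.g. complex32 input with the default dtype: keep the input dtype.
        return torch.sum(x.to(_COMPUTE[x.dtype])).to(x.dtype)
    return torch.sum(x, dtype=dtype)
```

With something along these lines, the first loop above should print torch.uint64 for all four unsigned input dtypes.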
Does array-api-compat have a mechanism for reporting the shortcomings it has to patch to the underlying libraries? If not, should I report this to PyTorch (if it is not already reported)?
> Does array-api-compat have a mechanism for reporting the shortcomings it has to patch to the underlying libraries?
So far, I think things have just been reported upstream on an ad-hoc basis.
+1 to reporting it upstream. Ideally, there's an upstream issue, plus a reference to it in the workaround in -compat.
EDIT: This isn't a hard requirement, of course.
Looks like support for uints is already tracked at pytorch/pytorch#58743. It happens to be near the top. Oh, and it has its own issue, pytorch/pytorch#58734.