
Output of `torch.sum` with unsigned input should be unsigned

Open mdhaber opened this issue 11 months ago • 3 comments

The array API standard's documentation of sum says the following about the dtype parameter:

If None, the returned array must have the same data type as x, unless x has an integer data type supporting a smaller range of values than the default integer data type... In those latter cases: ... if x has an unsigned integer data type, the returned array must have an unsigned integer data type having the same number of bits as the default integer data type.

If I understand correctly, the sums of the unsigned-dtype inputs below should have uint64 dtype:

from array_api_compat import torch as xp
for dtype in [xp.int8, xp.int16, xp.int32, xp.int64,
              xp.uint8, xp.uint16, xp.uint32, xp.uint64,
              xp.float32, xp.float64, xp.complex32, xp.complex64]:
    x = xp.asarray([1, 2, 3], dtype=dtype)
    try:
        print(xp.sum(x).dtype, dtype)
    except RuntimeError as e:
        print(e)

But the output is:

torch.int64 torch.int8
torch.int64 torch.int16
torch.int64 torch.int32
torch.int64 torch.int64
torch.int64 torch.uint8
torch.int64 torch.uint16
torch.int64 torch.uint32
torch.int64 torch.uint64
torch.float32 torch.float32
torch.float64 torch.float64
"sum_cpu" not implemented for 'ComplexHalf'
torch.complex64 torch.complex64

I think this is at least partially fixable within array-api-compat.
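As a rough illustration of the rule quoted from the standard, the expected result dtype for `dtype=None` can be written as a small lookup. This is only a sketch: `sum_result_dtype` is a hypothetical helper, dtypes are represented as plain strings rather than torch dtype objects, and the default integer dtype is assumed to be 64 bits wide (as the int64 results above suggest).

```python
# Hypothetical helper, not part of array-api-compat or torch.
# Dtypes are modelled as plain strings to keep the sketch self-contained.
_INT_BITS = {
    "int8": 8, "int16": 16, "int32": 32, "int64": 64,
    "uint8": 8, "uint16": 16, "uint32": 32, "uint64": 64,
}


def sum_result_dtype(dtype_name, default_int_bits=64):
    """Return the dtype name that sum(x, dtype=None) should produce.

    Assumes the default integer dtype has `default_int_bits` bits.
    """
    bits = _INT_BITS.get(dtype_name)
    if bits is None:
        # Floating-point and complex inputs keep their dtype.
        return dtype_name
    if bits >= default_int_bits:
        # Not narrower than the default integer dtype: unchanged.
        return dtype_name
    # Narrower integers are promoted to the default width, preserving signedness.
    prefix = "uint" if dtype_name.startswith("uint") else "int"
    return f"{prefix}{default_int_bits}"


for name in ["int8", "uint8", "uint16", "uint32", "uint64", "float32"]:
    print(name, "->", sum_result_dtype(name))
```

Under these assumptions, every unsigned input maps to uint64 rather than int64, which is the behaviour the standard appears to require.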

Also, torch doesn't seem to natively support sum for most uint dtypes or complex32. If we change xp.sum(x) to xp.sum(x, dtype=dtype) in the code above, the output is:

torch.int8 torch.int8
torch.int16 torch.int16
torch.int32 torch.int32
torch.int64 torch.int64
torch.uint8 torch.uint8
"sum_cpu" not implemented for 'UInt16'
"sum_cpu" not implemented for 'UInt32'
"sum_cpu" not implemented for 'UInt64'
torch.float32 torch.float32
torch.float64 torch.float64
"sum_cpu" not implemented for 'ComplexHalf'
torch.complex64 torch.complex64

It would be helpful if array-api-compat implemented sum for these types, even if that means upcasting to a supported type before summing and then downcasting. (Overflow is slightly more likely with an int64 accumulator than with uint64, and the conversion may not always be safe, so what should happen in those cases is up for discussion.)
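To make the overflow concern concrete, the sketch below models the upcast-sum-downcast idea with plain Python integers instead of torch tensors. `sum_via_int64` is a hypothetical name, and raising `OverflowError` is just one possible policy chosen for illustration; nothing here reflects what array-api-compat actually does.

```python
# Pure-Python model of summing unsigned values through an int64 accumulator.
# Hypothetical sketch: names and the raise-on-overflow policy are assumptions.
INT64_MAX = 2**63 - 1


def sum_via_int64(values, target_bits):
    """Sum unsigned integers as if upcast to int64, then cast back to uint<target_bits>."""
    target_max = 2**target_bits - 1
    total = 0
    for v in values:
        if not 0 <= v <= target_max:
            raise ValueError(f"{v} is not a valid uint{target_bits} value")
        if v > INT64_MAX:
            # A uint64 value of 2**63 or more has no int64 representation.
            raise OverflowError(f"{v} does not fit in int64")
        total += v
        if total > INT64_MAX:
            # The int64 accumulator would wrap here, even if uint64 would not.
            raise OverflowError("int64 accumulator overflowed")
    if total > target_max:
        raise OverflowError(f"sum {total} does not fit in uint{target_bits}")
    return total


print(sum_via_int64([1, 2, 3], target_bits=16))  # small values are unaffected
```

In this model, `sum_via_int64([2**62, 2**62], target_bits=64)` raises even though the true sum, 2**63, is representable in uint64. That is the extra overflow risk of accumulating in int64.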

Does array-api-compat have a mechanism for reporting the shortcomings it has to patch to the underlying libraries? If not, should I report this to PyTorch (if it is not already reported)?

mdhaber • Jan 24 '25 13:01

Does array-api-compat have a mechanism for reporting the shortcomings it has to patch to the underlying libraries?

So far, I think things have just been reported upstream on an ad-hoc basis.

lucascolley • Jan 26 '25 11:01

+1 to reporting it upstream. Ideally, there's an upstream issue, plus a reference to it in a workaround in -compat.

EDIT: This isn't a hard requirement, of course.

ev-br • Jan 26 '25 12:01

Looks like support for uints is already tracked at pytorch/pytorch#58743. It happens to be near the top. Oh, and it has its own issue, pytorch/pytorch#58734.

mdhaber • Jan 29 '25 06:01