
Error when trying to cast int32 to int64

Status: Open (opened by qihqi 1 year ago, 0 comments)

🐛 Bug

If I add an int32 tensor to an int64 scalar tensor (rank 0), the result cannot be cast to int64.
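For context, PyTorch's documented type-promotion rules say that a 0-dim tensor does not widen a dimensioned tensor of the same category (here, integers). So in the failing case the sum should already be int32, and `.long()` performs a genuine int32 to int64 conversion. The CPU-only sketch below (it does not touch XLA; the variable names are illustrative) checks this with `torch.result_type`.

```python
import torch

# CPU-only check of PyTorch's promotion rule for 0-dim tensors (no XLA involved).
scalar_i64 = torch.tensor(1)  # 0-dim tensor; integer literals default to torch.int64
tensor_i32 = torch.tensor([[478, 1896, 1097]], dtype=torch.int32)

# A 0-dim integer operand does not widen a dimensioned integer operand,
# so the sum stays int32 even though one input is int64.
result = scalar_i64 + tensor_i32
print(result.dtype)                                # torch.int32
print(torch.result_type(scalar_i64, tensor_i32))   # torch.int32

# The explicit cast in the reproduction is therefore a real int32 -> int64 conversion.
widened = result.long()
print(widened.dtype)                               # torch.int64
```

The error message's `s64[]` and `s32[]` operands suggest the XLA lowering of that cast ends up mixing the two element types, although this sketch cannot confirm where.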

The error message is:

RuntimeError: Error while lowering: [] xla::cast, location=test_math_1@test_export_misc.py:53, xla_shape=s64[1,10]{1,0}, dynamic_dims: (), type=s64, dtype=Long, stype=Int
XLA builder error: INVALID_ARGUMENT: Binary op shift-left with different element types: s64[] and s32[].:

It looks like this is triggered here: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/convert_ops.cpp#L27

To Reproduce

Run the following script:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Case 1: int32 0-dim scalar (on CPU) + int32 XLA tensor, then cast to int64.
input_ids = torch.tensor(1, dtype=torch.int32)
input_ids2 = torch.tensor(
    [[478, 1896, 1097, 299, 373, 1668, 1906, 1579, 145, 112]],
    device=device,
    dtype=torch.int32)
print('works', (input_ids + input_ids2).long())

# Case 2: plain Python int + int32 XLA tensor, then cast to int64.
input_ids = 1
input_ids2 = torch.tensor(
    [[478, 1896, 1097, 299, 373, 1668, 1906, 1579, 145, 112]],
    device=device,
    dtype=torch.int32)
print('works', (input_ids + input_ids2).long())

# Case 3: int64 0-dim scalar (torch.tensor(1) defaults to int64) + int32 XLA tensor.
# The sum should be int32 under PyTorch type promotion; lowering the .long() cast fails.
input_ids = torch.tensor(1)
input_ids2 = torch.tensor(
    [[478, 1896, 1097, 299, 373, 1668, 1906, 1579, 145, 112]],
    device=device,
    dtype=torch.int32)
print('fails', (input_ids + input_ids2).long())
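Since case 1 works, one possible workaround (a sketch, not verified here on an XLA device) is to cast the int64 0-dim scalar to the tensor's dtype before adding, which turns the failing case into case 1. The CPU fallback when `torch_xla` is not installed is an addition for convenience only; on CPU the original failing case does not error in the first place.

```python
import torch

# Use an XLA device if torch_xla is available; otherwise fall back to CPU so the sketch runs.
try:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
except ImportError:
    device = torch.device("cpu")

input_ids = torch.tensor(1)  # 0-dim int64 scalar, as in the failing case
input_ids2 = torch.tensor(
    [[478, 1896, 1097, 299, 373, 1668, 1906, 1579, 145, 112]],
    device=device,
    dtype=torch.int32)

# Workaround (assumption): match the scalar's dtype to the tensor's before the add,
# so both operands are int32, exactly as in the working case 1.
summed = input_ids.to(input_ids2.dtype) + input_ids2
result = summed.long()
print(result.dtype)  # torch.int64
```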

qihqi, Feb 15 '24 02:02