Error when trying to cast int32 to int64
🐛 Bug
If I add an int32 tensor to an int64 scalar tensor (of rank 0), the result cannot be cast to int64.
The error message is:
RuntimeError: Error while lowering: [] xla::cast, location=test_math_1@test_export_misc.py:53, xla_shape=s64[1,10]{1,0}, dynamic_dims: (), type=s64, dtype=Long, stype=Int
XLA builder error: INVALID_ARGUMENT: Binary op shift-left with different element types: s64[] and s32[].:
It looks like it is triggered here: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/convert_ops.cpp#L27
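For reference, PyTorch's type-promotion rules say a 0-dim tensor of the same category (integer) does not widen a dimensioned operand, so int32 + int64 scalar should produce an int32 result, and `.long()` is then an ordinary int32 to int64 cast. The sketch below runs the three cases from the reproducer on the CPU only (no torch_xla needed) to show the dtypes the XLA path would be expected to match; the helper name `expected_dtypes` is hypothetical.

```python
import torch


def expected_dtypes():
    """Return (sum dtype, .long() dtype) for each reproducer case, computed on CPU."""
    row = torch.tensor(
        [[478, 1896, 1097, 299, 373, 1668, 1906, 1579, 145, 112]],
        dtype=torch.int32)
    cases = {
        "int32 0-dim tensor": torch.tensor(1, dtype=torch.int32),
        "python int": 1,
        "int64 0-dim tensor": torch.tensor(1),  # torch.tensor(1) defaults to int64
    }
    results = {}
    for name, scalar in cases.items():
        summed = scalar + row  # 0-dim operand does not promote the int32 row
        results[name] = (summed.dtype, summed.long().dtype)
    return results


for name, (sum_dtype, cast_dtype) in expected_dtypes().items():
    print(f"{name}: sum -> {sum_dtype}, .long() -> {cast_dtype}")
```

On CPU all three cases yield an int32 sum, so the failing case should lower to the same int32 to int64 conversion as the working ones.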
To Reproduce
Run the following script:
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Case 1: int32 0-dim CPU tensor + int32 XLA tensor; .long() works.
input_ids = torch.tensor(1, dtype=torch.int32)
input_ids2 = torch.tensor(
[[ 478, 1896, 1097, 299, 373, 1668, 1906, 1579, 145, 112]],
device=device,
dtype=torch.int32)
print('works', (input_ids + input_ids2).long())

# Case 2: Python int + int32 XLA tensor; .long() works.
input_ids = 1
input_ids2 = torch.tensor(
[[ 478, 1896, 1097, 299, 373, 1668, 1906, 1579, 145, 112]],
device=device,
dtype=torch.int32)
print('works', (input_ids + input_ids2).long())

# Case 3: int64 0-dim CPU tensor (torch.tensor(1) defaults to int64)
# + int32 XLA tensor; .long() raises the lowering error.
input_ids = torch.tensor(1)
input_ids2 = torch.tensor(
[[ 478, 1896, 1097, 299, 373, 1668, 1906, 1579, 145, 112]],
device=device,
dtype=torch.int32)
print('fails', (input_ids + input_ids2).long())
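Since the first case (an int32 0-dim tensor on the CPU) works, one possible workaround, not verified on XLA here, is to cast the 0-dim scalar to the other operand's dtype before adding. A minimal sketch, assuming that mirroring case 1 avoids the bug; `add_scalar_then_long` is a hypothetical helper and is only exercised on CPU below.

```python
import torch


def add_scalar_then_long(scalar, tensor):
    """Hypothetical helper: add a scalar to `tensor`, then cast the result to int64.

    A 0-dim tensor scalar is first cast to `tensor.dtype` (it stays on its
    original device), matching the int32 0-dim case that works in the reproducer.
    """
    if isinstance(scalar, torch.Tensor) and scalar.dim() == 0:
        scalar = scalar.to(tensor.dtype)
    return (scalar + tensor).long()


row = torch.tensor([[478, 1896]], dtype=torch.int32)
out = add_scalar_then_long(torch.tensor(1), row)
print(out.dtype, out.tolist())  # torch.int64 [[479, 1897]]
```

The cast does not change the numeric result, because the promotion rules already give an int32 sum; it only makes the operand dtypes match explicitly.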