Ops unnecessarily failing on tensors with integer values
🐛 Bug
Tensors created with integer values cause unnecessary errors. The behavior is inconsistent between eager PyTorch and XLA, and the error message is unclear: it says we are lowering a Float[1] tensor, but then reports that it got S64 when it expected a floating-point or complex type. It also refers to a cosine operation, even though torch.tan was called.
To Reproduce
Define a tensor with integer values and apply torch.tan to it:
>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> x_xla = torch.tensor([3], device=xm.xla_device())
>>> y_xla = torch.tan(x_xla)
>>> y_xla
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 344, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor_str.py", line 484, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor_str.py", line 341, in _str_intern
    self = self.to('cpu')
RuntimeError: Error while lowering: [Float[1]] aten::tan, location=<module>@<stdin>:1
XLA builder error: INVALID_ARGUMENT: Expected element type in shape to be floating or complex for cosine operation; got S64.:
Frames:
  <module> (<stdin>:1)
Note that the same code works when the tensor is created with a float value:
>>> x_xla = torch.tensor([3.0], device=xm.xla_device())
>>> y_xla = torch.tan(x_xla)
>>> y_xla
tensor([-0.1425], device='xla:0')
Expected behavior
With regular (non-XLA) PyTorch, the same integer input works, and torch.tan returns a floating-point result:
>>> x = torch.tensor([3])
>>> x
tensor([3])
>>> y = torch.tan(x)
>>> y
tensor([-0.1425])
I would expect to be able to do this in XLA, or at least get a clearer error message.
Environment
- Reproducible on XLA backend [CPU/TPU]: CPU
- torch_xla version: built from HEAD
Additional context
My guess is that xla::tan only handles floating-point types. This should be easy to fix: we can check the element type of the input shape for tan and cast it to floating point before passing it to xla::tan in https://github.com/pytorch/xla/blob/master/torch_xla/csrc/ops/ops_lower_fn.cpp#L139. That said, I am not sure how many people actually pass an integer tensor to torch.tan.
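The cast described above can be sketched in plain Python. This is a toy model, not torch_xla code: element types are represented by their XLA PrimitiveType names as strings, `toy_xla_tan` and `lower_tan` are hypothetical stand-ins for the builder op and for the lowering in ops_lower_fn.cpp, and promoting integral inputs to F32 is an assumption chosen to mirror PyTorch's default float dtype.

```python
import cmath
import math

# Toy model only: XLA element types are plain strings named after
# XLA PrimitiveType values; nothing here calls torch_xla or XLA.
FLOATING = {"F16", "BF16", "F32", "F64"}
COMPLEX = {"C64", "C128"}
INTEGRAL = {"PRED", "S8", "S16", "S32", "S64", "U8", "U16", "U32", "U64"}


def toy_xla_tan(values, element_type):
    """Hypothetical stand-in for the builder op: rejects integral inputs."""
    if element_type not in FLOATING and element_type not in COMPLEX:
        raise ValueError(
            "Expected element type in shape to be floating or complex; "
            f"got {element_type}."
        )
    fn = cmath.tan if element_type in COMPLEX else math.tan
    return [fn(v) for v in values]


def lower_tan(values, element_type, default_float="F32"):
    """Hypothetical lowering: convert integral inputs to a float type first."""
    if element_type in INTEGRAL:
        # Models an element-type conversion inserted before tan.
        values = [float(v) for v in values]
        element_type = default_float
    return element_type, toy_xla_tan(values, element_type)


out_type, out = lower_tan([3], "S64")
print(out_type, [round(v, 4) for v in out])  # F32 [-0.1425]
```

Calling `toy_xla_tan([3], "S64")` directly raises, which corresponds to the current behavior, while `lower_tan` checks the element type, converts if it is integral, and only then applies tan. In the real lowering the conversion would be an XLA op on the operand (for example `xla::ConvertElementType`) rather than a Python `float()` call.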
This issue was addressed in https://github.com/pytorch/xla/pull/4333