
Ops unnecessarily failing on tensors with integer values

Open steventk-g opened this issue 3 years ago • 1 comment

🐛 Bug

Tensors defined with integer values cause superfluous errors. The behavior is inconsistent between torch and torch_xla, and the error message is unclear: it notes that we are trying to lower a Float[1] tensor, but then says it got S64 when expecting a floating or complex type (and it references a cosine operation even though torch.tan was called).

To Reproduce

Define a tensor using integers and call torch.tan on it:

>>> x_xla = torch.tensor([3], device=xm.xla_device())
>>> y_xla = torch.tan(x_xla)
>>> y_xla
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 344, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor_str.py", line 484, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor_str.py", line 341, in _str_intern
    self = self.to('cpu')
RuntimeError: Error while lowering: [Float[1]] aten::tan, location=<module>@<stdin>:1
XLA builder error: INVALID_ARGUMENT: Expected element type in shape to be floating or complex for cosine operation; got S64.: 
Frames:
  <module> (<stdin>:1)

Note that things work fine using a float explicitly:

>>> x_xla = torch.tensor([3.0], device=xm.xla_device())
>>> y_xla = torch.tan(x_xla)
>>> y_xla
tensor([-0.1425], device='xla:0')

Expected behavior

For non-XLA (eager) PyTorch, we can do the following:

>>> x = torch.tensor([3])
>>> x
tensor([3])
>>> y = torch.tan(x)
>>> y
tensor([-0.1425])

I would expect to be able to do this in XLA, or at least get a clearer error message.
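
In the meantime, a minimal workaround sketch is to promote the tensor to float explicitly before the op (assuming the same torch and xm imports as in the repro above); this mirrors the integer-to-float promotion that eager PyTorch applies implicitly:

>>> x_xla = torch.tensor([3], device=xm.xla_device())
>>> y_xla = torch.tan(x_xla.float())  # explicit S64 -> F32 cast before the op
>>> y_xla
tensor([-0.1425], device='xla:0')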

Environment

  • Reproducible on XLA backend [CPU/TPU]: CPU
  • torch_xla version: built from HEAD

Additional context

steventk-g · Jul 14 '22 23:07

My guess is that xla::tan only handles floating point. This is easily fixable: we can do a shape check for tan and cast the input to floating point before passing it to xla::tan in https://github.com/pytorch/xla/blob/master/torch_xla/csrc/ops/ops_lower_fn.cpp#L139, though I am not sure how many people actually pass an integer to torch.tan.

JackCaoG · Jul 14 '22 23:07
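
For illustration, here is a minimal Python sketch of the promotion logic the comment above describes; the actual fix belongs in the C++ lowering in ops_lower_fn.cpp, and tan_with_promotion is a hypothetical name used only for this sketch:

import torch

def tan_with_promotion(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical caller-side equivalent of the proposed lowering fix:
    # integral inputs are cast to the default float dtype before the op,
    # matching eager PyTorch's type promotion for unary float functions.
    if not (x.dtype.is_floating_point or x.dtype.is_complex):
        x = x.to(torch.get_default_dtype())
    return torch.tan(x)

print(tan_with_promotion(torch.tensor([3])))  # tensor([-0.1425])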

Addressed the issue with https://github.com/pytorch/xla/pull/4333

steventk-g · Dec 15 '22 23:12