
`__getitem__` fails to vmap for one dimensional tensors

Open · zdevito opened this issue 2 years ago • 3 comments

Example code:

from functorch import vmap
import torch
def f(x):
    return torch.ones(4)[x]

y = torch.arange(2)
print(vmap(f)(y))

Output:

RuntimeError                              Traceback (most recent call last)
<ipython-input-3-8b227da701ad> in <module>()
      5 
      6 y = torch.arange(2)
----> 7 print(vmap(f)(y))

1 frames
<ipython-input-3-8b227da701ad> in f(x)
      2 import torch
      3 def f(x):
----> 4     return torch.ones(4)[x]
      5 
      6 y = torch.arange(2)

RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

Expected Output:

tensor([1., 1.])

Commentary:

This is likely `__getitem__` applying an optimization when an input has 0 dims (not counting its single batch dim), without realizing that there are batch dimensions. Indexing with batched tensors that have at least one normal dimension does work, e.g. `torch.ones(4)[x[None]][0]`.
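
For reference, a minimal sketch of that workaround (the name `f_workaround` is just illustrative):

from functorch import vmap
import torch

def f_workaround(x):
    # Under vmap, x arrives as a 0-dim tensor per batch element.
    # Giving the index tensor an extra dimension (x[None], shape [1])
    # avoids the scalar-indexing path that ends up calling .item();
    # the trailing [0] then drops that extra dimension again.
    return torch.ones(4)[x[None]][0]

y = torch.arange(2)
print(vmap(f_workaround)(y))  # expected: tensor([1., 1.])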

zdevito · Apr 25 '22 22:04

Duplicate of https://github.com/pytorch/functorch/issues/745 ?

vfdev-5 · Apr 25 '22 22:04

Closed #745 in favor of this.

zou3519 · Apr 26 '22 16:04