functorch
`__getitem__` fails to vmap for one dimensional tensors
Example code:

```python
from functorch import vmap
import torch

def f(x):
    return torch.ones(4)[x]

y = torch.arange(2)
print(vmap(f)(y))
```
Output:

```
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-8b227da701ad> in <module>()
      5
      6 y = torch.arange(2)
----> 7 print(vmap(f)(y))

1 frames
<ipython-input-3-8b227da701ad> in f(x)
      2 import torch
      3 def f(x):
----> 4     return torch.ones(4)[x]
      5
      6 y = torch.arange(2)

RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
```
Expected Output:
tensor([1., 1.])
Commentary:
This is likely `__getitem__` applying an optimization when an input has 0 dims (not counting its single batch dim), without realizing that batch dimensions are present. Indexing with batched tensors that have at least one non-batch dimension does work, e.g. `torch.ones(4)[x[None]][0]`.
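The workaround mentioned above can be sketched as follows: inside the vmapped function, `x[None]` promotes the per-example 0-dim index to a 1-dim tensor, which avoids the `.item()` fast path, and `[0]` squeezes the result back to a scalar.

```python
from functorch import vmap
import torch

def f(x):
    # x is a 0-dim tensor per vmapped example; x[None] gives it one
    # non-batch dimension so __getitem__ does not hit the .item() path.
    return torch.ones(4)[x[None]][0]

y = torch.arange(2)
print(vmap(f)(y))  # tensor([1., 1.])
```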
Duplicate of https://github.com/pytorch/functorch/issues/745 ?
Closed #745 in favor of this.