functorch
.item() error without calling .item()
Hi, I'm getting an error that seems to come from inside PyTorch internals:
RuntimeError Traceback (most recent call last)
Input In [103], in <cell line: 2>()
1 # this isn't working, something about the memory access and in place ops make it unhappy
----> 2 hij = vectorized(vmap_sm, torch.tensor(modified_lengths), -1, 1)
File ~/anaconda3/envs/gpt/lib/python3.10/site-packages/functorch/_src/vmap.py:383, in vmap.<locals>.wrapped(*args, **kwargs)
381 try:
382 batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 383 batched_outputs = func(*batched_inputs, **kwargs)
384 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
385 finally:
Input In [99], in torch_nw.sco(self, x, lengths, gap, temp)
85 print(sm["L"][0],sm["L"][1])
86 # return hij, sm
---> 87 hij = hij[sm["L"][0],sm["L"][1]]
88 # hij = hij.index_select(0, torch.tensor([sm["L"][0],sm["L"][1]]))
89 print(hij.requires_grad)
RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
@maxzvyagin can you post a minimal code snippet that reproduces the issue?
I think this is the same issue as https://github.com/pytorch/pytorch/issues/124423
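The thread does not confirm the cause. One plausible reading of the traceback is that the implicit .item() comes from the line hij = hij[sm["L"][0],sm["L"][1]]. PyTorch's indexing code can turn a 0-dim integer tensor index into a plain Python integer, and under vmap that conversion behaves like calling .item(). If that is the case, a possible workaround is to give each index one dimension, so the lookup becomes advanced (tensor) indexing, which vmap can batch. The sketch below assumes torch.func.vmap (the successor of functorch.vmap). The function pick and the tensor shapes are hypothetical stand-ins, not the poster's torch_nw.sco code.

```python
import torch
from torch.func import vmap  # functorch.vmap in older releases


def pick(hij, row, col):
    # Hypothetical per-example function. Under vmap, row and col arrive as
    # 0-dim integer tensors. Indexing with them directly (hij[row, col]) may
    # be converted to an integer via .item(), which vmap rejects.
    # unsqueeze(0) makes them 1-element tensors, so this is advanced indexing
    # instead; squeeze(0) drops the extra dimension from the result.
    return hij[row.unsqueeze(0), col.unsqueeze(0)].squeeze(0)


B, N, M = 4, 5, 6
hij = torch.randn(B, N, M)
rows = torch.tensor([0, 1, 2, 3])
cols = torch.tensor([5, 4, 3, 2])

# One scalar per batch element: out[b] == hij[b, rows[b], cols[b]].
out = vmap(pick)(hij, rows, cols)

# Reference result computed without vmap, using plain batched indexing.
expected = hij[torch.arange(B), rows, cols]
print(out.shape, torch.equal(out, expected))
```

If vmap is not needed for this particular step, the batched form used for expected reads the same values for the whole batch in one indexing call and avoids the question entirely.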