Remove upsample_*.vec ops
Fixes https://github.com/pytorch/pytorch/pull/85638
The upstream PR makes the upsample_*.vec ops CompositeImplicit, so their backward no longer needs an explicit implementation.
The upstream PR's base branch is a bit outdated. Left a comment to rebase the PR.
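To make that concrete, here is a minimal Python sketch (my own illustration, not part of this PR; shapes and the device are arbitrary) of why a CompositeImplicit op needs no hand-written backward: autograd differentiates through the decomposition, so the output still carries a grad_fn.
# Minimal sketch, assuming the .vec overload is reached via F.interpolate with
# scale_factor; once it is CompositeImplicit, autograd differentiates through
# the decomposition and the output gets a grad_fn without an explicit lowering.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8, requires_grad=True)         # with torch_xla this would live on an XLA device
out = F.interpolate(x, scale_factor=2, mode="nearest")   # should hit upsample_nearest2d.vec
print(out.grad_fn is not None)                           # True: backward is derived automatically
out.backward(torch.randn_like(out))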
PyTorch python op tests are failing:
======================================================================
ERROR: test_upsamplingNearest2d_xla (__main__.TestNNDeviceTypeXLA)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 391, in instantiated_test
raise rte
File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 378, in instantiated_test
result = test(self, **param_kwargs)
File "/tmp/pytorch/xla/test/../../test/test_nn.py", line 14537, in test_upsamplingNearest2d
helper(torch.contiguous_format, "nearest")
File "/tmp/pytorch/xla/test/../../test/test_nn.py", line 14500, in helper
out_t.backward(torch.randn_like(out_t))
File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 484, in backward
self, gradient, retain_graph, create_graph, inputs=inputs
File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 199, in backward
allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I was able to see similar pt/xla cpp unit tests fail with the same error. I was under the assumption that the upstream PR makes the backward.vec op's explicit implementation unnecessary. @bdhirsh, do you have any idea what is causing this error in PyTorch?
cc @JackCaoG
@bdhirsh This error message is too vague; do you spot anything that's obviously wrong here?
At the risk of stating the obvious, that message happens when you try to run backward
on a tensor that does not have requires_grad=True
and does not come from a computation that tracks gradients. E.g.
>>> import torch
>>> torch.randn(1).backward()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/lib/python3.10/site-packages/torch/_tensor.py", line 396, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/usr/lib/python3.10/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Ok, so I think what's happening here is that you are removing the whole backward function, while in the other PR they are removing only the backward for the .vec overload.
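To spell out the distinction, here is a small sketch (overload names as discussed in this thread; the exact call signatures are my assumption) of the two overloads involved: upstream only made the .vec overload CompositeImplicit, so the plain overload's backward lowering is still needed on the XLA side.
# Hypothetical sketch of the two overloads in question.
import torch

x = torch.randn(1, 3, 4, 4)
# plain overload: explicit output size (its backward lowering must remain)
out = torch.ops.aten.upsample_nearest2d(x, [8, 8])
# .vec overload: optional size plus scale factors (decomposed, backward derived)
out_vec = torch.ops.aten.upsample_nearest2d.vec(x, None, [2.0, 2.0])
print(out.shape, out_vec.shape)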
Ok, after removing the upsample_nearest2d.vec op and updating the existing GetOutputSizeWithScale function to accept scale_h and scale_w, as here: https://github.com/pytorch/xla/blob/1d0c3393fb48cb8740379e4fea9c37a0e131a7dd/torch_xla/csrc/aten_xla_type.cpp#L165-L175, I can confirm the related cpp tests are passing: UpsampleNearest2D, UpsampleNearest2DWithScale, UpsampleNearest2DBackward, and UpsampleNearest2DBackwardWithScale.
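For reference, a small Python sketch (my own illustration; the mapping of tests to call paths is an assumption) of the two paths those tests exercise, an explicit output size versus per-dimension scale factors, which is why GetOutputSizeWithScale has to account for scale_h and scale_w.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 4, 4)
# explicit output size -> the plain UpsampleNearest2D / ...Backward path
out_size = F.interpolate(x, size=(8, 8), mode="nearest")
# per-dimension scale factors -> the ...WithScale variants
out_scale = F.interpolate(x, scale_factor=(2.0, 1.5), mode="nearest")
print(out_size.shape, out_scale.shape)  # (1, 3, 8, 8) and (1, 3, 8, 6)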
The CI will fail for now because the pinned PyTorch PR is not rebased onto the latest master. Once the upstream PR is rebased and the commit pin is updated, I'll re-run the CI.
The upstream PR has merged. Deleting torch_pin and merging this.