nitorch
Push/Pull: use undefined tensors in backward pass when their values are not needed
Currently, the backward passes of grid_push, grid_pull, etc. require the forward input tensors as arguments even when the values they hold are not needed.
For example, let the forward pass be grid_pull(image, grid), with image.requires_grad == True and grid.requires_grad == False. In that case, only grid is needed in the backward pass. However, the current implementation still requires image to be provided, both to know that its gradient is required and to compute the shape of the backward output (the gradient with respect to image).
A better solution would be to provide a placeholder tensor in place of image, e.g. pseudo_image = torch.Tensor(), with pseudo_image.requires_grad == True and pseudo_image.shape == image.shape + (0,), so that it holds no data but still carries the shape and the requires_grad flag.
(Or find another way to specify the shape, e.g. a batch dimension of size 0?)
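The trailing-zero placeholder can be sketched in PyTorch as follows. Both make_placeholder and shape_from_placeholder are hypothetical helpers, not part of nitorch; they only show that a tensor of shape image.shape + (0,) stores zero elements while still recording the original shape and requires_grad.

```python
import torch


def make_placeholder(image):
    """Hypothetical zero-element stand-in for `image`.

    Appending a trailing dimension of size 0 makes numel() == 0, so no image
    values are stored, while the leading dimensions still record image.shape.
    """
    return torch.empty(
        tuple(image.shape) + (0,),
        dtype=image.dtype,
        device=image.device,
        requires_grad=image.requires_grad,
    )


def shape_from_placeholder(t):
    """Hypothetical helper a backward kernel could use to recover the shape."""
    if t.numel() == 0 and t.dim() > 0 and t.shape[-1] == 0:
        return t.shape[:-1]  # placeholder: drop the trailing 0
    return t.shape  # a real tensor was passed
```

For image = torch.randn(2, 3, 4, 4, requires_grad=True), make_placeholder(image) has shape (2, 3, 4, 4, 0), numel() == 0 and requires_grad == True. Note that this convention is ambiguous for a genuine input whose last dimension is already 0, which is one reason to consider another way of passing the shape.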