Vmap causing TypeError when applied to Rotate
I'm running into an issue trying to vmap over the torchvision rotate function. rotate() requires an int or float angle and does not accept single-element tensors, but vmap requires that all batched inputs come in as tensors. Additionally, one can't create a helper function that calls rotate(im, angle.item()), since .item() calls are not allowed inside vmap.
There might be a simple solution to this that I'm not seeing, but if not, it would be nice for rotate to accept single-element tensors as input.
import torch
from functorch import vmap
from torchvision.transforms.functional import rotate

vrot = vmap(rotate, in_dims=(0, 0), out_dims=0)

b, dim = 10, 64
inp_ims = torch.rand((b, dim, dim))
angles = torch.rand(b)
vrot(inp_ims, angles)
This gives: TypeError: Argument angle should be int or float
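For reference, the helper-function workaround mentioned above hits the second restriction (a sketch reusing the tensors from the snippet above; rotate_item and vrot2 are just illustrative names):

def rotate_item(im, angle):
    # .item() is not allowed on vmap-batched tensors, so this fails as well
    return rotate(im, angle.item())

vrot2 = vmap(rotate_item, in_dims=(0, 0), out_dims=0)
vrot2(inp_ims, angles)  # errors: .item() is not supported inside vmap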
cc @vfdev-5 @datumbox
@Art-MC I think this won't work even if rotate could accept a tensor. rotate applied to 4D input images (N, C, H, W) rotates all images by the same angle. Currently, it is not possible to rotate them by different angles, as we construct a single affine matrix. In theory we could extend the code to make this possible and create N affine matrices.
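For illustration, this is the current behaviour with a 4D batch (a minimal sketch, shapes picked arbitrarily):

import torch
from torchvision.transforms.functional import rotate

batch = torch.rand(10, 3, 64, 64)  # (N, C, H, W)
out = rotate(batch, 45.0)          # every image in the batch is rotated by the same 45 degrees
# there is no way to pass a per-image angle here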
@vfdev-5 I see what you mean. From what I can tell, it looks like it would be possible to enable vmapping, since at the end of the day torch.nn.functional.grid_sample takes up to a 5D input. This is something that would certainly be useful to me, but I'm not sure how much it would be to others. I think the main item would be expanding transforms.functional._get_inverse_affine_matrix, which, as you said, currently creates the single matrix.
@Art-MC basically, you would like to "vectorize" the following code:
import torch
torch.manual_seed(11)
batch_size = 4
image_size = (32, 32)
angles = torch.randint(-180, 180, size=(batch_size, ))
images = torch.rand(batch_size, 3, *image_size)
print(angles)
rotation_matrix = torch.zeros(batch_size, 2, 3)
rotation_matrix[:, 0, 0] = torch.cos(angles * torch.pi / 180.0)
rotation_matrix[:, 1, 1] = rotation_matrix[:, 0, 0]
rotation_matrix[:, 0, 1] = -torch.sin(angles * torch.pi / 180.0) # +/- sin(angle)
rotation_matrix[:, 1, 0] = -rotation_matrix[:, 0, 1]
rotation_grids = torch.nn.functional.affine_grid(rotation_matrix, (batch_size, 3, *image_size))
out_images = torch.nn.functional.grid_sample(images, rotation_grids)
With functorch and the vmap approach, there are currently two issues:
- it is tricky to create a rotation matrix of size (2, 3) from the batched angle without doing in-place ops (see the sketch after this list)
- there is no appropriate batch rule for torch.nn.functional.affine_grid
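For the first point, one way around the in-place ops (a sketch, assuming the per-sample angle is a 0-dim tensor in degrees; make_rotation_matrix is just an illustrative helper name) is to assemble the matrix with torch.stack:

import torch

def make_rotation_matrix(angle):
    # build the (2, 3) affine matrix functionally, without indexing into a pre-allocated tensor
    rad = angle * torch.pi / 180.0
    cos, sin = torch.cos(rad), torch.sin(rad)
    zero = torch.zeros_like(rad)
    return torch.stack([
        torch.stack([cos, -sin, zero]),
        torch.stack([sin, cos, zero]),
    ])  # shape (2, 3)

This still leaves the second point (no batch rule for affine_grid) open.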
I was thinking about something like the following code:
def rotation(image, angle):
    # image: a single (C, H, W) image, angle: a 0-dim tensor in degrees
    rad = angle * torch.pi / 180.0
    rotation_matrix = torch.zeros(2, 3)
    rotation_matrix[0, 0] = torch.cos(rad)
    rotation_matrix[0, 1] = -torch.sin(rad)
    rotation_matrix[1, 0] = torch.sin(rad)
    rotation_matrix[1, 1] = torch.cos(rad)
    rotation_grid = torch.nn.functional.affine_grid(rotation_matrix.unsqueeze(0), (1, *image.shape))
    return torch.nn.functional.grid_sample(image.unsqueeze(0), rotation_grid)[0]
vrotation = vmap(rotation)
Maybe it's worth opening an issue on the pytorch/functorch repo to get help with writing a vectorizable rotation function.