Rotation sampling grid is undesirably low-res
🐛 Describe the bug
https://github.com/pytorch/vision/blob/f799a5348d990de02f9189d9496f369bacfe5cf3/torchvision/transforms/_functional_tensor.py#L594
This function, which creates the resampling grid, uses the default tensor dtype. If that dtype is bfloat16/float16 and the image is sufficiently large (over roughly 250x250 px), the grid coordinates are significantly quantized and the output looks heavily subsampled compared to the original. The problem does not occur when the default dtype is higher precision.
My suggestion would be to update these linspace calls to pass either dtype=theta.dtype or explicitly dtype=torch.float32 to prevent this from happening.
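For illustration, here is a minimal standalone sketch of the quantization effect (not the torchvision code itself). bfloat16 stores only 8 mantissa bits, so between 0.5 and 1.0 consecutive representable values are 2**-8 ≈ 0.0039 apart; a 1000-point grid over [-1, 1] has a step of 2/999 ≈ 0.002, so neighboring coordinates round to the same value:
>>> import torch
>>> torch.linspace(-1, 1, 1000, dtype=torch.float32).unique().numel()  # 1000 distinct coordinates
>>> torch.linspace(-1, 1, 1000, dtype=torch.bfloat16).float().unique().numel()  # noticeably fewer: coordinates collapse near ±1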
Versions
Current
Thanks for the report, @danielgordon10. This area of the transforms isn't actively developed or maintained anymore; we now support the torchvision.transforms.v2 namespace. Can you please clarify whether you're observing the same issue with the utilities in v2? Also, can you please provide a minimal reproducing example, to give us an idea of which public APIs are impacted? Thank you
v2 has a similar issue, though it may not be exactly the same: depending on the dtype of the input, the sampling grid is lower resolution than it should be, resulting in blocky outputs. You can especially see it on the border of the image.
>>> import cv2
>>> import torch
>>> from torchvision.transforms.v2 import functional as tvf
>>> img = cv2.imread("lenna.png")
>>> cv2.imshow("lenna", img)
>>> cv2.waitKey(0)
>>> img_t = torch.as_tensor(img)
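>>> # float32 path: the rotated image looks correct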
>>> img2 = tvf.rotate(img_t.to(dtype=torch.float32, device="cuda").permute(2, 0, 1), 10).permute(1, 2, 0).to(dtype=torch.uint8, device="cpu")
>>> cv2.imshow("lenna2", img2.numpy())
>>> cv2.waitKey(0)
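>>> # bfloat16 path: the rotated image is visibly blocky, especially near the borders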
>>> img2 = tvf.rotate(img_t.to(dtype=torch.bfloat16, device="cuda").permute(2, 0, 1), 10).permute(1, 2, 0).to(dtype=torch.uint8, device="cpu")
>>> cv2.imshow("lenna2", img2.numpy())
>>> cv2.waitKey(0)
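To quantify the difference rather than eyeballing the windows, one could compare the two rotations directly (a sketch reusing the same img_t and assuming a CUDA device):
>>> a = tvf.rotate(img_t.to(dtype=torch.float32, device="cuda").permute(2, 0, 1), 10)
>>> b = tvf.rotate(img_t.to(dtype=torch.bfloat16, device="cuda").permute(2, 0, 1), 10).float()
>>> (a - b).abs().max()  # grows large for images over ~250 px, consistent with the quantized grid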
Have you tried using a different interpolation method? I tried your code using interpolation=InterpolationMode.BILINEAR (from torchvision.transforms) and got improved results.
float32: (screenshot)
bfloat16: (screenshot)
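For reference, a sketch of the bilinear variant of the repro above (InterpolationMode imported from torchvision.transforms; rotate defaults to InterpolationMode.NEAREST):
>>> from torchvision.transforms import InterpolationMode
>>> img2 = tvf.rotate(img_t.to(dtype=torch.bfloat16, device="cuda").permute(2, 0, 1), 10,
...                   interpolation=InterpolationMode.BILINEAR).permute(1, 2, 0).to(dtype=torch.uint8, device="cpu")
>>> cv2.imshow("lenna2_bilinear", img2.numpy())
>>> cv2.waitKey(0)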