
Rotation sampling grid is undesirably low-res

Open danielgordon10 opened this issue 8 months ago • 3 comments

🐛 Describe the bug

https://github.com/pytorch/vision/blob/f799a5348d990de02f9189d9496f369bacfe5cf3/torchvision/transforms/_functional_tensor.py#L594

The function that creates the resampling grid uses the default tensor dtype. If that is bfloat16 or float16 and the image is sufficiently large (over 250x250 px), the grid coordinates are heavily quantized, which seems undesirable: the output looks severely subsampled compared with the original. This would not occur if the dtype had higher precision.

My suggestion would be to update these linspace calls to use either dtype=theta.dtype or, explicitly, dtype=torch.float32 to prevent this from happening.
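To illustrate why a low-precision grid loses resolution, here is a minimal sketch that emulates bfloat16 rounding in pure Python. It is not torchvision's code path: the helper names (round_to_bfloat16, round_to_float32, pixel_centers, distinct_coordinates) are hypothetical, and the grid layout (pixel centres centred on zero, one unit apart) is an assumption made for illustration.

```python
import struct


def round_to_bfloat16(x):
    """Emulate rounding a value to bfloat16 (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of a float32. Bias the discarded half
    # so that ties round to the value whose last kept bit is 0.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def round_to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def pixel_centers(width):
    # Assumed layout (hypothetical): pixel centres centred on zero,
    # i.e. -width/2 + 0.5, ..., width/2 - 0.5.
    return [-width * 0.5 + 0.5 + i for i in range(width)]


def distinct_coordinates(width, rounder):
    """Count how many distinct coordinates survive rounding."""
    return len({rounder(c) for c in pixel_centers(width)})


if __name__ == "__main__":
    for width in (128, 256, 512, 1024):
        print(
            width,
            distinct_coordinates(width, round_to_bfloat16),
            distinct_coordinates(width, round_to_float32),
        )
```

Under these assumptions, a 256-column grid still has one distinct coordinate per column in bfloat16, while a 512-column grid has fewer distinct coordinates than columns, so neighbouring output pixels sample the same source location. The float32 emulation keeps every column distinct at these sizes, which is consistent with the suggestion to build the grid in float32.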

Versions

Current

danielgordon10, Apr 22 '25 01:04

Thanks for the report @danielgordon10. This area of the transforms isn't actively developed or maintained anymore. We now support the torchvision.transforms.v2 namespace. Can you please clarify whether you're observing the same issue with the utilities in v2? Also, can you please provide a minimal reproducing example, to give us an idea of which public APIs are impacted? Thank you.

NicolasHug, Apr 22 '25 11:04

v2 has a similar issue, though it may not be exactly the same. Depending on the dtype of the input, the sampling grid has a lower resolution than makes sense, resulting in blocky outputs. This is especially visible at the border of the image.

>>> import cv2
>>> import torch
>>> from torchvision.transforms.v2 import functional as tvf
>>> img = cv2.imread("lenna.png")
>>> cv2.imshow("lenna", img)
>>> cv2.waitKey(0)

Image

>>> img_t = torch.as_tensor(img)
>>> img2 = tvf.rotate(img_t.to(dtype=torch.float32, device="cuda").permute(2, 0, 1), 10).permute(1, 2, 0).to(dtype=torch.uint8, device="cpu")
>>> cv2.imshow("lenna2", img2.numpy())
>>> cv2.waitKey(0)

Image

>>> img2 = tvf.rotate(img_t.to(dtype=torch.bfloat16, device="cuda").permute(2, 0, 1), 10).permute(1, 2, 0).to(dtype=torch.uint8, device="cpu")
>>> cv2.imshow("lenna2", img2.numpy())
>>> cv2.waitKey(0)

Image

danielgordon10, May 02 '25 21:05

Have you tried using a different interpolation method? I tried your code using interpolation=tvf.InterpolationMode.BILINEAR and got improved results.

float32: Image

bfloat16: Image

chengolivia, Jun 15 '25 02:06