Can't use gaussian_blur if sigma is a tensor on gpu

Open Xact-sniper opened this issue 1 year ago • 2 comments

🐛 Describe the bug

Admittedly perhaps an unconventional use, but I'm using gaussian_blur in my model to blur attention maps and I want to have the sigma be a parameter.

It would work, except for this function: https://github.com/pytorch/vision/blob/06ad737628abc3a1e617571dc03cbdd5b36ea96a/torchvision/transforms/_functional_tensor.py#L725

x is not moved to the device that sigma is on.

I believe it is like this in all torchvision versions.

WORKS:

import torch
from torchvision.transforms.functional import gaussian_blur
k = 15
s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True)

blurred = gaussian_blur(torch.randn(1, 3, 256, 256), k, [s])
blurred.mean().backward()
print(s.grad)
>>> tensor(-4.6193e-05)

DOES NOT:

import torch
from torchvision.transforms.functional import gaussian_blur
k = 15
s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True, device='cuda')

blurred = gaussian_blur(torch.randn(1, 3, 256, 256, device='cuda'), k, [s])
blurred.mean().backward()
print(s.grad)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
D:\Temp\ipykernel_39000\3525683463.py in <module>
      4 s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True, device='cuda')
      5 
----> 6 blurred = gaussian_blur(torch.randn(1, 3, 256, 256, device='cuda'), k, [s])
      7 blurred.mean().backward()
      8 print(s.grad)

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\functional.py in gaussian_blur(img, kernel_size, sigma)
   1361         t_img = pil_to_tensor(img)
   1362 
-> 1363     output = F_t.gaussian_blur(t_img, kernel_size, sigma)
   1364 
   1365     if not isinstance(img, torch.Tensor):

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in gaussian_blur(img, kernel_size, sigma)
    749 
    750     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
--> 751     kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    752     kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
    753 

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in _get_gaussian_kernel2d(kernel_size, sigma, dtype, device)
    736     kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
    737 ) -> Tensor:
--> 738     kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    739     kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    740     kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in _get_gaussian_kernel1d(kernel_size, sigma)
    727 
    728     x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
--> 729     pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    730     kernel1d = pdf / pdf.sum()
    731 

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

~~I don't know about the convention, like whether device should be passed in, but the simplest fix I believe would just be to change: 728 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) to: 728 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size).to(sigma.device)~~

Actually, that won't work when sigma is just a float. So I guess there could be a check for whether it's a float or a float tensor.
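A minimal sketch of what such a check could look like, assuming the helper body matches the lines shown in the traceback. The name `gaussian_kernel1d` is hypothetical (it is not torchvision's private function), and this is only one possible shape for a fix, not necessarily what the maintainers would choose:

```python
from typing import Union

import torch


def gaussian_kernel1d(kernel_size: int, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
    """Hypothetical stand-in for the helper in the traceback, not torchvision code."""
    ksize_half = (kernel_size - 1) * 0.5
    # If sigma is a tensor, build x on sigma's device; for a plain float,
    # device=None keeps torch.linspace on the default device.
    device = sigma.device if isinstance(sigma, torch.Tensor) else None
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=device)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


# Works with a plain float ...
k_float = gaussian_kernel1d(5, 1.0)

# ... and with a tensor sigma that requires grad, on whatever device it lives on.
dev = "cuda" if torch.cuda.is_available() else "cpu"
s = torch.tensor(1.0, requires_grad=True, device=dev)
k_tensor = gaussian_kernel1d(5, s)
```

The kernel is still moved to the image's device and dtype by the caller, as in the existing `_get_gaussian_kernel2d` code, so this change only affects where `x` is first created.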

Versions

[pip3] efficientunet-pytorch==0.0.6
[pip3] ema-pytorch==0.4.5
[pip3] flake8==6.0.0
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.3
[pip3] numpydoc==1.4.0
[pip3] pytorch-msssim==1.0.0
[pip3] siren-pytorch==0.1.7
[pip3] torch==2.2.2+cu118
[pip3] torch-cluster==1.6.0+pt113cu116
[pip3] torch_geometric==2.4.0
[pip3] torch-scatter==2.1.0+pt113cu116
[pip3] torch-sparse==0.6.16+pt113cu116
[pip3] torch-spline-conv==1.2.1+pt113cu116
[pip3] torch-tools==0.1.5
[pip3] torchaudio==2.2.2+cu118
[pip3] torchbearer==0.5.3
[pip3] torchmeta==1.8.0
[pip3] torchvision==0.17.2+cu118
[pip3] uformer-pytorch==0.0.8
[pip3] vit-pytorch==1.5.0
[conda] blas 1.0 mkl
[conda] efficientunet-pytorch 0.0.6 pypi_0 pypi
[conda] ema-pytorch 0.4.5 pypi_0 pypi
[conda] mkl 2021.4.0 haa95532_640
[conda] mkl-service 2.4.0 py39h2bbff1b_0
[conda] mkl_fft 1.3.1 py39h277e83a_0
[conda] mkl_random 1.2.2 py39hf11a4ad_0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] numpydoc 1.4.0 py39haa95532_0
[conda] pytorch-cuda 11.6 h867d48c_1 pytorch
[conda] pytorch-msssim 1.0.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] siren-pytorch 0.1.7 pypi_0 pypi
[conda] torch 1.13.0 pypi_0 pypi
[conda] torch-cluster 1.6.0+pt113cu116 pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torch-scatter 2.1.0+pt113cu116 pypi_0 pypi
[conda] torch-sparse 0.6.16+pt113cu116 pypi_0 pypi
[conda] torch-spline-conv 1.2.1+pt113cu116 pypi_0 pypi
[conda] torch-tools 0.1.5 pypi_0 pypi
[conda] torchaudio 0.9.1 pypi_0 pypi
[conda] torchbearer 0.5.3 pypi_0 pypi
[conda] torchmeta 1.8.0 pypi_0 pypi
[conda] torchvision 0.17.2+cu118 pypi_0 pypi
[conda] uformer-pytorch 0.0.8 pypi_0 pypi
[conda] vit-pytorch 1.5.0 pypi_0 pypi

Xact-sniper, May 01 '24 23:05

Hi @pmeier, is this a good first issue? Would it be suitable for a beginner?

Bhavay-2001, May 08 '24 11:05

Hi @Xact-sniper, I think a possible fix is that we can add torch.device to this function call here.
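To illustrate the idea of passing the device explicitly, here is a rough user-side sketch that builds the kernel directly on the image's device and dtype, then applies it as a depthwise convolution. `blur_with_tensor_sigma` is a hypothetical helper, not a torchvision API. It assumes a float image of shape (N, C, H, W), an odd `kernel_size`, and a single 0-dim tensor sigma used for both axes:

```python
import torch
import torch.nn.functional as F


def blur_with_tensor_sigma(img: torch.Tensor, kernel_size: int, sigma: torch.Tensor) -> torch.Tensor:
    """Hypothetical workaround: Gaussian blur with a learnable tensor sigma."""
    half = (kernel_size - 1) * 0.5
    # Create x on the image's device and dtype so it matches a sigma on the same device.
    x = torch.linspace(-half, half, steps=kernel_size, dtype=img.dtype, device=img.device)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()
    kernel2d = kernel1d[:, None] * kernel1d[None, :]  # outer product, sums to 1

    channels = img.shape[-3]
    # One copy of the kernel per channel for a depthwise (grouped) convolution.
    weight = kernel2d.expand(channels, 1, kernel_size, kernel_size)

    pad = kernel_size // 2
    padded = F.pad(img, [pad, pad, pad, pad], mode="reflect")
    return F.conv2d(padded, weight, groups=channels)


device = "cuda" if torch.cuda.is_available() else "cpu"
s = torch.tensor(1.1, requires_grad=True, device=device)
img = torch.randn(1, 3, 64, 64, device=device)

blurred = blur_with_tensor_sigma(img, 15, s)
blurred.pow(2).mean().backward()
print(s.grad)  # populated on the same device as s
```

Because every tensor involved is created from `img.device`, the device mismatch from the traceback does not arise, and gradients still flow back to `s`.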

Could you please share a reproducible code snippet?

@pmeier @NicolasHug, any suggestions on this?

Bhavay-2001, May 09 '24 11:05

Thanks for the report, @Xact-sniper. This will be fixed by https://github.com/pytorch/vision/pull/8426.

What you're trying to do won't work on the V2 version. I'm not sure why just yet, so I've opened https://github.com/pytorch/vision/issues/8450 to keep track of that separately.

NicolasHug, May 29 '24 12:05