Can't use gaussian_blur if sigma is a tensor on GPU
🐛 Describe the bug
Admittedly perhaps an unconventional use, but I'm using gaussian_blur in my model to blur attention maps and I want to have the sigma be a parameter.
It would work, except for this function: https://github.com/pytorch/vision/blob/06ad737628abc3a1e617571dc03cbdd5b36ea96a/torchvision/transforms/_functional_tensor.py#L725
x is not moved to the device that sigma is on.
I believe it is like this in all torchvision versions.
WORKS:
import torch
from torchvision.transforms.functional import gaussian_blur
k = 15
s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True)
blurred = gaussian_blur(torch.randn(1, 3, 256, 256), k, [s])
blurred.mean().backward()
print(s.grad)
>>> tensor(-4.6193e-05)
DOES NOT:
import torch
from torchvision.transforms.functional import gaussian_blur
k = 15
s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True, device='cuda')
blurred = gaussian_blur(torch.randn(1, 3, 256, 256, device='cuda'), k, [s])
blurred.mean().backward()
print(s.grad)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
D:\Temp\ipykernel_39000\3525683463.py in <module>
4 s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True, device='cuda')
5
----> 6 blurred = gaussian_blur(torch.randn(1, 3, 256, 256, device='cuda'), k, [s])
7 blurred.mean().backward()
8 print(s.grad)
s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\functional.py in gaussian_blur(img, kernel_size, sigma)
1361 t_img = pil_to_tensor(img)
1362
-> 1363 output = F_t.gaussian_blur(t_img, kernel_size, sigma)
1364
1365 if not isinstance(img, torch.Tensor):
s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in gaussian_blur(img, kernel_size, sigma)
749
750 dtype = img.dtype if torch.is_floating_point(img) else torch.float32
--> 751 kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
752 kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
753
s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in _get_gaussian_kernel2d(kernel_size, sigma, dtype, device)
736 kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
737 ) -> Tensor:
--> 738 kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
739 kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
740 kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in _get_gaussian_kernel1d(kernel_size, sigma)
727
728 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
--> 729 pdf = torch.exp(-0.5 * (x / sigma).pow(2))
730 kernel1d = pdf / pdf.sum()
731
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
~~I don't know about the convention, like whether device should be passed in, but the simplest fix I believe would just be to change:
728 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
to:
728 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size).to(sigma.device)~~
Actually, that won't work when sigma is just a float, since a float has no .device attribute. So I guess there could be a check for whether it's a float or a float tensor.
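To make the float-vs-tensor check concrete, here is a minimal sketch of what it could look like. The function name `gaussian_kernel1d` is hypothetical and this is a standalone helper modeled on the lines shown in the traceback, not torchvision's actual code or the fix that was eventually merged.

```python
import torch


def gaussian_kernel1d(kernel_size, sigma):
    """Hypothetical standalone helper, modeled on the traceback above.

    Builds a normalized 1D Gaussian kernel. If sigma is a tensor, the
    sample grid is created on sigma's device; if sigma is a plain float,
    the default device is used.
    """
    ksize_half = (kernel_size - 1) * 0.5

    # A Python float has no .device, so only read it from tensors.
    device = sigma.device if isinstance(sigma, torch.Tensor) else None

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=device)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


# Works with a plain float ...
k_float = gaussian_kernel1d(15, 1.1)

# ... and with a tensor that requires grad (CPU here; a CUDA tensor
# would place the grid on the same CUDA device).
s = torch.tensor(1.1, requires_grad=True)
k_tensor = gaussian_kernel1d(15, s)
```

Inferring the device from sigma keeps the signature unchanged, but it assumes sigma and the image already live on the same device.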
Versions
[pip3] efficientunet-pytorch==0.0.6
[pip3] ema-pytorch==0.4.5
[pip3] flake8==6.0.0
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.3
[pip3] numpydoc==1.4.0
[pip3] pytorch-msssim==1.0.0
[pip3] siren-pytorch==0.1.7
[pip3] torch==2.2.2+cu118
[pip3] torch-cluster==1.6.0+pt113cu116
[pip3] torch_geometric==2.4.0
[pip3] torch-scatter==2.1.0+pt113cu116
[pip3] torch-sparse==0.6.16+pt113cu116
[pip3] torch-spline-conv==1.2.1+pt113cu116
[pip3] torch-tools==0.1.5
[pip3] torchaudio==2.2.2+cu118
[pip3] torchbearer==0.5.3
[pip3] torchmeta==1.8.0
[pip3] torchvision==0.17.2+cu118
[pip3] uformer-pytorch==0.0.8
[pip3] vit-pytorch==1.5.0
[conda] blas 1.0 mkl
[conda] efficientunet-pytorch 0.0.6 pypi_0 pypi
[conda] ema-pytorch 0.4.5 pypi_0 pypi
[conda] mkl 2021.4.0 haa95532_640
[conda] mkl-service 2.4.0 py39h2bbff1b_0
[conda] mkl_fft 1.3.1 py39h277e83a_0
[conda] mkl_random 1.2.2 py39hf11a4ad_0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] numpydoc 1.4.0 py39haa95532_0
[conda] pytorch-cuda 11.6 h867d48c_1 pytorch
[conda] pytorch-msssim 1.0.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] siren-pytorch 0.1.7 pypi_0 pypi
[conda] torch 1.13.0 pypi_0 pypi
[conda] torch-cluster 1.6.0+pt113cu116 pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torch-scatter 2.1.0+pt113cu116 pypi_0 pypi
[conda] torch-sparse 0.6.16+pt113cu116 pypi_0 pypi
[conda] torch-spline-conv 1.2.1+pt113cu116 pypi_0 pypi
[conda] torch-tools 0.1.5 pypi_0 pypi
[conda] torchaudio 0.9.1 pypi_0 pypi
[conda] torchbearer 0.5.3 pypi_0 pypi
[conda] torchmeta 1.8.0 pypi_0 pypi
[conda] torchvision 0.17.2+cu118 pypi_0 pypi
[conda] uformer-pytorch 0.0.8 pypi_0 pypi
[conda] vit-pytorch 1.5.0 pypi_0 pypi
Hi @pmeier, is this a good first issue? Would it be suitable for a beginner?
Hi @Xact-sniper, I think a possible fix is to pass a torch.device to this function call here.
Can you please send a reproducible code snippet?
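For illustration, the "pass a torch.device" idea could be sketched roughly as follows. The names `kernel1d` and `kernel2d` are hypothetical stand-ins for the private helpers in the traceback, and the signatures are an assumption rather than what torchvision ships.

```python
import torch


def kernel1d(kernel_size, sigma, dtype, device):
    # Hypothetical helper: the grid is created directly on the requested
    # device, so it matches a sigma tensor that lives on that device.
    ksize_half = (kernel_size - 1) * 0.5
    x = torch.linspace(
        -ksize_half, ksize_half, steps=kernel_size, dtype=dtype, device=device
    )
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


def kernel2d(kernel_size, sigma, dtype, device):
    # Outer product of the y and x kernels, as in the traceback above.
    kx = kernel1d(kernel_size[0], sigma[0], dtype, device)
    ky = kernel1d(kernel_size[1], sigma[1], dtype, device)
    return torch.mm(ky[:, None], kx[None, :])


# CPU is used here so the snippet runs anywhere; with a CUDA image one
# would pass img.device instead.
kernel = kernel2d([15, 15], [1.1, 1.1], torch.float32, torch.device("cpu"))
```

Compared with inferring the device from sigma, this variant needs no float-vs-tensor check, because the device always comes from the caller (for example, from the image).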
@pmeier @NicolasHug, any suggestions for this?
Thanks for the report, @Xact-sniper. This will be fixed by https://github.com/pytorch/vision/pull/8426.
What you're trying to do won't work on the V2 version. I'm not sure why just yet, so I've opened https://github.com/pytorch/vision/issues/8450 to keep track of that separately.