CheckerboardLatentCodec broken in torch<2.0
Bug
Using the CheckerboardLatentCodec with a non-identity context_prediction module results in a runtime error during the forward pass. I believe this should only occur when using a torch version less than 2.0.
To Reproduce
Steps to reproduce the behavior:
- Instantiate a CheckerboardLatentCodec.
- Create any tensor and pass it to the forward() method of the latent codec.
- Observe bug.
Minimal working example:
import torch
from compressai.latent_codecs import CheckerboardLatentCodec, GaussianConditionalLatentCodec
from compressai.layers.layers import CheckerboardMaskedConv2d

lc = CheckerboardLatentCodec(
    latent_codec={
        "y": GaussianConditionalLatentCodec(),
    },
    context_prediction=CheckerboardMaskedConv2d(4, 8, kernel_size=5, stride=1, padding=2),
)
t = torch.randn((1, 4, 64, 64))  # arbitrary shape, just must match channel size in context_prediction layer
ctx = torch.randn((1, 8, 16, 16))  # arbitrary shape
output = lc(t, ctx)
This code results in the error:
File ~/conda/miniconda3-ubuntu22/envs/sdv2-new/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/software/compressai/compressai/latent_codecs/checkerboard.py:149, in CheckerboardLatentCodec.forward(self, y, side_params)
147 return self._forward_onepass(y, side_params)
148 if self.forward_method == "twopass":
--> 149 return self._forward_twopass(y, side_params)
150 if self.forward_method == "twopass_faster":
151 return self._forward_twopass_faster(y, side_params)
File ~/software/compressai/compressai/latent_codecs/checkerboard.py:192, in CheckerboardLatentCodec._forward_twopass(self, y, side_params)
187 B, C, H, W = y.shape
189 params = y.new_zeros((B, C * 2, H, W))
191 y_hat_anchors = self._forward_twopass_step(
--> 192 y, side_params, params, self._y_ctx_zero(y), "anchor"
193 )
195 y_hat_non_anchors = self._forward_twopass_step(
196 y, side_params, params, self.context_prediction(y_hat_anchors), "non_anchor"
197 )
199 y_hat = y_hat_anchors + y_hat_non_anchors
File ~/conda/miniconda3-ubuntu22/envs/sdv2-new/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File ~/software/compressai/compressai/latent_codecs/checkerboard.py:272, in CheckerboardLatentCodec._y_ctx_zero(self, y)
269 @torch.no_grad()
270 def _y_ctx_zero(self, y: Tensor) -> Tensor:
271 """Create a zero tensor with correct shape for y_ctx."""
--> 272 y_ctx_meta = self.context_prediction(y.to("meta"))
273 return y.new_zeros(y_ctx_meta.shape)
File ~/conda/miniconda3-ubuntu22/envs/sdv2-new/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/software/compressai/compressai/layers/layers.py:144, in MaskedConv2d.forward(self, x)
141 def forward(self, x: Tensor) -> Tensor:
142 # TODO(begaintj): weight assigment is not supported by torchscript
143 self.weight.data = self.weight.data * self.mask
--> 144 return super().forward(x)
File ~/conda/miniconda3-ubuntu22/envs/sdv2-new/lib/python3.8/site-packages/torch/nn/modules/conv.py:457, in Conv2d.forward(self, input)
456 def forward(self, input: Tensor) -> Tensor:
--> 457 return self._conv_forward(input, self.weight, self.bias)
File ~/conda/miniconda3-ubuntu22/envs/sdv2-new/lib/python3.8/site-packages/torch/nn/modules/conv.py:453, in Conv2d._conv_forward(self, input, weight, bias)
449 if self.padding_mode != 'zeros':
450 return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
451 weight, bias, self.stride,
452 _pair(0), self.dilation, self.groups)
--> 453 return F.conv2d(input, weight, bias, self.stride,
454 self.padding, self.dilation, self.groups)
NotImplementedError: convolution_overrideable not implemented. You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function
Expected behavior
The code should not throw an error.
Environment
Output from python3 -m torch.utils.collect_env:
PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.8.18 (default, Sep 11 2023, 13:40:15) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-44-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.7.99
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] open-clip-torch==2.7.0
[pip3] pytorch-lightning==1.4.2
[pip3] pytorch-msssim==1.0.0
[pip3] torch==1.12.1
[pip3] torch_geometric==2.5.3
[pip3] torchaudio==0.12.1
[pip3] torchmetrics==0.6.0
[pip3] torchvision==0.13.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 hb98b00a_13 conda-forge
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.23.1 py38h6c91a56_0
[conda] numpy-base 1.23.1 py38ha15fc14_0
[conda] open-clip-torch 2.7.0 pypi_0 pypi
[conda] pytorch 1.12.1 py3.8_cuda11.3_cudnn8.3.2_0 pytorch
[conda] pytorch-lightning 1.4.2 pypi_0 pypi
[conda] pytorch-msssim 1.0.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 1.13.0+cu117 pypi_0 pypi
[conda] torch-geometric 2.5.3 pypi_0 pypi
[conda] torchaudio 0.13.0+cu117 pypi_0 pypi
[conda] torchmetrics 0.6.0 pyhd8ed1ab_0 conda-forge
[conda] torchvision 0.14.0+cu117 pypi_0 pypi
- PyTorch / CompressAI Version: 1.12.1 / 1.2.6
- OS: Linux, Ubuntu 22.04.3
- How you installed PyTorch / CompressAI: source
- Build command you used (if compiling from source):
git clone https://github.com/InterDigitalInc/CompressAI compressai
cd compressai
pip install -U pip && pip install -e .
- Python version: 3.8.18
- CUDA/cuDNN version: 11.7
- GPU models and configuration: 1x NVIDIA GeForce RTX 3090
- Any other relevant information: N/A
Additional context
I am fairly certain this happens because older PyTorch versions do not support some operations (here, the convolution shown in the traceback) on tensors that live on the "meta" device. I think that support was introduced with PyTorch 2.0, but I couldn't find anything definitive from a quick search.
I traced this back to commit eddb1bc, which uses meta device tensors to compute the expected size of the checkerboard context tensor. Replacing these lines with the previous version resolved the issue for me.
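For the case where context_prediction is a single Conv2d-style layer (the traceback suggests CheckerboardMaskedConv2d ends up in Conv2d.forward), the output shape could also be computed arithmetically, without the meta device and without running the convolution. A minimal sketch using the standard Conv2d output-size formula; `conv2d_output_shape` is a hypothetical helper, not part of CompressAI, and it would not cover an arbitrary context_prediction module:

```python
from typing import Sequence, Tuple, Union

IntOrPair = Union[int, Sequence[int]]


def _pair(value: IntOrPair) -> Tuple[int, int]:
    """Expand an int to (int, int); pass a 2-sequence through as a tuple."""
    if isinstance(value, int):
        return (value, value)
    h, w = value
    return (h, w)


def conv2d_output_shape(
    input_shape: Sequence[int],
    out_channels: int,
    kernel_size: IntOrPair,
    stride: IntOrPair = 1,
    padding: IntOrPair = 0,
    dilation: IntOrPair = 1,
) -> Tuple[int, int, int, int]:
    """Shape of a 2D convolution output for a (B, C, H, W) input.

    Uses out = floor((in + 2*pad - dil*(k - 1) - 1) / stride) + 1
    per spatial dimension. Integer padding only (no "same"/"valid").
    """
    batch, _, height, width = input_shape
    (kh, kw) = _pair(kernel_size)
    (sh, sw) = _pair(stride)
    (ph, pw) = _pair(padding)
    (dh, dw) = _pair(dilation)
    h_out = (height + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    w_out = (width + 2 * pw - dw * (kw - 1) - 1) // sw + 1
    return (batch, out_channels, h_out, w_out)


# Shape for the layer in the example above:
# CheckerboardMaskedConv2d(4, 8, kernel_size=5, stride=1, padding=2)
ctx_shape = conv2d_output_shape((1, 4, 64, 64), 8, kernel_size=5, stride=1, padding=2)
```

With such a shape in hand, `y.new_zeros(ctx_shape)` would give the zero context tensor on any torch version.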
Thanks for the report.
I could add a version check for torch<2.0:
from packaging.version import Version

class CheckerboardLatentCodec(LatentCodec):
    def _y_ctx_zero(self, y: Tensor) -> Tensor:
        if Version(torch.__version__) < Version("2.0.0"):
            return self._mask(self.context_prediction(y).detach(), "all")
        return y.new_zeros(self.context_prediction(y.to("meta")).shape)
...but perhaps simpler is just to revert:
class CheckerboardLatentCodec(LatentCodec):
    def _y_ctx_zero(self, y: Tensor) -> Tensor:
        return self._mask(self.context_prediction(y).detach(), "all")
To be fair, I don't actually know which is the earliest torch version that supports meta device tensors, as I couldn't find any solid information.
That said, I think the simpler fix is probably good enough. On my machine (14900K CPU, RTX 3090 GPU), with an unreasonably large context size of (16, 192, 512, 512), that line takes 0.06 ms to execute on GPU and about 4 seconds on CPU. With a more reasonable context size of (16, 192, 32, 32), it takes roughly 80 ms on CPU.
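One caveat on the GPU number: CUDA kernels launch asynchronously, so a timing taken without synchronizing may capture only the launch overhead rather than the full execution. A minimal sketch of a timing helper that allows for this; `time_call_ms` is a hypothetical helper, and for GPU measurements one would pass `sync=torch.cuda.synchronize`:

```python
import time
from typing import Callable, Optional


def time_call_ms(
    fn: Callable[[], object],
    repeats: int = 10,
    warmup: int = 2,
    sync: Optional[Callable[[], object]] = None,
) -> float:
    """Average wall-clock time of fn() in milliseconds.

    `sync` is called before starting and before stopping the clock so that
    asynchronously queued work (e.g. CUDA kernels) is included in the result.
    """

    def _sync() -> None:
        if sync is not None:
            sync()

    for _ in range(warmup):  # warm-up runs are not timed
        fn()
    _sync()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    _sync()
    elapsed = time.perf_counter() - start
    return elapsed * 1000.0 / repeats


# CPU-only demonstration with a trivial workload.
avg_ms = time_call_ms(lambda: sum(range(1000)), repeats=5, warmup=1)
```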