EntropyBottleneck with adjustable bin width
Feature
Support for Custom Bin Width in EntropyBottleneck
Motivation
So far, only a bin width of 1 is supported; it would be good to have this as a tunable option.
Additional context
Are `quantize` and `_likelihood` the only methods that need to change, or are there other important changes I am missing? I'm also not sure about `_get_medians`.
Here is how I would change quantize:
```python
def quantize(
    self, inputs: Tensor, mode: str, means: Optional[Tensor] = None, bin_width: float = 1.0
) -> Tensor:
    if mode not in ("noise", "dequantize", "symbols"):
        raise ValueError(f'Invalid quantization mode: "{mode}"')

    if mode == "noise":
        half = bin_width / 2
        noise = torch.empty_like(inputs).uniform_(-half, half)
        inputs = inputs + noise
        return inputs

    outputs = inputs.clone()
    if means is not None:
        outputs -= means

    # Round to the nearest multiple of bin_width.
    outputs = torch.round(outputs / bin_width) * bin_width

    if mode == "dequantize":
        if means is not None:
            outputs += means
        return outputs

    assert mode == "symbols", mode
    outputs = outputs.int()
    return outputs
```
and here is how I would change `_likelihood`:

```python
def _likelihood(
    self, inputs: Tensor, bin_width: float = 1.0, stop_gradient: bool = False
) -> Tuple[Tensor, Tensor, Tensor]:
    half = bin_width / 2  # integrate the density over one bin of width bin_width
    lower = self._logits_cumulative(inputs - half, stop_gradient=stop_gradient)
    upper = self._logits_cumulative(inputs + half, stop_gradient=stop_gradient)
    likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
    return likelihood, lower, upper
```
Your further guidance is appreciated! Thank you!
At a quick glance, this should work for training, though I wonder if the lossless entropy coder also needs adjustment.
A simpler method might be to just rescale the outputs by the desired bin width instead. Since both are uniform quantizers, it should be equivalent.
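Concretely, "round to a multiple of the bin width" and "rescale, round to an integer, rescale back" compute the same thing (a quick check with an arbitrary bin width):

```python
import torch

x = torch.randn(4)
bin_width = 0.5

# Quantizer with an explicit bin width:
direct = torch.round(x / bin_width) * bin_width

# Same quantizer expressed via rescaling around the existing unit-width one:
x_scaled_q = torch.round(x / bin_width)  # the unchanged integer quantizer
rescaled = x_scaled_q * bin_width

assert torch.equal(direct, rescaled)
```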
Thank you for the suggestion! Here's how I implemented it:
Overview of Changes:
I added a bin_width parameter to the quantize, dequantize, compress, and decompress methods of the EntropyModel to handle different quantization bin widths at test time.
Quantize and Dequantize Methods:
```python
def quantize(
    self, inputs: Tensor, mode: str, means: Optional[Tensor] = None, bin_width: float = 1.0
) -> Tensor:
    if mode not in ("noise", "dequantize", "symbols"):
        raise ValueError(f'Invalid quantization mode: "{mode}"')

    # Scale inputs by the bin width so the rest of the method can reuse
    # the existing unit-width quantizer.
    inputs_scaled = inputs / bin_width

    if mode == "noise":
        half = 0.5
        noise = torch.empty_like(inputs_scaled).uniform_(-half, half)
        inputs_scaled = inputs_scaled + noise
        return inputs_scaled * bin_width  # scale back after adding noise

    outputs = inputs_scaled.clone()
    if means is not None:
        outputs -= means / bin_width  # scale means accordingly

    outputs = torch.round(outputs)  # quantize to the nearest integer

    if mode == "dequantize":
        if means is not None:
            outputs += means / bin_width
        return outputs * bin_width  # scale back to the original range

    assert mode == "symbols", mode
    outputs = outputs.int()
    return outputs

@staticmethod
def dequantize(
    inputs: Tensor,
    means: Optional[Tensor] = None,
    dtype: torch.dtype = torch.float,
    bin_width: float = 1.0,
) -> Tensor:
    if means is not None:
        outputs = inputs.type_as(means)
        outputs += means / bin_width  # means are added in the scaled domain
    else:
        outputs = inputs.type(dtype)
    return outputs * bin_width  # scale back to the original range
```
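As a quick sanity check of the arithmetic these two methods implement (a standalone sketch with made-up values):

```python
import torch

x = torch.tensor([0.30, -1.20, 2.05])
means = torch.tensor([0.10, 0.10, 0.10])
bin_width = 0.5

# "symbols" mode reduces to round((x - means) / bin_width):
symbols = torch.round((x - means) / bin_width).int()  # tensor([ 0, -3,  4], dtype=torch.int32)

# dequantize then recovers bin_width * symbols + means:
x_hat = symbols.float() * bin_width + means           # tensor([ 0.1000, -1.4000,  2.1000])
```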
Compress and Decompress Methods:
```python
def compress(self, inputs, indexes, means=None, bin_width: float = 1.0):
    symbols = self.quantize(inputs, "symbols", means, bin_width)

    if len(inputs.size()) < 2:
        raise ValueError(
            "Invalid `inputs` size. Expected a tensor with at least 2 dimensions."
        )

    if inputs.size() != indexes.size():
        raise ValueError("`inputs` and `indexes` should have the same size.")

    self._check_cdf_size()
    self._check_cdf_length()
    self._check_offsets_size()

    strings = []
    for i in range(symbols.size(0)):
        rv = self.entropy_coder.encode_with_indexes(
            symbols[i].reshape(-1).int().tolist(),
            indexes[i].reshape(-1).int().tolist(),
            self._quantized_cdf.tolist(),
            self._cdf_length.reshape(-1).int().tolist(),
            self._offset.reshape(-1).int().tolist(),
        )
        strings.append(rv)
    return strings

def decompress(
    self,
    strings: str,
    indexes: torch.IntTensor,
    dtype: torch.dtype = torch.float,
    means: torch.Tensor = None,
    bin_width: float = 1.0,
):
    if not isinstance(strings, (tuple, list)):
        raise ValueError("Invalid `strings` parameter type.")

    if not len(strings) == indexes.size(0):
        raise ValueError("Invalid strings or indexes parameters")

    if len(indexes.size()) < 2:
        raise ValueError(
            "Invalid `indexes` size. Expected a tensor with at least 2 dimensions."
        )

    self._check_cdf_size()
    self._check_cdf_length()
    self._check_offsets_size()

    if means is not None:
        if means.size()[:2] != indexes.size()[:2]:
            raise ValueError("Invalid means or indexes parameters")
        if means.size() != indexes.size():
            for i in range(2, len(indexes.size())):
                if means.size(i) != 1:
                    raise ValueError("Invalid means parameters")

    cdf = self._quantized_cdf
    outputs = cdf.new_empty(indexes.size())

    for i, s in enumerate(strings):
        values = self.entropy_coder.decode_with_indexes(
            s,
            indexes[i].reshape(-1).int().tolist(),
            cdf.tolist(),
            self._cdf_length.reshape(-1).int().tolist(),
            self._offset.reshape(-1).int().tolist(),
        )
        outputs[i] = torch.tensor(
            values, device=outputs.device, dtype=outputs.dtype
        ).reshape(outputs[i].size())

    outputs = self.dequantize(outputs, means, dtype, bin_width)
    return outputs
```
EntropyBottleneck Changes:
In the EntropyBottleneck, I added the bin_width parameter to the compress and decompress methods:
```python
def compress(self, x, bin_width=1.0):
    indexes = self._build_indexes(x.size())
    medians = self._get_medians().detach()
    spatial_dims = len(x.size()) - 2
    medians = self._extend_ndims(medians, spatial_dims)
    medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1)))
    return super().compress(x, indexes, medians, bin_width)

def decompress(self, strings, size, bin_width=1.0):
    output_size = (len(strings), self._quantized_cdf.size(0), *size)
    indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
    medians = self._extend_ndims(self._get_medians().detach(), len(size))
    medians = medians.expand(len(strings), *([-1] * (len(size) + 1)))
    return super().decompress(strings, indexes, medians.dtype, medians, bin_width)
```
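With these overrides, a test-time round trip would look like this (a sketch that assumes the modified class above; the channel count and tensor shape are arbitrary):

```python
import torch
from compressai.entropy_models import EntropyBottleneck

eb = EntropyBottleneck(128)
eb.update()  # build the CDF tables used by the entropy coder

y = torch.randn(1, 128, 16, 16)
strings = eb.compress(y, bin_width=0.5)
y_hat = eb.decompress(strings, size=(16, 16), bin_width=0.5)

# Each element of y_hat should sit within about bin_width / 2 of y.
print((y - y_hat).abs().max())
```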
My goal with these changes is to vary `bin_width` only at test time and analyze the trade-off between bitstream length and reconstruction error.
Do you think rescaling inputs and means by bin_width inside the quantize and dequantize methods is the correct approach?
Thank you!
Apologies for the late response.
It looks like that should work for the runtime codec.
However, if you wanted to train it as well, you would also need to modify EntropyBottleneck.forward:
```python
class EntropyBottleneck(EntropyModel):
    def forward(
        self, x: Tensor, training: Optional[bool] = None, bin_width=1.0
    ) -> Tuple[Tensor, Tensor]:
        ...
        outputs = self.quantize(
            values,
            "noise" if training else "dequantize",
            self._get_medians(),
            bin_width=bin_width,
        )
        ...
```
...And use it as follows:
```python
@register_model("bmshj2018-factorized-vbr")
class FactorizedPriorVbr(CompressionModel):
    def forward(self, x, *, bin_width=1.0):
        y = self.g_a(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y, bin_width=bin_width)
        x_hat = self.g_s(y_hat)
        return {"x_hat": x_hat, "likelihoods": {"y": y_likelihoods}}

    def compress(self, x, *, bin_width=1.0):
        y = self.g_a(x)
        y_strings = self.entropy_bottleneck.compress(y, bin_width=bin_width)
        return {"strings": [y_strings], "shape": y.size()[-2:]}

    def decompress(self, strings, shape, *, bin_width=1.0):
        assert isinstance(strings, list) and len(strings) == 1
        y_hat = self.entropy_bottleneck.decompress(strings[0], shape, bin_width=bin_width)
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}
```
QVRF[^1] does something quite similar, though with two differences:
- They apply `bin_width` ($≝ 1/a$) to the `GaussianConditional` for the mean-scale hyperprior (mbt2018-mean) model.
- They "finetune" a single high-rate (λ=0.18) model over a set $A = \{ a_1, a_2, \ldots, a_n \}$ of different bin widths to ensure that the model works well over those bin widths:

> We use staged training strategies, where the network parameters are optimized on $λ=0.18$ for the first 2000 epochs. Then, $A$ are optimized jointly with noise approximation for 500 epochs and straight-through estimation for another 500 epochs.
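A fine-tuning loop in that spirit might look like the following rough sketch, where `model` is the FactorizedPriorVbr above, `dataloader` is assumed to exist, the bin widths in `A` are illustrative, and the loss follows the usual CompressAI rate-distortion form:

```python
import random
import torch
import torch.nn.functional as F

A = [0.5, 1.0, 2.0, 4.0]  # candidate bin widths (illustrative values)
lmbda = 0.18              # the single high-rate lambda

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for x in dataloader:
    bin_width = random.choice(A)  # sample one bin width per batch
    out = model(x, bin_width=bin_width)

    num_pixels = x.size(0) * x.size(2) * x.size(3)
    bpp = sum(
        (-torch.log2(lik)).sum() / num_pixels
        for lik in out["likelihoods"].values()
    )
    mse = F.mse_loss(out["x_hat"], x)
    loss = bpp + lmbda * 255**2 * mse  # standard rate-distortion objective

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```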
[^1]: Tong et al. QVRF: A Quantization-error-aware Variable Rate Framework for Learned Image Compression. https://arxiv.org/pdf/2303.05744
Thank you very much!
Hello, I was wondering whether I need to include any additional changes beyond those discussed in my previous comment.
In particular, I’m interested in reproducing the results shown in Figure 5(a) of the paper “Variable Rate Deep Image Compression With a Conditional Autoencoder” by Yoojin Choi, Mostafa El-Khamy, and Jungwon Lee.
Any guidance on necessary modifications or configurations would be greatly appreciated.
Thanks in advance!