CompressAI

EntropyBottleneck with adjustable bin width

Open · ghost opened this issue 1 year ago · 2 comments

Feature

Support for Custom Bin Width in EntropyBottleneck

Motivation

So far, only a bin width of 1 is supported; it would be good to have the bin width as a tunable option.

Additional context

Are quantize and _likelihood the only methods that need to change, or am I missing other important changes? I'm not sure about _get_medians.

Here is how I would change quantize:

def quantize(
    self, inputs: Tensor, mode: str, means: Optional[Tensor] = None, bin_width: float = 1.0) -> Tensor:
    if mode not in ("noise", "dequantize", "symbols"):
        raise ValueError(f'Invalid quantization mode: "{mode}"')

    if mode == "noise":
        half = bin_width / 2
        noise = torch.empty_like(inputs).uniform_(-half, half)
        inputs = inputs + noise
        return inputs

    outputs = inputs.clone()
    if means is not None:
        outputs -= means

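    # Integer bin index on a grid of step bin_width (centred on means, if given)
    indices = torch.round(outputs / bin_width)

    if mode == "dequantize":
        outputs = indices * bin_width  # back to the input scale
        if means is not None:
            outputs += means
        return outputs

    assert mode == "symbols", mode
    # Symbols must be the bin indices; casting indices * bin_width to int
    # would truncate non-integer multiples (e.g. 0.5 -> 0)
    outputs = indices.int()
    return outputs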
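A quick scalar check of the "symbols" branch: with a non-integer bin width, the integer handed to the entropy coder has to be the bin index round((x - m) / bin_width). Casting the rescaled value round(...) * bin_width to int instead truncates it, so distinct bins can collapse onto the same symbol. The helpers below (symbol_index, dequantize_index, truncated_symbol) are hypothetical standard-library stand-ins for the tensor code; Python's round, like torch.round, rounds halves to even.

```python
def symbol_index(x: float, bin_width: float, mean: float = 0.0) -> int:
    """Integer bin index of x on a grid of step bin_width centred on mean."""
    return int(round((x - mean) / bin_width))


def dequantize_index(k: int, bin_width: float, mean: float = 0.0) -> float:
    """Reconstruction point (bin centre) for bin index k."""
    return k * bin_width + mean


def truncated_symbol(x: float, bin_width: float, mean: float = 0.0) -> int:
    """Round to a multiple of bin_width, then cast to int (loses information)."""
    return int(round((x - mean) / bin_width) * bin_width)


if __name__ == "__main__":
    for x in (0.5, 1.3, -0.7):
        k = symbol_index(x, 0.5)
        x_hat = dequantize_index(k, 0.5)
        # The reconstruction error never exceeds half a bin.
        assert abs(x - x_hat) <= 0.25 + 1e-12
        print(x, k, x_hat, truncated_symbol(x, 0.5))
```

With bin_width = 0.5, the inputs 0.5 and -0.7 land in bins 1 and -1, but the truncating version maps both to symbol 0.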

And here is _likelihood:

def _likelihood(self, inputs: Tensor, bin_width: float = 1.0, stop_gradient: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
    half = bin_width / 2  # Adjust based on the bin width
    lower = self._logits_cumulative(inputs - half, stop_gradient=stop_gradient)
    upper = self._logits_cumulative(inputs + half, stop_gradient=stop_gradient)
    likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
    return likelihood, lower, upper
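To sanity-check the bin-width version of _likelihood, here is a scalar sketch that swaps the learned _logits_cumulative for a logistic CDF with parameters loc and scale. This is an assumption for illustration only; the real density is learned. It keeps the same sigmoid(upper) - sigmoid(lower) form with half = bin_width / 2 and shows that, for a density that is wide compared with the bins, halving the bin width costs about one extra bit per element.

```python
import math


def logistic_cdf(t: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Numerically stable sigmoid((t - loc) / scale)."""
    z = (t - loc) / scale
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bin_likelihood(y: float, bin_width: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Mass of the bin of width bin_width centred on y: F(y + half) - F(y - half)."""
    half = bin_width / 2
    upper = logistic_cdf(y + half, loc, scale)
    lower = logistic_cdf(y - half, loc, scale)
    return upper - lower


def discrete_entropy_bits(bin_width: float, scale: float, span: float = 40.0) -> float:
    """Entropy (bits) of a zero-centred logistic quantized on bins at k * bin_width."""
    n = int(span * scale / bin_width)
    total = 0.0
    for k in range(-n, n + 1):
        p = bin_likelihood(k * bin_width, bin_width, 0.0, scale)
        if p > 0.0:
            total -= p * math.log2(p)
    return total


if __name__ == "__main__":
    for w in (2.0, 1.0, 0.5, 0.25):
        print(w, round(discrete_entropy_bits(w, 10.0), 3))
```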

Your further guidance is appreciated! Thank you!

ghost, Sep 21 '24

At a quick glance, this should work for training, though I wonder if the lossless entropy coder also needs adjustment.

A simpler method might be to just rescale the outputs by the desired bin width instead. Since both are uniform quantizers, it should be equivalent.
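The equivalence also holds at the level of probabilities, provided the density is rescaled along with the inputs. The mass of a width-Δ bin around y under a CDF F equals the mass of a unit bin around y / Δ under the rescaled CDF G(t) = F(Δ·t). Here is a minimal check, using a logistic CDF as a hypothetical stand-in for the learned one:

```python
import math
from typing import Callable

Cdf = Callable[[float], float]


def logistic_cdf(t: float, scale: float = 1.0) -> float:
    return 1.0 / (1.0 + math.exp(-t / scale))


def mass_width(y: float, bin_width: float, cdf: Cdf) -> float:
    """Probability of the bin of width bin_width centred on y."""
    return cdf(y + bin_width / 2) - cdf(y - bin_width / 2)


def mass_unit_rescaled(y: float, bin_width: float, cdf: Cdf) -> float:
    """Unit bin around y / bin_width under the rescaled CDF G(t) = F(bin_width * t)."""
    def g(t: float) -> float:
        return cdf(bin_width * t)

    u = y / bin_width
    return g(u + 0.5) - g(u - 0.5)
```

If only the inputs are rescaled and the density model (or its CDF tables) is not, the two quantizers are no longer equivalent in rate.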

YodaEmbedding, Sep 22 '24

Thank you for the suggestion! Here's how I implemented it:

Overview of Changes:

I added a bin_width parameter to the quantize, dequantize, compress, and decompress methods of the EntropyModel to handle different quantization bin widths at test time.

Quantize and Dequantize Methods:

def quantize(
    self, inputs: Tensor, mode: str, means: Optional[Tensor] = None, bin_width: float = 1.0
) -> Tensor:
    if mode not in ("noise", "dequantize", "symbols"):
        raise ValueError(f'Invalid quantization mode: "{mode}"')

    # Scale inputs by bin width before quantization
    inputs_scaled = inputs / bin_width

    if mode == "noise":
        half = float(0.5)
        noise = torch.empty_like(inputs_scaled).uniform_(-half, half)
        inputs_scaled = inputs_scaled + noise
        return inputs_scaled * bin_width  # Scale back after adding noise

    outputs = inputs_scaled.clone()
    if means is not None:
        outputs -= means / bin_width  # Scale means accordingly

    outputs = torch.round(outputs)  # Quantize to nearest integer

    if mode == "dequantize":
        if means is not None:
            outputs += means / bin_width
        return outputs * bin_width  # Scale back to original

    assert mode == "symbols", mode
    outputs = outputs.int()
    return outputs

@staticmethod
def dequantize(
    inputs: Tensor, means: Optional[Tensor] = None, dtype: torch.dtype = torch.float, bin_width: float = 1.0
) -> Tensor:
    if means is not None:
        outputs = inputs.type_as(means)
        outputs += means / bin_width  # Adjust means
    else:
        outputs = inputs.type(dtype)
    return outputs * bin_width  # Scale back to original
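Whatever the exact tensor code, the decoder-side dequantize has to reproduce what quantize(..., "dequantize", ...) gives on the encoder side. Otherwise the reconstruction obtained after decompress differs from the one seen during evaluation. Below is a scalar check of the rescaling formulas above; the function names are hypothetical stand-ins for the two methods, not CompressAI API.

```python
import random


def quantize_dequantize(x: float, mean: float, bin_width: float) -> float:
    # "dequantize" branch: subtract the rescaled mean, round, add it back, rescale.
    scaled = x / bin_width - mean / bin_width
    return (round(scaled) + mean / bin_width) * bin_width


def quantize_symbols(x: float, mean: float, bin_width: float) -> int:
    # "symbols" branch: the integer that goes to the entropy coder.
    return int(round(x / bin_width - mean / bin_width))


def dequantize(symbol: int, mean: float, bin_width: float) -> float:
    # Decoder side: add the rescaled mean, then scale back.
    return (symbol + mean / bin_width) * bin_width


def check_consistency(n: int = 1000, seed: int = 0) -> bool:
    """True if decoding the symbols matches the encoder-side reconstruction."""
    rng = random.Random(seed)
    for _ in range(n):
        x = rng.uniform(-10.0, 10.0)
        mean = rng.uniform(-1.0, 1.0)
        w = rng.choice([0.25, 0.5, 1.0, 2.0])
        decoded = dequantize(quantize_symbols(x, mean, w), mean, w)
        if decoded != quantize_dequantize(x, mean, w):
            return False
        if abs(x - decoded) > w / 2 + 1e-9:  # error stays within half a bin
            return False
    return True
```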

Compress and Decompress Methods:

def compress(self, inputs, indexes, means=None, bin_width: float = 1.0):
    symbols = self.quantize(inputs, "symbols", means, bin_width)

    if len(inputs.size()) < 2:
        raise ValueError(
            "Invalid `inputs` size. Expected a tensor with at least 2 dimensions."
        )

    if inputs.size() != indexes.size():
        raise ValueError("`inputs` and `indexes` should have the same size.")

    self._check_cdf_size()
    self._check_cdf_length()
    self._check_offsets_size()

    strings = []
    for i in range(symbols.size(0)):
        rv = self.entropy_coder.encode_with_indexes(
            symbols[i].reshape(-1).int().tolist(),
            indexes[i].reshape(-1).int().tolist(),
            self._quantized_cdf.tolist(),
            self._cdf_length.reshape(-1).int().tolist(),
            self._offset.reshape(-1).int().tolist(),
        )
        strings.append(rv)
    return strings

def decompress(
    self,
    strings: str,
    indexes: torch.IntTensor,
    dtype: torch.dtype = torch.float,
    means: torch.Tensor = None,
    bin_width: float = 1.0,
):
    if not isinstance(strings, (tuple, list)):
        raise ValueError("Invalid `strings` parameter type.")

    if not len(strings) == indexes.size(0):
        raise ValueError("Invalid strings or indexes parameters")

    if len(indexes.size()) < 2:
        raise ValueError(
            "Invalid `indexes` size. Expected a tensor with at least 2 dimensions."
        )

    self._check_cdf_size()
    self._check_cdf_length()
    self._check_offsets_size()

    if means is not None:
        if means.size()[:2] != indexes.size()[:2]:
            raise ValueError("Invalid means or indexes parameters")
        if means.size() != indexes.size():
            for i in range(2, len(indexes.size())):
                if means.size(i) != 1:
                    raise ValueError("Invalid means parameters")

    cdf = self._quantized_cdf
    outputs = cdf.new_empty(indexes.size())

    for i, s in enumerate(strings):
        values = self.entropy_coder.decode_with_indexes(
            s,
            indexes[i].reshape(-1).int().tolist(),
            cdf.tolist(),
            self._cdf_length.reshape(-1).int().tolist(),
            self._offset.reshape(-1).int().tolist(),
        )
        outputs[i] = torch.tensor(
            values, device=outputs.device, dtype=outputs.dtype
        ).reshape(outputs[i].size())
    outputs = self.dequantize(outputs, means, dtype, bin_width)
    return outputs
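On whether the entropy coder needs adjusting: compress and decompress above rescale the symbols but still pass self._quantized_cdf, which update() builds for unit-width bins around the medians. Decoding should remain lossless as long as both sides use the same bin_width. However, the coder then spends the cross-entropy between the width-Δ symbol distribution and the unit-bin table rather than the entropy. The sketch below estimates that gap for a zero-median logistic density (a hypothetical stand-in for the learned one). It ignores the table's finite support and any out-of-range handling.

```python
import math
from typing import Tuple


def logistic_cdf(t: float, scale: float = 1.0) -> float:
    z = t / scale
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def pmf(k: int, bin_width: float, scale: float) -> float:
    """Probability of symbol k when bins of width bin_width sit at k * bin_width."""
    half = bin_width / 2
    d = abs(k * bin_width)
    # The logistic is symmetric, so evaluate in the lower tail to avoid cancellation.
    return logistic_cdf(-d + half, scale) - logistic_cdf(-d - half, scale)


def coding_cost_bits(bin_width: float, scale: float, span: float = 30.0) -> Tuple[float, float]:
    """Return (entropy, cross-entropy) in bits per symbol.

    entropy:       symbols coded with a table matched to bin_width
    cross-entropy: the same symbols coded with a unit-bin table
    """
    n = int(span * scale / bin_width)
    entropy = cross = 0.0
    for k in range(-n, n + 1):
        p = pmf(k, bin_width, scale)  # true probability of symbol k
        q = pmf(k, 1.0, scale)        # probability the unit-bin table assigns
        if p > 0.0:
            entropy -= p * math.log2(p)
            cross -= p * math.log2(q)
    return entropy, cross
```

For bin_width = 1.0 both numbers coincide. For other widths, the gap is the extra bitstream length caused by reusing the unit-bin tables, which matters when measuring the rate side of the tradeoff.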

EntropyBottleneck Changes:

In the EntropyBottleneck, I added the bin_width parameter to the compress and decompress methods:

def compress(self, x, bin_width=1.0):
    indexes = self._build_indexes(x.size())
    medians = self._get_medians().detach()
    spatial_dims = len(x.size()) - 2
    medians = self._extend_ndims(medians, spatial_dims)
    medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1)))
    return super().compress(x, indexes, medians, bin_width)

def decompress(self, strings, size, bin_width=1.0):
    output_size = (len(strings), self._quantized_cdf.size(0), *size)
    indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
    medians = self._extend_ndims(self._get_medians().detach(), len(size))
    medians = medians.expand(len(strings), *([-1] * (len(size) + 1)))
    return super().decompress(strings, indexes, medians.dtype, medians, bin_width)
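Note that bin_width is not stored in the returned strings, so decompress must be called with exactly the value used in compress. If the bin width should travel with the bitstream, one option is a small header. The helpers pack_bin_width and unpack_bin_width below are hypothetical (they are not part of CompressAI) and prepend a fixed 8-byte little-endian float64.

```python
import struct
from typing import Tuple

_HEADER = struct.Struct("<d")  # one little-endian float64


def pack_bin_width(bin_width: float, payload: bytes) -> bytes:
    """Prefix an entropy-coded string with the bin width used to produce it."""
    return _HEADER.pack(bin_width) + payload


def unpack_bin_width(data: bytes) -> Tuple[float, bytes]:
    """Split a packed string back into (bin_width, payload)."""
    (bin_width,) = _HEADER.unpack_from(data, 0)
    return bin_width, data[_HEADER.size:]
```

Since compress returns one string per batch element, each entry of that list would be wrapped separately, at a cost of 8 bytes per image.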

My goal with these changes is to test different bin_width values only at test time and analyze the tradeoff between the length of the bitstream and the reconstruction error.
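For such a test-time sweep, the expected trend can be previewed without a trained model. The sketch below draws synthetic latents from a zero-median logistic distribution (an assumption standing in for real encoder outputs y) and quantizes them with several bin widths. For each width, it reports the ideal code length (-log2 of the bin probability, with no entropy-coder overhead) and the latent-domain MSE. The MSE should approach bin_width**2 / 12 for small bins; image-domain distortion after g_s will differ.

```python
import math
import random
from typing import Dict, Iterable, Tuple


def logistic_cdf(t: float, scale: float) -> float:
    z = t / scale
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bin_prob(center: float, width: float, scale: float) -> float:
    # Symmetric density: evaluate in the lower tail to avoid cancellation.
    d = abs(center)
    return logistic_cdf(-d + width / 2, scale) - logistic_cdf(-d - width / 2, scale)


def sweep(
    bin_widths: Iterable[float], scale: float = 4.0, n: int = 20000, seed: int = 0
) -> Dict[float, Tuple[float, float]]:
    """Return {bin_width: (bits_per_element, mse)} for synthetic logistic latents."""
    rng = random.Random(seed)
    ys = []
    for _ in range(n):
        u = rng.uniform(1e-12, 1.0 - 1e-12)
        ys.append(scale * math.log(u / (1.0 - u)))  # inverse CDF of the logistic
    results = {}
    for w in bin_widths:
        bits = sq_err = 0.0
        for y in ys:
            y_hat = round(y / w) * w  # median assumed to be 0
            bits -= math.log2(bin_prob(y_hat, w, scale))
            sq_err += (y - y_hat) ** 2
        results[w] = (bits / n, sq_err / n)
    return results


if __name__ == "__main__":
    for w, (bpe, mse) in sweep([0.25, 0.5, 1.0, 2.0]).items():
        print(f"bin_width={w}: {bpe:.3f} bits/element, MSE={mse:.4f} (w^2/12={w * w / 12:.4f})")
```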

Do you think rescaling inputs and means by bin_width inside the quantize and dequantize methods is the correct approach?

Thank you!

uzumaki671, Oct 10 '24

Apologies for the late response.

It looks like that should work for the runtime codec.

However, if you wanted to train it as well, you would also need to modify EntropyBottleneck.forward:

class EntropyBottleneck(EntropyModel):
    def forward(
        self, x: Tensor, training: Optional[bool] = None, bin_width=1.0
    ) -> Tuple[Tensor, Tensor]:
        ...

        outputs = self.quantize(
            values,
            "noise" if training else "dequantize",
            self._get_medians(),
            bin_width=bin_width,
        )

        ...

...And use it as follows:

@register_model("bmshj2018-factorized-vbr")
class FactorizedPriorVbr(CompressionModel):
    def forward(self, x, *, bin_width=1.0):
        y = self.g_a(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y, bin_width=bin_width)
        x_hat = self.g_s(y_hat)
        return {"x_hat": x_hat, "likelihoods": {"y": y_likelihoods}}

    def compress(self, x, *, bin_width=1.0):
        y = self.g_a(x)
        y_strings = self.entropy_bottleneck.compress(y, bin_width=bin_width)
        return {"strings": [y_strings], "shape": y.size()[-2:]}

    def decompress(self, strings, shape, *, bin_width=1.0):
        assert isinstance(strings, list) and len(strings) == 1
        y_hat = self.entropy_bottleneck.decompress(strings[0], shape, bin_width=bin_width)
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}
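One more point if the model is trained with bin_width: the "noise" branch then adds U(-bin_width/2, bin_width/2). The likelihood that forward computes from those noisy values should therefore integrate over the same width, i.e. half = bin_width / 2, as in the _likelihood proposed at the top of the thread. With unit-width bins, the rate term is off by roughly log2(1 / bin_width) bits per element and underestimates the rate when bin_width < 1. A Monte Carlo sketch with a zero-median logistic density (a hypothetical stand-in for the learned one) compares both choices with the exact discrete entropy:

```python
import math
import random


def logistic_cdf(t: float, scale: float) -> float:
    z = t / scale
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bin_prob(y: float, width: float, scale: float) -> float:
    """Mass of the interval of length `width` centred on y (symmetric density)."""
    half = width / 2
    d = abs(y)
    return logistic_cdf(-d + half, scale) - logistic_cdf(-d - half, scale)


def discrete_entropy(width: float, scale: float, span: float = 30.0) -> float:
    """Exact entropy (bits) of the latent quantized with bins of `width`."""
    n = int(span * scale / width)
    probs = (bin_prob(k * width, width, scale) for k in range(-n, n + 1))
    return -sum(p * math.log2(p) for p in probs if p > 0.0)


def noisy_rate_estimate(
    noise_width: float, likelihood_width: float, scale: float, n: int = 20000, seed: int = 0
) -> float:
    """Average -log2 likelihood of noisy latents, as a training-time rate proxy."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        u = rng.uniform(1e-12, 1.0 - 1e-12)
        y = scale * math.log(u / (1.0 - u))  # sample the latent
        y_tilde = y + rng.uniform(-noise_width / 2, noise_width / 2)  # "noise" mode
        total -= math.log2(bin_prob(y_tilde, likelihood_width, scale))
    return total / n
```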

QVRF[^1] does something quite similar, though with two differences:

  • They apply bin_width ($≝ 1 / a$) to the GaussianConditional for the mean-scale hyperprior (mbt2018-mean) model.
  • They "finetune" a single high-rate (λ=0.18) model over a set $A = \{ a_1, a_2, \ldots, a_n \}$ of different bin widths to ensure that the model works well over those bin widths.

    We use staged training strategies, where the network parameters are optimized on $λ=0.18$ for the first 2000 epochs. Then, $A$ are optimized jointly with noise approximation for 500 epochs and straight-through estimation for another 500 epochs.
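For QVRF-style finetuning, each training step would draw one a from A, set bin_width = 1 / a, and pass it to forward. The sketch below shows that step's objective in scalar form, with the rate term computed as bits per pixel from the likelihoods. The weighting λ·255²·MSE + bpp is assumed to follow the convention of CompressAI's example training script; it is not specified in this thread or in the QVRF excerpt. The helper names and the values used for A and λ are placeholders.

```python
import math
import random
from typing import Sequence


def sample_bin_width(A: Sequence[float], rng: random.Random) -> float:
    """Pick a quantization regulator a from A and return the bin width 1 / a."""
    return 1.0 / rng.choice(list(A))


def rd_loss(likelihoods: Sequence[float], num_pixels: int, mse: float, lmbda: float) -> float:
    """Rate-distortion objective: bpp + lmbda * 255**2 * mse (assumed weighting)."""
    bpp = sum(-math.log2(p) for p in likelihoods) / num_pixels
    return bpp + lmbda * 255 ** 2 * mse
```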

[^1]: Tong et al. QVRF: A Quantization-error-aware Variable Rate Framework for Learned Image Compression. https://arxiv.org/pdf/2303.05744

YodaEmbedding, Nov 04 '24

Thank you very much!

uzumaki671, Nov 05 '24

Hello, I was wondering whether I need to include any additional changes beyond those discussed in my previous comment.

In particular, I’m interested in reproducing the results shown in Figure 5(a) of the paper “Variable Rate Deep Image Compression With a Conditional Autoencoder” by Yoojin Choi, Mostafa El-Khamy, and Jungwon Lee:

[Image: Figure 5(a) from Choi et al., not reproduced here]

Any guidance on necessary modifications or configurations would be greatly appreciated.

Thanks in advance!

uzumaki671, Jul 31 '25