CompressAI icon indicating copy to clipboard operation
CompressAI copied to clipboard

How to .update() a GaussianConditional entropy model and properly call .compress() and .decompress()?

Open johnli25 opened this issue 7 months ago • 2 comments

Hello! I'm trying to train my own custom autoencoder model while integrating EntropyBottleneck and GaussianConditional. Here's a snippet of my class:

class AEWithEntropy(nn.Module):
    """Autoencoder wrapped with a hyperprior-style entropy model.

    The base autoencoder (``PNC32``) produces a latent ``y``. A 1x1
    hyper-encoder maps ``y`` to a side-information latent ``z`` that is coded
    with an ``EntropyBottleneck``; a 1x1 hyper-decoder maps the quantized ``z``
    to the scales used by a ``GaussianConditional`` to code ``y``.

    Call :meth:`update` once after training (and before :meth:`compress` /
    :meth:`decompress`) so that both entropy models have their CDF tables
    built.
    """

    def __init__(self, freeze_base=True):
        """Build the entropy-model layers around the base autoencoder.

        Args:
            freeze_base: When true, gradients are disabled for every parameter
                of the base autoencoder so only the hyperprior layers train.
        """
        super().__init__()
        self.base = PNC32()  # original autoencoder

        C = self.base.encoder2.out_channels
        # 1x1 convs for the hyperprior, so the base model needs no changes.
        self.hyper_encoder = nn.Conv2d(C, C, kernel_size=1)
        self.hyper_decoder = nn.Conv2d(C, C, kernel_size=1)

        self.entropy_z = EntropyBottleneck(C)
        # The scale table is left empty here; update() fills it in.
        self.gauss_y = GaussianConditional(None)

        if freeze_base:
            for p in self.base.parameters():
                p.requires_grad = False

    def update(self, scale_table=None, force=False):
        """Build the CDF tables required for actual entropy coding.

        Args:
            scale_table: Scales used to tabulate the Gaussian CDFs. When
                ``None``, 64 log-spaced scales between 0.11 and 256 are used.
            force: Rebuild the tables even if they already exist.

        Returns:
            ``True`` if at least one of the entropy models was updated.
        """
        import math

        import torch

        if scale_table is None:
            scale_table = torch.exp(
                torch.linspace(math.log(0.11), math.log(256), 64)
            )
        updated = self.gauss_y.update_scale_table(scale_table, force=force)
        updated |= self.entropy_z.update(force=force)
        return updated

    def forward(self, x, tail_length=None, quantize_level=0):
        """Reconstruct ``x`` and return the likelihoods for the rate loss.

        Args:
            x: Input batch passed to the base encoder.
            tail_length: Accepted for caller compatibility; unused here.
            quantize_level: Accepted for caller compatibility; unused here.

        Returns:
            Tuple ``(recon, y_lh, z_lh)``: the reconstruction and the
            likelihoods of the main and hyper latents.
        """
        y = self.base.encode(x)  # main latent
        z = self.hyper_encoder(y)  # side information describing y
        z_q, z_lh = self.entropy_z(z)  # quantized z and its likelihood
        sigma = self.hyper_decoder(z_q)  # scales of the Gaussian model for y
        y_q, y_lh = self.gauss_y(y, sigma)  # quantized y and its likelihood
        recon = self.base.decode(y_q)
        return recon, y_lh, z_lh

    def compress(self, x):
        """Entropy-code ``x`` into byte strings.

        Requires :meth:`update` to have been called first.

        Returns:
            Dict with the byte strings under ``"z"`` and ``"y"`` plus the
            spatial size of the hyper latent under ``"shape"`` (needed by
            :meth:`decompress`).
        """
        y = self.base.encode(x)
        z = self.hyper_encoder(y)
        z_shape = z.size()[-2:]
        z_bytes = self.entropy_z.compress(z)
        # Decode z exactly as the receiver will, so both sides derive the
        # same scales; the bottleneck needs the spatial size to do that.
        z_q = self.entropy_z.decompress(z_bytes, z_shape)
        sigma = self.hyper_decoder(z_q)
        # The Gaussian coder works on scale-table indexes, not raw scales.
        indexes = self.gauss_y.build_indexes(sigma)
        y_bytes = self.gauss_y.compress(y, indexes)
        return {"z": z_bytes, "y": y_bytes, "shape": z_shape}

    def decompress(self, streams):
        """Decode the dict produced by :meth:`compress` into a reconstruction.

        Args:
            streams: Dict with keys ``"z"``, ``"y"`` and ``"shape"``.

        Returns:
            The output of the base decoder for the decoded latent.
        """
        z_q = self.entropy_z.decompress(streams["z"], streams["shape"])
        sigma = self.hyper_decoder(z_q)
        indexes = self.gauss_y.build_indexes(sigma)
        y_q = self.gauss_y.decompress(streams["y"], indexes)
        return self.base.decode(y_q)

The model trains just fine, but during evaluation, I'm trying to run the compress() and decompress() methods to print the total number of bytes my model can compress/encode images into. I'm aware I'm supposed to call some .update(), and I successfully do for the EntropyBottleneck via `model.entropy_z.update(force=True)`, but I can't seem to do the same thing with GaussianConditional. I notice that I need to do something with CDF/scale tables but I'm stuck here. Here's the full error/output log:

Traceback (most recent call last):
  File "autoencoder_train.py", line 499, in <module>
    final_test_loss, final_psnr, final_ssim = eval_autoencoder(model=model, dataloader=test_loader, criterion=criterion, device=device, max_tail_length=drops, quantize=args.quantize)
  File "autoencoder_train.py", line 255, in eval_autoencoder
    outputs, _, _ = model(x=inputs, tail_length=max_tail_length, quantize_level=quantize)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
ValueError: Caught ValueError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "autoencoder_train.py", line 40, in forward
    strings = self.entropy_z.compress(y) # .compress(y) quantizes/compresses the latent feature y, and returns a string of bytes
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/compressai/entropy_models/entropy_models.py", line 541, in compress
    return super().compress(x, indexes, medians)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/compressai/entropy_models/entropy_models.py", line 254, in compress
    self._check_cdf_size()
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/compressai/entropy_models/entropy_models.py", line 216, in _check_cdf_size
    raise ValueError("Uninitialized CDFs. Run update() first")
ValueError: Uninitialized CDFs. Run update() first

Any help/suggestions? Thank you!

johnli25 avatar May 08 '25 23:05 johnli25