CompressAI
How to .update() a GaussianConditional entropy model and properly call .compress() and .decompress()?
Hello! I'm trying to train a custom autoencoder that integrates CompressAI's EntropyBottleneck and GaussianConditional. Here's a snippet of my class:
```python
import torch.nn as nn
from compressai.entropy_models import EntropyBottleneck, GaussianConditional

class AEWithEntropy(nn.Module):
    def __init__(self, freeze_base=True):
        super().__init__()
        self.base = PNC32()  # original autoencoder
        C = self.base.encoder2.out_channels  # 32
        # 1×1 convs for the hyperprior (no need to modify PNC32)
        self.hyper_encoder = nn.Conv2d(C, C, kernel_size=1)
        self.hyper_decoder = nn.Conv2d(C, C, kernel_size=1)
        self.entropy_z = EntropyBottleneck(C)
        self.gauss_y = GaussianConditional(None)
        if freeze_base:
            for p in self.base.parameters():
                p.requires_grad = False

    def forward(self, x, tail_length=None, quantize_level=0):
        y = self.base.encode(x)        # latent features from the base encoder
        z = self.hyper_encoder(y)      # side information about y (e.g., scale statistics) to improve compression
        z_q, z_lh = self.entropy_z(z)  # quantized z and its likelihoods under the model
        sigma = self.hyper_decoder(z_q)  # scales used to model y as a conditional Gaussian
        y_q, y_lh = self.gauss_y(y, sigma)  # quantized y (fed to the decoder) and its likelihoods
        recon = self.base.decode(y_q)
        return recon, y_lh, z_lh

    def compress(self, x):
        y = self.base.encode(x)
        z = self.hyper_encoder(y)
        z_bytes = self.entropy_z.compress(z)
        z_q = self.entropy_z.decompress(z_bytes)
        sigma = self.hyper_decoder(z_q)
        y_bytes = self.gauss_y.compress(y, sigma)
        return {"z": z_bytes, "y": y_bytes}

    def decompress(self, streams):
        z_q = self.entropy_z.decompress(streams["z"])
        sigma = self.hyper_decoder(z_q)
        y_q = self.gauss_y.decompress(streams["y"], sigma)
        return self.base.decode(y_q)
```
The model trains just fine, but during evaluation I want to run compress() and decompress() to print the total number of bytes my model compresses each image into. I'm aware I'm supposed to call .update() first, and that works for the EntropyBottleneck via `model.entropy_z.update(force=True)`, but I can't figure out the equivalent for the GaussianConditional. I gather I need to do something with CDF/scale tables, but that's where I'm stuck.
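For context, my evaluation loop does roughly the following to measure the compressed size (simplified; `inputs` is a batch from my test loader):

```python
model.eval()
model.entropy_z.update(force=True)  # this part works

with torch.no_grad():
    streams = model.compress(inputs)
    # compress() returns one byte string per image in the batch
    total_bytes = sum(len(s) for s in streams["z"]) + sum(len(s) for s in streams["y"])
    print(f"total compressed size: {total_bytes} bytes")
    recon = model.decompress(streams)
```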
Here's the full error/output log:

```
Traceback (most recent call last):
  File "autoencoder_train.py", line 499, in <module>
    final_test_loss, final_psnr, final_ssim = eval_autoencoder(model=model, dataloader=test_loader, criterion=criterion, device=device, max_tail_length=drops, quantize=args.quantize)
  File "autoencoder_train.py", line 255, in eval_autoencoder
    outputs, _, _ = model(x=inputs, tail_length=max_tail_length, quantize_level=quantize)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
ValueError: Caught ValueError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "autoencoder_train.py", line 40, in forward
    strings = self.entropy_z.compress(y)  # .compress(y) quantizes/compresses the latent feature y, and returns a string of bytes
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/compressai/entropy_models/entropy_models.py", line 541, in compress
    return super().compress(x, indexes, medians)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/compressai/entropy_models/entropy_models.py", line 254, in compress
    self._check_cdf_size()
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/compressai/entropy_models/entropy_models.py", line 216, in _check_cdf_size
    raise ValueError("Uninitialized CDFs. Run update() first")
ValueError: Uninitialized CDFs. Run update() first
```
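From poking around `compressai/entropy_models/entropy_models.py`, it looks like GaussianConditional has no CDFs until it's given a scale table. Is something like the sketch below the intended way to do this? (Untested; `get_scale_table` and the SCALES_* constants just mirror what I found in CompressAI's bundled models, and I'm guessing that compress()/decompress() want indexes from build_indexes() rather than sigma directly.)

```python
import math
import torch

# Default scale-table constants copied from CompressAI's example models
SCALES_MIN, SCALES_MAX, SCALES_LEVELS = 0.11, 256, 64

def get_scale_table(min_=SCALES_MIN, max_=SCALES_MAX, levels=SCALES_LEVELS):
    # Log-spaced grid of scales the entropy coder can index into
    return torch.exp(torch.linspace(math.log(min_), math.log(max_), levels))

model.entropy_z.update(force=True)
model.gauss_y.update_scale_table(get_scale_table(), force=True)  # builds the CDFs?

# ...and then inside compress()/decompress(), quantize sigma to table indexes:
indexes = model.gauss_y.build_indexes(sigma)
y_bytes = model.gauss_y.compress(y, indexes)
y_q = model.gauss_y.decompress(y_bytes, indexes)
```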
Any help/suggestions? Thank you!