taming-transformers
Reconstructing image from codes highlights issues with VectorQuantizer2
Using a VQGAN model, we can encode an image to z, but encode also returns some other things, including the codebook indices. I'm trying to reconstruct an image from these codes (in practice, new codes generated by a transformer, but this is the simpler case).
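For context on where those indices come from: in standard vector quantization, each spatial vector of z is replaced by its nearest codebook entry, and the index of that entry is what gets returned. The following is a generic nearest-neighbour sketch in plain PyTorch, not the library's own code. The function name nearest_code_indices and the codebook sizes are hypothetical, and VectorQuantizer2 may compute the distances differently internally.

```python
import torch


def nearest_code_indices(z: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: index of the nearest codebook row for every spatial vector of z.

    z:        latent of shape (B, C, H, W)
    codebook: tensor of shape (n_e, C), one entry per row
    Result:   LongTensor of shape (B*H*W,), batch-first, then row-major over (H, W)
    """
    b, c, h, w = z.shape
    # Move channels last so each row of z_flat is one spatial position's C-dim vector.
    z_flat = z.permute(0, 2, 3, 1).reshape(-1, c)
    # Euclidean distance from every position to every codebook entry: (B*H*W, n_e).
    dists = torch.cdist(z_flat, codebook)
    return dists.argmin(dim=1)


if __name__ == "__main__":
    codebook = torch.randn(1024, 256)  # illustrative: 1024 entries of dimension 256
    z = torch.randn(1, 256, 16, 16)    # illustrative f16 latent for a 256x256 image
    idx = nearest_code_indices(z, codebook)
    print(idx.shape)                   # torch.Size([256]), one index per 16x16 token
```

Looking those indices back up in the codebook (and undoing the permute) recovers the quantized latent, which is exactly the step decode_code needs.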
The first thing I tried was the useful-looking function decode_code (taming/models/vqgan.py line 66). Sadly this gives an error:
AttributeError: 'VectorQuantizer2' object has no attribute 'embed_code'
An issue was raised about this (https://github.com/CompVis/taming-transformers/issues/42), but it was simply dismissed as 'old code no longer used'.
While trying to find a workaround, I ran into other bugs in VectorQuantizer2. For example, unmap_to_all(self, inds) references self.used, which isn't defined, so calling unmap_to_all raises an error.
It's probably worth removing or fixing these bits of old code so others don't run into the same issues. For example, decode_code could do something like:
quant_z = self.quantize.embedding(code_b).reshape(1, 16, 16, 256).permute(0,3,1,2)
return self.decode(quant_z)
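where z (and therefore quant_z) has shape (1, 256, 16, 16) here, and code_b holds the codebook indices.
(Or, even better, use quant_z = self.quantize.get_codebook_entry(code_b, shape), where shape is given channels-last as (batch, height, width, channels), e.g. (1, 16, 16, 256).)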
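The lookup inside decode_code can be sketched as a small standalone helper. This is a minimal sketch, assuming no index remapping is configured and that the codebook is an nn.Embedding like self.quantize.embedding; the helper name codes_to_latent is hypothetical. It reads the channel count from the embedding instead of hard-coding 256, and follows the same channels-last reshape then permute(0, 3, 1, 2) as the snippet above.

```python
import torch
import torch.nn as nn


def codes_to_latent(embedding: nn.Embedding, indices: torch.Tensor,
                    batch: int, height: int, width: int) -> torch.Tensor:
    """Hypothetical helper: flat codebook indices -> (B, C, H, W) latent for decode().

    Assumes no index remapping, so indices address embedding rows directly.
    """
    channels = embedding.embedding_dim
    z_q = embedding(indices)                           # (B*H*W, C): one codebook row per token
    z_q = z_q.reshape(batch, height, width, channels)  # channels-last token grid
    return z_q.permute(0, 3, 1, 2).contiguous()        # channels-first, like the encoder's z


if __name__ == "__main__":
    codebook = nn.Embedding(1024, 256)                 # illustrative sizes
    code_b = torch.randint(0, 1024, (256,))            # 16x16 tokens
    print(codes_to_latent(codebook, code_b, 1, 16, 16).shape)  # torch.Size([1, 256, 16, 16])
```

Inside decode_code this would become something like self.decode(codes_to_latent(self.quantize.embedding, code_b, 1, 16, 16)) for an f16 model on 256x256 inputs.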
I don't know the code well enough to say whether or how things like unmap_to_all should be modified, but hopefully the above is enough to either fix or remove decode_code.
Code going from an image -> codebook indices (idx) -> back to an image we can display, for anyone else stuck on this (assumes an f16 model):
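import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

im = Image.open('im.jpeg').convert('RGB').resize((256, 256))
im_tensor = torch.tensor(np.array(im)).permute(2, 0, 1) / 255  # HWC uint8 -> CHW float in [0, 1]
with torch.no_grad():  # vqgan_model (f16) and device are assumed to be set up already
    z, _, info = vqgan_model.encode(im_tensor.to(device).unsqueeze(0) * 2 - 1)  # scale input to [-1, 1]
    idx = info[-1]  # the codebook indices (16x16 tokens, so shape is [256])
    z_q = vqgan_model.quantize.embedding(idx).reshape(1, 16, 16, 256).permute(0, 3, 1, 2)
    plt.imshow(vqgan_model.decode(z_q).add(1).div(2).clamp(0, 1).cpu().squeeze().permute(1, 2, 0))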
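The hard-coded 16, 16 and 256 in these snippets come from the image size, the downsampling factor and the codebook embedding dimension. A small hypothetical helper (latent_shape and num_tokens are made-up names, not part of taming-transformers) makes that arithmetic explicit, assuming each image side divides exactly by the factor:

```python
from typing import Tuple


def latent_shape(image_hw: Tuple[int, int], factor: int, embed_dim: int,
                 batch: int = 1) -> Tuple[int, int, int, int]:
    """Hypothetical helper: (B, C, H', W') of the quantized latent for an image of size image_hw.

    Assumes each side is divisible by the downsampling factor (e.g. 16 for an f16 model).
    """
    h, w = image_hw
    if h % factor or w % factor:
        raise ValueError(f"image size {image_hw} is not divisible by factor {factor}")
    return (batch, embed_dim, h // factor, w // factor)


def num_tokens(image_hw: Tuple[int, int], factor: int) -> int:
    """Number of codebook indices per image: one per latent position."""
    _, _, h, w = latent_shape(image_hw, factor, embed_dim=1)
    return h * w


if __name__ == "__main__":
    print(latent_shape((256, 256), 16, 256))  # (1, 256, 16, 16)
    print(num_tokens((256, 256), 16))         # 256
```

For the f16 case above this reproduces the (1, 256, 16, 16) latent and its 256 tokens; with factor 8 the same 256x256 image would give a 32x32 grid, i.e. 1024 tokens.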
I hope this is helpful to someone.