MaskGIT-pytorch
vq_gan reconstruction results blurry using default code
Hi, thank you for this very interesting work! I'm currently trying to train the VQGAN part on a few-shot dataset (~300 dog or cat images) at 256x256 resolution. However, using the default settings in the code, the reconstructions still look blurry after 200 epochs of training (as shown below; the first row shows the real images, the second row the reconstructions after training).
After comparing the code with the setup in the paper, I found two differences:
- the default embedding dimension is 256 in the code, whereas the paper uses 768
- the non-local block uses single-head attention, whereas the paper uses 8-head attention
I'm not sure whether these differences could cause blurriness to this extent, or whether there are other factors I need to pay attention to. Thanks!
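For reference, here is a minimal, generic sketch of where that embedding dimension enters the first stage, assuming a standard VQ codebook with a straight-through estimator (the class and argument names below are illustrative, not this repo's exact code):

```python
import torch
import torch.nn as nn

class Codebook(nn.Module):
    """Generic VQ codebook sketch: `embedding_dim` is the 256-vs-768 number discussed above."""
    def __init__(self, num_codebook_vectors=1024, embedding_dim=256, beta=0.25):
        super().__init__()
        self.embedding = nn.Embedding(num_codebook_vectors, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_codebook_vectors,
                                            1.0 / num_codebook_vectors)
        self.beta = beta

    def forward(self, z):
        # z: (B, C, H, W) encoder output with C == embedding_dim
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flat = z.view(-1, z.shape[-1])
        # squared distances to every codebook vector, then nearest-neighbour lookup
        d = (z_flat.pow(2).sum(1, keepdim=True)
             - 2 * z_flat @ self.embedding.weight.t()
             + self.embedding.weight.pow(2).sum(1))
        indices = d.argmin(1)
        z_q = self.embedding(indices).view(z.shape)
        # codebook + commitment losses, straight-through gradient back to the encoder
        loss = ((z_q.detach() - z) ** 2).mean() + self.beta * ((z_q - z.detach()) ** 2).mean()
        z_q = z + (z_q - z).detach()
        return z_q.permute(0, 3, 1, 2), indices, loss
```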
Hello, the first stage of the VQGAN implementation might not be perfect. The embedding dimension might be worth trying; I used 256 because my machine couldn't handle much more. Also check out the original repo configs https://github.com/CompVis/taming-transformers/tree/master/configs and see if you could change some hyperparameters. I also believe they used a latent dim of 256; where did you find the 768? Thanks for pointing out the missing heads, but I would actually think that the authors used single-head attention. If you look at the original VQGAN repo, https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py#L376, I see no signs of multiple heads.
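To make that concrete, here is a condensed sketch of the non-local block pattern used there (illustrative, not a verbatim copy of the linked file): q, k and v are plain 1x1 convolutions and the channel dimension is never split into heads, so the attention is effectively single-head.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadAttnBlock(nn.Module):
    """Condensed sketch of the VQGAN non-local block: one attention head over spatial positions."""
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)  # channels assumed divisible by 32
        self.q = nn.Conv2d(channels, channels, 1)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)
        self.proj_out = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        h = self.norm(x)
        q, k, v = self.q(h), self.k(h), self.v(h)
        b, c, height, width = q.shape
        q = q.reshape(b, c, height * width).permute(0, 2, 1)  # (B, HW, C)
        k = k.reshape(b, c, height * width)                   # (B, C, HW)
        attn = torch.bmm(q, k) * (c ** -0.5)                  # (B, HW, HW), no head split
        attn = F.softmax(attn, dim=2)
        v = v.reshape(b, c, height * width)                   # (B, C, HW)
        h = torch.bmm(v, attn.permute(0, 2, 1)).reshape(b, c, height, width)
        return x + self.proj_out(h)
```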
I don't actually know of other things that could differ from the original repo. The encoder and decoder are exactly the same in terms of architecture and parameter count, the codebook is the same as theirs, and the discriminator was taken from the PatchGAN paper, just as the authors did.
So I believe it can only be the hyperparameters that need to be changed.
Have you tried your dataset with the main VQGAN repo? How did it perform there? As expected?
Thank you for the prompt and detailed reply! I haven't tried the images on VQGAN yet; I'm currently trying to run it. About the settings in the paper: they are in Section 4.1 (Experimental Setup), second paragraph, where it says all models have 24 layers, 8 attention heads, 768 embedding dimensions, and 3072 hidden dimensions. I'm not sure whether this is the correct place to look? Also, what does "hidden dimension" refer to?
> where it says all models have 24 layers, 8 attention heads, 768 embedding dimensions and 3072 hidden dimensions
This is for the second stage of the VQGAN, i.e. the transformer part. It has nothing to do with the reconstruction part, which is learned first. The correct place to look is the original VQGAN paper: https://arxiv.org/pdf/2012.09841.pdf. MaskGIT builds on top of it, replacing only the second stage and leaving the reconstruction learning untouched.
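In other words, the numbers from Section 4.1 configure only the second-stage transformer, and "hidden dimension" most likely means the width of the feed-forward (MLP) layer inside each transformer block (3072 = 4 x 768). A rough summary of how the two stages are parameterized, using values mentioned in this thread and the MaskGIT paper (the dictionary keys are illustrative):

```python
# First stage: VQGAN encoder/decoder + codebook, learns the reconstruction.
stage1_vqgan = {
    "latent_dim": 256,          # codebook embedding dimension used in this repo
    "attention": "single-head", # non-local block, as in taming-transformers
}

# Second stage: bidirectional transformer over codebook indices (MaskGIT Sec. 4.1).
stage2_transformer = {
    "num_layers": 24,
    "num_heads": 8,
    "embedding_dim": 768,
    "hidden_dim": 3072,  # feed-forward / MLP width inside each transformer block
}
```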
Thanks for your clarification, this helps me a lot
No problem. Also, just drop the adjustments you made here if you find something that works for your case. Maybe others can benefit from it too.
Sure. Also, could you share some pretrained MaskGIT weights so that we can play with them a bit? Thanks!
I have met the same issue with reconstruction: when the discriminator starts training, the images gain patches of noise.
I met a similar problem with the taming-transformers VQGAN implementation. I think a quick fix is to let the generator train for a couple of epochs before starting the discriminator, and to use a smaller discriminator loss weight on generated images (0.2 in my case). https://github.com/CompVis/taming-transformers/issues/93
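A minimal sketch of that schedule, in the spirit of the adopt_weight helper used in taming-transformers: keep the adversarial weight at zero until a chosen step, then switch to a reduced factor (the 0.2 comes from the comment above; the step threshold is a hypothetical placeholder):

```python
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return `value` (typically 0) before `threshold` steps, and `weight` afterwards."""
    if global_step < threshold:
        weight = value
    return weight

# The generator trains alone until `disc_start_step`; afterwards the adversarial
# term enters with the reduced factor of 0.2 mentioned above.
disc_start_step = 10_000  # hypothetical: roughly "a couple of epochs" worth of iterations
for step in (0, 5_000, 10_000, 20_000):
    print(step, adopt_weight(weight=0.2, global_step=step, threshold=disc_start_step))
```

In the training loop, the returned factor would multiply the adversarial term of the generator loss and scale the discriminator loss, so the reconstruction and perceptual losses dominate early training.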