A question about vocab_size in token embedding
Hi! When I read your source code, I found that you set vocab_size = self.codebook_size + 1000 + 1 in the token embedding stage. Why not directly set vocab_size = self.codebook_size? What do the extra 1001 embeddings mean? Are they the embeddings of the class labels and the mask token? Can I understand it this way: when there is no class condition, vocab_size should be set to self.codebook_size + 1?
Looking forward to your reply!
Hi, thanks for your interest! Yes, the vocab_size should be self.codebook_size + 1 when there is no class condition. We set it to self.codebook_size + 1000 + 1 just because our pre-trained checkpoint from JAX uses this redundant codebook size, and to load the pre-trained weights we need to keep it that way.
Thanks for your reply! It was very kind of you! But if so, how should I set self.fake_class_label? Is it reasonable to set it to any value within 0-1024?
You can actually set it to any value larger than or equal to 1024, and smaller than 1024+1000+1 -- but the pre-trained model set it to 1100 (again, a legacy issue).
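For illustration, here is a minimal sketch of how the embedding table could be set up (the variable names are illustrative, not the exact ones in the repo):

```python
import torch.nn as nn

codebook_size = 1024        # VQGAN codebook entries
num_class_slots = 1000      # legacy slots kept so the JAX checkpoint loads
embed_dim = 768             # ViT-B hidden size

# Pre-trained MAGE checkpoints: codebook tokens + 1000 legacy slots + 1 extra token.
vocab_size = codebook_size + num_class_slots + 1    # = 2025
fake_class_label = 1100     # any id in [1024, 2024] works; 1100 matches the checkpoint

# Training from scratch without class conditioning, you could instead use:
# vocab_size = codebook_size + 1                    # codebook tokens + 1 extra token
# fake_class_label = codebook_size

token_emb = nn.Embedding(vocab_size, embed_dim)
```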
OK! Thanks a lot!
By the way, the current code does not seem to contain the contrastive loss part. Do you have any plans to release the complete training code, including this part?
Unfortunately I don't have access to the original JAX code now, so there is no plan to release the contrastive training part. However, that part is quite straightforward if you want to re-implement it -- simply a SimCLR-based contrastive loss similar to this one.
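Roughly something like this (a sketch of the standard SimCLR/NT-Xent formulation in PyTorch, not the original JAX code):

```python
import torch
import torch.nn.functional as F

def simclr_loss(z1, z2, temperature=0.2):
    """InfoNCE loss over projected features (N, D) from two augmented views."""
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    z = torch.cat([z1, z2], dim=0)            # (2N, D)
    sim = z @ z.t() / temperature             # (2N, 2N) cosine similarities
    sim.fill_diagonal_(float('-inf'))         # a sample must not match itself
    n = z1.shape[0]
    # The positive for sample i is its other view: i + n (first half) or i - n (second half).
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)
```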
Hello, author! Did you re-train the VQGAN yourself? It seems different from the pre-trained model released by the VQGAN authors. If I want to apply MAGE to other datasets, what should I pay attention to when training the VQGAN?
Yes, we trained the VQGAN ourselves. From my experience, a VQGAN is much harder to train than a continuous VAE; the GAN loss and perceptual loss parts are especially important -- warmup epochs, GAN/perceptual weights, the discriminator you use, etc. One important thing is to always visualize the reconstruction results -- you may get a low reconstruction error, but the visual quality could still be quite blurry.
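For example, a quick way to eyeball reconstructions during training (just a sketch; `encode`/`decode` here follow a taming-transformers-style interface, so adapt it to your own tokenizer):

```python
import torch
import torchvision

@torch.no_grad()
def save_reconstructions(vqgan, images, path="recon.png"):
    """Save originals and reconstructions side by side; assumes images in [0, 1]."""
    vqgan.eval()
    quant, _, _ = vqgan.encode(images)        # return signature varies across VQGAN implementations
    recon = vqgan.decode(quant).clamp(0, 1)
    grid = torchvision.utils.make_grid(torch.cat([images, recon], dim=0), nrow=images.shape[0])
    torchvision.utils.save_image(grid, path)
```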
Thanks a lot! But I would like to ask, why don't you just use the pre-trained model provided by the VQGAN authors? Is there any problem with it?
The reconstruction FID of the original VQGAN is too poor (7.94), which bounds the generation performance. We follow many of the practices of ViT-VQGAN in our training, including a larger batch size (256) and a StyleGAN discriminator (instead of PatchGAN). The reconstruction FID of the tokenizer in MAGE is 2-3, which is significantly better than the original VQGAN.
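If you want to measure reconstruction FID on your own dataset, something like this works (a sketch using torchmetrics; it assumes uint8 images and a taming-style `encode`/`decode` interface, so adapt the normalization to your tokenizer):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

@torch.no_grad()
def reconstruction_fid(vqgan, dataloader, device="cuda"):
    """FID between original images and their VQGAN reconstructions."""
    fid = FrechetInceptionDistance(feature=2048).to(device)
    vqgan.eval()
    for images, _ in dataloader:                             # images: uint8, (B, 3, H, W)
        images = images.to(device)
        quant, _, _ = vqgan.encode(images.float() / 255.0)   # rescale as your tokenizer expects
        recon = (vqgan.decode(quant).clamp(0, 1) * 255).to(torch.uint8)
        fid.update(images, real=True)
        fid.update(recon, real=False)
    return fid.compute().item()
```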
I got it! Thank you so much!
Sorry to bother you again. My computing resources are limited, so I wonder: if I don't use the contrastive loss, or just use MoCo v2 in MAGE training, can I set the batch size to a small value such as 256? My MAGE model uses ViT-B.
I haven't tested small batch sizes. For a SimCLR-based contrastive loss (as we used), a large batch size is typically needed. If you don't use the contrastive loss, a smaller batch size might also be fine (although MAE and MAGE both use large batch sizes).
I would say a batch size of 256 will not give you too bad performance if you don't use the contrastive loss -- just maybe slightly worse than a large batch size.
Ok, thanks again!
Hello, author! Sorry to bother you again. According to my understanding, MAGE can be trained on a V100 with batch size 64. Recently I made a preliminary attempt to train MAGE on a 3090 with batch size 64, but I ran out of memory. Do you have any experience solving this problem?
The MAGE-B model can be trained with batch size 64; MAGE-L can be trained with batch size 32. I have never used a 3090 before -- the V100s we use have 32GB of memory.
Thank you very much! I just found out that I was training MAGE-L. I should try to train MAGE-B instead.
Another question is about something I accidentally found when using your VQGAN before: the ResnetBlock in the VQGAN you use does not contain a real shortcut. Is this a special design?
Good catch -- it is just a stupid bug in Google's internal implementation -- but the performance is actually fine.
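For anyone else reading this, the difference amounts to something like the following in a typical VQGAN-style residual block (illustrative code, not the actual implementation):

```python
import torch.nn as nn

class ResnetBlock(nn.Module):
    """Typical residual block; the 'missing shortcut' amounts to returning h instead of x + h."""
    def __init__(self, channels):                   # channels assumed divisible by 32
        super().__init__()
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.act = nn.SiLU()

    def forward(self, x):
        h = self.conv1(self.act(self.norm1(x)))
        h = self.conv2(self.act(self.norm2(h)))
        return x + h    # with a real shortcut; without it, the block just returns h
```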
Yes, no worries -- it performs very well! Thanks again!
Dear author, I want to ask another question. During the training of the MAGE encoder, why do you keep some masked tokens from being dropped? Why not adopt a dynamic mask dropping ratio, instead of fixing the dropping ratio to a value such as 0.5?
The dropping ratio is set to the minimum masking ratio, which is 0.5. The actual masking ratio (which determines the number of masked tokens) is sampled from [0.5, 1.0].
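In code, the idea is roughly this (a simplified sketch; the actual sampling distribution and implementation in MAGE differ in the details):

```python
import math
import torch

def sample_mask_and_drop(num_tokens, min_mask_ratio=0.5, max_mask_ratio=1.0):
    """Sample a variable masking ratio but drop only a fixed fraction of tokens."""
    # Masking ratio is drawn from [0.5, 1.0] (uniform here for simplicity).
    mask_ratio = torch.empty(1).uniform_(min_mask_ratio, max_mask_ratio).item()
    num_masked = int(math.ceil(num_tokens * mask_ratio))         # tokens replaced by the mask token
    # The drop ratio stays at the minimum masking ratio (0.5), so the encoder
    # always sees a fixed-length input; the dropped tokens are a subset of the masked ones.
    num_dropped = int(math.ceil(num_tokens * min_mask_ratio))
    return num_masked, num_dropped
```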