
A question about vocab_size in token embedding

Open tanbuzheng opened this issue 10 months ago • 23 comments

Hi! When I read your source code, I found you set vocab_size = self.codebook_size + 1000 + 1 in the token embedding stage. Why not directly set vocab_size = self.codebook_size? What do the extra 1001 embeddings mean? Are they the embeddings of the class labels and the mask token? Can I understand it this way: when there is no class condition, vocab_size should be set to self.codebook_size + 1?

Looking forward to your reply!

tanbuzheng avatar Apr 24 '24 09:04 tanbuzheng

Hi, thanks for your interest! Yes, the vocab_size should be self.codebook_size + 1 when there is no class condition. We set it to self.codebook_size + 1000 + 1 just because our pre-trained checkpoint from JAX uses this redundant codebook size, and to load the pre-trained weights we need to keep it that way.
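
For clarity, the sizing breaks down as follows (a minimal sketch with illustrative variable names, not the exact code from the repo):

```python
codebook_size = 1024                          # VQGAN codebook token ids: 0 .. 1023
num_classes = 1000                            # ImageNet class-label token ids: 1024 .. 2023
vocab_size = codebook_size + num_classes + 1  # +1 for the mask token (id 2024)

# Without class conditioning, only the codebook tokens and the mask token are needed:
vocab_size_no_class = codebook_size + 1
```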

LTH14 avatar Apr 24 '24 14:04 LTH14

Thanks for your reply! It was very kind of you! But if so, how should I set self.fake_class_label? Is it reasonable to set it to any value within 0-1024?

tanbuzheng avatar Apr 24 '24 16:04 tanbuzheng

You can actually set it to any value larger than or equal to 1024, and smaller than 1024+1000+1 -- but the pre-trained model set it to 1100 (again, a legacy issue).
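
In other words (a hypothetical sanity check using the numbers above):

```python
codebook_size, num_classes = 1024, 1000
fake_class_label = 1100   # the value used by the pre-trained checkpoint (legacy choice)
assert codebook_size <= fake_class_label < codebook_size + num_classes + 1
```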

LTH14 avatar Apr 24 '24 16:04 LTH14

OK! Thanks a lot!

tanbuzheng avatar Apr 25 '24 02:04 tanbuzheng

By the way, the current code does not seem to contain the contrastive loss part. I would like to ask whether you have any plans to release the complete training code, including this part?

tanbuzheng avatar Apr 25 '24 13:04 tanbuzheng

Unfortunately I don't have access to the original JAX code now, so there is no plan to release the contrastive training part. However, that part is quite straightforward if you want to re-implement it -- it is simply a SimCLR-based contrastive loss similar to this one.
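
For reference, a minimal PyTorch sketch of such a SimCLR-style (InfoNCE) loss could look like the one below. It assumes `feat_a` and `feat_b` are projected features of two views of the same batch; this is a generic re-implementation sketch, not the original JAX code.

```python
import torch
import torch.nn.functional as F

def simclr_loss(feat_a, feat_b, temperature=0.2):
    """InfoNCE loss over two views; feat_a/feat_b have shape [batch, dim]."""
    feat_a = F.normalize(feat_a, dim=-1)
    feat_b = F.normalize(feat_b, dim=-1)
    feats = torch.cat([feat_a, feat_b], dim=0)   # [2B, D]
    logits = feats @ feats.t() / temperature     # [2B, 2B] cosine similarities
    logits.fill_diagonal_(float('-inf'))         # a sample cannot match itself
    batch = feat_a.size(0)
    # The positive for sample i is its other view: i + B for the first half, i - B for the second.
    targets = torch.cat([torch.arange(batch) + batch, torch.arange(batch)]).to(logits.device)
    return F.cross_entropy(logits, targets)
```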

LTH14 avatar Apr 25 '24 19:04 LTH14

Hello, author! Did you re-train the VQGAN yourself? It seems different from the pre-trained model released by the VQGAN authors. So if I want to apply MAGE to other datasets, what should I pay attention to when training the VQGAN?

tanbuzheng avatar Jun 02 '24 10:06 tanbuzheng

Yes, we trained the VQGAN ourselves. In my experience, a VQGAN is much harder to train than a continuous VAE; the GAN loss and perceptual loss parts are especially important -- the warmup epochs, the GAN/perceptual loss weights, the discriminator you use, etc. One important thing is to always visualize the reconstruction results -- you may get a low reconstruction error, but the visual quality could still be quite blurry.
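
As a concrete example of that last point, something like the sketch below can be run every few epochs to eyeball reconstructions (the `vqgan` forward interface here is an assumption; adapt it to your own tokenizer code):

```python
import torch
import torchvision

@torch.no_grad()
def save_reconstruction_grid(vqgan, images, path):
    """Save originals (top rows) next to reconstructions (bottom rows) for visual inspection."""
    vqgan.eval()
    recon = vqgan(images)                   # assumed: forward returns the reconstructed batch
    grid = torch.cat([images, recon], dim=0)
    # Blurriness is obvious at a glance even when the pixel-wise reconstruction loss looks low.
    torchvision.utils.save_image(grid, path, nrow=images.size(0), normalize=True)
```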

LTH14 avatar Jun 02 '24 12:06 LTH14

Thanks a lot! But I would like to ask, why don't you just use the pre-trained model provided by the VQGAN authors? Is there any problem with it?

tanbuzheng avatar Jun 02 '24 12:06 tanbuzheng

The reconstruction FID of the original VQGAN is too poor (7.94), which bounds the generation performance. We follow many practices of ViT-VQGAN in our training, including a larger batch size (256) and a StyleGAN discriminator (instead of PatchGAN). The reconstruction FID of the tokenizer in MAGE is 2-3, which is significantly better than the original VQGAN.

LTH14 avatar Jun 02 '24 18:06 LTH14

I got it! Thank you so much!

tanbuzheng avatar Jun 03 '24 01:06 tanbuzheng

Sorry to bother you again. My computing resources are limited, so I wonder: if I don't use the contrastive loss, or just use MoCo v2 in MAGE training, can I set the batch size to a smaller value, such as 256? The MAGE model is ViT-B.

tanbuzheng avatar Jun 03 '24 02:06 tanbuzheng

I haven't tested small batch sizes. For a SimCLR-based contrastive loss (as we used), a large batch size is typically needed. If you don't use the contrastive loss, a smaller batch size might also be fine (although MAE and MAGE both use large batch sizes).

LTH14 avatar Jun 03 '24 05:06 LTH14

I would say a batch size of 256 will not give you too bad performance if you don't use the contrastive loss -- just maybe slightly worse than with a large batch size.

LTH14 avatar Jun 03 '24 05:06 LTH14

Ok, thanks again!

tanbuzheng avatar Jun 03 '24 06:06 tanbuzheng

Hello, author! Sorry to bother you again. According to my understanding, MAGE can be trained on a V100 with batch size 64. Recently I made a preliminary attempt to train MAGE on a 3090 with batch size 64, but I ran out of memory. Do you have any experience solving this problem?

tanbuzheng avatar Jun 13 '24 04:06 tanbuzheng

The MAGE-B model can be trained with batch size 64, and MAGE-L with batch size 32. I have never used a 3090 -- the V100s we use have 32GB of memory.

LTH14 avatar Jun 13 '24 05:06 LTH14

Thank you very much! I just found out that I was training MAGE-L. I should try to train MAGE-B instead.

tanbuzheng avatar Jun 13 '24 07:06 tanbuzheng

Another question: I noticed something by accident when I was using your VQGAN. The ResnetBlock in the VQGAN you use does not contain a real shortcut -- is this a special design?

tanbuzheng avatar Jun 13 '24 07:06 tanbuzheng

Good catch -- it is just a stupid bug in Google's internal implementation, but the performance is actually fine.
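
For anyone else reading: the "real shortcut" refers to the residual connection in the block's forward pass. A standard ResnetBlock would look roughly like this (illustrative sketch, not the exact block from the MAGE VQGAN):

```python
import torch
import torch.nn as nn

class ResnetBlock(nn.Module):
    def __init__(self, channels):  # assumes channels is divisible by 32 for GroupNorm
        super().__init__()
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        h = self.conv1(torch.relu(self.norm1(x)))
        h = self.conv2(torch.relu(self.norm2(h)))
        return x + h  # the residual shortcut; the bug mentioned above effectively drops the "x +"
```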

LTH14 avatar Jun 13 '24 08:06 LTH14

Yes, never mind. It performs very well! Thanks again!

tanbuzheng avatar Jun 13 '24 10:06 tanbuzheng

Dear author, I want to ask another question. During the training of the MAGE encoder, why do you keep some masked tokens from being dropped? And why not adopt a dynamic mask-dropping ratio instead of setting the dropping ratio to a fixed value such as 0.5?

tanbuzheng avatar Jun 15 '24 05:06 tanbuzheng

The dropping ratio is set to the minimum masking ratio, which is 0.5. The actual masking ratio (which determines the number of masked tokens) is sampled from [0.5, 1.0].
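
A minimal sketch of that logic (illustrative only; the exact sampling distribution and implementation in MAGE may differ, e.g. a truncated Gaussian rather than a uniform distribution):

```python
import torch

def mask_and_drop(token_ids, mask_token_id, min_mask_ratio=0.5):
    """token_ids: LongTensor [batch, seq_len]. Mask a sampled fraction, drop a fixed fraction."""
    batch, seq_len = token_ids.shape
    mask_ratio = float(torch.empty(1).uniform_(min_mask_ratio, 1.0))  # sampled from [0.5, 1.0]
    num_masked = int(seq_len * mask_ratio)        # positions replaced with the mask token
    num_dropped = int(seq_len * min_mask_ratio)   # positions removed from the encoder input

    # Random order per sample; the first num_masked shuffled positions get masked.
    ids_shuffle = torch.rand(batch, seq_len).argsort(dim=1)
    masked_ids = token_ids.clone()
    mask_positions = ids_shuffle[:, :num_masked]
    masked_ids.scatter_(1, mask_positions, torch.full_like(mask_positions, mask_token_id))

    # Drop only the first num_dropped masked positions; the remaining masked positions
    # are still fed to the encoder as [MASK] tokens.
    keep_positions = ids_shuffle[:, num_dropped:]
    encoder_input = torch.gather(masked_ids, 1, keep_positions)
    return masked_ids, encoder_input
```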

LTH14 avatar Jun 17 '24 05:06 LTH14