
Training difficulties

Open LouisSerrano opened this issue 1 year ago • 51 comments

Hi, I am experiencing some difficulties training magvit2. I don't know if I made a mistake somewhere or where the problem might be coming from.

It seems that my understanding of the paper might be erroneous. I tried with 2 codebooks of size 512 and I can't seem to fit the training data; the training is really unstable. I tried replacing the LFQ with a classical VQ and it was more stable and able to converge. What is the config that you tried for training the model?
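For reference, a minimal sketch of that swap at the quantizer level, using the LFQ and VectorQuantize classes from vector-quantize-pytorch (argument names are assumed from that library's API at the time and may not match the exact config used here):

```python
import torch
from vector_quantize_pytorch import LFQ, VectorQuantize

# LFQ with 2 codebooks of size 512: each codebook contributes log2(512) = 9 bits
lfq = LFQ(
    codebook_size = 512,
    num_codebooks = 2,
    dim = 18,                   # 2 codebooks * 9 bits per codebook
    entropy_loss_weight = 0.1,
)

# the classical VQ baseline that was more stable
vq = VectorQuantize(
    dim = 18,
    codebook_size = 512,
    commitment_weight = 1.,
)

image_feats = torch.randn(1, 18, 32, 32)                 # (batch, dim, height, width)
quantized, indices, entropy_aux_loss = lfq(image_feats)  # LFQ returns an entropy aux loss

tokens = torch.randn(1, 1024, 18)                        # (batch, num tokens, dim)
quantized, indices, commit_loss = vq(tokens)             # VQ returns a commitment loss
```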

LouisSerrano avatar Nov 02 '23 10:11 LouisSerrano

@LouisSerrano hey Louis and thank you for reporting this

that is disappointing to hear, as I had high hopes for LFQ. there is the possibility that I implemented it incorrectly, but the authors had already given it a code review, so that seems less likely. i'll be training the model tonight (regaining access to my deep learning rig) so i can see this instability for myself.
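For readers following along, this is the core operation LFQ uses in place of a codebook lookup, per the MAGVIT-v2 paper (a sketch of the idea only, not this repository's implementation):

```python
import torch

def lfq_quantize(z):
    # lookup-free quantization: each channel becomes one bit (its sign),
    # so a d-dimensional latent indexes an implicit codebook of size 2**d
    quantized = torch.where(z > 0, torch.ones_like(z), -torch.ones_like(z))
    # straight-through estimator: forward pass uses the quantized values,
    # backward pass treats the quantization as identity
    return z + (quantized - z).detach()
```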

lucidrains avatar Nov 02 '23 13:11 lucidrains

@LouisSerrano as a silver lining, i guess your experiments show the rest of the framework to be working ok. i can run a few toy experiments on LFQ this morning before running it on a big image dataset tonight and see what the issue is

lucidrains avatar Nov 02 '23 13:11 lucidrains

@LouisSerrano just to make sure, are you on the latest version of the vector-quantize-pytorch library?

lucidrains avatar Nov 02 '23 13:11 lucidrains

Ok thank you very much, it might simply be an error on my side, probably in the configuration. I used the model config that you suggested in the README.md, and for LFQ I used 1 codebook of size 512.

LouisSerrano avatar Nov 02 '23 13:11 LouisSerrano

@lucidrains yes I am using 1.10.4

LouisSerrano avatar Nov 02 '23 13:11 LouisSerrano

@LouisSerrano no that should be fine, you could try increasing to 2048 or 4096, but shouldn't make a big difference

lucidrains avatar Nov 02 '23 13:11 lucidrains

@lucidrains To clarify, I tried with a different dataset, a smaller one actually, but it should be less challenging than the ones from the paper.

LouisSerrano avatar Nov 02 '23 13:11 LouisSerrano

@LouisSerrano this is kind of a stab in the dark, but would you like to try adding lfq_activation = nn.Tanh() (with from torch import nn) to your VideoTokenizer init? (v0.0.64)
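A minimal sketch of where that keyword argument goes; apart from lfq_activation, the constructor arguments and values here are assumed from the README and are just placeholders:

```python
from torch import nn
from magvit2_pytorch import VideoTokenizer

tokenizer = VideoTokenizer(
    image_size = 128,             # placeholder values, not a recommended config
    codebook_size = 512,
    lfq_activation = nn.Tanh(),   # bounds the pre-quantization activations to (-1, 1)
)
```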

lucidrains avatar Nov 02 '23 14:11 lucidrains

they didn't have this in the paper, but i think it should make sense

lucidrains avatar Nov 02 '23 14:11 lucidrains

@LouisSerrano let me do some toy experiments right now and make sure it isn't some obvious bug

lucidrains avatar Nov 02 '23 14:11 lucidrains

@lucidrains Ok I am going to try to increase the codebook size, just in case. Sure, I can check with tanh activation.

LouisSerrano avatar Nov 02 '23 14:11 LouisSerrano

@LouisSerrano thank you! 🙏

lucidrains avatar Nov 02 '23 14:11 lucidrains

@lucidrains Thanks for the awesome work !

LouisSerrano avatar Nov 02 '23 14:11 LouisSerrano

@LouisSerrano yea no problem! well, we'll see if this work pans out. thank you for attempting to replicate!

lucidrains avatar Nov 02 '23 14:11 lucidrains

[Screenshot: LFQ vs VQ on the toy task, 2023-11-02 07:25]

lucidrains avatar Nov 02 '23 14:11 lucidrains

for the toy task, LFQ looks ok compared to VQ. Tanh won't work, but you can try Tanh x 10

lfq_activation = lambda x: torch.nn.functional.tanh(x) * 10,
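As a side note, the lambda won't pickle with a checkpoint; a pickle-friendly equivalent written as a small module (a sketch, not something in the repo):

```python
import torch
from torch import nn

class ScaledTanh(nn.Module):
    """tanh scaled by a constant; numerically identical to the lambda above."""
    def __init__(self, scale = 10.0):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        return torch.tanh(x) * self.scale

# lfq_activation = ScaledTanh(10.0)
```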

lucidrains avatar Nov 02 '23 14:11 lucidrains

Ok thanks, I will try this! I'll let you know if I encounter any issues. Also, what kind of weights do you use for the commitment and entropy losses?

LouisSerrano avatar Nov 02 '23 14:11 LouisSerrano

@LouisSerrano i think for commitment loss it is ok to keep it at the same value as regular VQ, i.e. 1., but i'm not sure about the respective per-sample and batch entropy weights

lucidrains avatar Nov 02 '23 14:11 lucidrains

@LouisSerrano i just increased the batch entropy weight a bit, to what works for the toy task (just fashion mnist)
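For anyone following along, a sketch of the entropy penalty being discussed, as described in the MAGVIT-v2 paper (the function shape and weights here are illustrative, not the library's actual interface): the per-sample entropy term pushes each latent to commit to a single code, while the batch entropy term pushes overall code usage to stay diverse.

```python
import torch

def entropy_aux_loss(logits, per_sample_weight = 1.0, batch_weight = 1.0):
    # logits: (num_latents, codebook_size) affinities between latents and codes
    probs = logits.softmax(dim = -1)
    log_probs = torch.log(probs.clamp(min = 1e-9))

    # minimize: each latent should be confident about one code
    per_sample_entropy = -(probs * log_probs).sum(dim = -1).mean()

    # maximize (hence the minus sign): codes should be used evenly across the batch
    avg_probs = probs.mean(dim = 0)
    batch_entropy = -(avg_probs * torch.log(avg_probs.clamp(min = 1e-9))).sum()

    return per_sample_weight * per_sample_entropy - batch_weight * batch_entropy
```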

lucidrains avatar Nov 02 '23 14:11 lucidrains

@lucidrains Ok great, thanks for the tips.

LouisSerrano avatar Nov 02 '23 14:11 LouisSerrano

@LouisSerrano hey Louis, just noticed that you used the default layer structure from the readme. feel free to try the updated one and see how that fares with LFQ

would still be interested to know if tanh(x) * 10 helps resolve your previous instability, if you have an experiment in progress.

lucidrains avatar Nov 02 '23 15:11 lucidrains

This was with my previous config. I benchmarked against fsq. I also show the aux_loss, which is going crazy for lfq_tanh

[Screenshot: benchmark curves against FSQ, including the aux_loss for lfq_tanh, 2023-11-02 16:51]

LouisSerrano avatar Nov 02 '23 15:11 LouisSerrano

@LouisSerrano thank you! do you have the original unstable LFQ plot too? and wow, you actually gave FSQ a test drive; what level settings did you use for it?

lucidrains avatar Nov 02 '23 15:11 lucidrains

@LouisSerrano did you compare it to baseline VQ by any chance?

lucidrains avatar Nov 02 '23 15:11 lucidrains

that plot for FSQ looks very encouraging, maybe i'll add it to the repository tonight for further research. thanks again for running the experiment and sharing it

lucidrains avatar Nov 02 '23 16:11 lucidrains

I think my main concern was the auxiliary loss. I did not look too much into the details, but I assumed that somehow the model struggled to maintain both a diverse codebook and good reconstructions.

[Screenshot: auxiliary loss curves, 2023-11-02 17:02]

LouisSerrano avatar Nov 02 '23 16:11 LouisSerrano

@lucidrains I did not use VQ yet but I am going to launch it tonight, so I'll let you know when I get the results. Also, I tested two recommended configurations for fsq: levels_big = [8,8,8,6,5] and levels_mid = [8,5,5,5]. On my simple problem I would have expected similar results for both, and that is the case, so I am happy with it!
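For reference, a sketch of those two configurations at the quantizer level, assuming the FSQ class from vector-quantize-pytorch (the level settings themselves come from the FSQ paper's recommendations):

```python
from vector_quantize_pytorch import FSQ

# effective codebook size is the product of the levels:
# 8 * 8 * 8 * 6 * 5 = 15360 and 8 * 5 * 5 * 5 = 1000
fsq_big = FSQ(levels = [8, 8, 8, 6, 5])
fsq_mid = FSQ(levels = [8, 5, 5, 5])
```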

LouisSerrano avatar Nov 02 '23 16:11 LouisSerrano

@lucidrains I will give your config a shot, thanks again!

LouisSerrano avatar Nov 02 '23 16:11 LouisSerrano

@LouisSerrano ok, let me know how VQ goes on your end! i'm hoping it is worse than both LFQ and FSQ, on that broken layer architecture you used (which is still a good benchmark)

if VQ outperforms both LFQ and FSQ on both the old and new layer architecture, then i don't know what to do. that would be the worst outcome

lucidrains avatar Nov 02 '23 16:11 lucidrains

@LouisSerrano i'll be running experiments tonight and making sure image pretraining is seamless. will update on what i see

lucidrains avatar Nov 02 '23 16:11 lucidrains