
[data2vec] Loss goes down and up again

Open Tomsen1410 opened this issue 2 years ago • 10 comments

I am trying to train data2vec on music data (the FMA dataset). I've made some modifications to the feature-extractor ConvNet (essentially turning it into a small ResNet) and reduced the size of the transformer encoder (8 layers, 8 attention heads, 512 embedding dim, 2048 FFN dim). Since I am new to training transformer models, I've tried to reproduce the paper's hyperparameters as closely as possible. My batch size is much smaller, so I've adjusted the learning rate accordingly.

The first few epochs seem promising. The loss goes down quite quickly, and the embedding variances of the teacher and student models both increase over time (i.e. the model is not collapsing to a constant representation). However, after a while the loss abruptly turns around, goes back up, and stays high for the remainder of training (the variances stay high, though). I experienced similar issues with another exponential-moving-average (EMA) model, DINO (the ResNet variant), a while ago.
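For reference, the two mechanisms discussed in this thread can be sketched as follows. This is an illustrative sketch, not the fairseq implementation: an EMA update that moves the teacher slowly toward the student, and a simple variance probe (mean per-dimension standard deviation over the batch) that should stay well above zero if the representations are not collapsing.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(teacher, student, decay=0.999):
    # Teacher parameters track an exponential moving average of the student.
    for t, s in zip(teacher.parameters(), student.parameters()):
        t.mul_(decay).add_(s, alpha=1.0 - decay)

def embedding_variance(x):
    # x: (batch, time, dim). Returns the mean over dimensions of the
    # per-dimension std across all batch/time positions; values near
    # zero indicate collapse to a constant representation.
    x = x.reshape(-1, x.size(-1))
    return torch.sqrt(x.var(dim=0) + 1e-6).mean()
```

Logging `embedding_variance` for both the student predictions and the teacher targets every few hundred steps is a cheap way to catch a collapse early.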

When reducing the learning rate, the effect just happens later.

There seems to be a similar issue reported, indirectly, here.

Does anyone have an idea why this might be happening?

Tomsen1410 avatar Feb 02 '22 15:02 Tomsen1410

Okay, so I've plotted the teacher and student variances from two test runs with a small dataset and only for a few epochs:

https://imgur.com/a/dynF7P9

The student variance converges toward the teacher variance up to a specific point. After that, the loss quickly increases and a relatively constant gap remains between the two variances.

Not sure how to interpret this. Maybe this is the desired behaviour, since the teacher should always be "better" in order to lead the student.

Tomsen1410 avatar Feb 03 '22 14:02 Tomsen1410

Hello, @Tomsen1410! Have you found a reasonable explanation? Following is my strange loss plot: image

Ramlinbird avatar Feb 14 '22 06:02 Ramlinbird

Hey! Unfortunately not. Your loss seems to be OK, though? I am not sure what the "correct" loss curve should look like; it is strange that my loss values are two orders of magnitude smaller than yours. It would also be great to see a training plot from the authors. Could you tell me which hyperparameters you used (batch size, ...)?

Tomsen1410 avatar Feb 15 '22 12:02 Tomsen1410

Thanks. I didn't change any hyperparameters in base_librispeech.yaml except max_tokens (due to a memory error). @alexeib, could you share your training plot with us and help us figure this out? Thanks so much.

Ramlinbird avatar Feb 16 '22 01:02 Ramlinbird

hey, if variances are jumping up and down, that looks like a collapse and you may want to lower your learning rate. i don't have variance plots for the nlp models (variance wasn't being logged when they were trained) but here is the loss curve:

image

here are a couple of example plots from a reduced speech setup with variance etc. (it should look somewhat similar for nlp, but not exactly the same). this is also using a tri-stage lr scheduler that holds the learning rate at its peak for 90% of training:

loss: image

pred var: image

target var: image
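The tri-stage schedule mentioned above (similar in spirit to fairseq's `tri_stage` scheduler) can be sketched as a plain function of the step count. The phase fractions, peak rate, and scale factors below are illustrative placeholders, not the values used in the paper:

```python
import math

def tri_stage_lr(step, total_steps, peak_lr=5e-4,
                 init_scale=0.01, final_scale=0.05,
                 phases=(0.05, 0.90, 0.05)):
    # Three phases: linear warmup, hold at peak, exponential decay.
    warmup = int(phases[0] * total_steps)
    hold = int(phases[1] * total_steps)
    if step < warmup:
        # Linear warmup from init_scale * peak_lr up to peak_lr.
        frac = step / max(1, warmup)
        return peak_lr * (init_scale + (1 - init_scale) * frac)
    if step < warmup + hold:
        return peak_lr  # held at peak for ~90% of training
    # Exponential decay toward final_scale * peak_lr.
    decay_steps = max(1, total_steps - warmup - hold)
    frac = (step - warmup - hold) / decay_steps
    return peak_lr * math.exp(math.log(final_scale) * frac)
```

Holding the peak rate for most of training (rather than decaying immediately) matches the shape of the plots above, where the loss flattens long before the final decay phase.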

alexeib avatar Feb 16 '22 02:02 alexeib

@alexeib Thanks a lot for your quick reply! According to your loss plot, my training seems all right? I also checked my variance plot: image However, the pred var and target var are not flat at the end; they are still dropping. Is this OK? What is the actual meaning of these values, and what relation should they have? (From the source code, they appear to just be the standard variance of the networks' outputs.)

Ramlinbird avatar Feb 16 '22 02:02 Ramlinbird

Was there also instance_norm applied to the targets at the reduced speech setup @alexeib ?
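For context on the instance-norm question: my reading of the data2vec recipe (this is a hedged sketch of my understanding, not verbatim fairseq code) is that the targets are built by averaging the top-K teacher layer outputs, with instance normalization applied to each layer over the time dimension before averaging. The value of `k` below is illustrative:

```python
import torch
import torch.nn.functional as F

def build_targets(layer_outputs, k=8):
    # layer_outputs: list of (batch, time, dim) tensors from the teacher,
    # ordered bottom to top. Take the top-k layers.
    tops = layer_outputs[-k:]
    normed = [
        # instance_norm expects (batch, channels, length), so normalize
        # each feature dimension over time, then transpose back.
        F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
        for y in tops
    ]
    # Average the normalized layers into a single regression target.
    return sum(normed) / len(normed)
```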

Tomsen1410 avatar Mar 03 '22 17:03 Tomsen1410

@alexeib thank you for the logs. Could this behavior of the loss (going down, up, and down again) be related to the near-optimality of the predictor in BYOL (Bootstrap Your Own Latent)? The BYOL paper demonstrates the importance of a near-optimal predictor for preventing collapse.

Orlllem avatar Jul 05 '22 15:07 Orlllem

@alexeib Hi, why does the NLP curve look so different from the speech curve? It seems like it does not converge. Actually, my audio-modality training curve is similar to your nlp curve, and my pred and target var do not seem to be collapsing. How do I check whether my audio model converges?

image image image

a43992899 avatar Oct 04 '22 16:10 a43992899

FWIW, I was getting a target_var < 0.1 error while training data2vec 2.0 on my data. I lowered the learning rate from the default 0.00075 to 0.00050 and the error vanished.

saurabh-kataria avatar Jan 24 '23 17:01 saurabh-kataria