satflow

Try Teacher Forcing for RNNs

Open jacobbieker opened this issue 3 years ago • 2 comments

The idea is that, with some probability, we feed an RNN such as ConvLSTM the ground-truth (GT) frame instead of its own prediction while it is generating sequences during training. This can help with convergence, especially early in training: if the model messes up at the second timestep, all later ones will be useless, but if we give it the GT image to continue the sequence, it might still learn better representations.

jacobbieker, Jul 07 '21 14:07

The probability that any single RNN output is swapped for the GT label should decay fairly quickly: start near 1 at the beginning of training and drop to 0 after 5 or so epochs.
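
As a rough sketch of such a schedule (not code from satflow), a linear decay could look like the following. The function name `teacher_forcing_schedule` and the parameters `decay_epochs` and `start` are hypothetical, and the linear shape is only one assumption; an exponential or inverse-sigmoid decay would fit the description equally well.

```python
def teacher_forcing_schedule(epoch: int, decay_epochs: int = 5, start: float = 1.0) -> float:
    """Hypothetical helper: probability of feeding the ground truth at `epoch`.

    Decays linearly from `start` at epoch 0 to 0.0 at `decay_epochs`,
    then stays at 0.0 for the rest of training.
    """
    if decay_epochs <= 0:
        return 0.0
    remaining = 1.0 - epoch / decay_epochs
    # Clamp so epochs past `decay_epochs` never produce a negative probability.
    return start * max(0.0, min(1.0, remaining))


# Probabilities for the first 8 epochs (zero-based).
schedule = [teacher_forcing_schedule(epoch) for epoch in range(8)]
```

With the defaults, the probability is 1.0 at epoch 0 and reaches 0.0 from epoch 5 onwards.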

jacobbieker, Jul 07 '21 14:07

A good way to do this is to put a parameter on the model, like:

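import random

teacher_forcing_prob = 0.0  # model attribute, updated by the training loop
...
def forward(self, x, targets=None):
  # Swap in the ground truth only with probability teacher_forcing_prob,
  # and only when targets are available (i.e. during training).
  if targets is not None and random.random() < self.teacher_forcing_prob:
    replace the model input with targets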

and then in the training:

model.teacher_forcing_prob = 1 - epoch/total_epochs
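
To make the per-timestep swap concrete, here is a minimal, framework-free sketch. It is not satflow's implementation: `rollout`, `step`, and the toy `double` model are hypothetical names, and scalar floats stand in for image frames. The assumption is that the model predicts one timestep at a time and that each prediction would normally be fed back as the next input.

```python
import random
from typing import Callable, List, Optional, Sequence


def rollout(
    step: Callable[[float], float],
    first_input: float,
    targets: Sequence[float],
    teacher_forcing_prob: float = 0.0,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Generate len(targets) predictions, one timestep at a time.

    After each step, the next input is the ground-truth value with
    probability `teacher_forcing_prob`, otherwise the model's own prediction.
    """
    if rng is None:
        rng = random.Random()
    predictions: List[float] = []
    current = first_input
    for target in targets:
        prediction = step(current)
        predictions.append(prediction)
        # Decide independently at every timestep whether to teacher-force.
        use_ground_truth = rng.random() < teacher_forcing_prob
        current = target if use_ground_truth else prediction
    return predictions


def double(value: float) -> float:
    """Toy stand-in for a recurrent cell: doubles its input."""
    return 2.0 * value


targets = [10.0, 20.0, 30.0]
free_running = rollout(double, 1.0, targets, teacher_forcing_prob=0.0)
fully_forced = rollout(double, 1.0, targets, teacher_forcing_prob=1.0)
```

With a probability of 0 the toy model only ever sees its own outputs, so an early error compounds through the sequence. With a probability of 1 every step after the first starts from the ground truth. Values in between mix the two, which is what the per-epoch assignment above anneals.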

tcapelle, Oct 07 '21 09:10