satflow
Try Teacher Forcing for RNNs
This is where, with some random chance, we give an RNN such as a ConvLSTM the ground-truth label instead of its own previous output while it is generating sequences during training. This can help with convergence, especially early on: if the model messes up at the second timestep, all later predictions are built on that error, but if we feed it the GT image to continue the sequence, it can still learn useful representations.
The probability that a single RNN output is swapped with a GT label should decay fairly quickly, e.g. starting near 1 at the beginning of training and dropping to 0 after roughly 5 epochs.
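As a minimal sketch of what that decay could look like (the linear-over-5-epochs shape here is just an illustrative choice, not a fixed spec):

```python
def teacher_forcing_schedule(epoch: int, decay_epochs: int = 5) -> float:
    """Linearly decay the teacher-forcing probability from 1.0 at
    epoch 0 down to 0.0 at `decay_epochs`, then stay at 0."""
    return max(0.0, 1.0 - epoch / decay_epochs)

for epoch in range(7):
    print(epoch, teacher_forcing_schedule(epoch))
# 0 → 1.0, 1 → 0.8, ... 5 → 0.0, 6 → 0.0
```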
A good approach is to put a param on the model like:

    teacher_forcing_prob = 0
    ...
    def forward(self, x, targets):
        ...
        # at each timestep, with probability teacher_forcing_prob,
        # feed the GT frame instead of the model's previous output
        if random.random() < self.teacher_forcing_prob:
            next_input = targets[t]
and then in the training loop:

    model.teacher_forcing_prob = 1 - epoch / total_epochs