llm-foundry
WIP: Prevent the loss from being computed when the input token is the EOS token
The model should not be trained to predict the token that follows the eos_token, because that token comes from a different sequence. This PR implements that logic.
TODO: Experimental verification.
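For reference, a minimal sketch of the intended masking. The function name and the HF-style shifted-labels / `-100` ignore-index convention are assumptions here, not the actual llm-foundry implementation:

```python
import torch

def mask_loss_after_eos(input_ids: torch.Tensor,
                        labels: torch.Tensor,
                        eos_token_id: int,
                        ignore_index: int = -100) -> torch.Tensor:
    """Sketch only: wherever the *input* token is EOS, set the corresponding
    label to ignore_index so cross-entropy skips that position.

    Assumes input_ids and labels are (batch, seq_len) and labels are already
    shifted so that labels[:, i] is the target for input_ids[:, i].
    """
    labels = labels.clone()
    labels[input_ids == eos_token_id] = ignore_index
    return labels
```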
I think having this option is good; some users almost certainly want it.
However, I think this should be optional, as I am not convinced the model shouldn't learn to predict the token after EOS. I'd expect the model to learn that, after EOS (if sequences are joined randomly), it can disregard all context and pick from the distribution of tokens that begin sequences. That is a different distribution than raw unigram frequencies, which are the probabilities it should use when picking a token not conditioned on EOS.
Then, if sequences are not joined randomly, as in that TSP NN method, we definitely want to compute loss.
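A toy illustration of that distinction (made-up data, purely for intuition):

```python
from collections import Counter

# After random packing, P(t | prev == EOS) is the distribution over
# sequence-*initial* tokens, which differs from the raw unigram P(t).
sequences = [["the", "cat", "sat"], ["the", "dog", "ran"], ["a", "bird", "flew"]]

unigram = Counter(tok for seq in sequences for tok in seq)  # ~ P(t)
first_tokens = Counter(seq[0] for seq in sequences)         # ~ P(t | EOS)

print(unigram)       # "the" appears 2/9 of the time overall
print(first_tokens)  # but 2/3 of the time immediately after EOS
```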
Thanks for your comment! Yes, what you said makes sense. This is still very much a work in progress, and I just wanted to run some initial experiments as a sanity check. Also, this is mainly for the case where we do sequence-id-based masking. In that case, the eos token is still part of the previous sequence, but its target is the first token of the next sequence.
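To make that concrete, a toy packed row (illustrative only, not llm-foundry code):

```python
# Two sequences packed into one row, separated by <eos>:
tokens      = ["A", "B", "<eos>", "X", "Y", "<eos>"]
sequence_id = [ 0,   0,   0,       1,   1,   1     ]
# With next-token prediction, the target at each position is the following
# token, so the target at the first <eos> is "X": the first token of the
# *next* sequence, even though <eos> itself still carries sequence_id 0.
targets     = ["B", "<eos>", "X", "Y", "<eos>", None]
```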
@samhavens should we also add the option to not predict BOS (assuming the previous token is the end of the previous sequence)?
@vchiley for models that have both EOS and BOS, are you saying don't learn that BOS comes after EOS? It isn't worth learning, true, but also... we'll always stop generating at EOS, so it wouldn't matter... or am I misunderstanding?
as discussed on Slack, I think that:
- EOS is effectively a BOS token, so we want P(t|EOS) to be different from P(t), which means we don't want to mask this loss
- however, when doing seq id masking, we currently mask EOS for every token other than the first, so we learn P(t_0|EOS), P(t_1|t_0), P(t_2|t_0, t_1), ...
- So @ShashankMosaicML will confirm this and, if it is happening, shift the mask so that EOS is visible after t_0 (rough sketch below)
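A rough sketch of one way that shift could look, assuming a per-position `sequence_id` tensor of shape `(batch, seq_len)`; the function name and details are illustrative, not the planned implementation:

```python
import torch

def shift_eos_into_next_sequence(input_ids: torch.Tensor,
                                 sequence_id: torch.Tensor,
                                 eos_token_id: int) -> torch.Tensor:
    """Re-assign each EOS position the sequence id of the token that follows it,
    so an attention mask built from sequence_id lets t_1, t_2, ... of the next
    sequence attend to the preceding EOS, rather than only having EOS's loss
    target be t_0."""
    # Sequence id of the next position; the last position keeps its own id.
    next_seq_id = torch.cat([sequence_id[:, 1:], sequence_id[:, -1:]], dim=1)
    return torch.where(input_ids == eos_token_id, next_seq_id, sequence_id)
```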