MNTP Question
Hi, great work on this!
Just had a question about the MNTP. In the paper, you mention "when predicting a masked token at position i, we compute the loss based on the logits obtained from the token representation at the previous position i − 1, not the masked position itself".
I was a bit confused about this, and also about why it is done this way. Could you provide a more detailed explanation and the intuition behind it?
Thanks, Brett
Hi @bdytx5,
thanks for your interest in our work. We did this to align our training objective with the pre-training setup of decoder-only LLMs. Decoder-only LMs are trained to predict the token at position i using the embedding of the token at position i-1. By making sure our training objective follows a similar pattern, the intuition is that we will maximally use the inherent capabilities of the model.
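To make the shift concrete, here is a minimal sketch (not the actual repo code) of how MNTP labels can be built so that the loss for a masked token at position i is assigned to position i-1. The function name, `ignore_index` convention (-100, as in typical cross-entropy setups), and mask token id are illustrative assumptions:

```python
def build_mntp_labels(input_ids, masked_positions, mask_token_id=0, ignore_index=-100):
    """Sketch of MNTP label construction.

    The label for a masked token at position i is placed at position i-1,
    so the loss is computed from the representation of the *previous* token,
    mirroring the next-token-prediction setup of decoder-only LMs.
    """
    labels = [ignore_index] * len(input_ids)  # ignore all unmasked positions
    masked = list(input_ids)
    for i in masked_positions:
        labels[i - 1] = input_ids[i]  # shift target to the previous position
        masked[i] = mask_token_id     # replace the token itself with [MASK]
    return masked, labels

# Example: mask the token at position 2 of [10, 11, 12, 13].
masked, labels = build_mntp_labels([10, 11, 12, 13], masked_positions=[2])
# The target 12 now sits at position 1, i.e. position i-1.
```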
Let me know if you have any further questions.
ok, thanks!
Hello, I'm confused about the figure for the MNTP task in your paper.
In this figure, the formulas show that w3 depends on w4 and [mask], but in a decoder-only model, the answer would be predicted right after w2 is input, right?
@vaibhavad
I understand now, the predictions for all positions are computed in parallel.