nanoGPT
Why does a transformer compute outputs for all tokens but use only the last token for prediction?
The input to the transformer is (B, T) and the output after the final MLP is also (B, T, ...), yet we only use the embedding at the last position to predict the next token. Why can't we do something with the embeddings of the other positions? It's my first time learning transformers.
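Roughly what I mean, as a minimal sketch (the shapes assume a nanoGPT-style model whose forward pass returns logits of shape (B, T, vocab_size); the random tensors here just stand in for real data and a real model):

```python
import torch

B, T, vocab_size = 4, 8, 50304
idx = torch.randint(vocab_size, (B, T))       # input token ids, shape (B, T)
logits = torch.randn(B, T, vocab_size)        # stand-in for model(idx) output
next_token_logits = logits[:, -1, :]          # only the last position is used
next_token = torch.argmax(next_token_logits, dim=-1)  # shape (B,)
```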
try it to find out
I was wondering the same thing. It is useful during training to calculate the loss and adjust weights, but during prediction it seems the other token predictions are a waste.
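Something like this is what I mean by "useful during training" (a minimal sketch, assuming logits of shape (B, T, vocab_size) and next-token targets of shape (B, T), which is how nanoGPT computes its loss):

```python
import torch
import torch.nn.functional as F

B, T, vocab_size = 4, 8, 50304
logits = torch.randn(B, T, vocab_size)        # model output for every position
targets = torch.randint(vocab_size, (B, T))   # next-token target at every position
# every one of the B*T positions contributes to the loss, not just the last one
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
```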
Yeah, apparently some implementations do this optimization and feed only the last token's output into the final projection at inference time.
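A rough sketch of that kind of optimization (not the exact nanoGPT code; here `x` stands for the (B, T, C) hidden states from the last transformer block and `lm_head` for the vocabulary projection):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, T, C, vocab_size = 4, 8, 64, 50304
x = torch.randn(B, T, C)          # hidden states from the last transformer block
targets = None                    # None at inference time, (B, T) token ids during training
lm_head = nn.Linear(C, vocab_size, bias=False)

if targets is not None:
    # training: project every position and score all of them against the targets
    logits = lm_head(x)                                   # (B, T, vocab_size)
    loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
else:
    # inference: only the last position is needed to pick the next token,
    # so skip the projection for the other T-1 positions
    logits = lm_head(x[:, [-1], :])                       # (B, 1, vocab_size)
    loss = None
```

Note that the attention layers themselves still process all T tokens (the last position attends to every earlier one); the saving is only in the final projection to the vocabulary.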