Logits for all positions?
In model.Transformer.forward, the following line says it'll only compute the logits for the last position in h:
output = self.output(h[:, -1, :]) # only compute last logits
I'm interested in getting surprisal values for each word in a sentence, so I'd like logits for every position.
It looks like first, I need to fix up the inputs by converting the pad_ids to eos_ids, since pad_id is -1, which doesn't have an embedding. In contrast, eos_id is 2, which does have an embedding (though I'm not bothering to examine the logits for it or anything after it; it's just to be able to run batches of sentences with unequal lengths).
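For reference, here's roughly what I have in mind for that pad-to-eos fix (a minimal sketch; tokens is just a placeholder name for the batch of token IDs from my own calling code, not something in the repo):

import torch

def replace_padding(tokens: torch.Tensor, pad_id: int, eos_id: int) -> torch.Tensor:
    # Replace every pad_id (-1) with eos_id (2) so each position has a valid embedding.
    fixed = tokens.clone()
    fixed[fixed == pad_id] = eos_id
    return fixed

# Keep a mask of the real (non-pad) positions so the padded logits can be ignored later:
# real_mask = tokens != pad_id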
After I do this, is it as simple as changing the line above to the following to get the logits for each position for each example in the batch? Just want to make sure I'm not missing anything obvious.
output = self.output(h)
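For what it's worth, once I have per-position logits, my plan for the surprisal values is roughly this (again a sketch; logits and tokens are placeholder names for the full output and the padded-then-fixed input IDs, and the logits at position t are taken to predict the token at position t + 1):

import torch.nn.functional as F

log_probs = F.log_softmax(logits.float(), dim=-1)  # (batch, seq_len, vocab_size)
targets = tokens[:, 1:]  # the token actually observed after each position
token_log_probs = log_probs[:, :-1, :].gather(-1, targets.unsqueeze(-1)).squeeze(-1)
surprisal = -token_log_probs  # per-token surprisal in nats; mask out padded positions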