
Logits for all positions?

Open mawilson1234 opened this issue 1 year ago • 0 comments

In model.Transformer.forward, the comment on the following line says it only computes the logits for the last position in h:

output = self.output(h[:, -1, :])  # only compute last logits

I'm interested in getting surprisal values for each word in a sentence, so I'd like logits for every position.

It looks like I first need to fix up the inputs by converting pad_id tokens to eos_id, since pad_id is -1, which doesn't have an embedding. In contrast, eos_id is 2, which does have an embedding (I'm not examining the logits for it or anything after it; it's just so I can run batches of sentences of unequal lengths).
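
Here's a rough sketch of the padding fix I have in mind (the helper name is just my own, and the pad_id/eos_id values are the ones I see from the tokenizer):

```python
import torch

def replace_pad_with_eos(tokens: torch.Tensor, pad_id: int = -1, eos_id: int = 2) -> torch.Tensor:
    """Swap pad ids for eos ids so every position maps to a valid embedding row."""
    tokens = tokens.clone()
    tokens[tokens == pad_id] = eos_id
    return tokens
```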

After I do this, is it as simple as changing the line above to the following to get the logits for each position for each example in the batch? Just want to make sure I'm not missing anything obvious.

output = self.output(h)
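
For context, here's roughly how I plan to turn the full-position logits into per-token surprisals. This helper is just my own sketch, assuming `output` has shape [batch, seq_len, vocab_size] after the change above and `tokens` holds the original input ids:

```python
import torch
import torch.nn.functional as F

def surprisal(output: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Surprisal (negative log-probability, in nats) of each token given its left context."""
    log_probs = F.log_softmax(output, dim=-1)             # [batch, seq, vocab]
    # logits at position i predict the token at position i + 1, so shift by one
    targets = tokens[:, 1:].unsqueeze(-1)                  # [batch, seq - 1, 1]
    token_log_probs = log_probs[:, :-1, :].gather(-1, targets).squeeze(-1)
    return -token_log_probs                                # [batch, seq - 1]
```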

mawilson1234 · Apr 30, 2023