Question about the generate method
When running the generate method, the logits are obtained like this:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
Initially, prev_pos=0, so the first step returns predictions based on all tokens from 0 up to the length of the shortest prompt in the batch (which is the initial value of cur_pos). After that, prev_pos gets set to cur_pos, and cur_pos gets incremented by 1 on each iteration (until we reach the maximum length). The next token is determined from these logits (by sampling, argmax, or replacement with the provided token for prompts longer than the shortest one) and appended to the sequence before the next iteration of the loop.
But this means that on each subsequent iteration, prev_pos = cur_pos - 1, so tokens[:, prev_pos:cur_pos] only gives us a single token per example in the batch on all but the first iteration of the loop. Does this mean that subsequent prediction steps only condition on the immediately preceding token, rather than on all preceding tokens in the prompt? That seems odd for the longer prompts in the batch, where I'd want the model to consider all of the preceding context, not just the token right before the end, when generating. Am I misunderstanding something about how the forward method works here that would account for this?
Edit: Another way of putting the question: what is the difference between using the snippet above to get the logits and doing this instead?
logits = self.model.forward(tokens[:, 0:cur_pos], 0)
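In other words, the comparison I have in mind is something like the following hypothetical check inside the loop (assuming a model with the same forward(tokens, start_pos) signature; this is not code from the repo):

```python
# Hypothetical check, assuming `model.forward` returns per-position logits.
incremental = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)  # one new token, relies on the cache
from_scratch = model.forward(tokens[:, 0:cur_pos], 0)               # the full prefix, from position 0
# If the KV cache works the way I think it does, the last position's logits should agree:
assert torch.allclose(incremental[:, -1], from_scratch[:, -1], atol=1e-4)
```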
Second edit: Is this not an issue because cache_k and cache_v get updated when the forward method of the attention heads gets called?
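i.e., something like the following. This is a minimal sketch of the caching behaviour I'm assuming (projections, multiple heads, rotary embeddings, and the causal mask are all omitted), not the actual Attention.forward:

```python
import torch

class CachedAttentionSketch:
    """Assumed caching behaviour: new keys/values are written at start_pos,
    and attention is computed over everything cached up to the current position."""

    def __init__(self, max_batch: int, max_seq_len: int, dim: int):
        self.cache_k = torch.zeros(max_batch, max_seq_len, dim)
        self.cache_v = torch.zeros(max_batch, max_seq_len, dim)

    def forward(self, x: torch.Tensor, start_pos: int) -> torch.Tensor:
        bsz, seqlen, dim = x.shape
        # Treat x as the already-projected queries/keys/values for simplicity.
        q, k, v = x, x, x
        # Write the new positions into the cache at [start_pos : start_pos + seqlen].
        self.cache_k[:bsz, start_pos:start_pos + seqlen] = k
        self.cache_v[:bsz, start_pos:start_pos + seqlen] = v
        # Attend over *all* cached positions up to the current one, so even a
        # single new token sees the full preceding context, not just one token.
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        scores = torch.softmax(q @ keys.transpose(1, 2) / dim ** 0.5, dim=-1)
        return scores @ values
```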