llama
Question about total_len and max_gen_len
Line 165 in generation.py sets total_len as follows:
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
The description of max_gen_len reads:
max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence. If not provided, it's set to the model's maximum sequence length minus 1.
Consider the following example for text completion:
- Number of prompts = 2
- Prompt 1 has 8 initial input tokens
- Prompt 2 has 13 initial input tokens
- max_gen_len = 64
- max_seq_len = 512
In this case, min_prompt_len = 8, max_prompt_len = 13, max_gen_len + max_prompt_len = 77, and total_len = min(512, 77) = 77. The model ends up producing tokens for both prompts until each has 77 tokens in total. This means the model generated 69 tokens for the first prompt (77 - 8) and 64 tokens for the second prompt (77 - 13). This seems to violate what max_gen_len is meant to enforce: the model should generate at most 64 tokens per prompt.
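A minimal sketch of the arithmetic above, assuming no prompt stops early at an end-of-sequence token. Only the total_len expression mirrors line 165; the helper name generated_lengths is hypothetical and just reproduces the per-prompt counts.

```python
def generated_lengths(prompt_lens, max_gen_len, max_seq_len):
    """Return (total_len, new tokens per prompt) when total_len is
    anchored on the longest prompt, as on line 165."""
    max_prompt_len = max(prompt_lens)
    total_len = min(max_seq_len, max_gen_len + max_prompt_len)
    # Every row in the batch is filled up to total_len positions,
    # so a shorter prompt receives more new tokens than a longer one.
    return total_len, [total_len - n for n in prompt_lens]


total_len, gen = generated_lengths([8, 13], max_gen_len=64, max_seq_len=512)
print(total_len, gen)  # 77 [69, 64]
```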
Should line 165 instead read as follows?
total_len = min(params.max_seq_len, max_gen_len + min_prompt_len)
Hi @dnatarajan00!
Using total_len = min(params.max_seq_len, max_gen_len + min_prompt_len) appears to be semantically more correct, but note that it reduces the effective max_gen_len for longer prompts. In your example, total_len would become min(512, 64 + 8) = 72, so the output for prompt 2 (input_len=13) is capped at 72 - 13 = 59 generated tokens despite setting max_gen_len=64.
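The same check for the proposed variant, again a hypothetical sketch that assumes no early stop: anchoring total_len on the shortest prompt keeps prompt 1 at exactly max_gen_len but shortchanges prompt 2.

```python
def generated_lengths_min_anchor(prompt_lens, max_gen_len, max_seq_len):
    """Return (total_len, new tokens per prompt) when total_len is
    anchored on the shortest prompt (the proposed change)."""
    min_prompt_len = min(prompt_lens)
    total_len = min(max_seq_len, max_gen_len + min_prompt_len)
    # Longer prompts use up more of the shared total_len budget.
    return total_len, [total_len - n for n in prompt_lens]


total_len, gen = generated_lengths_min_anchor([8, 13], max_gen_len=64, max_seq_len=512)
print(total_len, gen)  # 72 [64, 59]
```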
Using total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) is based on the longest prompt in the batch, so for shorter prompts the model can indeed generate more tokens than max_gen_len.
I think we defaulted to this implementation because it's easier to detect and truncate a sequence that is too long than to regenerate one that is too short, but I agree that the semantics here are tricky.
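The "truncate a long sequence" idea can be sketched as a post-processing step: generate up to the max_prompt_len-based total_len, then keep at most max_gen_len new tokens after each prompt. This is a hypothetical illustration, not the repository's code; the helper truncate_to_max_gen_len and the dummy token ids standing in for model output are assumptions.

```python
def truncate_to_max_gen_len(sequences, prompt_lens, max_gen_len):
    """Keep each prompt plus at most max_gen_len generated tokens."""
    return [seq[: p + max_gen_len] for seq, p in zip(sequences, prompt_lens)]


prompt_lens = [8, 13]
max_gen_len = 64
max_seq_len = 512
total_len = min(max_seq_len, max_gen_len + max(prompt_lens))  # 77

# Stand-in for model output: every row filled to total_len with dummy ids.
sequences = [list(range(total_len)) for _ in prompt_lens]
trimmed = truncate_to_max_gen_len(sequences, prompt_lens, max_gen_len)
print([len(s) - p for s, p in zip(trimmed, prompt_lens)])  # [64, 64]
```

Over-generating and slicing keeps a single batched loop while still honoring max_gen_len per prompt, at the cost of computing some tokens that are thrown away for shorter prompts.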
File "/home/ljn/project/llama/llama/generation.py", line 165, in generate
    total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
TypeError: can only concatenate str (not "int") to str

Why did this happen?