
Question about total_len and max_gen_len

Open · dnatarajan00 opened this issue 1 year ago · 1 comment

Line 165 in generation.py sets total_len as follows: total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

The description of max_gen_len here is:

max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence. If not provided, it's set to the model's maximum sequence length minus 1.
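Putting those two pieces together, here is the relevant bookkeeping as a self-contained sketch (simplified; compute_total_len is my own wrapper name for illustration, not a function in the repo):

```python
# Minimal sketch of the length bookkeeping around line 165 of generate(),
# reduced to the parts relevant to this issue. prompt_tokens is a list of
# token-id lists, one per prompt in the batch.
def compute_total_len(prompt_tokens, max_seq_len, max_gen_len=None):
    # Per the docstring: if not provided, max_gen_len defaults to the
    # model's maximum sequence length minus 1.
    if max_gen_len is None:
        max_gen_len = max_seq_len - 1

    max_prompt_len = max(len(t) for t in prompt_tokens)

    # Line 165: the batch-wide budget is anchored on the LONGEST prompt.
    return min(max_seq_len, max_gen_len + max_prompt_len)
```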

Consider the following example for text completion:

Number of prompts = 2
prompt 1 has 8 initial input tokens
prompt 2 has 13 initial input tokens
max_gen_len = 64
max_seq_len = 512

In this case, min_prompt_len = 8, max_prompt_len = 13, max_gen_len + max_prompt_len = 77, and total_len = min(512, 77) = 77. The model then produces tokens for both prompts until each sequence reaches 77 tokens total, meaning it generated 69 tokens for the first prompt (and 64 for the second). This seems to violate what max_gen_len is meant to enforce: that the model should generate at most 64 tokens per prompt.
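To make the arithmetic concrete (plain Python, nothing llama-specific):

```python
# Reproducing the numbers from the example above.
max_seq_len = 512
max_gen_len = 64
prompt_lens = [8, 13]  # prompt 1 and prompt 2

max_prompt_len = max(prompt_lens)  # 13
total_len = min(max_seq_len, max_gen_len + max_prompt_len)  # min(512, 77) = 77

for i, plen in enumerate(prompt_lens, start=1):
    print(f"prompt {i}: {total_len - plen} generated tokens")
# prompt 1: 69 generated tokens  <- exceeds max_gen_len = 64
# prompt 2: 64 generated tokens
```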

Should line 165 instead say total_len = min(params.max_seq_len, max_gen_len + min_prompt_len)?

dnatarajan00 · Jan 05 '24

Hi @dnatarajan00!

Using total_len = min(params.max_seq_len, max_gen_len + min_prompt_len) appears to be semantically more correct, but note that it reduces the effective max_gen_len for longer prompts. In your example, prompt 2 (input_len=13) would be capped at 59 generated tokens despite max_gen_len=64.

Using total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) is based on the longest prompt in the batch, so for shorter prompts the model can indeed generate more tokens than max_gen_len.
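To make the trade-off concrete with your example's numbers:

```python
# Alternative formula: anchor the budget on the SHORTEST prompt.
max_seq_len, max_gen_len = 512, 64
prompt_lens = [8, 13]

total_len = min(max_seq_len, max_gen_len + min(prompt_lens))  # min(512, 72) = 72
for i, plen in enumerate(prompt_lens, start=1):
    print(f"prompt {i}: {total_len - plen} generated tokens")
# prompt 1: 64 generated tokens  <- exactly max_gen_len
# prompt 2: 59 generated tokens  <- capped below max_gen_len = 64
```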

I think we defaulted to this implementation because it's easier to detect and truncate a sequence that came out too long than to regenerate one that came out too short... but I agree that the semantics here are tricky.
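For reference, that truncation can be done per prompt after generation; a minimal sketch (illustrative, not a quote of the repo's code):

```python
# Sketch: enforce max_gen_len per prompt by truncating after generation.
# `tokens` is assumed to be a list of token-id lists (prompt + generated),
# one per prompt in the batch.
def truncate_to_max_gen_len(tokens, prompt_tokens, max_gen_len):
    out = []
    for toks, prompt in zip(tokens, prompt_tokens):
        # Keep the prompt plus at most max_gen_len generated tokens.
        out.append(toks[: len(prompt) + max_gen_len])
    return out
```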

subramen · Jan 10 '24

```
File "/home/ljn/project/llama/llama/generation.py", line 165, in generate
    total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
TypeError: can only concatenate str (not "int") to str
```

Why did this happen?

debfhfhkgref · Apr 18 '24
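That error suggests max_gen_len reached line 165 as a string rather than an int, e.g. passed straight from CLI arguments or a config file without conversion; a minimal reproduction and fix, assuming that's the cause:

```python
# Likely cause (an assumption, not confirmed from the traceback alone):
# max_gen_len arrived as a string, so `max_gen_len + max_prompt_len`
# tries to concatenate str + int.
max_gen_len = "64"   # what the error implies was passed in
max_prompt_len = 13

# max_gen_len + max_prompt_len  # TypeError: can only concatenate str (not "int") to str

max_gen_len = int(max_gen_len)  # convert before calling generate()
total_len = min(512, max_gen_len + max_prompt_len)  # now works: 77
```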