fix(generation): stop beam search per-instance when heuristic satisfied
What does this PR do?
This PR fixes a bug in beam search generation where the early-stopping heuristic (applied when early_stopping=False) was incorrectly evaluated across the entire batch instead of per instance.
🔍 Problem
When early_stopping=False, beam search is supposed to stop generating once it is unlikely that any beam can still improve. However, the current behavior waits until all instances in the batch satisfy this heuristic before halting (a sketch of this control flow follows the list). This causes:
- Instances that are already “done” (according to the heuristic) to continue generating,
- Unnecessarily long and repetitive outputs,
- Inconsistent behavior depending on batch composition.
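For intuition, here is a minimal sketch of the current behavior, assuming a heavily simplified beam-search loop; heuristic_done and update_hypotheses are hypothetical stand-ins, not actual transformers internals:
import torch

# Simplified sketch of the *current* control flow, not the real transformers
# beam-search loop. heuristic_done(step) returns a bool tensor of shape
# (batch_size,) that is True once the early-stopping heuristic says no beam
# of that instance can still improve.
def current_behaviour(batch_size, max_steps, heuristic_done, update_hypotheses):
    done = torch.zeros(batch_size, dtype=torch.bool)
    for step in range(max_steps):
        done |= heuristic_done(step)
        if done.all():  # halt only once EVERY instance satisfies the heuristic
            break
        # every instance keeps updating its finished hypotheses, including the
        # ones the heuristic already flagged -> long, repetitive outputs
        update_hypotheses(torch.ones(batch_size, dtype=torch.bool))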
✅ Fix
We now apply the early-stopping heuristic per instance. As soon as an instance has no beams left that can improve, its subsequent generation steps no longer update that instance's answers (see the sketch after this list). This restores the expected behavior and leads to:
- Consistency between single-instance and batched generation,
- Parity with behavior in transformers < 4.50.
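And the intended per-instance behavior, sketched with the same hypothetical helpers; the only change is that instances already flagged by the heuristic are masked out of the hypothesis updates:
import torch

def fixed_behaviour(batch_size, max_steps, heuristic_done, update_hypotheses):
    done = torch.zeros(batch_size, dtype=torch.bool)
    for step in range(max_steps):
        done |= heuristic_done(step)
        if done.all():
            break
        # only instances that can still improve contribute new hypotheses;
        # done instances keep their finalized beams untouched
        update_hypotheses(~done)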
🧪 Reproduction Example
Working case (single input)
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

olmo_model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B-Instruct")
olmo_model = olmo_model.to("cuda")
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-0425-1B-Instruct", padding_side="left")
generation_config = GenerationConfig(
    num_beams=10,
    max_new_tokens=256,
    length_penalty=2,
)
question = [{"role": "user", "content": "What is 3+5?"}]
# With tokenize=False the chat template is rendered to a string; tokenization happens below.
question = tokenizer.apply_chat_template(
    question, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(question, return_tensors="pt", padding=True).to("cuda")
outputs = olmo_model.generate(
    **inputs,
    generation_config=generation_config,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
Produces clean output:
The sum of 3 and 5 is 8.
So, 3 + 5 = 8.
...
The sum of 3 and 5 is \(\boxed{8}\).
Broken case (batched input)
question = [{"role": "user", "content": "What is 3+5?"}]
cot_question = [{"role": "user", "content": "What is 3+5? Explain your reasoning step by step, and provide the final answer at the end."}]
question = tokenizer.apply_chat_template(
    question, tokenize=False, add_generation_prompt=True
)
cot_question = tokenizer.apply_chat_template(
    cot_question, tokenize=False, add_generation_prompt=True
)
# Batch the short prompt together with the longer chain-of-thought prompt.
inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to("cuda")
outputs = olmo_model.generate(
    **inputs,
    generation_config=generation_config,
)
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(responses[0])
Produces repetitive output:
The sum of 3 and 5 is 8.
...
The sum of \(3 + 5\) is \(\boxed{8}\).
If you have any more questions or need further assistance, feel free to ask!
The sum of \(3 + 5\) is \(\boxed{8}\).
If you have any more questions or need further assistance, feel free to ask!
The sum of \(3 + 5\) is \(\boxed{8}\).
If you have any more questions or need further assistance, feel free to ask!
The sum of \(3 + 5\) is \(\boxed{8}\).
If you have any more questions or need further assistance, feel free to ask!
The sum of \(3 + 5\) is \(\boxed{8}\).
If you have any more questions or need further assistance, feel free to ask!
This undesirable repetition happens only when the short prompt is batched together with longer examples. It can occur even with default settings such as length_penalty=1.
This bug appears in recent versions with vectorized beam search. It does not appear in transformers < 4.50.0.
Who can review?
@gante Could you please take a look at this? Thanks!