Generate: support passing position_ids
Thank you, @tengomucho, for uncovering this bug.
The problem
In a nutshell, passing the correct `position_ids` to `generate` should yield exactly the same outputs as not passing them. In other words, the following test should pass on all models if added to `GenerationTesterMixin`. We can see that it currently fails in general.
```python
def test_passing_position_ids(self):
    # Check that passing position ids to generate yields the same results as not passing them, if the position ids
    # are correctly built. If the test fails, it means one of two things:
    # 1 - the manual position ids are not being piped correctly; OR
    # 2 - the automated position ids are not being correctly built.
    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
        if config.is_encoder_decoder:
            self.skipTest("This model does not support position_ids")

        # To truly test this property, let's create a batch where the second row corresponds to the test input with
        # left padding of 1.
        pad_token = torch.tensor([[config.pad_token_id or 0]], device=input_ids.device, dtype=input_ids.dtype)
        input_ids = torch.cat((input_ids, torch.cat((pad_token, input_ids[:, 1:]), dim=1)), dim=0)
        pad_mask = torch.zeros((1, 1), dtype=attention_mask.dtype, device=attention_mask.device)
        attention_mask = torch.cat((attention_mask, torch.cat((pad_mask, attention_mask[:, 1:]), dim=1)), dim=0)
        position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)

        config.use_cache = True
        config.is_decoder = True
        model = model_class(config).to(torch_device).eval()

        try:
            output_position_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                max_new_tokens=10,
            )
        except ValueError as exc:
            if "The following `model_kwargs` are not used by the model: ['position_ids']" in str(exc):
                self.skipTest("This model does not support position_ids")
            else:
                raise

        output_no_position_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=10,
        )
        self.assertListEqual(output_no_position_ids.tolist(), output_position_ids.tolist())
```
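The key step in the test above is how `position_ids` are derived from the attention mask. As a standalone illustration (plain PyTorch, with toy values chosen for this sketch): the cumulative sum counts real tokens, subtracting 1 makes positions 0-based, and the clamp keeps left-pad slots at 0.

```python
import torch

# Toy batch: row 0 has no padding, row 1 has one left-pad token (mask = 0).
attention_mask = torch.tensor([
    [1, 1, 1, 1],
    [0, 1, 1, 1],
])

# Same construction as in the test: cumsum counts real tokens,
# -1 shifts to 0-based positions, clamp pins pad slots to 0.
position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)

print(position_ids)
# tensor([[0, 1, 2, 3],
#         [0, 0, 1, 2]])
```

Note that the padded row starts counting from its first real token, so both rows end on different positions even though they have the same length.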
The fix
There are two root causes for this:

1. `position_ids` is rejected in some models when it is passed (e.g. see here). These models often assume no padding when `position_ids` is rejected.
2. `position_ids` is never updated, so it is only correct when created from scratch (i.e. not passed).
As such, a fix to this problem should consist of updating `position_ids` in `generate`, with `prepare_inputs_for_generation` only creating new `position_ids` when they don't exist.
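To make the intended fix concrete, here is a minimal sketch of what "updating `position_ids`" between decoding steps could look like. `update_position_ids` is a hypothetical helper, not an existing transformers function; it simply appends the next position (last position + 1) per row, which stays correct under left padding because each row advances independently.

```python
import torch

def update_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: extend position_ids by one decoding step.
    # The newly generated token sits at (last position + 1) in every row.
    next_positions = position_ids[:, -1:] + 1
    return torch.cat([position_ids, next_positions], dim=-1)

# Positions for the same toy batch as above (row 1 is left-padded).
position_ids = torch.tensor([
    [0, 1, 2, 3],
    [0, 0, 1, 2],
])
position_ids = update_position_ids(position_ids)

print(position_ids)
# tensor([[0, 1, 2, 3, 4],
#         [0, 0, 1, 2, 3]])
```

With caching enabled, only the last column would actually be fed to the model at each step, but the bookkeeping is the same.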
The test pasted above should be part of our tests after fixing the issue.
@zucchini-nlp FYI. We shouldn't fix this now, as it requires significant manual labor to update all models. After the static cache sprint we should have a look at this :)