simple-generation
Truncation by tokenizer not working correctly
Hello,
Truncation of the input_ids during tokenization (i.e., line 336) does not work properly, throwing the following warning:
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
And then, in the generation loop:
Error The size of tensor a (8192) must match the size of tensor b (10824) at non-singleton dimension 3
Generation failed. Skipping batch.
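For context, the root cause is reproducible outside the library: when a tokenizer has no predefined maximum length, truncation=True without an explicit max_length is a no-op. A minimal sketch (the checkpoint is a stand-in, and the sentinel assignment simulates a tokenizer saved without model_max_length):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
tokenizer.model_max_length = int(1e30)  # transformers' sentinel for "no predefined maximum length"
ids = tokenizer("long prompt " * 5000, truncation=True)["input_ids"]
# logs: "Asking to truncate to max_length but no maximum length is provided
# and the model has no predefined maximum length. Default to no truncation."
print(len(ids))  # not truncated, so the prompt can exceed the model's context window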
I suggest replacing lambda x: self.tokenizer(x["text"], truncation=True)
with
lambda x: self.tokenizer(
    x["text"],
    truncation=True,
    max_length=self.model.config.max_position_embeddings - current_generation_args["max_new_tokens"] - 8,
)
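With an explicit max_length, truncation does kick in. A rough sanity check of the arithmetic, under the same assumptions as above (stand-in checkpoint; 8192 and 1000 stand in for max_position_embeddings and max_new_tokens; the 8 is a small safety margin, e.g. for special tokens):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
max_length = 8192 - 1000 - 8  # max_position_embeddings - max_new_tokens - margin
ids = tokenizer("long prompt " * 5000, truncation=True, max_length=max_length)["input_ids"]
assert len(ids) <= max_length  # prompt + max_new_tokens now fits within the context window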
The _prepare_generation_args method would then be modified accordingly:
def _prepare_generation_args(self, **generation_kwargs):
    current_generation_args = self.generation_config.to_dict()

    logger.info("Setting pad_token_id to eos_token_id for open-end generation")
    current_generation_args["pad_token_id"] = self.tokenizer.eos_token_id
    current_generation_args["eos_token_id"] = self.tokenizer.eos_token_id

    # Fix models whose generation config defaults to the outdated "max_length" parameter
    if "max_new_tokens" in current_generation_args:
        if "max_length" in current_generation_args:
            logger.warning(
                "Found 'max_length' in the model's default generation config. Using 'max_new_tokens' instead."
            )
            current_generation_args.pop("max_length")
    elif "max_length" in current_generation_args:
        logger.warning(
            "Found 'max_length' in the model's default generation config. Renaming it 'max_new_tokens'."
        )
        current_generation_args["max_new_tokens"] = current_generation_args.pop(
            "max_length"
        )
    else:
        current_generation_args["max_new_tokens"] = 1000

    if len(generation_kwargs) > 0:
        logger.info(
            "Custom generation args passed. Any named parameters will override the same default one."
        )
        current_generation_args.update(generation_kwargs)

    # Postprocess generation kwargs
    if (
        "temperature" in current_generation_args
        and current_generation_args["temperature"] == 0
    ):
        logger.info("Temperature cannot be 0. Setting it to 1e-4.")
        current_generation_args["temperature"] = 1e-4

    return current_generation_args
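One caveat: the tokenizer fix above reads current_generation_args, so _prepare_generation_args would have to run before tokenization. A hedged sketch of the intended call order (the dataset variable and the map call site are assumptions about the library's internals):

# Resolve generation args first so max_new_tokens is known at tokenization time
current_generation_args = self._prepare_generation_args(**generation_kwargs)
dataset = dataset.map(
    lambda x: self.tokenizer(
        x["text"],
        truncation=True,
        max_length=self.model.config.max_position_embeddings
        - current_generation_args["max_new_tokens"]
        - 8,
    )
)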
I can do a PR if needed :)