
Truncation by tokenizer not working correctly

Open · lorelupo opened this issue 11 months ago · 1 comment

Hello,

Truncation of the `input_ids` during tokenization (i.e., at line 336) does not work: the tokenizer emits the following warning:

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
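For context, here is a simplified, dependency-free model of the decision behind this warning. It is an illustrative re-implementation, not the transformers source, and the two constants are assumptions that may differ between versions. With `truncation=True` but no `max_length`, the tokenizer falls back to its `model_max_length`; when that is only a huge placeholder, it keeps the full sequence.

```python
# Simplified model of how a Hugging Face tokenizer resolves truncation.
# Constants mirror transformers' internals but may differ between versions
# (assumption). The legacy truncation=None default is not modelled.
VERY_LARGE_INTEGER = int(1e30)  # model_max_length placeholder when unknown
LARGE_INTEGER = int(1e20)


def effective_truncation_length(truncation, max_length, model_max_length):
    """Return the length inputs are cut to, or None if nothing is truncated.

    `truncation` is a plain bool here.
    """
    if not truncation:
        return None
    if max_length is not None:
        # An explicit max_length is always honoured.
        return max_length
    if model_max_length > LARGE_INTEGER:
        # No usable limit anywhere: this is where the warning is emitted.
        return None
    return model_max_length


# truncation=True alone, with no real model_max_length: nothing is truncated.
print(effective_truncation_length(True, None, VERY_LARGE_INTEGER))  # None
# Passing max_length explicitly restores truncation.
print(effective_truncation_length(True, 7184, VERY_LARGE_INTEGER))  # 7184
```

The second call shows why passing `max_length` explicitly is enough to make truncation take effect.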

Then, in the generation loop:

Error The size of tensor a (8192) must match the size of tensor b (10824) at non-singleton dimension 3 Generation failed. Skipping batch.
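One plausible reading of this error, not confirmed by a full traceback, is that 8192 is the model's position limit and 10824 is the length the untruncated sequence reached. The hypothetical helper below only does that arithmetic: a request fits when the prompt length plus `max_new_tokens` stays within the position limit.

```python
def context_overflow(prompt_len, max_new_tokens, max_positions):
    """How many positions a request needs beyond the context window (0 if it fits)."""
    needed = prompt_len + max_new_tokens
    return max(0, needed - max_positions)


# Treating 10824 as the sequence length and 8192 as the limit:
print(context_overflow(10824, 0, 8192))  # 2632
```

Any positive result means the prompt has to be shortened before generation can succeed.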

I suggest replacing `lambda x: self.tokenizer(x["text"], truncation=True)` with:

```python
lambda x: self.tokenizer(
    x["text"],
    truncation=True,
    # Leave room for the generated tokens, plus a small safety margin of 8
    max_length=self.model.config.max_position_embeddings
    - current_generation_args["max_new_tokens"]
    - 8,
)
```

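and modifying the `_prepare_generation_args` method so that `max_new_tokens` is always set, since the lambda reads it (this also means the generation args must be prepared before the dataset is tokenized):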

```python
def _prepare_generation_args(self, **generation_kwargs):
    # Start from the model's default generation config
    current_generation_args = self.generation_config.to_dict()

    logger.info("Setting pad_token_id to eos_token_id for open-end generation")
    current_generation_args["pad_token_id"] = self.tokenizer.eos_token_id
    current_generation_args["eos_token_id"] = self.tokenizer.eos_token_id

    # Some models default to the outdated "max_length" parameter.
    # Check values rather than keys: to_dict() also lists attributes left at None.
    if current_generation_args.get("max_new_tokens") is not None:
        if current_generation_args.get("max_length") is not None:
            logger.warning(
                "Found 'max_length' in the model's default generation config. Using 'max_new_tokens' instead."
            )
        current_generation_args.pop("max_length", None)
    elif current_generation_args.get("max_length") is not None:
        logger.warning(
            "Found 'max_length' in the model's default generation config. Renaming it 'max_new_tokens'."
        )
        current_generation_args["max_new_tokens"] = current_generation_args.pop("max_length")
    else:
        # Neither limit is set: fall back to a fixed default
        current_generation_args["max_new_tokens"] = 1000

    if generation_kwargs:
        logger.info(
            "Custom generation args passed. Any named parameters will override the same default one."
        )
        current_generation_args.update(generation_kwargs)

    # Postprocess generation kwargs
    if current_generation_args.get("temperature") == 0:
        logger.info("Temperature cannot be 0. Setting it to 1e-4.")
        current_generation_args["temperature"] = 1e-4

    return current_generation_args
```
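To sanity-check both parts of the proposal without loading a model, here is a hypothetical, dependency-free sketch; `normalize_length_args` and `prompt_token_budget` are illustrative names, not functions from this repository. The normalization branches on `None` values rather than key presence, on the assumption (worth verifying against the installed transformers version) that `GenerationConfig.to_dict()` returns every attribute, including `max_new_tokens` when it is unset. The budget reproduces the `max_position_embeddings - max_new_tokens - 8` arithmetic from the suggested tokenizer call.

```python
# Hypothetical sketch; names are illustrative, not part of the project.


def normalize_length_args(args, default_max_new_tokens=1000):
    """Return a copy of `args` with `max_new_tokens` set and `max_length` dropped.

    Branches on None values rather than key presence, because a config dict
    may contain keys whose value is None.
    """
    args = dict(args)  # do not mutate the caller's dict
    if args.get("max_new_tokens") is not None:
        args.pop("max_length", None)
    elif args.get("max_length") is not None:
        # Approximation: max_length originally counted prompt tokens too.
        args["max_new_tokens"] = args.pop("max_length")
    else:
        args.pop("max_length", None)
        args["max_new_tokens"] = default_max_new_tokens
    return args


def prompt_token_budget(max_position_embeddings, max_new_tokens, margin=8):
    """Longest prompt that still leaves room for `max_new_tokens` new tokens.

    `margin` mirrors the `- 8` slack in the suggested tokenizer call.
    """
    budget = max_position_embeddings - max_new_tokens - margin
    if budget <= 0:
        raise ValueError("max_new_tokens leaves no room for the prompt")
    return budget


if __name__ == "__main__":
    # A config dict that lists max_new_tokens with a None value.
    args = normalize_length_args({"max_new_tokens": None, "max_length": 20})
    print(args)  # {'max_new_tokens': 20}
    budget = prompt_token_budget(8192, 1000)
    print(budget)  # 7184
    prompt_ids = list(range(10824))  # stand-in for an over-long tokenized prompt
    print(len(prompt_ids[:budget]))  # 7184 (right-side truncation)
```

Raising on a non-positive budget surfaces a misconfigured `max_new_tokens` early instead of letting the tokenizer receive a meaningless `max_length`.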


I can do a PR if needed :)

lorelupo, Feb 27 '24 19:02