
(Not So) Bad words list for text generation

iiglesias-asapp opened this issue · 9 comments

Feature request

Support a soft penalization logits processor in the transformers generate method (extends NoBadWordsLogitsProcessor).

Motivation

  • The NoBadWordsLogitsProcessor forbids the generation of certain tokens in absolute terms by overwriting their logits with minus infinity
  • The request is to add a softer version of this, one in which certain tokens are penalized or boosted, but only mildly
  • This is in the spirit of the logit_bias parameter in the generate methods here (OpenAI) and here (Cohere)
  • Possible use cases include, but are not limited to: enhancing extractiveness during document summarization by boosting tokens present in the input, and style guidance by penalizing/boosting the appropriate vocabulary (a minimal sketch of the idea follows this list)
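
For intuition, the kind of control being requested can be sketched with a custom logits processor that adds a fixed additive bias to chosen token ids, in the spirit of the logit_bias parameters above (an illustrative sketch with a hypothetical class name, not the proposed implementation):

from typing import Dict

import torch
from transformers import LogitsProcessor

class SimpleLogitBiasProcessor(LogitsProcessor):
    """Adds a fixed additive bias to selected token ids, akin to OpenAI's logit_bias."""

    def __init__(self, logit_bias: Dict[int, float]):
        self.logit_bias = logit_bias

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        for token_id, bias in self.logit_bias.items():
            # A positive bias encourages the token, a negative bias discourages it
            scores[:, token_id] = scores[:, token_id] + bias
        return scores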

Your contribution

Overview

  • A new class, BendLogitsProcessor, is defined based on the current NoBadWordsLogitsProcessor class
  • The current argument bad_words_ids is enriched to include a float value per list of token ids, i.e. the penalization/boosting score. Large positive values encourage the token to be generated, while large negative values do the opposite
  • Penalization/boosting scores are unbounded but could later be scaled, as seems to be the case in the implementations referenced above, e.g. the logit bias is in [-10, 10] here and in [-100, 100] here
  • Observe that the NoBadWordsLogitsProcessor behavior can be recovered just by explicitly setting the penalization/boosting score to float("-inf"); an illustrative bend_list is sketched after this list
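
For concreteness, each bend_list entry pairs a float score with a token-id sequence; the ids below are placeholders:

bend_list = [
    [2.5, [50112]],            # mildly boost a single token (placeholder id)
    [-4.0, [7993, 431]],       # mildly penalize a two-token sequence (placeholder ids)
    [float("-inf"), [31204]],  # hard ban, recovering NoBadWordsLogitsProcessor behavior
]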

The new class. This is very much the same as NoBadWordsLogitsProcessor; I tried to keep as much as possible intact. There might be a more efficient implementation.

from typing import Iterable, List, Union

import numpy as np
import torch

from transformers import LogitsProcessor
from transformers.utils import logging

logger = logging.get_logger(__name__)


class BendLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that softly penalizes or boosts certain token/s
    Args:
        bend_list (`List[List[Union[float, List[int]]]]`):
            List of `[coefficient, token_ids]` pairs, where `coefficient` is a float penalization/boosting
            score and `token_ids` is a list of token ids.
            In order to get the token ids of the words, use `tokenizer(bad_words, add_prefix_space=True,
            add_special_tokens=False).input_ids`.
        eos_token_id (`int`):
            The id of the *end-of-sequence* token.
    """
    def __init__(self, bend_list: List[List[Union[float, List[int]]]], eos_token_id: int):
        if not isinstance(bend_list, list) or len(bend_list) == 0:
            raise ValueError(f"`bend_list` has to be a non-empty list, but is {bend_list}.")

        self.bend_list = bend_list
        coefs = [coef for coef, tok in self.bend_list]
        words_ids = [tok for coef, tok in self.bend_list]

        if any(not isinstance(word_ids, list) for word_ids in words_ids):
            raise ValueError(f"`words_ids` has to be a list of lists, but is {words_ids}.")
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in word_ids)
            for word_ids in words_ids
        ):
            raise ValueError(
                f"Each list in `words_ids` has to be a list of positive integers, but is {words_ids}."
            )
        if any(not isinstance(coef, float) for coef in coefs):
            raise ValueError(f"Each coefficient in `bend_list` has to be a float, but the coefficients are {coefs}.")
        
        # Drop entries that would bias the EOS token alone, keeping coefficients aligned with token sequences
        filtered = [(coef, tok) for coef, tok in zip(coefs, words_ids) if tok != [eos_token_id]]
        coefs = [coef for coef, tok in filtered]
        words_ids = [tok for coef, tok in filtered]
        self.words_id_length_1, self.coefs_length_1 = [],[]
        self.words_id_length_greater_than_1, self.coefs_length_greater_than_1 = [],[]
        for coef,word in zip(coefs,words_ids):
            if len(word) == 1:
                self.words_id_length_1.append(word[0])
                self.coefs_length_1.append(coef)
            else:
                self.words_id_length_greater_than_1.append(word)
                self.coefs_length_greater_than_1.append(coef)

        for token_seq in self.words_id_length_greater_than_1:
            if len(token_seq) == 0:
                raise ValueError(f"Words token sequences {words_ids} cannot have an empty list")

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
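        # For each biased entry, build a vocabulary mask over its token ids and rescale the masked
        # logits: when boosting (coef >= 0), positive logits are multiplied by (1 + coef) and negative
        # logits divided by (1 + coef); when penalizing (coef < 0), the scaling is reversed.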
        masks_length_1, scores_length_1 = [], torch.zeros_like(scores)
        masks_length_greater_than_1, scores_length_greater_than_1 = [], torch.zeros_like(scores)
        if len(self.words_id_length_1) > 0:
            for word_id,coef in zip(self.words_id_length_1,self.coefs_length_1):
                mask = self._get_mask_length_1(scores,word_id) 
                masks_length_1.append(mask)
                if coef >= 0:
                    score = scores.masked_fill(scores.masked_fill(~mask,0) < 0,0) * (1 + coef) + \
                            scores.masked_fill(scores.masked_fill(~mask,0) >= 0,0) / (1 + coef)
                if coef < 0:
                    score = scores.masked_fill(scores.masked_fill(~mask,0) < 0,0) / (1 + abs(coef)) + \
                            scores.masked_fill(scores.masked_fill(~mask,0) >= 0,0) * (1 + abs(coef))
                scores_length_1 += score

        if len(self.words_id_length_greater_than_1) > 0:
            for word_ids,coef in zip(self.words_id_length_greater_than_1,self.coefs_length_greater_than_1):
                mask = self._get_mask_length_greater_than_1(input_ids.tolist(),scores,word_ids) 
                masks_length_greater_than_1.append(mask)
                if coef >= 0:
                    score = scores.masked_fill(scores.masked_fill(~mask,0) < 0,0) * (1 + coef) + \
                            scores.masked_fill(scores.masked_fill(~mask,0) >= 0,0) / (1 + coef)
                if coef < 0:
                    score = scores.masked_fill(scores.masked_fill(~mask,0) < 0,0) / (1 + abs(coef)) + \
                            scores.masked_fill(scores.masked_fill(~mask,0) >= 0,0) * (1 + abs(coef))
                scores_length_greater_than_1 += score

        masks_all_lengths = masks_length_1 + masks_length_greater_than_1
        one_large_mask = torch.zeros_like(scores).bool()
        for mask in masks_all_lengths:
            one_large_mask = torch.bitwise_or(one_large_mask,mask)
        
        base_scores = scores.masked_fill(one_large_mask,0.)
        new_scores = base_scores + scores_length_1 + scores_length_greater_than_1 

        return new_scores
    
    def _get_mask_length_1(self, scores: torch.FloatTensor, word_id: int) -> torch.BoolTensor:
        mask = torch.zeros(scores.shape[1])
        mask[word_id] = 1
        return mask.unsqueeze(0).to(scores.device).bool()

    def _tokens_match(self, prev_tokens: List[int], tokens: List[int]) -> bool:
        if len(tokens) == 0:
            return True
        elif len(tokens) > len(prev_tokens):
            return False
        else:
            return prev_tokens[-len(tokens) :] == tokens

    def _calc_word_ids(self, prev_input_ids: List[List[int]], word_ids:List[int]) -> Iterable[int]:
        tokens = []
        for prev_input_ids_slice in prev_input_ids:
            tokens_slice = []
            if self._tokens_match(prev_input_ids_slice, word_ids[:-1]):
                tokens_slice.append(word_ids[-1])

            tokens.append(tokens_slice)

        return tokens

    def _get_mask_length_greater_than_1(self, input_ids: list, scores: torch.FloatTensor, word_ids:List[int]) -> torch.BoolTensor:
        dynamic_tokens = self._calc_word_ids(input_ids, word_ids)
        mask_list = []
        for idx, batch_tokens in enumerate(dynamic_tokens):
            for token in batch_tokens:
                # Skip invalid token ids that fall outside the vocabulary.
                if token < scores.shape[1]:
                    mask_list.append([idx, token])
                else:
                    logger.error(
                        f"An invalid bad word ID is defined: {token}. This ID is not contained in the "
                        "vocabulary, and is therefore ignored."
                    )
        if not mask_list:
            mask = torch.zeros_like(scores).bool()

        else:
            mask = torch.LongTensor(mask_list)
            indices = torch.ones(len(mask))
            mask = (
                torch.sparse.LongTensor(mask.t(), indices, scores.size())
                .to(scores.device)
                .to_dense()
                .bool()
            )
        return mask

An example. Take the summarization example in the BART documentation here. Set add_prefix_space=True in the tokenizer and remove max_length=20 from the generate method call.

from transformers import AutoTokenizer, BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn", add_prefix_space=True)

ARTICLE_TO_SUMMARIZE = (
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")

# Generate Summary
summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0)
tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

This yields the following summary:

Nearly 800 thousand customers were scheduled to be affected by the shutoffs. PG&E stated it scheduled the blackouts in response to forecasts for high winds.

At this point the new logits processor class is applied. The objective will be to make the model output the number of customers affected as digits and replace the word “shutoffs”. We do so by penalizing the token ids for “thousand” and “shutoffs” while boosting the ones for “shutdowns”.
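
For reference, the token ids used in the bend_list below can be looked up with the tokenizer (illustrative; the exact ids depend on the tokenizer and on whether a prefix space is added):

for word in ["thousand", "shutdowns", "shutoffs"]:
    print(word, tokenizer(word, add_special_tokens=False).input_ids)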

from transformers import LogitsProcessorList

logits_processor = LogitsProcessorList(
    [
        BendLogitsProcessor(
            bend_list = [[-10000.,[7673]], # thousand
                          [1000.,[5001, 29]], # shutdowns
                          [-1000000.,[2572, 10816]], # shutoffs
                          [-1000000.,[2572, 1529]], # shutoffs
                         ], 
            eos_token_id=model.config.eos_token_id
        )
    ]
)

# Generate Summary
summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, logits_processor=logits_processor, renormalize_logits=True)
tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

If we call the summary generation again, this time including the logits processor and renormalizing the logits, we get:

Nearly 800,000 customers were scheduled to be affected by the shutdowns. PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions.

iiglesias-asapp avatar Mar 14 '23 23:03 iiglesias-asapp

cc @gante

amyeroberts avatar Mar 15 '23 13:03 amyeroberts

Hey @iiglesias-asapp 👋 Thank you for the suggestion!

Before we dive into adding code, a disclaimer -- one of the current problems with .generate() is that there are too many options, scaring users away from the docs. This means that I will be conservative before giving the green light to add more options 🤗

We do have an option to control extractive vs. abstractive summarization, the encoder_repetition_penalty (docs). This is a multiplicative factor on the logits that increases/decreases the odds of reusing the tokens in the input.
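
For illustration, that knob is just a flag on generate (a rough sketch; values above 1.0 push generation towards reusing input tokens):

summary_ids = model.generate(inputs["input_ids"], num_beams=2, encoder_repetition_penalty=1.5)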

Do you have more use cases in mind, where your suggestion would be critical?

gante avatar Mar 21 '23 12:03 gante

Hi @gante! Thanks for the reply.

I agree that there are many options already 😅 I wasn't thinking of this as an additional option but more like an "upgrade" of the existing feature, since it gives the user a bit more flexibility while keeping the previous functionality, i.e. tokens are boosted/penalized instead of forced/forbidden, and users willing to forbid the appearance of a certain token can still input float("-Inf") as the score.

The main use case in mind was cheap model customization via a set of (score, [token_ids]) pairs. More generally, it is desirable to allow the model to generate a certain token if there is no natural replacement for it and to discourage it otherwise; the sort of soft penalization that is allowed in other APIs.

iiglesias-asapp avatar Mar 21 '23 12:03 iiglesias-asapp

@iiglesias-asapp I see your point - controlling at a token level may be advantageous. Nevertheless, i) without a specific common use case in mind and ii) having not heard the demand for this feature before, I'm reluctant to add it. Remember that custom logits processors can be used, so not adding it to the codebase doesn't mean that it can't be used 🤗

Let's not close this issue and do the following. If this comment gets 10 reactions/this issue gets mentioned 10 times, then it means that folks have been searching for this feature. In that case, let's roll back my decision above, and add it to the codebase. That way, we can balance HF's limited maintenance resources with actual feature demand! (Whoever does the 10th react, plz tag me)

@iiglesias-asapp does it sound good to you?

gante avatar Mar 21 '23 15:03 gante

Sounds good! Thanks for considering it @gante

iiglesias-asapp avatar Mar 21 '23 15:03 iiglesias-asapp

Please add this because I have an alpaca model that was trained on a bad dataset with many cases of input and output fields containing "<noinput" and "nooutput>" text, which causes my LLM to constantly respond with those words :/

teknium1 avatar Mar 30 '23 16:03 teknium1

@teknium1 I think that bad_words_list as it is would be enough for your example. But if you still feel something like the logit_bias parameter is what you need, react to @gante comment to make this available

iiglesias-asapp avatar Mar 30 '23 17:03 iiglesias-asapp

@teknium1 I think that bad_words_list as it is would be enough for your example. But if you still feel something like the logit_bias parameter is what you need, react to @gante comment to make this available

Oh can you point me to where/how I can use the bad_words_list

edit: nvm found it ty

teknium1 avatar Mar 30 '23 23:03 teknium1

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Apr 24 '23 15:04 github-actions[bot]

custom logits processors

@iiglesias-asapp I see your point - controlling at a token level may be advantageous. Nevertheless, i) without a specific common use case in mind and ii) having not heard the demand for this feature before, I'm reluctant to add it. Remember that custom logits processors can be used, so not adding it to the codebase doesn't mean that it can't be used 🤗

Let's not close this issue and do the following. If this comment gets 10 reactions/this issue gets mentioned 10 times, then it means that folks have been searching for this feature. In that case, let's roll back my decision above, and add it to the codebase. That way, we can balance HF's limited maintenance resources with actual feature demand! (Whoever does the 10th react, plz tag me)

@iiglesias-asapp does it sound good to you?

@gante

There are many use cases:

  1. Increase the length of generated text by making the end-of-text token less probable.

  2. If you use few-shot learning and you have a problem with the labels that are used, you can increase the probability of a label. For example:

         instruction: write me a joke about cars
         answer: some response
         instruction: write me a joke about [subject2]
         answer: some response
         instruction: write me a joke about [subject3]
         answer: some response

     Then you need to increase the probability of "answer:" in some cases, when not everything works as it should. No-repeat n-grams are one option, but they sometimes generate strange text.

     2a) The same thing applies if you do few-shot learning to generate HTML text. For example, when you want text not to repeat: if you set the parameters for that, then HTML tags won't be repeated either and the text will be strangely formatted. So you just increase the probability of the HTML tags and you get much better output.

  3. Paraphrasing for dataset augmentation, to get more unique paraphrases; it is good to lower the probability of the original words.

  4. OpenAI has this feature. I really doubt they would implement something and write documentation for it if they did not think that some users would use it.

Oxi84 avatar Jun 03 '23 21:06 Oxi84

@gante Here comes the 10th reaction! Thanks for considering adding this feature. I really need this since I'm currently working on building APIs similar to the OpenAI API. It would be convenient if it were officially supported!

andyh0913 avatar Jun 14 '23 09:06 andyh0913

As promised, I've added it to my queue! 🫡

gante avatar Jun 14 '23 11:06 gante

Hey everyone 👋 A way to bias specific tokens has been added on main. You can check its docs here (which contains a thorough example) and the corresponding GenerationConfig flag here. Let me know if it is not working properly 🤗
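
A minimal usage sketch, assuming the new GenerationConfig flag is the sequence_bias argument (a dict mapping tuples of token ids to a float bias; the ids below reuse the ones from the example earlier in this thread):

sequence_bias = {(7673,): -10.0, (5001, 29): 10.0}  # penalize "thousand", boost "shutdowns"
summary_ids = model.generate(inputs["input_ids"], num_beams=2, sequence_bias=sequence_bias)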

Tagging the folks that have upvoted the comment above and/or replied on this thread for visibility: @iiglesias-asapp @teknium1 @liaeh @skevy @talkhaldi @francislabountyjr @tristanvdb @thjwhite @NanoCode012 @zhuhl98 @Oxi84 @andyh0913 @Vonathar

gante avatar Jun 21 '23 11:06 gante