transformers
Add hallucination filter in generate()
Feature request
Add a filter of some sort to the generate function to prevent more than x words from outside the input from appearing in the generated text.
This could work in a number of ways. It could be a cap on the number of out-of-source words appearing in the generated text (e.g. a value of 2 would mean that at most 2 words could be present in the generated text but not in the source), or it could be some sort of damping variable (0 to 1) applied to the probabilities of out-of-source words, reducing the likelihood that they appear in the generated text. If the damping variable were set to 0, the generation task would be purely extractive and carry no risk of hallucinations.
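The damping variant could be sketched roughly like this in PyTorch (a hypothetical helper for illustration only, not existing transformers API):

```python
import torch

def dampen_out_of_source(logits: torch.Tensor, source_token_ids: torch.Tensor,
                         damping: float) -> torch.Tensor:
    """Scale down the probabilities of tokens absent from the source.

    damping=1.0 leaves the distribution untouched; damping=0.0 makes
    generation purely extractive (only source tokens can be produced).
    """
    probs = torch.softmax(logits, dim=-1)
    out_of_source = torch.ones_like(probs)                  # 1.0 -> out-of-source token
    out_of_source[..., source_token_ids.flatten().unique()] = 0.0
    probs = probs * (1.0 - out_of_source * (1.0 - damping))  # dampen out-of-source probs
    return probs / probs.sum(dim=-1, keepdim=True)           # renormalise
```

With `damping=0.0`, every out-of-source token ends up with probability zero after renormalisation.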
Motivation
To control the risk of hallucinations in the generated text.
Your contribution
Happy to work on a PR for this if I just get a bit of guidance on the best place to start.
WDYT of this request @gante ?
Hi @KMFODA 👋
I'm not sure if I got the proposal right, let me try to explain it in my own words to double-check :) You would like some sort of filter that prevents a large number of different tokens (that are not present in the input prompt) from being generated. Taking your example, considering words=tokens -- if the input is `this is` and the maximum number of new tokens is 2, then `this is a cat` would be a desirable output and `this is a brown dog` would be undesirable (because it uses 3 new words). However, `This is what it is` would be okay, because `is` is present in the original prompt. Correct?
Before going into the implementation, let's take a step back. I have two questions:
- One of the current issues with `generate` is that it has many options, so we need to be mindful when adding new functionality in order to contain its complexity. Don't take this as "we don't want your suggestion", but rather as "let's see if we can make it happen without adding code" :) For instance, the repetition penalty seems very close to what you want -- it adds a penalty to existing output tokens. A full list of logits processors and constraints can be seen here.
- Assuming there is no combination of logits processors and constraints that can achieve what you had in mind: what is the use case? Can you give me an example?
Hi @gante. Apologies, I don't think I've properly clarified the use case I think this could solve. I unfortunately can't share my model or my dataset here (but happy to do so privately) as they're both private and sensitive, but I can try to give as much context as possible. This problem arises for me when using a PEGASUS model for summarisation on a private dataset of meeting segments and summaries.
The input for scenarios where this occurs looks like this:
Person A: text
Person B: text
Person A: text
Person B: text
and the output ends up looking like this:
Person C and Person B met today to discuss...
In this example, Person C was never in the input text. This is the type of hallucination I was hoping to fix in the generate function, as it can be really off-putting to someone using the model. It isn't necessarily restricted to a person's name either. It could be an address / company name / product / number etc.
Reading parts of the linked repetition penalty paper, I believe this parameter is designed to penalise previously generated tokens, which is not exactly what is needed for this use case. This use case could be stated as reducing the probability that a new word, such as Person C, appears in the output.
I've looked at other LogitsProcessors and I don't think any would do this out of the box. I don't know whether it would be helpful to have a processor that handles this use case. If it is, I was thinking it could be done in one of 3 ways:
- Penalising the generation of new tokens
- Boosting the probabilities of input tokens
- Having a hard limit on the number of new tokens that can appear in the output
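The second option could be sketched like this (a hypothetical helper for illustration, not part of transformers):

```python
import torch

def boost_source_tokens(scores: torch.Tensor, source_token_ids: torch.Tensor,
                        boost: float = 5.0) -> torch.Tensor:
    """Add a constant bonus to the logits of every token that appears in the
    source, making out-of-source words comparatively less likely."""
    boosted = scores.clone()
    boosted[..., source_token_ids.flatten().unique()] += boost
    return boosted
```

Since the boost is applied before softmax, even a modest constant shifts a lot of probability mass towards the source vocabulary.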
Thank you for the clarification @KMFODA, it makes total sense for a summarization task!
There is a chance you can solve it with existing code :D Can you have a look at the constrained beam search documentation, which has an example, and attempt to constrain the generation to include the name of the individuals in the conversation? See the examples in this PR header as well -- https://github.com/huggingface/transformers/pull/15761
Let us know if this strategy helps :)
Thanks for the suggestion @gante. The constrained beam search option is really cool. Playing around with it, though, shows that it doesn't necessarily prevent hallucinations. If, for example, you add the constraint "Person A and Person B" (which in itself is not very generalisable, as you don't always want a summary with every participant in the meeting), you still get Person C appearing in the text.
What would effectively solve this is a constraint to not include the tokens "Person C". Or potentially to boost all the vocabulary in the original text thereby boosting Person A and Person B over Person C.
If this seems like a very specific case happy to just work on it privately. Just thought I'd raise it in case it would benefit the wider community.
Thank you for trying it @KMFODA 🙏 I wasn't sure whether it would help here. And apologies for adding so many speed bumps along the way -- our generate function has many many options, and we are trying to be more conservative before adding more flags. We need to rule out potential duplicates :)
I'd like to get the input of @patrickvonplaten (who's currently off) on this topic. Maybe there are specific solutions for this problem that he knows of, or maybe we can create novel research work from this problem!
Meanwhile, if you'd like to experiment, here are some pointers:
- Adapting the logits can be done with subclasses of `LogitsProcessor`. There are many examples below in this file, and they are relatively simple;
- If a processor gets appended to the list of processors (here), then your `generate` will feel the effects of the transformation;
- For experimenting, I'd recommend hardcoding the inclusion of the new processor into the list of processors (the line linked in 2.) -- we can worry about the whole `generate` API later, if the results are positive;
- For your particular problem:
  a. The input tokens are inside `encoder_input_ids`, which are an input to the function linked in 2. You can store them inside the processor's `__init__` for later use;
  b. A modified version of `RepetitionPenaltyLogitsProcessor` will probably do the trick -- in its `__call__`, instead of gathering and scattering over `input_ids` (the generated tokens), you want to do it over everything that's NOT in `encoder_input_ids` (the input tokens). Or perhaps it will be more efficient to add the penalty everywhere (multiplication by a constant) and reverse the penalty addition for the tokens in the original input;
  c. Note: do not remove `input_ids` and `scores` from the signature in `__call__` even if they are not used -- it will raise exceptions.
Let us know if you get any interesting results. Depending on @patrickvonplaten's suggestion, we may build the tool ourselves!
Thanks @gante, that's very helpful. I worked on a draft PR (just in case this is deemed useful to anyone else) and initial results look promising. If I feed a hallucination penalty of 2 (instead of the default of 1) to the greedy search function, the text goes from:
Person C will send Person B an email
where Person C was not in the input text and is therefore classed as a hallucination.
to:
Person A will send Person B an email
I'm aware this is just one datapoint, so I want to test this further, but I only have 3 data points with hallucinations in my sample so far. I'll try to find an open-source dataset focused on hallucinations that I can use to test this more thoroughly and report back.
Super interesting discussion here! Thanks for writing this all down @gante and @KMFODA :-) The PR looks nice to me in general - thanks a lot for opening it @KMFODA!
Just FYI we also have a similar processor to not repeat ngrams: https://github.com/huggingface/transformers/blob/06d1ba1a55a12b3fb3ca081bdd4f812fda800c37/src/transformers/generation_logits_process.py#L364
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Commenting to confirm this is not stale. A request to review my latest changes to the PR is out.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Commenting again to confirm this is not stale. The PR has incorporated all comments and passed all tests. I believe it's just waiting on a second pair of 👀
PR merged. Closing this now. Thanks for all the help @gante and @patrickvonplaten.
Is there an example of how to use this logits processor?