Support customized vocabulary for decoding (in model.generate)
Feature request
Use case:
Given a small list of tokens that is a subset of the full vocabulary of the T5 tokenizer, for example ["put", "move", "pick", "up", "on", "in", "apple", "bag", ...].
When decoding with model.generate(), we want the model to only output sentences made up of words from the above list (i.e., a limited vocabulary for beam search or sampling).
Maybe it is already supported in some way?
Motivation
For some applications, we only want to decode sentences with a limited vocabulary instead of allowing open-ended generation.
Your contribution
I'm not sure what the best way to add this feature is. If it is easy to limit the vocabulary for the generate functions, I can help by adding a PR.
I have read this post: https://huggingface.co/blog/constrained-beam-search
But it seems that such Constraints can only ensure that certain tokens appear somewhere in the output; they cannot prevent other tokens from being selected during decoding.
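For reference, this is roughly what that API looks like (a minimal sketch based on the blog post; the model and the forced word are just placeholders): it guarantees that the forced tokens appear somewhere in the output, but every other vocabulary token is still allowed.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
input_ids = tokenizer("Explain this concept to me: machine learning", return_tensors="pt").input_ids
# forces "learning" to appear in the output, but does not restrict the rest of the vocabulary
force_words_ids = tokenizer(["learning"], add_special_tokens=False).input_ids
outputs = model.generate(input_ids, force_words_ids=force_words_ids, num_beams=5)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))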
I also found this post, which suggests passing the whole vocabulary minus the customized vocabulary as bad_words_ids:
https://stackoverflow.com/questions/63920887/whitelist-tokens-for-text-generation-xlnet-gpt-2-in-huggingface-transformers
I will give it a try, but it sounds a bit awkward to use.
cc @gante
Hey @yuchenlin 👋
My first approach would be to use bad_words_ids, passing it all tokens except the ones you want to keep. It's a no-code approach, but perhaps not the most efficient computationally.
Alternatively, you can write your own logits processor class that sets the logits of all tokens except the ones you want to consider to -inf. To do so, you would have to:
- Write your own class that implements the logic. You can see plenty of examples in this file
- Use your class at generation time, e.g.
tokens_to_keep = tokenizer(xxx) # xxx = list with your valid words
my_processor = MyLogitsProcessorClass(tokens_to_keep=tokens_to_keep)
model.generate(inputs, ..., logits_processor=LogitsProcessorList([my_processor]))
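For the first step, a minimal sketch of such a class could look like this (the class and argument names are just placeholders; remember to keep special tokens such as the EOS id in tokens_to_keep so generation can terminate):
import torch
from transformers import LogitsProcessor

class WhitelistLogitsProcessor(LogitsProcessor):
    """Masks every logit outside `tokens_to_keep` with -inf so only whitelisted tokens can be generated."""

    def __init__(self, tokens_to_keep, vocab_size):
        mask = torch.full((vocab_size,), float("-inf"))
        mask[torch.tensor(list(tokens_to_keep))] = 0.0
        self.mask = mask

    def __call__(self, input_ids, scores):
        # scores has shape (batch_size * num_beams, vocab_size)
        return scores + self.mask.to(scores.device)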
I hope this short guide helps 🤗
Hi @gante ,
Thanks a lot! Yeah, I have tried bad_words_ids (see the example below) and found that the generated outputs are much worse than before, although they are indeed constrained to the given vocabulary. I was using beam search, and I'm not sure whether the vocabulary is so small that the normalization or some other step becomes unstable.
I will try the logit processor idea as well. Thank you! :D
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
whitelist = ["move", "it", "pick", "up", "focus", "on"]
# take the first sub-token id of each word (T5's encode() also appends </s>, which we drop here)
whitelist_ids = [tokenizer.encode(word)[0] for word in whitelist]
whitelist_ids.append(tokenizer.eos_token_id)  # keep EOS allowed so generation can terminate
# every token id outside the whitelist becomes a single-token "bad word"
bad_words_ids = [[token_id] for token_id in range(tokenizer.vocab_size) if token_id not in whitelist_ids]
encoder_input_str = "Explain this concept to me: machine learning"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    num_beams=10,
    do_sample=False,
    num_return_sequences=1,
    no_repeat_ngram_size=1,  # note: this forbids repeating any single token
    remove_invalid_values=True,
    bad_words_ids=bad_words_ids,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
@yuchenlin haha yes, the quality of the output will likely decrease significantly; that is to be expected!
Instead of whitelisting words, consider the "soft-whitelisting" alternative: increase the odds of picking a token from the whitelist. You can easily implement this by changing the repetition penalty logits processor to boost the odds of certain tokens :)
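A rough sketch of that idea (the class name and boost value are illustrative, not an existing transformers processor): add a constant bonus to the whitelisted logits instead of forbidding everything else.
import torch
from transformers import LogitsProcessor, LogitsProcessorList

class WhitelistBoostLogitsProcessor(LogitsProcessor):
    """Soft whitelist: adds a constant bonus to the logits of whitelisted tokens instead of banning the rest."""

    def __init__(self, tokens_to_boost, boost=5.0):
        self.tokens_to_boost = torch.tensor(list(tokens_to_boost))
        self.boost = boost

    def __call__(self, input_ids, scores):
        scores = scores.clone()
        scores[:, self.tokens_to_boost.to(scores.device)] += self.boost
        return scores

# usage, e.g.:
# model.generate(input_ids, num_beams=10,
#                logits_processor=LogitsProcessorList([WhitelistBoostLogitsProcessor(whitelist_ids)]))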
Thanks a lot for the advice! For now I used a simpler method: adding some random tokens (say 30% of the whole vocab) to the whitelist, and it seems to help.
Will also try your idea soon! Thanks again! :D
Just in case you are interested in a wider variety of such constraints, I wrote a whole package and paper about this idea: https://github.com/Hellisotherpeople/Constrained-Text-Generation-Studio
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
An alternative path: using the transformers-compatible outlines library 🤗
https://github.com/outlines-dev/outlines
wow outlines is so cool!