Allow setting different decoder_start_token_ids for each item in a batch in the generate function.
Feature request
@gante
The `generate` function has a `decoder_start_token_id` argument that allows the specification of the decoder start token when generating from an encoder-decoder model (e.g. mT5). Currently, `decoder_start_token_id` must be an integer, which means that the same start token is used for all elements in the batch. I request that you allow the specification of different start tokens for each element of the batch. For this purpose, `decoder_start_token_id` must be a tensor with shape `(batch_size,)`.
Motivation
Some multilingual encoder-decoder models use the `decoder_start_token_id` to indicate the target language. Thus, this change would allow generation into multiple target languages in parallel, as illustrated in the code below.
Your contribution
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
WHITESPACE_HANDLER = lambda k: re.sub(r'\s+', ' ', re.sub(r'\n+', ' ', k.strip()))
article_text = """Videos that say approved vaccines are dangerous and cause autism, cancer or infertility are among those that will be taken down, the company said. The policy includes the termination of accounts of anti-vaccine influencers. Tech giants have been criticised for not doing more to counter false health information on their sites. In July, US President Joe Biden said social media platforms were largely responsible for people's scepticism in getting vaccinated by spreading misinformation, and appealed for them to address the issue. YouTube, which is owned by Google, said 130,000 videos were removed from its platform since last year, when it implemented a ban on content spreading misinformation about Covid vaccines. In a blog post, the company said it had seen false claims about Covid jabs "spill over into misinformation about vaccines in general". The new policy covers long-approved vaccines, such as those against measles or hepatitis B. "We're expanding our medical misinformation policies on YouTube with new guidelines on currently administered vaccines that are approved and confirmed to be safe and effective by local health authorities and the WHO," the post said, referring to the World Health Organization."""
model_name = "csebuetnlp/mT5_m2m_crossSum_enhanced"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
get_lang_id = lambda lang: tokenizer._convert_token_to_id(
    model.config.task_specific_params["langid_map"][lang][1]
)
target_langs = ["portuguese", "spanish"]
input_ids = tokenizer(
    [WHITESPACE_HANDLER(article_text)],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512
)["input_ids"]
input_ids = input_ids.expand(len(target_langs), -1)  # shape (num_target_languages, num_input_tokens)
decoder_start_token_id = torch.tensor(
    [get_lang_id(t) for t in target_langs],
    dtype=input_ids.dtype,
    device=input_ids.device
)  # shape (num_target_languages,)
output_ids = model.generate(
    input_ids=input_ids,
    decoder_start_token_id=decoder_start_token_id,
    max_length=84,
    no_repeat_ngram_size=2,
    num_beams=4,
)
summaries = tokenizer.batch_decode(
    output_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)
print(summaries)
cc @zucchini-nlp
@dpernes Hi, if you want to specify a different `decoder_start_token_id` for each element, you can do it by passing a tensor of shape `(batch_size, seq_len)`. In your case, adding this line before `generate` is called will solve the issue:
decoder_start_token_id = decoder_start_token_id.unsqueeze(1) # shape (num_target_languages, 1)
Great, thank you @zucchini-nlp! This behavior is not documented, though:
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
You may want to change it to something like:
decoder_start_token_id (`Union[int, torch.LongTensor]`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. Optionally, use a `torch.LongTensor` of shape `(batch_size, sequence_length)` to specify a prompt for the decoder.
But why isn't this the same as passing `decoder_input_ids` to `generate`? I tried passing the same tensor as `decoder_input_ids` instead of `decoder_start_token_id` and the results do not match.
Thanks, I added a PR extending the docs.
Regarding your question, there is a subtle difference between them. The `decoder_start_token_id` is used as the very first token in generation, the `BOS` token in most cases. But `decoder_input_ids` are used to start/continue the sentence from them. In most cases you do not provide `decoder_input_ids` yourself when calling `generate`, so they will be filled with `decoder_start_token_id` to start generation from `BOS`.
The general format is `[decoder_start_token_id, decoder_input_ids]`, and `generate` automatically fills in `decoder_start_token_id` from the config if you do not provide it.
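To make the difference concrete, here is a minimal sketch. It reuses `model`, `input_ids`, and the unsqueezed `decoder_start_token_id` (shape `(batch_size, 1)`) from the snippets above; the comments restate the behavior described in this thread rather than spelling out the exact implementation:
# (a) the per-example language id is the very first decoder token
out_a = model.generate(
    input_ids=input_ids,
    decoder_start_token_id=decoder_start_token_id,  # shape (batch_size, 1)
)

# (b) the same tensor passed as decoder_input_ids is treated as a decoder prompt,
#     so generate still prepends the start token from the model config and the
#     decoder effectively sees [config_start_token, lang_id, ...], which explains
#     the mismatch observed above
out_b = model.generate(
    input_ids=input_ids,
    decoder_input_ids=decoder_start_token_id,
)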
Hi,
Is there any way to specify `decoder_start_token_id` during training as well? Like
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["labels"],
    decoder_start_token_id=decoder_start_token_id,
)
loss = outputs.loss
Each batch may require a different `decoder_start_token_id` during training, because each batch has a specific input language and output language: sometimes the output language is `<ENG>` and other times it is `<FRE>`.
Changing `model.config.decoder_start_token_id` for each batch doesn't seem to be a good approach. Specifically, it seems to cause a lot of inconsistency when using Accelerator with DeepSpeed.
Hey @tehranixyz, you do not need to specify `decoder_start_token_ids` while training. All you need is to prepare the `decoder_input_ids` and pass them to the forward pass. We use the start token from the model config only when we do not find `decoder_input_ids` from the user (see the code snippet for preparing decoder input ids from labels).
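For reference, a rough sketch of what that preparation could look like. `shift_labels_right` here is a hypothetical helper modeled on the label-shifting behavior described above, not a function exported by the library:
import torch

def shift_labels_right(labels, decoder_start_token_id, pad_token_id):
    # Prepend the start token, drop the last position, and replace the -100
    # loss-masking value with the pad token so the decoder input is valid.
    decoder_input_ids = labels.new_zeros(labels.shape)
    decoder_input_ids[:, 1:] = labels[:, :-1].clone()
    decoder_input_ids[:, 0] = decoder_start_token_id  # an int, or a (batch_size,) tensor
    decoder_input_ids.masked_fill_(decoder_input_ids == -100, pad_token_id)
    return decoder_input_ids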
Gotcha! I was a bit confused by the warning saying
The decoder_input_ids are now created based on the "labels", no need to pass them yourself anymore.
when using `EncoderDecoderModel`. So in my case, I guess, as you said, I have to prepare the `decoder_input_ids` myself by shifting the labels and adding the appropriate `start_token` at the beginning.
Many thanks!
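Using the hypothetical `shift_labels_right` helper sketched above, and assuming `lang_start_id` holds the start token id for the current batch's target language, the training step would then look roughly like this:
decoder_input_ids = shift_labels_right(
    batch["labels"], lang_start_id, tokenizer.pad_token_id
)
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    decoder_input_ids=decoder_input_ids,
    labels=batch["labels"],  # the loss is still computed against the original labels
)
loss = outputs.loss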