MBART pretrained model is unable to produce output in the target language

Open haqsaiful opened this issue 2 years ago • 5 comments

Hi,

I am using mbart-large-50 for a generation task. The source language is Hindi and the target language is Gujarati. However, I always get output in Hindi. I would expect at least a few tokens in the target language, even though it is a pretrained model, since I am forcing the BOS token to the target language.

Here is the code I am using for this task.

# translate Hindi to Gujarati
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")

tokenizer.src_lang = "hi_IN"
article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
encoded_hi = tokenizer(article_hi, return_tensors="pt")
generated_tokens = model.generate(
    **encoded_hi,
    forced_bos_token_id=tokenizer.lang_code_to_id["gu_IN"],
)
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

@patrickvonplaten

haqsaiful avatar Dec 06 '22 08:12 haqsaiful

You also need to set the tokenizer.tgt_lang I believe. Also cc @ArthurZucker
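
Something along these lines (a rough sketch; for generation itself, forced_bos_token_id is still what selects the target language, while tgt_lang mainly controls how target/label text is tokenized):

from transformers import MBart50TokenizerFast

tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
tokenizer.src_lang = "hi_IN"  # source text is tokenized as [hi_IN] X </s>
tokenizer.tgt_lang = "gu_IN"  # target/label text is tokenized as [gu_IN] Y </s>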

sgugger avatar Dec 06 '22 12:12 sgugger

I think you are just using the wrong checkpoint. Using "facebook/mbart-large-50-many-to-many-mmt", I obtain the following: યુનાઇટેડ સ્ટેટ્સ ઓફ અમેરિકાના પ્રાંતિકારી کہتے हैं कि सीरिया में कोई सैन्य समाधान नहीं है which, according to Google, is Gujarati!
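
For reference, the snippet from above with only the checkpoint name swapped should reproduce this (untested sketch):

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

tokenizer.src_lang = "hi_IN"
article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
encoded_hi = tokenizer(article_hi, return_tensors="pt")
generated_tokens = model.generate(
    **encoded_hi,
    forced_bos_token_id=tokenizer.lang_code_to_id["gu_IN"],
)
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))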

ArthurZucker avatar Dec 06 '22 15:12 ArthurZucker

@ArthurZucker "facebook/mbart-large-50-many-to-many-mmt" is a fine-tuned checkpoint. I am trying with the pretrained checkpoint, which is "facebook/mbart-large-50".

The pretrained checkpoint should also be able to produce output in the target language if we force the BOS token to the target language. The output may be a little distorted, but that's fine. Here, it produces output in the same language as the source.

haqsaiful avatar Dec 07 '22 01:12 haqsaiful

The pretrained checkpoint should also be able to produce output in the target language if we force the BOS token to the target language

I think this depends on the language pair, since it is a pretrained checkpoint, as mentioned on the model card:

mbart-large-50 is a pre-trained model and is primarily aimed at being fine-tuned on translation tasks. It can also be fine-tuned on other multilingual sequence-to-sequence tasks. See the model hub to look for fine-tuned versions.

Since it works totally fine with the fine-tuned checkpoint, this is not a bug.

ArthurZucker avatar Dec 30 '22 06:12 ArthurZucker

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Jan 23 '23 15:01 github-actions[bot]

Hi, is there a mismatch between the tokenizer of facebook/mbart-large-50 and the shift_tokens_right of MBartForConditionalGeneration? The tokenizer of facebook/mbart-large-en-ro produces X [eos, src_lang_code], while facebook/mbart-large-50's tokenizer produces [src_lang_code] X [eos], yet they both use the same shift_tokens_right method, which I believe is only suitable for input of the form X [eos, src_lang_code]:


import torch


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    """
    Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
    have a single `decoder_start_token_id` in contrast to other Bart-like models.
    """
    prev_output_tokens = input_ids.clone()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)

    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    prev_output_tokens[:, 0] = decoder_start_tokens

    return prev_output_tokens
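
For concreteness, a toy illustration of what this does to an MBart-50-style label sequence (the ids are made up: 250004 stands in for a language code, 2 for </s>, 1 for <pad>):

import torch

labels = torch.tensor([[250004, 100, 101, 102, 2]])   # [lang_code] X </s>
shifted = shift_tokens_right(labels, pad_token_id=1)
# shifted == [[2, 250004, 100, 101, 102]]             # </s> [lang_code] X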

Hannibal046 avatar Jan 29 '23 16:01 Hannibal046

Indeed. But as mentioned in the documentation:

The text format for MBart-50 is slightly different from mBART. For MBart-50 the language id token is used as a prefix for both source and target text, i.e. the text format is [lang_code] X [eos], where lang_code is the source language id for source text and the target language id for target text, with X being the source or target text respectively. For MBart [...] the source text format is X [eos, src_lang_code] where X is the source text. The target text format is [tgt_lang_code] X [eos]. bos is never used.

Which is why they don't have the same tokenization scheme. I checked that the forced_decoder_id works properly when generating, and I think this issue can be closed, as there is no guarantee that a given language pair will produce intelligible results since the checkpoints are only pretrained.
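
The difference can also be checked directly from the tokenizers (a rough sketch; the token sequences in the comments are abbreviated):

from transformers import MBartTokenizer, MBart50TokenizerFast

mbart_tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")
mbart50_tok = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX")

text = "UN Chief Says There Is No Military Solution in Syria"

print(mbart_tok.convert_ids_to_tokens(mbart_tok(text).input_ids))
# [..., '</s>', 'en_XX']   i.e. X [eos, src_lang_code]
print(mbart50_tok.convert_ids_to_tokens(mbart50_tok(text).input_ids))
# ['en_XX', ..., '</s>']   i.e. [src_lang_code] X [eos]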

ArthurZucker avatar Jan 30 '23 09:01 ArthurZucker

Hi, thanks for the comments! It is true that generation with MBart-50 works with the proper forced_decoder_id. But it doesn't work in supervised learning scenarios. When no decoder_input_ids are provided for training, MBart-50 automatically creates decoder_input_ids from the labels, which follows the tokenization scheme of MBart rather than MBart-50. I think this should be fixed. [screenshot: MBart vs. MBart-50 comparison]

Hannibal046 avatar Jan 30 '23 09:01 Hannibal046

I am not sure I understand. When the decoder_input_ids are created from the labels, they are a shifted version. Let's use an example:

  • src_text: 'en_XX UN Chief Says There Is No Military Solution in Syria</s>'
  • labels: 'ro_RO Şeful ONU declară că nu există o soluţie militară în Siria</s>'
  • shifted labels (= decoder_input_ids): '</s>ro_RO Şeful ONU declară că nu există o soluţie militară în Siria'

This means that the shifted labels will follow the correct pattern (the one you enforce when generating), as in the sketch below.
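
In code, a minimal supervised-style sketch (assuming the tokenizer's text_target API; no decoder_input_ids are passed, so the model builds them from the labels):

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO"
)
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

src_text = "UN Chief Says There Is No Military Solution in Syria"
tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"

batch = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
# labels look like: ro_RO Y </s>; the model shifts them right internally,
# so the decoder inputs start with </s> ro_RO, matching the pattern above.
outputs = model(**batch)
print(outputs.loss)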

ArthurZucker avatar Jan 30 '23 10:01 ArthurZucker

Sorry, my bad. You are right. I mistakenly thought the generation scheme of MBart-50 was the same as MBart's, whose decoder_start_token_id is the lang_id.

Hannibal046 avatar Jan 30 '23 10:01 Hannibal046

The MBart50 AI model is not translating the entire document; it is cutting it in half. How can we fix this?

mgrbonjourdesigns11 avatar Aug 04 '23 18:08 mgrbonjourdesigns11

Hey! Could you open a new issue with a reproducer for this? 😉

ArthurZucker avatar Aug 07 '23 06:08 ArthurZucker