Integrate MLMMaskGenerator into BERT example
Currently the BERT example writes custom code to generate the MLM mask, which is slow. We should replace it with `MLMMaskGenerator`.
Hey @chenmoneygithub I can take this up in case no one else wants to!
@aflah02 Thanks! assigned to you
@chenmoneygithub Hey, sorry for the delay. I want to confirm how I should handle the extra parameters of `MLMMaskGenerator` that are not used in the current function. From what I gathered from the following piece of code:
```python
(
    tokens,
    masked_lm_positions,
    masked_lm_labels,
) = create_masked_lm_predictions(
    tokens,
    masked_lm_prob,
    max_predictions_per_seq,
    vocab_words,
    rng,
)
```
- `masked_lm_prob` is the value for the `mask_selection_rate` parameter
- `max_predictions_per_seq` is the value for the `mask_selection_length` parameter
- `rng` is the value for the `rng` parameter
However, `MLMMaskGenerator` also takes the following:
- `mask_token_id`: I plan to handle this by computing the value before the call and passing it in.
- `unselectable_token_ids`: I don't think there is currently any similar parameter, so I'm not sure how to handle this without changing every function that calls this to pass the list through.
- `mask_token_rate`, `random_token_rate`: not quite sure how I'll incorporate these.
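To make my understanding concrete, here is a rough sketch of the argument mapping I have in mind. The helper `map_mlm_args`, the `vocabulary_size` key, and the `"[MASK]"` lookup are my assumptions for illustration, not the actual KerasNLP API:

```python
def map_mlm_args(masked_lm_prob, max_predictions_per_seq, vocab_words,
                 mask_token="[MASK]"):
    """Hypothetical helper: translate the old positional arguments of
    create_masked_lm_predictions into keyword arguments for an
    MLMMaskGenerator-style layer (names assumed for illustration)."""
    vocab_index = {tok: i for i, tok in enumerate(vocab_words)}
    return {
        "vocabulary_size": len(vocab_words),
        "mask_selection_rate": masked_lm_prob,             # was masked_lm_prob
        "mask_selection_length": max_predictions_per_seq,  # was max_predictions_per_seq
        "mask_token_id": vocab_index[mask_token],          # computed before the call
    }
```

This is just to show where each old argument would land; `rng` would still be passed at call time.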
@aflah02 thanks for working on this! Re your question:
- `mask_token_id`: this should be the index of `"[MASK]"` in the vocab.
- `unselectable_token_ids`: you can leave it at the default. Usually this is just the padding token id, since we don't want to mask out padding tokens for MLM.
- `mask_token_rate`, `random_token_rate`: set them to 0.8 and 0.1 respectively. This means that for each selected token, there is an 80% chance we replace it with the mask token, a 10% chance we replace it with a random token, and a 10% chance we leave it unchanged.
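For anyone following along, the 80/10/10 behavior described above can be sketched in plain Python. This is a minimal illustration of the semantics, not the library's implementation; `apply_mlm_masking` and its signature are made up for this sketch:

```python
import random

def apply_mlm_masking(token_ids, mask_token_id, vocab_size, rng,
                      mask_selection_rate=0.15, mask_token_rate=0.8,
                      random_token_rate=0.1, unselectable_token_ids=(0,)):
    """Sketch of MLM masking semantics: select roughly mask_selection_rate
    of the maskable tokens, then replace each selected token with the mask
    token 80% of the time, a random token 10% of the time, and leave it
    unchanged the remaining 10%."""
    output = list(token_ids)
    positions, labels = [], []
    for i, tok in enumerate(token_ids):
        if tok in unselectable_token_ids:
            continue  # never select padding/special tokens for masking
        if rng.random() >= mask_selection_rate:
            continue  # token not selected for prediction
        positions.append(i)
        labels.append(tok)  # the original token is the MLM label
        r = rng.random()
        if r < mask_token_rate:
            output[i] = mask_token_id              # 80%: replace with [MASK]
        elif r < mask_token_rate + random_token_rate:
            output[i] = rng.randrange(vocab_size)  # 10%: replace with a random token
        # else: 10%, keep the original token unchanged
    return output, positions, labels
```

The returned `positions` and `labels` play the same role as `masked_lm_positions` and `masked_lm_labels` in the existing example code.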
This is probably not a priority for us right now. Closing for now!