Layer for Permutation Language Modelling [XLNet]
Building on https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/layers/preprocessing/mlm_mask_generator.py, which dynamically masks tokens, I was wondering whether we could implement a layer that generates XLNet's permutation masks for its inputs (Permutation Language Modelling).
This function in Hugging Face Transformers generates inputs for XLNet and would be a good reference: https://github.com/huggingface/transformers/blob/72728be3dbca26c70dddc8b724eb2c8d901e97dc/src/transformers/data/data_collator.py#L1230
The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1. Sample a `span_length` from the interval `[1, max_span_length]` (length of the span of tokens to be masked).
2. Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked.
3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`.
4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the sequence to be processed), repeat from Step 1.
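For discussion, here is a minimal pure-Python sketch of the span-selection steps above for a single sequence. It is not the Hugging Face code and not a proposed KerasNLP API: the function name `sample_plm_targets`, its default arguments, and the choice of truncating `context_length` with `int()` are illustrative assumptions. It returns a boolean list marking which positions become prediction targets.

```python
import random


def sample_plm_targets(max_len, max_span_length=5, plm_probability=1 / 6, seed=None):
    """Hypothetical helper: mark target positions following steps 0-4.

    Assumes 0 < plm_probability <= 1 so that context_length >= span_length.
    """
    rng = random.Random(seed)
    masked = [False] * max_len
    cur_len = 0  # Step 0: number of tokens processed so far
    while cur_len < max_len:
        # Step 1: span length drawn uniformly from [1, max_span_length] (inclusive)
        span_length = rng.randint(1, max_span_length)
        # Step 2: reserve a context window around the span (truncated to an int here)
        context_length = int(span_length / plm_probability)
        # Step 3: start index drawn from [cur_len, cur_len + context_length - span_length]
        start_index = cur_len + rng.randint(0, context_length - span_length)
        # Mask the span, clipping it at the end of the sequence
        for i in range(start_index, min(start_index + span_length, max_len)):
            masked[i] = True
        # Step 4: move past the reserved context and repeat while tokens remain
        cur_len += context_length
    return masked
```

Each reserved window of `context_length` tokens holds exactly one span, so roughly a `plm_probability` fraction of the sequence ends up masked (less if the last span is clipped at `max_len`).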
Would love to take this up!
@abheesht17 Thanks for opening this feature request!
I have one question: why are we still masking tokens in step 3? I am not very familiar with permutation language modeling, but from the articles I've read, it no longer applies masks?
Hello, @chenmoneygithub! I think the reason is as follows:
XLNet uses multiple factorisation orders, since it permutes the input sequence. Suppose our input text is [1, 2, 3, 4] and XLNet samples two permutations: [3, 2, 4, 1] and [2, 4, 3, 1]. To compute the updated representation of token "3" in the first case, we mask "2", "4" and "1" (since they come after "3" in that order); in the second case, we mask only "1". That's why XLNet needs a "permutation mask".
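The example above can be made concrete with a small sketch that builds a permutation mask from a factorisation order. It assumes the tokens are unique (so a token value identifies its position) and uses a hypothetical helper `permutation_mask`, where `mask[i][j] = True` means token `i` may not attend to token `j`; to my understanding this matches the meaning of `perm_mask` in Hugging Face's XLNet, but that is an assumption here. Only tokens strictly later in the order are masked, matching the example; how a token treats its own position differs between XLNet's two attention streams and is not modelled.

```python
def permutation_mask(tokens, order):
    """Hypothetical helper: mask[i][j] is True if tokens[i] must not see tokens[j].

    Assumes `tokens` are unique and `order` is a permutation of them.
    """
    rank = {tok: r for r, tok in enumerate(order)}  # position in the factorisation order
    n = len(tokens)
    # Token i may only attend to tokens that come strictly earlier in `order`
    return [[rank[tokens[j]] > rank[tokens[i]] for j in range(n)] for i in range(n)]


def masked_for(tokens, order, token):
    """List the tokens hidden from `token` under the given factorisation order."""
    row = permutation_mask(tokens, order)[tokens.index(token)]
    return [tokens[j] for j, hidden in enumerate(row) if hidden]


print(masked_for([1, 2, 3, 4], [3, 2, 4, 1], 3))  # → [1, 2, 4]
print(masked_for([1, 2, 3, 4], [2, 4, 3, 1], 3))  # → [1]
```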
(Sorry for the late reply)
This figure explains it well:

A more concrete explanation:
https://www.borealisai.com/en/blog/understanding-xlnet/