keras-nlp

Layer for Permutation Language Modelling [XLNet]

Open · abheesht17 opened this issue 3 years ago · 4 comments

Building on https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/layers/preprocessing/mlm_mask_generator.py, which dynamically masks tokens, I was wondering whether we could implement a layer that generates XLNet-style permutation masks for its inputs (Permutation Language Modelling).

This function from Hugging Face Transformers' data collators is a good reference for how inputs are generated for XLNet: https://github.com/huggingface/transformers/blob/72728be3dbca26c70dddc8b724eb2c8d901e97dc/src/transformers/data/data_collator.py#L1230

The masked tokens to be predicted for a particular sequence are determined by the following algorithm:

0. Start from the beginning of the sequence by setting `cur_len = 0` (the number of tokens processed so far).
1. Sample a `span_length` (the length of the span of tokens to be masked) from the interval `[1, max_span_length]`.
2. Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked.
3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`.
4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the sequence to be processed), repeat from Step 1.
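A minimal pure-Python sketch of the span-masking loop above, for illustration only. The function name `plm_span_mask` and its defaults (`plm_probability=1/6`, `max_span_length=5`) are hypothetical choices, and rounding `context_length` down with `int()` is an assumption. It works on a single sequence and returns a list of booleans where `True` marks a token to be predicted; an actual keras-nlp layer would need to operate on batched tensors instead.

```python
import random
from typing import List, Optional


def plm_span_mask(
    max_len: int,
    plm_probability: float = 1 / 6,  # hypothetical default
    max_span_length: int = 5,  # hypothetical default
    seed: Optional[int] = None,
) -> List[bool]:
    """Hypothetical helper: return a boolean mask (True = token to predict)."""
    if not 0 < plm_probability <= 1:
        raise ValueError("plm_probability must be in (0, 1]")
    if max_span_length < 1:
        raise ValueError("max_span_length must be >= 1")
    rng = random.Random(seed)
    mask = [False] * max_len
    cur_len = 0  # Step 0: number of tokens processed so far
    while cur_len < max_len:
        # Step 1: sample the span length uniformly from [1, max_span_length].
        span_length = rng.randint(1, max_span_length)
        # Step 2: reserve a context around the span (assumed rounded down).
        context_length = int(span_length / plm_probability)
        # Step 3: sample start_index from [cur_len, cur_len + context_length - span_length].
        start_index = rng.randint(cur_len, cur_len + context_length - span_length)
        # Mask the span, clipping it at the end of the sequence.
        for i in range(start_index, min(start_index + span_length, max_len)):
            mask[i] = True
        # Step 4: move on to the next context window.
        cur_len += context_length
    return mask


mask = plm_span_mask(32, seed=0)
print("".join("x" if m else "." for m in mask))
```

Because `plm_probability <= 1` guarantees `context_length >= span_length >= 1`, each iteration advances `cur_len` and the sampled span always fits inside its own context window.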

Would love to take this up!

abheesht17 · Apr 23 '22 22:04

@abheesht17 Thanks for opening this feature request!

I have one question: why do we still mask tokens in step 3? I am not very familiar with permutation language modelling, but from the articles I have read, it seems that it no longer applies masks?

chenmoneygithub · Apr 25 '22 18:04

Hello, @chenmoneygithub! I think the reason is as follows:

XLNet considers multiple factorisation orders, i.e. permutations of the order in which tokens are predicted (the token positions themselves are not reordered). Suppose our input text is [1, 2, 3, 4], and assume that XLNet samples two permutations: [3, 2, 4, 1] and [2, 4, 3, 1]. In the first case, to compute the updated representation of token "3", we mask "2", "4" and "1" (since they come after "3"); in the second case, we mask only "1". That's why XLNet needs a "permutation mask".
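To make the example above concrete, here is a small sketch with a hypothetical helper, `permutation_mask`, that builds the attention mask for a given factorisation order. It assumes the convention `mask[i][j] == 1` means position `i` may not attend to position `j` (which I believe matches the `perm_mask` convention in Hugging Face's XLNet docs), and it masks exactly the tokens that come after `i` in the order. Each token's own position is left visible, which corresponds to XLNet's content stream; the query stream used for prediction additionally hides the token itself.

```python
from typing import List


def permutation_mask(order: List[int]) -> List[List[int]]:
    """Hypothetical helper: build an n x n mask from a factorisation order.

    `order` is a permutation of positions 0..n-1. mask[i][j] == 1 when
    position j comes after position i in `order` (so i may not attend to j),
    and 0 otherwise. The diagonal is 0: each position can see itself.
    """
    n = len(order)
    if sorted(order) != list(range(n)):
        raise ValueError("order must be a permutation of 0..n-1")
    # rank[pos] = index of `pos` within the factorisation order
    rank = [0] * n
    for r, pos in enumerate(order):
        rank[pos] = r
    return [[1 if rank[j] > rank[i] else 0 for j in range(n)] for i in range(n)]


tokens = [1, 2, 3, 4]
for factorisation in ([3, 2, 4, 1], [2, 4, 3, 1]):
    # Convert the token-valued order from the comment into positions.
    order = [tokens.index(t) for t in factorisation]
    row = permutation_mask(order)[tokens.index(3)]
    hidden = [tokens[j] for j, m in enumerate(row) if m]
    print(factorisation, "-> token 3 cannot see", hidden)
```

This prints `[1, 2, 4]` for the first order and `[1]` for the second, matching the two cases described above. Working with positions rather than token values keeps the mask well defined when a sequence contains repeated tokens.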

(Sorry for the late reply)

abheesht17 · Apr 29 '22 15:04

This figure explains it well: [image not preserved]

abheesht17 · Apr 29 '22 15:04

A more concrete explanation: [image not preserved] https://www.borealisai.com/en/blog/understanding-xlnet/

abheesht17 · Apr 29 '22 16:04