xlnet
Understanding _local_perm
I'm trying to understand the _local_perm function.
I have trouble understanding what is done at that line.
Why is the first token of the inputs added to the target?
Let's take a concrete example:
sent_a = [10, 20, 30]
sent_b = [40, 50, 60]
After preprocessing we have (4 is the SEP token and 3 is the CLS token):
input = [10, 20, 30, 4, 40, 50, 4, 3]
target = [20, 30, 40, 40, 50, 60, 3, 3]
(Note: 40 is repeated to pad the SEP token; see this issue.)
Let's say our goal is to predict sentence B from sentence A. So we have:
is_masked = [False, False, False, False, True, True, False, False]
If I run _local_perm with these inputs (assuming the permutation happens to give exactly the same order as the token order), I get:
perm_mask = [[0, 0, 0, 1, 1, 1, 1, 1],
             [0, 0, 0, 1, 1, 1, 1, 1],
             [0, 0, 0, 1, 1, 1, 1, 1],
             [0, 0, 0, 0, 1, 1, 1, 1],
             [0, 0, 0, 0, 1, 1, 1, 1],
             [0, 0, 0, 0, 0, 1, 1, 1],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 0]]
target = [10, 20, 30, 40, 40, 50, 60, 3]
target_mask = [0, 0, 0, 0, 1, 1, 0, 0]
input K = [10, 20, 30, 4, 40, 50, 4, 3]
input Q = [0, 0, 0, 0, 1, 1, 0, 0]
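These numbers can be reproduced with a pure-Python sketch of the same computation. It is only a sketch based on one reading of what _local_perm does, not the repo's code: the TensorFlow ops are replaced by list comprehensions, the random permutation is replaced by an explicit `index` argument (identity here, as assumed above), and `local_perm_sketch`, `SEP_ID = 4` and `CLS_ID = 3` are hypothetical names/values chosen to match this example.

```python
SEP_ID = 4  # SEP token id in this example (hypothetical, for illustration)
CLS_ID = 3  # CLS token id in this example (hypothetical, for illustration)


def local_perm_sketch(inputs, targets, is_masked, index):
    """Hypothetical pure-Python mirror of _local_perm for a single sequence.

    `index` is the permutation order that the real function samples at random;
    it is passed in here so the result is deterministic.
    """
    n = len(inputs)
    # Functional tokens are SEP and CLS.
    non_func = [t not in (SEP_ID, CLS_ID) for t in inputs]
    # Visible context: tokens that are neither masked nor functional.
    non_mask = [(not m) and f for m, f in zip(is_masked, non_func)]
    masked_or_func = [not x for x in non_mask]

    # Visible context gets the smallest order (-1); other tokens keep their index.
    rev_index = [-1 if nm else idx for nm, idx in zip(non_mask, index)]

    # Prediction targets: masked tokens that are not functional.
    target_tokens = [mf and f for mf, f in zip(masked_or_func, non_func)]
    target_mask = [int(t) for t in target_tokens]

    # Targets keep their order (so they cannot see themselves);
    # every other position is shifted by +1 (so it can see itself).
    self_rev_index = [r if t else r + 1 for r, t in zip(rev_index, target_tokens)]

    # perm_mask[i][j] == 1 means position i may NOT attend to position j.
    # Columns of visible-context tokens are always 0, so everyone can see them.
    perm_mask = [
        [int(self_rev_index[i] <= rev_index[j] and masked_or_func[j]) for j in range(n)]
        for i in range(n)
    ]

    # Shift targets right by one: "next token" becomes "current token".
    new_targets = inputs[:1] + targets[:-1]
    inputs_k = list(inputs)
    inputs_q = list(target_mask)
    return perm_mask, new_targets, target_mask, inputs_k, inputs_q


inputs = [10, 20, 30, 4, 40, 50, 4, 3]
targets = [20, 30, 40, 40, 50, 60, 3, 3]
is_masked = [False, False, False, False, True, True, False, False]

perm_mask, new_targets, target_mask, inputs_k, inputs_q = local_perm_sketch(
    inputs, targets, is_masked, index=list(range(len(inputs)))
)
for row in perm_mask:
    print(row)
print("target      =", new_targets)
print("target_mask =", target_mask)
```

With the identity permutation, this prints the same perm_mask, target and target_mask as listed above; passing a shuffled `index` shows how the mask changes for other factorization orders.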
In this example, I don't understand why target has the same token at the same position as the input.
I would expect the output to be input K = [10, 20, 30, 4, 40, 50, 4, 3] (unchanged) and target = [20, 30, 40, 40, 50, 60, 4, 3], so that given 40 (position 4 of the input), the model predicts 50 (position 4 of the target).
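One way to see why predicting the token at its own position is not trivial: a 1 at row i, column j of perm_mask means position i cannot attend to position j, and the diagonal of the perm_mask above is exactly target_mask. So the two masked positions cannot see their own content, and their query-stream input (input Q) is only the mask flag. The snippet below checks this using nothing but the numbers from this example.

```python
# perm_mask and target_mask copied from the _local_perm output above.
perm_mask = [
    [0, 0, 0, 1, 1, 1, 1, 1],
    [0, 0, 0, 1, 1, 1, 1, 1],
    [0, 0, 0, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 1, 1, 1, 1],
    [0, 0, 0, 0, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 1, 1, 1],
    [0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 0, 0],
]
target_mask = [0, 0, 0, 0, 1, 1, 0, 0]

# perm_mask[i][i] == 1 means position i cannot attend to itself.
diagonal = [perm_mask[i][i] for i in range(len(perm_mask))]
print("diagonal    =", diagonal)
print("target_mask =", target_mask)

# For each predicted position, list the positions it is allowed to attend to.
visible_by_target = {}
for i, is_target in enumerate(target_mask):
    if is_target:
        visible_by_target[i] = [j for j, blocked in enumerate(perm_mask[i]) if not blocked]
print(visible_by_target)
```

Here position 4 (token 40) may attend only to positions 0 to 3 (10, 20, 30 and SEP), so predicting 40 "at its own position" still means predicting it from context.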
@kimiyoung I'm also confused about this question.
@Colanim Hi, I think the answer is this: in a conventional language model, we try to predict the next word based on the previous state, but in XLNet the model tries to predict the current word (which is a masked token) based not only on the previous state but also on the current position. So the real new_targets have to be the same as the inputs (except for masked tokens and functional tokens).
can't see itself. ------
# new target: [next token] for LM and [curr token] (self) for PLM
new_targets = tf.concat([inputs[0:1], targets[:-1]], axis=0)
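The quoted line can be illustrated with plain Python lists standing in for the tensors, using the example above. targets[i] is the token after inputs[i] (the LM target); prepending inputs[0] and dropping the last element shifts everything right by one, so new_targets[i] lines up with inputs[i] (the PLM target, i.e. the token at the current position). In this example the two only differ at the SEP positions.

```python
inputs = [10, 20, 30, 4, 40, 50, 4, 3]
targets = [20, 30, 40, 40, 50, 60, 3, 3]  # LM targets: the next token
target_mask = [0, 0, 0, 0, 1, 1, 0, 0]    # 1 = this position is predicted

# List equivalent of tf.concat([inputs[0:1], targets[:-1]], axis=0)
new_targets = inputs[0:1] + targets[:-1]

# Positions where the shifted target is not the input token itself.
mismatches = [i for i, (x, y) in enumerate(zip(inputs, new_targets)) if x != y]
print("new_targets =", new_targets)
print("positions where new_targets != inputs:", mismatches)

# At every predicted position, the target is the token at that same position.
assert all(new_targets[i] == inputs[i] for i, m in enumerate(target_mask) if m)
```

This prints new_targets = [10, 20, 30, 40, 40, 50, 60, 3] with mismatches only at positions 3 and 6, which are both SEP tokens (id 4) in the input.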