Why not use a partial factorization?
Hi, thanks for the great work. I have a question about the pre-training process. Assuming a length-T sequence with a permutation z, since we use partial prediction (only the tokens after z_c are predicted), why don't we use a partial factorization? That is, use an all-zeros attention mask for the first c tokens so that the context tokens can see each other?
Thank you.
Same question here. Why not permute only the partial-prediction part, and condition it on the bidirectional context x_{z<=c} as well as the previous tokens in the permutation? Like
E[ log P(x_{z>c} | x_{z<=c}) ] = E[ sum_{t=c+1}^{|z|} log P(x_{z_t} | x_{z_{c+1:t-1}}, x_{z<=c}) ]
Thank you.
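For reference, the paper's partial-prediction objective (paraphrased here in the same notation as this thread, so treat it as a reading rather than a verbatim quote) keeps an autoregressive factorization over the target subsequence:

```latex
\max_{\theta}\;
\mathbb{E}_{\mathbf{z}\sim\mathcal{Z}_T}\!\left[
  \log p_{\theta}\big(\mathbf{x}_{\mathbf{z}_{>c}} \mid \mathbf{x}_{\mathbf{z}_{\le c}}\big)
\right]
= \mathbb{E}_{\mathbf{z}\sim\mathcal{Z}_T}\!\left[
  \sum_{t=c+1}^{|\mathbf{z}|}
  \log p_{\theta}\big(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}\big)
\right]
```

Since every non-target position precedes the cut, the targets condition on the whole non-target subsequence either way; the difference raised in the questions above is only how the non-target tokens x_{z<=c} attend to each other when forming their own representations (factorization-order mask vs. fully bidirectional), which the reply below addresses.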
Good question. In fact we are using the implementation that you just mentioned. Sorry about the confusion.
@kimiyoung To make it clearer, let us walk through a concrete example:
Assume the original sentence is 12345678, and the permutation is 12367845. The last two tokens, i.e. 4 and 5, are to be predicted. What we want is: let tokens 123678 attend to each other (note that according to the paper, i cannot attend to j if i < j, i.e. 2 only attends to 1, 3 only attends to 1 and 2, and so forth), let token 4 attend to 123678, and let token 5 attend to 1236784.
I haven't fully understood the source code, but this comment says "can attend if i > j or j is non-masked". Combined with the above example, are tokens 4 and 5 masked? And do tokens 123678 attend to each other because "j is non-masked"?
If this understanding is correct, please update the formula in the paper accordingly to address this issue. Thanks.
Yes, tokens 123678 have bidirectional attention and attend to each other, while tokens 4 and 5 use an auto-regressive factorization conditioned on 123678. This is what we meant in our paper, but I agree that the description was not clear and specific enough. Will fix the paper soon.
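To make the confirmed behavior concrete, here is a minimal NumPy sketch of the attention pattern for this example (sequence 1..8, factorization order 1 2 3 6 7 8 | 4 5, with 4 and 5 as targets). This is only an illustration of the idea, not the repo's actual _local_perm code; the mask convention perm_mask[i, j] = 1 meaning "i cannot attend to j" is the one explained later in this thread.

```python
import numpy as np

seq = np.arange(1, 9)                       # tokens 1..8 in original order
perm = np.array([1, 2, 3, 6, 7, 8, 4, 5])   # factorization order
targets = {4, 5}                            # tokens to be predicted
order = {tok: pos for pos, tok in enumerate(perm)}  # position in that order

# perm_mask[i, j] = 1 means the i-th token (original order) CANNOT attend
# to the j-th token.
perm_mask = np.zeros((8, 8))
for i, ti in enumerate(seq):
    for j, tj in enumerate(seq):
        # A target token is hidden from every position that does not come
        # strictly later than it in the factorization order.
        if tj in targets and order[tj] >= order[ti]:
            perm_mask[i, j] = 1

# Tokens 1, 2, 3, 6, 7, 8 see each other freely (bidirectional context).
assert perm_mask[np.ix_([0, 1, 2, 5, 6, 7], [0, 1, 2, 5, 6, 7])].sum() == 0
# Token 4 sees 123678 but not 5; token 5 sees 123678 and 4.
assert perm_mask[3, 4] == 1 and perm_mask[4, 3] == 0
print(perm_mask)
```

Note that with this rule a target token also cannot attend to itself through perm_mask, which matches the matrix printed later in the thread.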
@kimiyoung Hi, I also have a problem here. I tried the source code _local_perm in data_utils with a given input:

inputs: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
is_masked: False, False, False, False, False, True, True, False, False, False
perm_size and seq_len: 10

The rev_index in the function is:
[-1, -1, -1, -1, -1, 7, 8, -1, -1, -1]

And the perm_mask is:
[[0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.]]
It looks like the shuffled inputs are x, x, x, x, x, 7, 8, x, x, x, but I can't understand the output of perm_mask here. Shouldn't all the units after 8 be masked for num 8? Another question: the inputs' order and target_mask are preserved (the shuffle only happens inside the function), so does the partial prediction still point to the original 5, 6 rather than the shuffled 7, 8?
Be aware that the inputs are not actually shuffled (in Section 2.2 the authors say "we keep the original sequence order, and rely on a proper attention mask in Transformers to achieve permutation of the factorization order").

Back to your example: suppose inputs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], and the masked tokens are inputs[5] and inputs[6] (i.e. num6 and num7). Now you run _local_perm and get rev_index = [-1, -1, -1, -1, -1, 7, 8, -1, -1, -1]. Notice that the 7 in rev_index is not a token but an index: rev_index[5] = 7 means the token at position 5 of the original inputs (i.e. num6) is permuted to location 7. So the "shuffled inputs" you mentioned can be seen as [x, x, x, x, x, x, x, 6, 7, x].

With that in mind, you can understand perm_mask better. perm_mask[i, j] = 1 means the i-th token of inputs (original sequence order!) cannot attend to the j-th token of inputs (original sequence order!). A concrete example: perm_mask[5, 6] = 1 means inputs[5] (aka num6) cannot attend to inputs[6] (aka num7), which is straightforward because num7 is masked and comes after num6. perm_mask[6, 5] = 0 means inputs[6] (aka num7) can attend to inputs[5] (aka num6), because num6 sits to the left of num7 in the factorization order.
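To make the explanation above runnable, here is a small NumPy reconstruction of the described logic. It is only a sketch based on this thread, not the actual TensorFlow _local_perm in data_utils.py (which, among other things, also shuffles positions in perm_size chunks and special-cases the functional SEP/CLS tokens):

```python
import numpy as np

def local_perm_sketch(seq_len, is_masked, seed=0):
    """Simplified reconstruction of the masking logic described above.

    Returns (rev_index, perm_mask), where perm_mask[i, j] = 1 means the
    i-th token (original sequence order) cannot attend to the j-th token.
    """
    rng = np.random.default_rng(seed)
    index = rng.permutation(seq_len)   # factorization order over positions

    # Non-masked tokens get the smallest index (-1): every position may
    # attend to them, and they can all attend to each other.
    rev_index = np.where(is_masked, index, -1)

    # A masked (target) token cannot attend to itself through this mask.
    self_rev_index = np.where(is_masked, rev_index, rev_index + 1)

    # i cannot attend to j  iff  j is masked and j does not come strictly
    # earlier than i in the factorization order.
    perm_mask = ((self_rev_index[:, None] <= rev_index[None, :])
                 & is_masked[None, :]).astype(np.float32)
    return rev_index, perm_mask

is_masked = np.array([False] * 5 + [True, True] + [False] * 3)
rev_index, perm_mask = local_perm_sketch(10, is_masked)
print(rev_index)   # -1 everywhere except the two masked positions
print(perm_mask)   # same structure as the matrix quoted above
```

With the masked positions mapped to indices 7 and 8, this reproduces exactly the perm_mask printed above: every row has 1s in columns 5 and 6, except that the second target (column 6's row) can still attend to the first one.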
Hi @kimiyoung. As the paper says, and as you said here:
Formally, we split z into a non-target subsequence z≤c and a target subsequence z>c, where c is the cutting point. The objective is to maximize the log-likelihood of the target subsequence conditioned on the non-target subsequence.
However, I cannot find this cutting point between the target and non-target subsequences in this function: https://github.com/zihangdai/xlnet/blob/5cd50bc451436e188a8e7fea15358d5a8c916b72/data_utils.py#L331
This function, which creates the masks, seems to only do the following:
We employ an idea of span-based prediction, where we first sample a length L ∈ [1, · · · , 5], and then randomly select a consecutive span of L tokens as prediction targets within a context of (KL) tokens.
Or am I missing some function in the code that defines the cutting point?
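For what it's worth, my reading of the thread above is that the cutting point never appears as an explicit variable: the split into non-target (z<=c) and target (z>c) subsequences is implicit in which positions is_masked selects, since _local_perm assigns every non-masked token the smallest index and thereby places it before all targets in the factorization order. Below is a rough, purely illustrative sketch of the span-based target selection described in the quoted sentence; the function name and the default K are mine, and the repo's actual sampling code has more machinery.

```python
import numpy as np

def sample_span_targets(seq_len, K=6, seed=0):
    """Illustrative span-based target selection (hypothetical helper):
    sample a span length L in [1, 5], then choose a consecutive span of
    L tokens as prediction targets inside a context of roughly K*L tokens.
    K is the context-to-span ratio discussed in the paper; the value used
    here is just a placeholder for the demo."""
    rng = np.random.default_rng(seed)
    is_masked = np.zeros(seq_len, dtype=bool)
    L = int(rng.integers(1, 6))                       # span length in [1, 5]
    ctx = min(K * L, seq_len)                         # context of ~K*L tokens
    ctx_start = int(rng.integers(0, seq_len - ctx + 1))
    span_start = ctx_start + int(rng.integers(0, ctx - L + 1))
    is_masked[span_start:span_start + L] = True
    return is_masked                                  # fed to _local_perm as is_masked

print(sample_span_targets(32).astype(int))
```

In the paper's notation, c then corresponds to seq_len minus the number of masked tokens: the non-masked tokens play the role of x_{z<=c} and the masked span plays the role of x_{z>c}.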