why not use a partial factorization?

Open fyubang opened this issue 5 years ago • 7 comments

Hi, thanks for the great work. I have a question about the pre-training process. Assuming a length-T sequence with a permutation z, since we use partial prediction (only predicting the tokens after z_c), why don't we use a partial factorization? That is, use an all-zero attention mask for the first c tokens so that the context tokens can see each other?
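To make this concrete, something like the following toy mask (purely illustrative NumPy, not the repo's code, with a hypothetical length T = 6 and cutting point c = 4):

```python
import numpy as np

# Toy illustration, indexed in factorization order: mask[i, j] = 1 means
# position i cannot attend to position j. The first c positions are the
# context and get all-zero rows (full bidirectional visibility); only the
# last T - c positions, the prediction targets, keep the auto-regressive
# restriction.
T, c = 6, 4
mask = np.zeros((T, T))
for i in range(c, T):
    mask[i, i:] = 1  # target i sees everything before it in the order, nothing at or after
print(mask)
```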

Thank you.

fyubang avatar Jun 25 '19 09:06 fyubang

Same question here. Why not permute only the partial-prediction part, and let it be conditioned on the bidirectional context x_{z<=c} and the previous tokens in the permutation? Like

E[ log P(x_{z>c} | x_{z<=c}) ] = E[ sum_{t=c+1}^{|z|} log P(x_{z_t} | x_{z_{c+1:t-1}}, x_{z<=c}) ]
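For comparison, the partial-prediction objective as written in the paper (Section 2.3, to the best of my reading) is:

```latex
\max_{\theta}\;
\mathbb{E}_{\mathbf{z}\sim\mathcal{Z}_T}
\Big[\log p_{\theta}\big(\mathbf{x}_{\mathbf{z}_{>c}} \mid \mathbf{x}_{\mathbf{z}_{\le c}}\big)\Big]
=
\mathbb{E}_{\mathbf{z}\sim\mathcal{Z}_T}
\Big[\sum_{t=c+1}^{|\mathbf{z}|} \log p_{\theta}\big(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}\big)\Big]
```

which is the same conditioning set, since x_{z<t} for t > c is exactly x_{z<=c} together with the targets predicted before step t.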

Thank you.

desperadoola avatar Jun 25 '19 09:06 desperadoola

Good question. In fact we are using the implementation that you just mentioned. Sorry about the confusion.

kimiyoung avatar Jun 26 '19 00:06 kimiyoung

@kimiyoung To make it clearer, let us walk through a concrete example:

Assume the original sentence is 12345678 and the permutation is 12367845. The last two tokens, i.e. 4 and 5, are to be predicted. What we want is: let tokens 1, 2, 3, 6, 7, 8 attend to each other (note that according to the paper, i cannot attend to j if i < j, i.e. 2 only attends to 1, 3 only attends to 1 and 2, and so forth), let token 4 attend to 123678, and let token 5 attend to 1236784.

I haven't fully understood the source code, but this comment says a token can attend if i > j or j is non-masked. Combined with the above example, are tokens 4 and 5 masked? And do tokens 123678 attend to each other because j is non-masked?

If this understanding is correct, please update the formula in the paper accordingly to address this issue. Thanks.

soloice avatar Jun 27 '19 09:06 soloice

Yes, tokens 123678 have bidirectional attention and they attend to all the other tokens, while tokens 4 and 5 use an auto-regressive factorization conditioned on 123678. This is what we meant in our paper, but I agree that the description was not clear and specific enough. Will fix the paper soon.
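For anyone following along, here is a small NumPy sketch (not the repository code) of the permutation mask this implies for the 12345678 / 12367845 example, following the rule quoted from the code comment above ('can attend if i > j or j is non-masked'), with target tokens additionally masked from themselves as in the query stream:

```python
import numpy as np

# Original order: tokens 1..8. Factorization order: 1, 2, 3, 6, 7, 8, 4, 5.
# Targets (the last two in the factorization order): tokens 4 and 5.
tokens = list(range(1, 9))
factorization_order = [1, 2, 3, 6, 7, 8, 4, 5]
targets = {4, 5}
pos = {tok: k for k, tok in enumerate(factorization_order)}

# mask[i, j] = 1 means the token at original position i cannot attend to the
# token at original position j: attention is blocked only when j is a target
# that does not strictly precede i in the factorization order. Non-target
# tokens are therefore visible to everyone, while target content is hidden
# from the context tokens (no information leak) and targets follow the
# auto-regressive order among themselves.
mask = np.zeros((8, 8))
for i, ti in enumerate(tokens):
    for j, tj in enumerate(tokens):
        if tj in targets and pos[ti] <= pos[tj]:
            mask[i, j] = 1
print(mask)
# Rows for tokens 1, 2, 3, 6, 7, 8: attend to each other, not to 4 or 5.
# Row for token 4: attends to 1, 2, 3, 6, 7, 8 (not itself, not 5).
# Row for token 5: attends to 1, 2, 3, 6, 7, 8 and 4 (not itself).
```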

kimiyoung avatar Jun 27 '19 17:06 kimiyoung

@kimiyoung Hi, I also have a question here. I tried the source code _local_perm in data_utils.py with a given input:

inputs:    1, 2, 3, 4, 5, 6, 7, 8, 9, 10
is_masked: False, False, False, False, False, True, True, False, False, False
perm_size = seq_len = 10

The rev_index inside the function is:

[-1, -1, -1, -1, -1, 7, 8, -1, -1, -1]

and the perm_mask is:

[[0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
 [0., 0., 0., 0., 0., 1., 1., 0., 0., 0.]]

It looks like the shuffled inputs are x, x, x, x, x, 7, 8, x, x, x, but I can't understand the output of perm_mask here. Shouldn't all the units after 8 be masked for num8? Another question: the inputs' order and target_mask are preserved (the shuffle only happens inside the function), so does the partial prediction still point to the original 5, 6 rather than the shuffled 7, 8?

zkmake520 avatar Jun 28 '19 07:06 zkmake520

Be aware that inputs is not really shuffled. (In Section 2.2 the authors say: "we keep the original sequence order, and rely on a proper attention mask in Transformers to achieve permutation of the factorization order.")

Back to your example: suppose inputs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], and the masked tokens are inputs[5] and inputs[6] (i.e. num6 and num7). Now you run the function _local_perm and get rev_index = [-1, -1, -1, -1, -1, 7, 8, -1, -1, -1]. Notice that the 7 in rev_index is not a token but an index: rev_index[5] = 7 means the fifth token of the original inputs (which is num6) is now permuted to location 7. So the "shuffled inputs" you mentioned can be seen as [x, x, x, x, x, x, x, 6, 7, x].

Given the above, you can better understand perm_mask. perm_mask[i, j] = 1 means the i-th token of inputs (original sequence order!) cannot attend to the j-th token of inputs (original sequence order!). A concrete example: perm_mask[5, 6] = 1, which means inputs[5] (aka num6) cannot attend to inputs[6] (aka num7); this is straightforward because num7 is masked. perm_mask[6, 5] = 0, which means inputs[6] (aka num7) can attend to inputs[5] (aka num6), because num6 sits to the left of num7 in the factorization order.
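For anyone who wants to check this numerically, here is a minimal NumPy re-derivation of the printed perm_mask (a simplified version of the logic in _local_perm; functional tokens such as [SEP] and [CLS] are ignored here):

```python
import numpy as np

seq_len = 10
# Targets are inputs[5] and inputs[6] (num6 and num7).
is_masked = np.array([0, 0, 0, 0, 0, 1, 1, 0, 0, 0], dtype=bool)

# rev_index as printed above: non-target tokens get the smallest index (-1) so
# that every position may attend to them; target tokens keep their position in
# the sampled factorization order (here 7 and 8).
rev_index = np.array([-1, -1, -1, -1, -1, 7, 8, -1, -1, -1])

# Targets cannot attend to themselves, so they keep rev_index as their own
# "query" index; non-targets are shifted by +1 so they can see themselves.
self_rev_index = np.where(is_masked, rev_index, rev_index + 1)

# perm_mask[i, j] = 1  <=>  inputs[i] cannot attend to inputs[j], i.e. j is a
# target that does not come strictly before i in the factorization order.
perm_mask = (self_rev_index[:, None] <= rev_index[None, :]) & is_masked[None, :]
print(perm_mask.astype(np.float32))
```

This reproduces the matrix above: only columns 5 and 6 contain ones, and the only row that differs is row 6, whose zero at column 5 is exactly the "num7 can attend to num6" case.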

xingchensong avatar Jul 19 '19 10:07 xingchensong

Hi @kimiyoung. As the paper says, and as you confirmed here:

Formally, we split z into a non-target subsequence z≤c and a target subsequence z>c, where c is the cutting point. The objective is to maximize the log-likelihood of the target subsequence conditioned on the non-target subsequence.

However, I cannot find this cutting point between the target subsequence and the non-target subsequence in this function: https://github.com/zihangdai/xlnet/blob/5cd50bc451436e188a8e7fea15358d5a8c916b72/data_utils.py#L331 As far as I can tell, this function, which creates the mask, only does this:

We employ an idea of span-based prediction, where we first sample a length L ∈ [1, · · · , 5], and then randomly select a consecutive span of L tokens as prediction targets within a context of (KL) tokens.

Or am I missing some function in the code that defines the cutting point?

nina124 avatar Sep 05 '19 03:09 nina124