transformer-xl

Some questions about pytorch code and details.

Open wlhgtc opened this issue 6 years ago • 16 comments

Hi there. It's great that you released the original code, though it is a little difficult for me to reproduce. After nearly 1.5 days of matching your paper against the code, I still have some questions about the model structure. I hope you can help, maybe some foolish ...

  1. What's the difference between RelLearnableMultiHeadAttn and RelPartialLearnableMultiHeadAttn? It seems the most important part is the construction of the embedding (A+B+C+D), but the first one doesn't use the position embedding from "Attention Is All You Need"?

  2. Can you explain the function _rel_shift in detail for me? Especially the top -4 line code, I don't know why we need this?

  3. What happens when the param div_val > 1 and what's the meaning of the cutoff_xxx? More specifically, I think what we need is the part of code when div_val==1.

Hope you could help me, thx.

wlhgtc avatar Jan 15 '19 13:01 wlhgtc

  1. RelLearnableMultiHeadAttn corresponds to the "relative positional encoding" Shaw et al. (2018) proposed, which merges the multiplication W^kR into a single trainable matrix hat{R} (see the last paragraph of Page 5). RelPartialLearnableMultiHeadAttn is the "relative positional encoding" we proposed in this work.

  2. It is easier to give an example. To perform relative attention, we want to relatively shift the attention score matrix as follows:

a00 a01 a02      a02  0   0
a10 a11 a12  =>  a11 a12  0
a20 a21 a22      a20 a21 a22

What the _rel_shift does is just a clear way of achieving the transformation above:

a00 a01 a02      0 a00 a01 a02       0  a00 a01      a02  0  a10     a02  0   0
a10 a11 a12  =>  0 a10 a11 a12  =>  a02  0  a10  =>  a11 a12  0  =>  a11 a12  0
a20 a21 a22      0 a20 a21 a22      a11 a12  0       a20 a21 a22     a20 a21 a22
                                    a20 a21 a22
  • Append one "column" of zeros to the left
  • Reshape the matrix from [3 x 4] into [4 x 3]
  • Remove the first "row"
  • Mask out the upper triangle
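
For readers following along, the four steps listed above can be reproduced on the symbolic 3 x 3 example with a minimal pure-Python sketch. It is not the repository's `_rel_shift`, which works on batched tensors; the function name `rel_shift` and the parameters `mask_upper` and `pad` are hypothetical names chosen for this illustration.

```python
def rel_shift(scores, mask_upper=True, pad=0):
    """Relative shift of a [n_rows x n_cols] score matrix given as nested lists.

    `mask_upper` and `pad` are hypothetical parameters of this sketch.
    """
    n_rows, n_cols = len(scores), len(scores[0])

    # Step 1: prepend one column of padding -> [n_rows x (n_cols + 1)].
    padded = [[pad] + list(row) for row in scores]

    # Step 2: reinterpret the row-major buffer as [(n_cols + 1) x n_rows].
    flat = [x for row in padded for x in row]
    reshaped = [flat[r * n_rows:(r + 1) * n_rows] for r in range(n_cols + 1)]

    # Step 3: drop the first row, then view the rest as [n_rows x n_cols] again.
    flat = [x for row in reshaped[1:] for x in row]
    shifted = [flat[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]

    # Step 4 (optional): mask out the entries above the diagonal.
    if mask_upper:
        offset = n_cols - n_rows  # 0 for a square matrix
        shifted = [[x if j <= i + offset else pad for j, x in enumerate(row)]
                   for i, row in enumerate(shifted)]
    return shifted


scores = [["a00", "a01", "a02"],
          ["a10", "a11", "a12"],
          ["a20", "a21", "a22"]]
for row in rel_shift(scores):
    print(row)
```

Calling `rel_shift(scores, mask_upper=False)` stops after step 3, so the leftover entry `a10` stays in the top-right corner instead of being zeroed.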
  3. The div_val is the ratio by which the embedding dimension is reduced from one bin to the next, and the cutoffs are the boundaries of the bins. The names are adapted from the original PyTorch class.
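
To make the role of div_val and the cutoffs concrete, here is a small sketch of how the vocabulary might be split into bins with shrinking embedding sizes. It assumes that token ids are sorted by frequency and that bin `i` uses dimension `d_embed // div_val ** i`; the helper name `bin_layout` and the example numbers are made up for illustration and are not taken from the repository.

```python
def bin_layout(n_token, d_embed, cutoffs, div_val):
    """Return (start_id, end_id, embedding_dim) for every vocabulary bin.

    Assumption of this sketch: bin i uses d_embed // div_val**i dimensions.
    """
    edges = [0] + list(cutoffs) + [n_token]
    layout = []
    for i in range(len(edges) - 1):
        dim = d_embed // (div_val ** i)
        layout.append((edges[i], edges[i + 1], dim))
    return layout


# Hypothetical vocabulary of 10000 tokens with two cutoffs.
for start, end, dim in bin_layout(10000, 512, [2000, 6000], div_val=2):
    print(f"token ids [{start}, {end}) -> embedding dim {dim}")
```

With `div_val=1` every bin keeps the full `d_embed`, which is the simpler case mentioned in question 3.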

zihangdai avatar Jan 15 '19 16:01 zihangdai

So glad to see your reply. Here is my own understanding; could you help me correct it?

  1. About the shift operation in 2: it seems to be an easy way to calculate the position embedding terms in Appendix B?
  2. There seems to be some redundant code in the PyTorch version, e.g. the _shift function?

wlhgtc avatar Jan 16 '19 07:01 wlhgtc

By the way, it seems you add the position embedding at each layer. Is there any improvement compared with adding it only to the word embedding in your ablation study? @zihangdai

wlhgtc avatar Jan 16 '19 13:01 wlhgtc

@wlhgtc

  • def _rel_shift does correspond to Appendix B.
  • You are right. def _shift is unused.
  • Relative positional encodings are defined on word pairs rather than a single word so they could not be added on word embeddings. That being said, it is possible to tie the positional encodings for different layers. Empirically AFAIR, tying/untying the relative positional encodings does not lead to substantial changes in terms of performance.

kimiyoung avatar Jan 17 '19 00:01 kimiyoung

Your position embedding seems to conflict with the original version in "Attention Is All You Need". Your layer looks like the second column (sin, sin, ..., sin, cos, cos, ..., cos), but it should look like the first column (sin, cos, sin, cos, ...). [image]

wlhgtc avatar Jan 17 '19 13:01 wlhgtc

And thanks for your help, I have finished extracting the whole Transformer-XL model code into a single file. I still have one question about your training process: https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/train.py#L433-L437 It seems you split the whole context into several chunks, and mems[i] is used for training data[i+1]. But this code doesn't show that? Or is there something special in the BalancedDataParallel class?

wlhgtc avatar Jan 17 '19 13:01 wlhgtc

@kimiyoung Hope you could help

wlhgtc avatar Jan 17 '19 13:01 wlhgtc

  1. For position embedding, the two columns are equivalent, simply because they are consumed by the matrix multiplication which is permutation-invariant.
  2. To quote what we explained in the README file: --batch_chunk: this option allows one to trade speed for memory. For batch_chunk > 1, the program will split each training batch into batch_chunk sub-batches and perform forward and backward on each sub-batch sequentially, with the gradient accumulated and divided by batch_chunk. Hence, the memory usage will be proportionally lower while the computation time will be inversely higher.
  3. For BalancedDataParallel, see #5
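
The equivalence claimed in point 1 can be checked numerically. The sketch below builds the same sinusoidal encoding in both layouts and shows that a projection gives the same result once its weights are permuted to match. It uses plain Python lists; the weight vector `w_concat` stands in for one row of a hypothetical projection matrix and is not taken from the repository.

```python
import math


def sinusoid(pos, d_model, interleaved):
    """Sinusoidal encoding of one position in either feature layout."""
    half = d_model // 2
    freqs = [1.0 / (10000 ** (2 * i / d_model)) for i in range(half)]
    sin = [math.sin(pos * f) for f in freqs]
    cos = [math.cos(pos * f) for f in freqs]
    if interleaved:
        return [v for pair in zip(sin, cos) for v in pair]  # sin, cos, sin, cos, ...
    return sin + cos                                        # sin, ..., sin, cos, ..., cos


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


d_model = 8
half = d_model // 2
# perm[k] is the index in the concatenated layout of the k-th interleaved feature.
perm = [k // 2 + (k % 2) * half for k in range(d_model)]

w_concat = [0.1 * (k + 1) for k in range(d_model)]  # hypothetical projection weights
w_interleaved = [w_concat[p] for p in perm]         # same weights, reordered

e_concat = sinusoid(3, d_model, interleaved=False)
e_interleaved = sinusoid(3, d_model, interleaved=True)

print(dot(w_concat, e_concat), dot(w_interleaved, e_interleaved))
```

Since the projection weights are learned, the model is free to pick whichever ordering matches the layout it is given, so neither layout is more expressive than the other.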

zihangdai avatar Jan 17 '19 14:01 zihangdai

Yeah, but I mean that when we train on data[i], we need mems[i-1], the memory from the previous chunk. But ret = para_model(data_i, target_i, *mems[i]) seems to use mems[i]?

wlhgtc avatar Jan 17 '19 14:01 wlhgtc

No. The split by batch_chunk is along the batch dimension. In this case, mems is a Python list (line 424), where mems[i] corresponds to the i-th chunk, i.e., mini-batch.
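
A toy sketch may help separate the two axes here. Below, a batch holds one segment per sequence, batch_chunk splits the sequences (not time) into sub-batches, and mems[i] always belongs to the same sub-batch from one training step to the next. The "model" `toy_forward` is a stand-in invented for this illustration: it only concatenates the cached segment with the current one and returns a placeholder loss.

```python
def toy_forward(sub_batch, mem):
    """Hypothetical stand-in for the model: returns (loss, new_mem).

    sub_batch: list of segments, one per sequence in this sub-batch.
    mem: previous segments of the same sequences, or None on the first step.
    """
    if mem is None:
        context = [list(seg) for seg in sub_batch]
    else:
        context = [m + list(seg) for m, seg in zip(mem, sub_batch)]
    loss = float(sum(len(c) for c in context))  # placeholder scalar
    new_mem = [list(seg) for seg in sub_batch]  # cache the current segment
    return loss, new_mem


def train_step(batch, mems, batch_chunk):
    """Split along the batch dimension; mems[i] stays with the i-th sub-batch."""
    size = len(batch) // batch_chunk            # assumes an even split
    total = 0.0
    for i in range(batch_chunk):
        sub_batch = batch[i * size:(i + 1) * size]
        loss, mems[i] = toy_forward(sub_batch, mems[i])
        total += loss / batch_chunk             # accumulate, divided by batch_chunk
    return total, mems


# Four sequences; each training step sees the next segment of every sequence.
step0 = [[1, 2], [11, 12], [21, 22], [31, 32]]
step1 = [[3, 4], [13, 14], [23, 24], [33, 34]]

mems = [None, None]
loss0, mems = train_step(step0, mems, batch_chunk=2)
loss1, mems = train_step(step1, mems, batch_chunk=2)
print(loss0, loss1, mems)
```

In this sketch the recurrence runs across training steps: the memory written for sub-batch i at one step is read by sub-batch i at the next step, never by sub-batch i+1 of the same step.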

zihangdai avatar Jan 17 '19 14:01 zihangdai

Fine, I re-read the code. It seems the mems are updated when a batch finishes and are then used in the next batch, am I right?
But according to Figure 2 in your paper, I think the mems should flow between different segments. So I regarded each chunk as a different segment, but it seems that the different batches in the iterator are the different segments?

wlhgtc avatar Jan 17 '19 15:01 wlhgtc

Please refer to the _update_mems function for how a single mem is updated.

When batch_chunk is used, each element mems[i] in mems is updated in exactly the same way and then returned to the train loop so that it can be used for the next segment (see this line for how the mems[i] is returned).
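
As a rough picture of what such an update could look like, the sketch below appends the new hidden states of one layer to the cached ones and keeps only the most recent mem_len time steps. It is a simplification under stated assumptions: states are plain lists, there is no extended context, and the helper names `update_mem` and `update_mems` are hypothetical. As I understand it, the real implementation also cuts the gradient through the cached states, which plain lists cannot express.

```python
def update_mem(prev_mem, hidden, mem_len):
    """Keep the last `mem_len` time steps of prev_mem followed by hidden."""
    if mem_len <= 0:
        return []
    cat = list(prev_mem) + list(hidden)
    return cat[max(0, len(cat) - mem_len):]


def update_mems(prev_mems, hiddens, mem_len):
    """Apply the same update independently to every layer."""
    return [update_mem(m, h, mem_len) for m, h in zip(prev_mems, hiddens)]


# Two layers, memory of three time steps, a new segment of two time steps.
prev_mems = [["h1", "h2", "h3"], ["g1", "g2", "g3"]]
hiddens = [["h4", "h5"], ["g4", "g5"]]
print(update_mems(prev_mems, hiddens, mem_len=3))
```

The returned list would then be passed back in together with the next segment, which is the hand-off described above.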

zihangdai avatar Jan 17 '19 15:01 zihangdai

@zihangdai I'm going to piggyback on this issue since my question is somewhat related: could you explain in a little more detail how the segment-level recurrence works in code? I can see that you calculate the mems for an entire batch of (I assume consecutive) segments, and then reuse that in the next step, but I am confused about how recurrence between consecutive segments inside one batch works.

I'm asking this because I'm thinking about how to adapt this model to a different task like question answering, and I can't really wrap my head around how to build the segmentation when you have to distinguish between the contexts for different questions and can't treat the entire corpus as one large chunk of text.

BenjaminWinter avatar Jan 30 '19 14:01 BenjaminWinter

@kimiyoung @zihangdai Following up,

It seems that by default, zeroing of the upper triangle of the matrix is set to False.

https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L194

What is the reason for that?

abhitopia avatar Mar 27 '19 17:03 abhitopia

@BenjaminWinter I'm also confused about the segment-level recurrence in the paper.

LindgeW avatar Dec 18 '19 11:12 LindgeW

@zihangdai could you please clarify this issue? I can't find anywhere how you deal with multiple segments.

aleSuglia avatar Apr 29 '21 17:04 aleSuglia