transformer-xl
Some questions about pytorch code and details.
Hi there: so nice that you released the original code. It is a little difficult for me to reproduce, though :( After nearly 1.5 days of matching your paper against the code, I still have some questions about the model structure (maybe some foolish ones); I hope you can help.
- What's the difference between `RelLearnableMultiHeadAttn` and `RelPartialLearnableMultiHeadAttn`? The most important part seems to be the construction of the A + B + C + D terms, but the first one doesn't use the positional embedding from "Attention Is All You Need"?
- Can you explain the function `_rel_shift` in detail for me? Especially the top 4 lines of code; I don't know why we need them.
- What happens when the parameter `div_val > 1`, and what's the meaning of the `cutoff_xxx` values? More specifically, I think all we need is the part of the code where `div_val == 1`.

Hope you can help me, thanks.
- `RelLearnableMultiHeadAttn` corresponds to the "relative positional encoding" Shaw et al. (2018) proposed, which merges the multiplication W^k R into a single trainable matrix hat{R} (see the last paragraph of page 5). `RelPartialLearnableMultiHeadAttn` is the "relative positional encoding" we proposed in this work.
- It is easier to give an example. To perform relative attention, we want to relatively shift the attention score matrix as follows:

  ```
  a00 a01 a02      a02  0   0
  a10 a11 a12  =>  a11 a12  0
  a20 a21 a22      a20 a21 a22
  ```

  What `_rel_shift` does is just a clear way of achieving the transformation above:

  ```
  a00 a01 a02      0 a00 a01 a02      0   a00 a01      a02  0  a10      a02  0   0
  a10 a11 a12  =>  0 a10 a11 a12  =>  a02  0  a10  =>  a11 a12  0   =>  a11 a12  0
  a20 a21 a22      0 a20 a21 a22      a11 a12  0       a20 a21 a22      a20 a21 a22
                                      a20 a21 a22
  ```

  - Append one "column" of zeros to the left.
  - Reshape the matrix from [3 x 4] into [4 x 3].
  - Remove the first "row".
  - Mask out the upper triangle.

  (A short code sketch of these four steps is given after this list.)
- `div_val` is the ratio used to reduce the embedding dimension of each bin, where the `cutoff` values are the boundaries of the bins, so rarer tokens in later bins get smaller embeddings. The names are adapted from the original PyTorch class.
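For concreteness, here is a minimal sketch of those four steps on a plain 2-D score matrix. It is only an illustration: the actual `_rel_shift` in `mem_transformer.py` works on higher-rank tensors (extra batch/head dimensions) and, as discussed further down in this thread, leaves the upper-triangle masking (`zero_triu`) off by default.

```python
import torch

def rel_shift_2d(x, zero_triu=False):
    # x: [qlen x klen] attention scores, e.g. the 3 x 3 matrix in the diagram above
    qlen, klen = x.size()
    # 1. Append one "column" of zeros to the left -> [qlen x (klen + 1)]
    zero_pad = torch.zeros((qlen, 1), dtype=x.dtype, device=x.device)
    x_padded = torch.cat([zero_pad, x], dim=1)
    # 2. Reshape from [qlen x (klen + 1)] into [(klen + 1) x qlen]
    x_padded = x_padded.view(klen + 1, qlen)
    # 3. Remove the first "row" and view the result as [qlen x klen] again
    x = x_padded[1:].view(qlen, klen)
    # 4. Optionally mask out the upper triangle
    if zero_triu:
        x = x * torch.tril(torch.ones_like(x), diagonal=klen - qlen)
    return x

# Reproduces the 3 x 3 example above with a_ij = 3 * i + j
scores = torch.arange(9, dtype=torch.float).view(3, 3)
print(rel_shift_2d(scores, zero_triu=True))
# tensor([[2., 0., 0.],
#         [4., 5., 0.],
#         [6., 7., 8.]])
```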
So glad to see your reply. Let me list some of my own understanding; could you help me correct it?
- About the shift operation in 2.: it seems to be an easy way to calculate the positional embedding term in Appendix B?
- There seems to be some redundant code in the PyTorch version, e.g. the `_shift` function?

By the way, it seems you add the positional embedding at each layer. Is there any improvement, in your ablation study, compared with adding it only to the word embedding? @zihangdai
@wlhgtc
- `def _rel_shift` does correspond to Appendix B.
- You are right, `def _shift` is unused.
- Relative positional encodings are defined on word pairs rather than on a single word, so they cannot be added to the word embeddings. That being said, it is possible to tie the positional encodings across different layers. Empirically, AFAIR, tying/untying the relative positional encodings does not lead to substantial changes in performance.
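To make the "word pairs" point concrete, the relative attention score between query position i and key position j decomposes into the four terms the original question calls A + B + C + D (labelled (a)-(d) in Section 3.3 of the paper); since the relative encoding R_{i-j} depends on the pair (i, j), it cannot be folded into a single per-token embedding. Here E_{x_i} is the word embedding, W_q, W_{k,E}, W_{k,R} are projections, and u, v are the trainable bias vectors:

```latex
A^{\mathrm{rel}}_{i,j}
  = \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,E}\, E_{x_j}}_{(a)}
  + \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,R}\, R_{i-j}}_{(b)}
  + \underbrace{u^{\top} W_{k,E}\, E_{x_j}}_{(c)}
  + \underbrace{v^{\top} W_{k,R}\, R_{i-j}}_{(d)}
```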
It seems your positional embedding conflicts with the original version in "Attention Is All You Need"?
Your layer looks like the second column (sin, sin, ..., sin, cos, cos, ..., cos), but it should look like the first column (sin, cos, sin, cos, ...).
And thanks for your help; I have finished extracting the whole Transformer-XL model code into a single file.
Still one question about your training process:
https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/train.py#L433-L437
It seems you split the whole context into several chunks, and `mems[i]` is used for training `data[i+1]`.
But this code doesn't seem to show that? Or is there something special in the `BalancedDataParallel` class?
@kimiyoung Hope you can help.
- For the positional embedding, the two columns are equivalent, simply because they are only consumed by a matrix multiplication with learned weights, so the ordering of the embedding dimensions does not matter (the learned projection can absorb any permutation).
- Just copying what we have explained in the README file: `--batch_chunk`: this option allows one to trade speed for memory. For `batch_chunk > 1`, the program will split each training batch into `batch_chunk` sub-batches and perform forward and backward on each sub-batch sequentially, with the gradients accumulated and divided by `batch_chunk`. Hence, the memory usage will be proportionally lower while the computation time will be inversely higher. (See the sketch after this list.)
- For `BalancedDataParallel`, see #5.
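Here is a rough sketch of the `--batch_chunk` accumulation pattern described above. It is a simplified illustration, not the repo's actual training loop: the function and variable names are placeholders, and the real `para_model` returns a list that the train loop unpacks rather than a tuple.

```python
import torch

def train_step(para_model, optimizer, data, target, mems, batch_chunk):
    """One optimizer step with the batch split into `batch_chunk` sub-batches.

    data/target: [seq_len x batch] tensors for the current segment.
    mems: a Python list with one entry per sub-batch; each entry caches the
          hidden states of that slice of the batch from previous segments.
    """
    optimizer.zero_grad()
    data_chunks = torch.chunk(data, batch_chunk, dim=1)    # split along the batch dim
    target_chunks = torch.chunk(target, batch_chunk, dim=1)
    for i in range(batch_chunk):
        # Forward on the i-th sub-batch, reusing (and refreshing) its own memory.
        loss, mems[i] = para_model(data_chunks[i], target_chunks[i], mems[i])
        # Divide so the accumulated gradient matches a single full-batch update.
        (loss.float().mean() / batch_chunk).backward()
    optimizer.step()
    return mems
```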
Yeah, but I mean that when we train on `data[i]`, we need `mems[i-1]`: the memory from the previous chunk.
But `ret = para_model(data_i, target_i, *mems[i])` seems to use `mems[i]`?
No. The split by `batch_chunk` is along the batch dimension. In this case, `mems` is a Python list (line 424), where `mems[i]` corresponds to the i-th chunk, i.e., mini-batch.
Fine, I re-read the code. It seems `mems` is updated when a batch finishes and is then used in the next batch, am I right?
But according to Figure 2 in your paper, I think the `mems` should flow between different segments. So I regarded each chunk as a different segment, but it seems that the different batches from the iterator are the different segments?
Please refer to the `_update_mems` function for how a single `mem` is updated.
When `batch_chunk` is used, each element `mems[i]` in `mems` is updated in exactly the same way and then returned to the training loop so that it can be used for the next segment (see this line for how `mems[i]` is returned).
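As a rough picture of what `_update_mems` does, here is a simplified sketch. It assumes `ext_len == 0` and omits some bookkeeping from the real function, but the essence is: concatenate the cached states with the new hidden states, keep only the most recent `mem_len` steps, and detach so no gradient flows into the cache.

```python
import torch

def update_mems(hids, mems, mem_len):
    """Simplified per-layer memory update; assumes ext_len == 0.

    hids: list of per-layer hidden states for the current segment, each [qlen x bsz x d]
    mems: list of per-layer cached states from earlier segments, each [mlen x bsz x d]
    """
    with torch.no_grad():                             # the cache never receives gradients
        new_mems = []
        for hid, mem in zip(hids, mems):
            cat = torch.cat([mem, hid], dim=0)        # old memory followed by new states
            new_mems.append(cat[-mem_len:].detach())  # keep the last mem_len positions
    return new_mems
```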
@zihangdai Going to piggyback on this issue since my question is somewhat related: could you explain in a little more detail how the segment-level recurrence works in code? I can see that you calculate the mems for an entire batch of (I assume consecutive) segments and then reuse them in the next step, but I am confused about how recurrence between consecutive segments inside one batch works.
I'm asking this because I'm thinking about how to adapt this model to a different task like question answering, and I can't really wrap my head around how to build the segmentation when you have to distinguish between the contexts of different questions and can't treat the entire corpus as one large chunk of text.
@kimiyoung @zihangdai Following up: it seems that by default, the zeroing of the upper triangular matrix (`zero_triu`) is `False`.
https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L194
What is the reason for that?
@BenjaminWinter I'm also confused about the segment-level recurrence in the paper.
@zihangdai Could you please clarify this issue? I can't find anywhere how you deal with multiple segments.