transformer-xl
tf code question?
Great work! I have two questions about your TF code.
The first: in the paper, the query vector is computed from the previous layer's hidden state alone, not from the concatenation of the memory and the previous layer's hidden state. In the TF code, however, the query vector appears to be computed from the same concatenated input as the key and value vectors.
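To make this concrete, here is a minimal NumPy sketch of what I understood from the paper (the names `h`, `mem`, `cat`, `w_q` etc. are mine, not the repo's): the query projection sees only the current segment's hidden states, while the key/value projections see the memory-extended sequence.

```python
import numpy as np

mem_len, seg_len, batch, d_model = 4, 3, 2, 8
rng = np.random.default_rng(0)

h = rng.normal(size=(seg_len, batch, d_model))    # current segment hidden states
mem = rng.normal(size=(mem_len, batch, d_model))  # cached memory from the prior segment
cat = np.concatenate([mem, h], axis=0)            # memory-extended sequence

w_q = rng.normal(size=(d_model, d_model))
w_k = rng.normal(size=(d_model, d_model))
w_v = rng.normal(size=(d_model, d_model))

# What the paper describes:
q = h @ w_q    # queries from the current segment only -> [seg_len, batch, d_model]
k = cat @ w_k  # keys over memory + current segment -> [mem_len + seg_len, batch, d_model]
v = cat @ w_v  # values likewise

print(q.shape, k.shape, v.shape)
```

Since the projection is linear and applied per position, projecting `cat` and then keeping only the last `seg_len` positions would give the same queries as projecting `h` alone, so perhaps the code relies on that equivalence; is that the intent?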
The second: each layer has a memory tensor of shape [mem_len, batch_size, d_model]. When computing the query, key, and value vectors, the input to tf.layers.dense is the concatenation of the current layer's memory and the previous layer's output, which seems to conflict with the paper. Also, why is the gradient stopped in the _cache_mem method rather than in rel_multihead_attn? The latter seems to make more sense.
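To spell out what I mean by the two placements, here is a hedged NumPy-level sketch (function names are mine; `stop_gradient` is forward-identity here, standing in for tf.stop_gradient, which only blocks backprop):

```python
import numpy as np

def stop_gradient(x):
    # Forward identity; in TF this would cut the gradient flowing into x.
    return x

def cache_mem_style(prev_mem, curr_out, mem_len):
    # Placement A (as in _cache_mem): gradients are cut when the memory
    # is *stored*, so every later read of the cache is gradient-free.
    new_mem = np.concatenate([prev_mem, curr_out], axis=0)[-mem_len:]
    return stop_gradient(new_mem)

def attn_style(mem, h):
    # Placement B (stopping inside rel_multihead_attn): gradients are cut
    # at the point where the memory is *consumed* by attention.
    return np.concatenate([stop_gradient(mem), h], axis=0)
```

If each cached segment is only ever read by the attention of the next segment, the two placements look equivalent to me, which is why I am asking whether there is a reason to prefer stopping at caching time.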