Question on masking
Hi @Kyubyong,
Can you help explain the following masking code (the Key Masking and Query Masking) in modules.py? Why do we need them? Isn't the causality mask enough?
# Key Masking
key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1))) # (N, T_k)
key_masks = tf.tile(key_masks, [num_heads, 1]) # (h*N, T_k)
key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1]) # (h*N, T_q, T_k)
paddings = tf.ones_like(outputs)*(-2**32+1)
outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs) # (h*N, T_q, T_k)
# Query Masking
query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1))) # (N, T_q)
query_masks = tf.tile(query_masks, [num_heads, 1]) # (h*N, T_q)
query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]]) # (h*N, T_q, T_k)
outputs *= query_masks # broadcasting. (h*N, T_q, T_k)
Thanks!
I am also wondering about the purpose of the query masking and key masking. Can anyone help? Thanks very much. @duyvuleo @Kyubyong
@duyvuleo Hi, the query masking and key masking are used to filter out the attention allocated to 'padding' positions in the sequence.
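To make this concrete, here is a minimal NumPy sketch (my own toy example, not the repo's code; the function name masked_attention and the inputs are made up) of what the two masks do: the key mask pushes the attention logits of padded key positions to a large negative value so that softmax gives them roughly zero weight, and the query mask zeroes out the output rows that correspond to padded query positions.
import numpy as np

def masked_attention(queries, keys):  # queries: (T_q, d), keys: (T_k, d)
    logits = queries @ keys.T  # (T_q, T_k) raw attention scores
    key_mask = np.sign(np.abs(keys).sum(-1))  # (T_k,) 0 where the key row is an all-zero padding row
    logits = np.where(key_mask[None, :] == 0, -2.0 ** 32 + 1, logits)  # mask out padded keys
    weights = np.exp(logits - logits.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)  # softmax: padded keys get ~0 weight
    outputs = weights @ keys  # (T_q, d)
    query_mask = np.sign(np.abs(queries).sum(-1))  # (T_q,) 0 where the query row is padding
    return outputs * query_mask[:, None]  # zero out rows for padded queries

queries = np.array([[1., 2.], [0., 0.]])  # second query position is padding
keys = np.array([[1., 0.], [0., 0.]])     # second key position is padding
print(masked_attention(queries, keys))    # second output row comes out all zeros
Note that the causality mask is a different, additional mask: it only prevents a decoder position from attending to future positions and does nothing about padding, which is why the key/query masks are still needed.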
@shrshore: Thanks. I understand it now.
@haoransh When you pass inputs as an argument to positional_encoding, yes, the inputs contain the padding information. However, inside positional_encoding the code only extracts the shape of inputs, not the padding information itself. That means the zero row of the positional_encoding lookup table is NOT aligned with the zero row of the word-embedding lookup table.
As a result, the position-encoded vector at a padding position is non-zero. Let's take an example to make it clear. If T = maxlen = 6 and the input sentence is 'This mask simply fail', we get:
x = [[index_this, index_mask, index_simply, index_fail, 3, 0]], shape (1, T), where 3 represents '</S>' and 0 represents '<PAD>', so the padding sits at the 6th position.
word embedding: x_embedding = [[ [not-all-zeros], [not-all-zeros], ..., [0, 0, ..., 0] ]], shape (1, T, word-embedding size)
positional indices: x_position = [[0, 1, 2, 3, 4, 5]]; with zero_pad = True, x_position_embedding = [[ [0, 0, ..., 0], [not-all-zeros], ..., [not-all-zeros] ]], shape (1, T, positional-embedding size)
Now add the two embeddings (their dimensionalities are the same): x_embedding + x_position_embedding = [[ [not-all-zeros], [not-all-zeros], ..., [not-all-zeros] ]]
So the mask no longer fulfills its original purpose of finding the paddings.
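A minimal NumPy sketch with made-up numbers (not the repo's code) shows the effect: once a non-zero positional row is added to the all-zero word-embedding row of a padding position, a mask built like tf.sign(tf.abs(tf.reduce_sum(...))) can no longer tell which positions are padding.
import numpy as np

T, d = 3, 4
word_emb = np.array([[0.3, -0.1,  0.2, 0.5],   # real token
                     [0.7,  0.4, -0.2, 0.1],   # real token
                     [0.0,  0.0,  0.0, 0.0]])  # <PAD>: all-zero row from the word-embedding table
pos_emb = (np.arange(T * d, dtype=float).reshape(T, d) + 1.0) * 0.01  # non-zero at every position

x = word_emb + pos_emb  # what the encoder actually sees
mask_before = np.sign(np.abs(word_emb).sum(-1))  # [1. 1. 0.] -> padding still detectable
mask_after = np.sign(np.abs(x).sum(-1))          # [1. 1. 1.] -> padding no longer detectable
print(mask_before, mask_after)
With zero_pad = True only position index 0 of the positional table is a zero vector, which corresponds to the first time step, not to the <PAD> token, so it does not help here.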
Good answer! I also had the same doubts.
@gitfourteen @jiangxinyang227 Yes, this issue was raised before in https://github.com/Kyubyong/transformer/issues/33.
So this repo can only serve as a toy example; it is not the same as the original implementation in tensor2tensor. If you are interested, you can also refer to another TensorFlow implementation here, which matches the original implementation but is much easier to follow than tensor2tensor.
Also, see #3