transformer
About the query_mask
Source Code:
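padding_num = -2 ** 32 + 1  # ** binds tighter than unary minus, so this is -4294967295
if type in ("k", "key", "keys"):
    key_masks = tf.to_float(key_masks)  # (N, seqlen)
    key_masks = tf.tile(key_masks, [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1])  # (h*N, seqlen)
    key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, seqlen)
    # broadcasts over the query axis, adding padding_num wherever key_masks == 1
    outputs = inputs + key_masks * padding_num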
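To make the behaviour of the current code concrete, here is a minimal NumPy re-creation of it. NumPy stands in for TensorFlow only so the sketch runs without a TF 1.x install, and the names `key_mask_additive` and `softmax` are hypothetical. It assumes, as the additive form implies, that key_masks holds 1 for keys to mask and 0 otherwise. Because padding_num is about -4.29e9, the exponent of a masked logit underflows to 0, so softmax gives those keys zero weight. The (h*N, 1, T_k) mask broadcasts over any number of queries, so T_q and T_k may differ.

```python
import numpy as np

# Same constant as the TF code; ** binds tighter than unary minus, so this is -4294967295.
PADDING_NUM = -2 ** 32 + 1


def key_mask_additive(inputs, key_masks):
    """NumPy sketch of the current additive key masking.

    inputs:    (h*N, T_q, T_k) attention logits
    key_masks: (N, T_k), assumed 1 for keys to mask and 0 for real keys
    """
    key_masks = key_masks.astype(np.float32)
    # Repeat the per-sentence masks once per head: (N, T_k) -> (h*N, T_k)
    key_masks = np.tile(key_masks, (inputs.shape[0] // key_masks.shape[0], 1))
    # Add a query axis: (h*N, 1, T_k); broadcasting then covers every query
    key_masks = np.expand_dims(key_masks, 1)
    return inputs + key_masks * PADDING_NUM


def softmax(x, axis=-1):
    """Numerically stable softmax over the key axis."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


# Usage: h=2 heads, N=1 sentence, T_q=3 queries, T_k=4 keys (the last two are padding).
logits = np.zeros((2 * 1, 3, 4))
masks = np.array([[0, 0, 1, 1]])
weights = softmax(key_mask_additive(logits, masks))
```

With all-zero logits and two padded keys out of four, every query ends up with weights 0.5, 0.5, 0, 0.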
I think the outputs should be computed like this instead:
padding_num = -2 ** 32 + 1
if type in ("k", "key", "keys"):
    key_masks = tf.to_float(key_masks)  # (N, T_k)
    key_masks = tf.tile(key_masks, [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1])  # (h*N, T_k)
    # repeat over the query axis; T_q must come from inputs, since tf.shape(key_masks)[1] is T_k here
    key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(inputs)[1], 1])  # (h*N, T_q, T_k)
    paddings = tf.ones_like(key_masks) * padding_num
    # replace the logits at positions where key_masks == 0
    outputs = tf.where(tf.equal(key_masks, 0), paddings, inputs)
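Reading the two snippets side by side, they mask opposite positions: the current code adds padding_num where key_masks == 1, while the proposal overwrites logits where key_masks == 0. They therefore agree only if the mask passed to the proposal is flipped (1 for real keys). The minimal NumPy sketch below (hypothetical helper names, NumPy standing in for TF 1.x) checks that, after the flip, both produce the same softmax weights. It uses a cross-attention shape with T_q != T_k, where tiling the query axis by tf.shape(key_masks)[1] (which is T_k at that point) would give (h*N, T_k, T_k) instead of the logits' shape. The explicit tile is kept only to mirror the TF snippet; np.where would broadcast a (h*N, 1, T_k) mask on its own.

```python
import numpy as np

PADDING_NUM = -2 ** 32 + 1


def key_mask_where(inputs, key_masks):
    """NumPy sketch of the proposed replacement-style key masking.

    inputs:    (h*N, T_q, T_k) attention logits
    key_masks: (N, T_k), assumed 1 for real keys and 0 for keys to mask
    """
    key_masks = key_masks.astype(np.float32)
    key_masks = np.tile(key_masks, (inputs.shape[0] // key_masks.shape[0], 1))  # (h*N, T_k)
    # Tile to T_q taken from inputs, so query and key lengths may differ
    key_masks = np.tile(np.expand_dims(key_masks, 1), (1, inputs.shape[1], 1))  # (h*N, T_q, T_k)
    paddings = np.ones_like(key_masks) * PADDING_NUM
    # Overwrite the logit wherever the mask is 0
    return np.where(key_masks == 0, paddings, inputs)


def key_mask_additive(inputs, key_masks):
    """The current additive masking (1 marks keys to mask), for comparison."""
    key_masks = np.tile(key_masks.astype(np.float32), (inputs.shape[0] // key_masks.shape[0], 1))
    return inputs + np.expand_dims(key_masks, 1) * PADDING_NUM


def softmax(x, axis=-1):
    """Numerically stable softmax over the key axis."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


# h=2 heads, N=2 sentences, T_q=3 queries, T_k=5 keys (T_q != T_k, as in encoder-decoder attention)
rng = np.random.default_rng(0)
logits = rng.normal(size=(2 * 2, 3, 5))
pad = np.array([[0, 0, 0, 1, 1],
                [0, 1, 1, 1, 1]])  # 1 marks a padding key
additive = softmax(key_mask_additive(logits, pad))
replaced = softmax(key_mask_where(logits, 1 - pad))  # flip so that 1 marks a real key
```

On the design choice: the additive form needs no tiling over T_q and no extra paddings tensor, and it keeps the original logit inside the sum, whereas the replacement form sets masked logits to exactly padding_num. After softmax, both give the masked keys zero weight.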