gpt-2
During text generation, the first `past` is reused again and again: after the first word is generated, `present` does not get updated. The relevant code in `attn()` in `model.py` is:
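```python
def attn(x, scope, n_state, *, past, hparams):
    # Excerpt from attn() in gpt-2's model.py; the asserts and the nested
    # helpers (split_heads, merge_heads, multihead_attn) are omitted.
    with tf.variable_scope(scope):
        c = conv1d(x, 'c_attn', n_state*3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        # present is stacked here, BEFORE past is concatenated,
        # so it holds only the k, v of the current input tokens
        present = tf.stack([k, v], axis=1)
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state)
        return a, present
```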
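As a rough illustration of the cache handling above, here is a pure-Python toy. It assumes lists of strings stand in for the key/value tensors, and `toy_attn` is a hypothetical helper, not part of gpt-2. Because `present` is stacked before `past` is concatenated, it contains only the entries for the current input, while attention still sees the full history.

```python
def toy_attn(new_tokens, past=None):
    """Mimic the cache handling in attn(); returns (keys seen by attention, present)."""
    # k and v are computed only from the current input tokens
    k = [f"k{t}" for t in new_tokens]
    v = [f"v{t}" for t in new_tokens]
    # present is stacked BEFORE concatenating past, so it holds only the new entries
    present = (list(k), list(v))
    if past is not None:
        pk, pv = past
        k = pk + k  # attention sees the full history
        v = pv + v
    return k, present


keys, present = toy_attn([3], past=(["k1", "k2"], ["v1", "v2"]))
print(keys)     # ['k1', 'k2', 'k3']
print(present)  # (['k3'], ['v3'])
```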
Please change it to the following, which stacks `present` after the past keys and values have been concatenated:
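```python
def attn(x, scope, n_state, *, past, hparams):
    # Same excerpt of attn() from gpt-2's model.py, with one line moved.
    with tf.variable_scope(scope):
        c = conv1d(x, 'c_attn', n_state*3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        # Proposed: stack present AFTER concatenating past,
        # so it holds the full k, v history
        present = tf.stack([k, v], axis=1)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state)
        return a, present
```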
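:-) Never mind: they already do that inside the `tf.while_loop` in `sample.py`. Their code is efficient: when computing the logits for the newly predicted word, the model returns only the keys and values of the new tokens as `present` instead of the whole cache, and the sampling loop appends them to `past`. In `sample.py`, the loop body returns `next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2)`. So the original `attn()` code is correct, and the change above would make the cache duplicate earlier entries.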
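The interplay between `present` and the sampling loop can be sketched in pure Python. This is a toy under stated assumptions: lists of token ids stand in for the cached key/value tensors, each call processes a single token (the real loop first runs the whole context), and `toy_step`/`toy_sample` are hypothetical names. With the original ordering the cache grows by one entry per step; stacking `present` after the concatenation, while the loop also concatenates, re-appends the history on every step.

```python
def toy_step(token, past, stack_after_concat=False):
    """Hypothetical stand-in for one model call; returns what it reports as `present`.

    Lists of token ids stand in for the cached key/value tensors.
    """
    k = [token]              # keys/values for the new token only
    present = list(k)        # original model.py: stack before concatenating past
    if past is not None:
        k = past + k         # attention uses the full history
    if stack_after_concat:
        present = list(k)    # proposed change: stack after concatenating past
    return present


def toy_sample(n_steps, stack_after_concat=False):
    """Mimic the sampling loop: append each call's `present` to `past`."""
    past = None
    for token in range(n_steps):
        present = toy_step(token, past, stack_after_concat)
        # same pattern as sample.py: presents if past is None else concat(past, presents)
        past = present if past is None else past + present
    return past


print(toy_sample(4))                                # [0, 1, 2, 3]
print(len(toy_sample(4, stack_after_concat=True)))  # 15
```

With the proposed ordering the cache length goes 1, 3, 7, 15 instead of 1, 2, 3, 4, which is why the concatenation is left to the loop.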