tensorflow-seq2seq-tutorials

Mixing teacher forcing with "feed previous"

Open • suriyadeepan opened this issue 7 years ago • 1 comment

As you mentioned at the start of the 2nd tutorial, it is a good idea to mix teacher forcing with the "feed previous" technique while decoding. I thought I would share some ideas on how to do that.

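import tensorflow as tf  # TF 1.x API, as used in the tutorials

prob = 0.5  # probability of feeding the previous prediction; set as placeholder or tf.constant
r = tf.random_uniform(shape=[], minval=0.0, maxval=1.0, dtype=tf.float32)  # uniform sample in [0, 1)
feed_previous = r < prob  # boolean scalar tensor, True with probability prob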
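Which distribution `r` is drawn from matters. A normal sample centred at `prob` and compared with `prob` itself comes out True about half the time whatever `prob` is, so `prob` has no effect. A uniform sample in [0, 1) falls below `prob` with probability exactly `prob`. The quick check below estimates both rates empirically. It is a framework-free sketch: Python's `random` module stands in for the TensorFlow ops, and the helper names are hypothetical.

```python
import random


def feed_previous_uniform(prob: float, rng: random.Random) -> bool:
    """True with probability `prob`: a uniform sample in [0, 1) lands below prob."""
    return rng.random() < prob


def feed_previous_normal(prob: float, rng: random.Random, stddev: float = 0.5) -> bool:
    """Normal sample centred at `prob`, compared with `prob`: True about half the time."""
    return rng.gauss(prob, stddev) > prob


def empirical_rate(sampler, prob: float, trials: int = 100_000, seed: int = 0) -> float:
    """Fraction of `trials` draws for which `sampler` chose to feed the previous output."""
    rng = random.Random(seed)
    return sum(sampler(prob, rng) for _ in range(trials)) / trials


if __name__ == "__main__":
    for p in (0.1, 0.5, 0.9):
        print(f"prob={p}: uniform={empirical_rate(feed_previous_uniform, p):.3f}, "
              f"normal={empirical_rate(feed_previous_normal, p):.3f}")
```

A design note on where to create `r`: if it is built outside the transition function, it is drawn once per `session.run`, so the whole batch is decoded in a single mode. If it is built inside the transition function (which runs in the `raw_rnn` while-loop body), it is resampled at every decoding step.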

In the loop_fn_transition function, you could add an outer condition like this.

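# feed_previous is a tensor, so a Python `if` cannot branch on it in graph mode; use tf.cond.
# padded_next_input, search_for_next_input and fetch_next_decoder_target must be zero-argument callables.
input = tf.cond(
    feed_previous,
    lambda: tf.cond(finished, padded_next_input, search_for_next_input),      # feed previous prediction
    lambda: tf.cond(finished, padded_next_input, fetch_next_decoder_target))  # teacher forcing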
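To trace what the nested condition does step by step, here is a framework-free mirror of it for a single sequence. It is a sketch under several assumptions: plain Python branches replace `tf.cond`, integers stand in for embedded tokens, `PAD`/`EOS` are hypothetical token ids, and `step_fn` is a hypothetical stand-in for the decoder cell plus output projection. The teacher-forcing branch reads `targets[time - 1]` because, as with `raw_rnn`, the transition function is assumed to receive the index of the step about to run, so the previous step's ground truth sits one position earlier.

```python
from typing import Callable, List, Sequence

PAD, EOS = 0, 1  # hypothetical token ids


def next_decoder_input(time: int, finished: bool, feed_previous: bool,
                       previous_prediction: int, targets: Sequence[int]) -> int:
    """Plain-Python mirror of the nested tf.cond for decoder step `time` (time >= 1)."""
    if feed_previous:
        # search_for_next_input branch: the model's own previous prediction
        return PAD if finished else previous_prediction
    # fetch_next_decoder_target branch: ground truth emitted at step time - 1
    return PAD if finished else targets[time - 1]


def run_decoder(targets: Sequence[int], step_fn: Callable[[int], int],
                feed_previous: bool) -> List[int]:
    """Return the tokens fed to a toy decoder at steps 0..len(targets)."""
    inputs = [EOS]             # step 0 always starts from EOS
    prediction = step_fn(EOS)  # prediction for targets[0]
    for time in range(1, len(targets) + 1):
        finished = time >= len(targets)  # like `time >= decoder_lengths`
        token = next_decoder_input(time, finished, feed_previous, prediction, targets)
        inputs.append(token)
        prediction = step_fn(token)
    return inputs


def toy_step(token: int) -> int:
    """Hypothetical 'model' that just adds 10 to its input token."""
    return token + 10


if __name__ == "__main__":
    print(run_decoder([5, 6, 7], toy_step, feed_previous=False))  # [1, 5, 6, 0]
    print(run_decoder([5, 6, 7], toy_step, feed_previous=True))   # [1, 11, 21, 0]
```

With teacher forcing the decoder sees the shifted targets regardless of what it predicts; with feed previous it sees its own (here deliberately wrong) outputs, which is exactly the train/inference mismatch that mixing the two is meant to soften.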

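The fetch_next_decoder_target function should fetch the ground-truth token for the current step by indexing decoder_targets by time step, then embed it just as the feed-previous branch embeds its prediction. Because raw_rnn calls the transition function with the index of the step about to run, the target that matches the previous output is decoder_targets[time - 1]. This indexing only works if decoder_targets is in "time major" format ([max_time, batch_size]), so transpose it first if it is batch major.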
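The transpose-and-index step can be illustrated without TensorFlow. In the sketch below, nested lists stand in for the decoder_targets tensor, `to_time_major` plays the role of `tf.transpose` on a [batch_size, max_time] matrix, and `fetch_targets_for_step` (a hypothetical helper) returns the whole batch's ground-truth tokens for one step. The `time - 1` offset assumes raw_rnn's convention of passing the index of the step about to run; `PAD` is a hypothetical padding id.

```python
from typing import List, Sequence

PAD = 0  # hypothetical padding id


def to_time_major(batch_major: Sequence[Sequence[int]]) -> List[List[int]]:
    """[batch_size][max_time] -> [max_time][batch_size]; rows must be padded to equal length."""
    lengths = {len(row) for row in batch_major}
    if len(lengths) > 1:
        raise ValueError("pad all sequences to the same length before transposing")
    return [list(step) for step in zip(*batch_major)]


def fetch_targets_for_step(targets_time_major: Sequence[Sequence[int]], time: int) -> List[int]:
    """Ground-truth tokens of the whole batch fed at decoder step `time` (time >= 1)."""
    return list(targets_time_major[time - 1])


if __name__ == "__main__":
    batch_major = [[5, 6, 7],
                   [8, 9, PAD]]
    time_major = to_time_major(batch_major)
    print(time_major)                             # [[5, 8], [6, 9], [7, 0]]
    print(fetch_targets_for_step(time_major, 1))  # [5, 8]
```

In the TensorFlow graph it is cheaper to transpose once, outside the transition function, than to redo it at every decoding step.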

Hope this helps. I will try this and open a pull request if I find time.

suriyadeepan, Mar 01 '17 05:03