RLSeq2Seq
scheduled sampling OOV issue
Hi @yaserkl,
Thanks for publishing the code!
In line 525 of attention_decoder.py, embedding_lookup throws an error when sample_ids_sampling contains OOV IDs (IDs greater than or equal to the vocab size). I think these IDs should be replaced with the [UNK] ID before the lookup.
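For context, here is a minimal sketch of where such IDs can come from. It assumes a pointer-generator style decoder, where the sampling distribution covers an extended vocabulary and each distinct article OOV word gets the ID vocab_size + k. That assumption is not confirmed from the repository. The helper `extended_ids` and the toy vocabulary are hypothetical. An embedding table has only vocab_size rows, so none of those extended IDs has a row to look up.

```python
def extended_ids(words, vocab):
    """Hypothetical helper (pointer-generator style, assumed): in-vocab words
    keep their ID; each distinct OOV word gets len(vocab) + k."""
    vocab_size = len(vocab)
    oovs = []  # article OOV words, in order of first appearance
    ids = []
    for w in words:
        if w in vocab:
            ids.append(vocab[w])
        else:
            if w not in oovs:
                oovs.append(w)
            ids.append(vocab_size + oovs.index(w))
    return ids, oovs


# Toy vocabulary (hypothetical); the embedding table would have 4 rows.
vocab = {"[PAD]": 0, "[UNK]": 1, "the": 2, "cat": 3}
ids, oovs = extended_ids(["the", "zorblax", "cat", "quux", "zorblax"], vocab)

# IDs with no row in a len(vocab)-row embedding table.
out_of_range = [i for i in ids if i >= len(vocab)]
print(ids, oovs, out_of_range)
```

If the sampler picks 4 or 5 here, an embedding lookup into the 4-row table has nothing to return.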
Maybe something like this:
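sample_ids_sampling = tf.add(
    # keep the sampled ID where it is inside the vocabulary (mask is 1 where id < vocab size, else 0)
    tf.multiply(sample_ids_sampling,
                tf.cast(sample_ids_sampling < self._vocab.size(), dtype=tf.int32)),
    # add the [UNK] ID where the sampled ID is outside the vocabulary (mask is 1 where id >= vocab size)
    tf.multiply(self._vocab.word2id(params.UNKNOWN_TOKEN),
                tf.cast(sample_ids_sampling >= self._vocab.size(), dtype=tf.int32))
)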
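To check the masking arithmetic without TensorFlow, here is a minimal pure-Python sketch of the same formula: id * [id < V] + unk_id * [id >= V]. It compares that formula with a plain "select" form. The function names, V = 4 and unk_id = 1 are hypothetical. In TensorFlow the select form would presumably be `tf.where(sample_ids_sampling < vocab_size, sample_ids_sampling, tf.fill(tf.shape(sample_ids_sampling), unk_id))`. That line is untested here, and TF1's `tf.where` needs `x` and `y` of the same shape, hence the `tf.fill`.

```python
def replace_oov_ids(ids, vocab_size, unk_id):
    """Same arithmetic as the snippet above: the two masks are mutually
    exclusive, so each position gets either its own ID or unk_id."""
    return [i * int(i < vocab_size) + unk_id * int(i >= vocab_size) for i in ids]


def replace_oov_ids_select(ids, vocab_size, unk_id):
    """Equivalent select form (the idea behind tf.where)."""
    return [i if i < vocab_size else unk_id for i in ids]


print(replace_oov_ids([2, 4, 3, 5, 0], 4, 1))  # [2, 1, 3, 1, 0]
```

Both forms map every ID into [0, vocab_size), so the embedding lookup stays in range. The select form avoids the two multiplies and reads more directly.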
Please let me know if I am missing something.