RLSeq2Seq

scheduled sampling OOV issue

Open rajeev595 opened this issue 5 years ago • 0 comments

Hi @yaserkl,

Thanks for publishing the code!

At line 525 of attention_decoder.py, embedding_lookup throws an error when sample_ids_sampling contains OOV ids (ids greater than or equal to the vocabulary size). I think these should be replaced with the [UNK] id before the lookup.

Maybe something like this:

    # Replace every sampled id outside the vocabulary with the [UNK] id,
    # so that embedding_lookup only ever sees ids in [0, vocab size).
    unk_id = self._vocab.word2id(params.UNKNOWN_TOKEN)
    sample_ids_sampling = tf.where(
        sample_ids_sampling < self._vocab.size(),
        sample_ids_sampling,
        unk_id * tf.ones_like(sample_ids_sampling)
    )
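
To make the intended behaviour concrete, here is a small plain-Python sketch of the same replacement rule, with no TensorFlow involved. The vocabulary size, the [UNK] id, and the helper name `replace_oov_ids` are made up for illustration and are not taken from the repository.

```python
# Illustration only: VOCAB_SIZE, UNK_ID and replace_oov_ids are hypothetical.
VOCAB_SIZE = 5  # assumed size of the fixed vocabulary
UNK_ID = 0      # assumed id of the [UNK] token


def replace_oov_ids(sample_ids, vocab_size=VOCAB_SIZE, unk_id=UNK_ID):
    """Return a copy of sample_ids with every id >= vocab_size set to unk_id."""
    return [i if i < vocab_size else unk_id for i in sample_ids]


# Ids 7 and 5 fall outside [0, VOCAB_SIZE) and would break an embedding lookup.
sampled = [3, 7, 0, 5, 4]
cleaned = replace_oov_ids(sampled)
print(cleaned)  # [3, 0, 0, 0, 4]
```

After the replacement every id is a valid row index into an embedding table with VOCAB_SIZE rows, which is the property the lookup needs.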

Please let me know if I am missing something.

rajeev595, Jun 13 '19 06:06