MeanSum

Possible bug when setting k > 1 and gumbel_hard = True

Open · hadyelsahar opened this issue 5 years ago · 0 comments

Thanks for the interesting work and code!

While trying to get my head around the code, I ran into something I couldn't understand:

Suppose we train the mlstm model with the following set of parameters:

  • gumbel_hard = true
  • sampling method = "greedy" or "sample"
  • k > 1

At mlstm #L291, the logits_to_prob function returns a strict one-hot vector, following PyTorch's Gumbel-softmax implementation F.gumbel_softmax, which returns a one-hot sample when hard=True.
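To make this step concrete, here is a stdlib-only sketch of what a hard Gumbel-softmax produces in its forward pass. It is not MeanSum's logits_to_prob or PyTorch's F.gumbel_softmax; the function name gumbel_softmax_hard and the example logits are hypothetical, and the straight-through gradient trick that F.gumbel_softmax uses is left out because plain Python has no autograd.

```python
import math
import random


def gumbel_softmax_hard(logits, tau=1.0, rng=random):
    """Hypothetical forward pass of a 'hard' Gumbel-softmax.

    Returns a strict one-hot list. The soft sample is computed and then
    discarded; PyTorch keeps it around only for the backward pass.
    """
    noisy = []
    for x in logits:
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        # Add Gumbel(0, 1) noise g = -log(-log(u)) and scale by temperature.
        noisy.append((x - math.log(-math.log(u))) / tau)
    # Soft sample: numerically stable softmax of the perturbed logits.
    m = max(noisy)
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    soft = [e / total for e in exps]
    # Hard sample: one-hot at the argmax of the soft sample.
    best = max(range(len(soft)), key=soft.__getitem__)
    return [1.0 if i == best else 0.0 for i in range(len(soft))]


probs = gumbel_softmax_hard([2.0, 1.5, 0.3, -1.0], tau=0.5, rng=random.Random(0))
print(probs)  # exactly one entry is 1.0, all others are 0.0
```

Whatever the logits are, only one entry of the result carries any probability mass, which is what the downstream top-k and sampling steps then receive.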

Afterwards, this prob vector is passed to the prob_to_vocab_id method, which is supposed to apply either torch.topk (beam search) or torch.multinomial (top-k sampling).

Implementation-wise, this shouldn't raise any errors for beam search, because torch.topk can handle ties. However, the top k you get aren't the actual top-k probabilities, e.g.:

[screenshot]
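The tie problem can be reproduced without PyTorch. This sketch uses a hypothetical topk helper that mimics the (values, indices) shape of torch.topk on plain lists, plus made-up probabilities for a 5-token vocabulary. As far as I know, torch.topk does not promise any particular order among tied values; in this sketch Python's stable sort happens to return the lowest tied indices first.

```python
def topk(values, k):
    """Hypothetical stand-in for torch.topk on a 1-D list.

    Returns (top_values, top_indices), largest first. Among equal values,
    the stable sort keeps lower indices first.
    """
    order = sorted(range(len(values)), key=lambda i: -values[i])[:k]
    return [values[i] for i in order], order


# Made-up distribution over a 5-token vocabulary.
soft = [0.05, 0.40, 0.30, 0.20, 0.05]
# A hard Gumbel-softmax sample that happened to pick token 2.
one_hot = [0.0, 0.0, 1.0, 0.0, 0.0]

print(topk(soft, 3))     # ([0.4, 0.3, 0.2], [1, 2, 3])
print(topk(one_hot, 3))  # ([1.0, 0.0, 0.0], [2, 0, 1])
```

On the soft distribution the beam gets tokens 1, 2 and 3. On the one-hot vector it gets the sampled token 2 and then tokens 0 and 1, which are just tie-breaking leftovers with zero probability rather than the next most likely tokens.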

But if you try multinomial sampling from a one-hot vector with k > 1, you get a runtime error: [screenshot]
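The reason is that sampling k distinct tokens without replacement needs at least k categories with non-zero probability, and a one-hot vector has only one. Below is a minimal sketch, assuming sampling without replacement is what the top-k sampling path does; sample_without_replacement is a hypothetical helper, not torch.multinomial, and it raises a ValueError to mirror the reported error rather than returning zero-probability tokens.

```python
import random


def sample_without_replacement(weights, k, rng=random):
    """Hypothetical stand-in for torch.multinomial(probs, k, replacement=False).

    Draws k distinct indices, each draw proportional to the remaining weights.
    Raises if fewer than k entries have non-zero weight.
    """
    remaining = {i: w for i, w in enumerate(weights) if w > 0}
    if len(remaining) < k:
        raise ValueError(
            f"need {k} non-zero categories to sample without replacement, "
            f"found {len(remaining)}"
        )
    picked = []
    for _ in range(k):
        r = rng.random() * sum(remaining.values())
        acc = 0.0
        for i, w in remaining.items():
            acc += w
            if r < acc:
                break  # i is the drawn index (falls back to the last key on rounding)
        picked.append(i)
        del remaining[i]  # without replacement: never draw i again
    return picked


one_hot = [0.0, 0.0, 1.0, 0.0, 0.0]
print(sample_without_replacement(one_hot, 1))  # [2]
try:
    sample_without_replacement(one_hot, 3)
except ValueError as err:
    print("error:", err)
```

With replacement the call would not fail, but every one of the k draws from a one-hot vector would be the same token, so top-k sampling degenerates either way.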

Am I missing something here?

hadyelsahar, Jul 24 '19 14:07