Possible bug when setting `k > 1` & `gumbel_hard = True`
Thanks for the interesting work and code!
I was trying to get my head around the code, and there is something I couldn't understand:
When training the `mlstm` model with the following set of parameters:
- `gumbel_hard = True`
- sampling method = `"greedy"` or `"sample"`
- `k > 1`
then in line `mlstm` #L291, the `logits_to_prob` function will return a strict one-hot vector, according to the PyTorch Gumbel-softmax implementation (`F.gumbel_softmax` with `hard=True`).
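A quick way to see this in isolation (a minimal sketch, independent of the repo code):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 1.0, 0.1]])

# With hard=True, the straight-through estimator discretizes the output:
# the forward pass yields an exact one-hot vector (gradients still flow
# through the soft sample).
probs = F.gumbel_softmax(logits, tau=1.0, hard=True)
print(probs)  # e.g. tensor([[1., 0., 0., 0.]]) -- a strict one-hot vector
```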
Afterwards, this probability vector is sent to the `prob_to_vocab_id` method, which is supposed to apply either `torch.topk` (beam search) or `torch.multinomial` (top-k sampling).
Implementation-wise, the beam-search path shouldn't raise any errors, because `torch.topk` can handle draws; however, the top k you get aren't the actual top-k probabilities, as the example below shows.
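For instance (a minimal sketch with a hypothetical one-hot distribution):

```python
import torch

one_hot = torch.tensor([[0., 0., 1., 0., 0.]])
values, ids = torch.topk(one_hot, k=3, dim=-1)
print(values)  # tensor([[1., 0., 0.]])
print(ids)     # e.g. tensor([[2, 0, 1]]) -- only the first id is meaningful;
               # the other two are arbitrary tie-broken zero entries, not the
               # 2nd and 3rd most probable tokens of the original
               # (pre-Gumbel) distribution.
```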
But if you try to sample with `torch.multinomial` from a one-hot vector with `k > 1`, you get a runtime error, reproduced below.
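This is easy to reproduce in isolation (a minimal sketch, assuming sampling without replacement, which is `torch.multinomial`'s default):

```python
import torch

one_hot = torch.tensor([[0., 0., 1., 0., 0.]])

# Only one category has non-zero probability, so drawing k=3 distinct
# samples without replacement is impossible, and PyTorch raises a
# RuntimeError about the invalid multinomial distribution.
ids = torch.multinomial(one_hot, num_samples=3)
```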
Am I missing something here?