torch-rnn

Beam search for sampling

Open · ostrosablin opened this issue 8 years ago · 2 comments

I've looked at Andrej's char-rnn repository and found this issue, which describes a better sampling strategy that leads to significantly better sample quality with a much lower probability of typos: Issue 138

It's more computationally intensive, but the quality increase should be worth it. Would it be possible to implement that feature for torch-rnn? If I understand it correctly, it's just a heuristic-based variant of BFS that explores several possible next characters and chooses the best route, so it shouldn't be hard to implement.
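A minimal Python sketch of that idea, assuming a hypothetical `next_char_probs(text)` callable that returns the model's next-character distribution as a dict. This is a stand-in for the network's softmax output, not torch-rnn's actual API and not the code from any linked gist. At each step every surviving sequence is expanded by every possible next character, as in BFS, and then only the `beam_width` highest-scoring sequences are kept. `toy_model` is a made-up bigram table used only to exercise the function.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical interface: given the text so far, return a probability
# distribution over the next character (stand-in for the model's softmax).
NextCharFn = Callable[[str], Dict[str, float]]


def beam_search(next_char_probs: NextCharFn, prefix: str,
                length: int, beam_width: int) -> List[Tuple[str, float]]:
    """Expand every surviving sequence by every next character (BFS-style),
    then keep only the `beam_width` best, scored by summed log-probability
    (which ranks paths the same way as the product of probabilities)."""
    beams = [(prefix, 0.0)]  # (text, log-probability of the generated part)
    for _ in range(length):
        candidates = []
        for text, score in beams:
            for ch, p in next_char_probs(text).items():
                if p > 0.0:
                    candidates.append((text + ch, score + math.log(p)))
        # Prune: keep only the highest-scoring partial sequences.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams


def toy_model(text: str) -> Dict[str, float]:
    """Hypothetical bigram table keyed on the last character."""
    table = {
        "t": {"h": 0.6, "o": 0.4},
        "h": {"e": 0.9, "a": 0.1},
        "o": {" ": 1.0},
        "e": {" ": 1.0},
        "a": {" ": 1.0},
        " ": {"t": 1.0},
    }
    return table[text[-1]]


best_text, best_score = beam_search(toy_model, "t", 2, beam_width=3)[0]
print(best_text, math.exp(best_score))
```

Starting from `"t"` with `beam_width=3`, the top result on the toy table is `"the"`, with probability 0.6 × 0.9 = 0.54. Setting `beam_width=1` reduces the search to greedy decoding.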

ostrosablin, Jun 20 '16 10:06

This is quite rough (and I'm not 100% sure it's actually beam search), but I did get something up and running:

https://gist.github.com/robinsloan/e5ce8d3b7892f797f759905be5f7c68d

If you drop those three files into your torch-rnn folder and run beamtest.lua (making sure to point it at whatever checkpoint you have handy), you should get reasonable results. FWIW, I found the output less useful for my purposes than greedy search; it's more "conservative," similar to what you get from a lower sampling temperature.
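The comparison to a lower sampling temperature can be made concrete with the usual temperature-scaled softmax. This Python sketch assumes the sampler divides the network's raw scores by the temperature before normalizing; the exact code in torch-rnn's sampler may differ in details, and the logits below are made up for illustration. Lowering the temperature concentrates probability on the top-scoring character, which is why low-temperature output feels "safe" and repetitive.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Convert raw scores to probabilities after dividing by `temperature`.
    T < 1 sharpens the distribution toward the top choice; T > 1 flattens it."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max before exp() for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5]  # hypothetical scores for three candidate characters
for t in (1.0, 0.3):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

With these made-up scores, the top character's probability rises from roughly 0.63 at T = 1.0 to roughly 0.96 at T = 0.3.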

robinsloan, Jun 26 '16 20:06

Well, from how I understand the algorithm, that's exactly what it's supposed to be: explore the N most probable nodes and check which path yields the highest product of probabilities.
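Comparing paths by the product of their probabilities is usually done in log space in practice, because multiplying thousands of per-character probabilities underflows a 64-bit float to zero. A small Python illustration, with hypothetical constant per-character probabilities: both raw products collapse to `0.0`, while the summed log-probabilities still rank the paths correctly. Because log is monotonic, this gives the same ranking as the product whenever the product is representable. This is a common implementation detail, not necessarily what any particular beam search script does.

```python
import math
from typing import List


def path_product(probs: List[float]) -> float:
    """Score a path as the plain product of per-character probabilities."""
    result = 1.0
    for p in probs:
        result *= p
    return result


def path_log_score(probs: List[float]) -> float:
    """Same ranking as the product, but as a sum of logs (no underflow)."""
    return sum(math.log(p) for p in probs)


# Two hypothetical 2000-character paths with constant per-character probability.
better = [0.5] * 2000
worse = [0.4] * 2000

print(path_product(better), path_product(worse))      # prints 0.0 0.0: both underflow
print(path_log_score(better) > path_log_score(worse))  # prints True: still ranks correctly
```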

I haven't tried it yet, but a boost in "conservativeness" is to be expected. I just hope it's not too similar to low-temperature results (e.g. 0.3), because those usually end up repeating the same sentence through the entire sample, which is not interesting at all. According to the comments in Issue 138, beam search makes the network a bit more conservative, but in a good way (for example, it doesn't fall into loops on URIs), and should make the text more sensible than greedy search does.
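To illustrate how a beam can find a more probable sequence than greedy search, here is a hand-built case; the probabilities in `TOY_MODEL` are invented. Greedy decoding commits to `"a"` because 0.55 > 0.45, and ends with a path of probability 0.55 × 0.6 = 0.33. A beam of width 2 keeps `"b"` alive and finds `"bz"`, whose overall probability 0.45 × 0.95 = 0.4275 is higher. This is a sketch of the general technique, not the behavior of any particular implementation.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical toy model: next-character probabilities keyed on the text so far.
TOY_MODEL: Dict[str, Dict[str, float]] = {
    "": {"a": 0.55, "b": 0.45},
    "a": {"x": 0.6, "y": 0.4},
    "b": {"z": 0.95, "w": 0.05},
}


def greedy(steps: int) -> Tuple[str, float]:
    """Always take the single most probable next character."""
    text, logp = "", 0.0
    for _ in range(steps):
        ch, p = max(TOY_MODEL[text].items(), key=lambda kv: kv[1])
        text, logp = text + ch, logp + math.log(p)
    return text, logp


def beam(steps: int, width: int) -> Tuple[str, float]:
    """Keep the `width` best partial texts at every step; return the best one."""
    beams: List[Tuple[str, float]] = [("", 0.0)]
    for _ in range(steps):
        expanded = [(t + ch, lp + math.log(p))
                    for t, lp in beams
                    for ch, p in TOY_MODEL[t].items()]
        beams = sorted(expanded, key=lambda b: b[1], reverse=True)[:width]
    return beams[0]


g_text, g_lp = greedy(2)
b_text, b_lp = beam(2, width=2)
print(g_text, round(math.exp(g_lp), 4))  # greedy path and its probability
print(b_text, round(math.exp(b_lp), 4))  # beam path and its probability
```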

ostrosablin, Jun 27 '16 07:06