torch-rnn
Beam search for sampling
I've looked at Andrej's char-rnn repository and found this issue, which describes a better sampling strategy that leads to significantly better samples with far fewer typos: Issue 138
It's more computationally intensive, but the quality increase should be worth it. Would it be possible to implement this for torch-rnn? If I understand it correctly, it's just a heuristic-based variant of breadth-first search (BFS) that explores several possible next characters and chooses the best route, so it shouldn't be hard to implement.
This is quite rough, and I'm not 100% sure it's actually beam search, but I did get something up and running:
https://gist.github.com/robinsloan/e5ce8d3b7892f797f759905be5f7c68d
If you drop those three files in your torch-rnn folder and run beamtest.lua (making sure to point the checkpoint to whatever you have handy) you should get reasonable results. FWIW, I found the output less useful for my purposes than greedy search; it's more "conservative," similar to what you get from a lower sampling temperature.
Well, as I understand the algorithm, that's exactly what it's supposed to do: explore the N most probable nodes at each step and keep the path with the highest product of probabilities.
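To make that description concrete, here is a minimal Python sketch of the idea, not the code from the gist and not anything that exists in torch-rnn. The names `beam_search`, `next_log_probs` and `toy_model` are made up for illustration, and the toy probability table stands in for a real network. It sums log-probabilities, which is the same as multiplying probabilities but avoids underflow on long samples.

```python
import math


def beam_search(next_log_probs, start, steps, beam_width):
    """Return (log-probability, text) of the best sequence found.

    next_log_probs(prefix) is a hypothetical model interface: it returns a
    dict mapping each candidate next character to its log-probability.
    """
    beams = [(0.0, start)]  # (cumulative log-probability, text so far)
    for _ in range(steps):
        candidates = []
        for score, text in beams:
            for ch, lp in next_log_probs(text).items():
                # Adding log-probabilities == multiplying probabilities.
                candidates.append((score + lp, text + ch))
        # Keep only the beam_width most probable partial sequences.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:beam_width]
    return beams[0]


def toy_model(prefix):
    """Hypothetical stand-in for the RNN: a fixed table keyed on the last character."""
    table = {
        "": {"a": 0.6, "b": 0.4},
        "a": {"a": 0.55, "b": 0.45},
        "b": {"a": 0.1, "b": 0.9},
    }
    probs = table[prefix[-1:]]
    return {ch: math.log(p) for ch, p in probs.items()}


greedy = beam_search(toy_model, "", steps=2, beam_width=1)
beam = beam_search(toy_model, "", steps=2, beam_width=2)
print(greedy[1], beam[1])
```

With this toy table, a beam width of 1 (always taking the single most probable next character) commits to "a" first and ends at "aa" with probability 0.6 × 0.55 = 0.33, while a beam width of 2 keeps "b" alive long enough to find "bb" at 0.4 × 0.9 = 0.36. That is the "choose the best route" effect in miniature.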
I haven't tried it yet, but a boost in "conservativeness" is to be expected. I just hope the results aren't too similar to low sampling temperatures (like 0.3), because that usually ends up with the same sentence being repeated through the entire sample, which is not interesting at all. According to the comments in Issue 138, beam search makes the network a bit more conservative, but in a good way (such as not falling into loops on URIs), and should make the text more sensible than greedy search does.
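For comparison, this is roughly why a low temperature is so conservative. The sketch below is an illustration under assumptions, not torch-rnn's actual sampling code: `apply_temperature` is a made-up helper and the three-way distribution is invented. Dividing the log-probabilities by a temperature T and renormalising is equivalent to raising each probability to the power 1/T, so at T = 0.3 nearly all the mass moves onto the top choice.

```python
import math


def apply_temperature(probs, temperature):
    """Rescale a distribution as p_i ** (1 / T), then renormalise.

    Assumes every probability is strictly positive (log(0) is undefined).
    """
    scaled = [math.exp(math.log(p) / temperature) for p in probs]
    total = sum(scaled)
    return [s / total for s in scaled]


original = [0.5, 0.3, 0.2]  # invented next-character distribution
sharpened = apply_temperature(original, 0.3)
print([round(p, 3) for p in sharpened])
```

Here the most likely character goes from 50% to a bit over 80%, and the same sharpening is applied at every step regardless of context, which is consistent with samples collapsing into repetition. Beam search instead compares whole paths, so it can still pick a locally less likely character when the continuation is better.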