pytorch-seq2seq

Backtracking in beam search

Open shubhamagarwal92 opened this issue 6 years ago • 6 comments

Compared to OpenNMT, why do we need this block, which handles the dropped sequences that see EOS earlier? (There is no equivalent in their beam search implementation.) They follow a similar process of not letting EOS have children here, but they handle termination with this end condition when EOS reaches the top of the beam, and then reconstruct the hypothesis with the get_hyp function.

More specifically, can you explain in detail what we are doing here:

            #       1. If there is survived sequences in the return variables, replace
            #       the one with the lowest survived sequence score with the new ended
            #       sequences
            #       2. Otherwise, replace the ended sequence with the lowest sequence
            #       score with the new ended sequence

I understand why we need to handle EOS sequences, since their information is stored in the backtracking variables. But why do we need to "replace the one with the lowest survived sequence score with the new ended sequences"? AFAIK, res_k_idx tracks which beam (counting from the end) we can replace with this information (the two conditions specified in the comments). However, we are not replacing the contents of the beam that emitted EOS, i.e.:

t_predecessors[res_idx] = predecessors[t][idx[0]]
t_predecessors[idx] = ??

I understand that after this process all the beams remain static and we use index_select at each step to select the top beams.
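
To make my question concrete, here is a minimal standalone sketch (not the repo's _backtrack code; merge_finished, survivors, and finished are names I made up) of what I understand "replace the one with the lowest survived sequence score" to mean:

```python
# A self-contained sketch (NOT the repo's _backtrack) of the replacement idea:
# hypotheses that emitted EOS early are collected separately, and at the end
# they overwrite the lowest-scoring "survivor" slots, counting from the back
# of the beam (the res_k_idx idea). All names here are placeholders.

def merge_finished(survivors, finished, beam_width):
    """survivors: list of (score, tokens) still alive at the last step.
    finished: list of (score, tokens) that hit EOS earlier.
    Returns the final top-`beam_width` hypotheses."""
    result = sorted(survivors, key=lambda h: h[0], reverse=True)[:beam_width]
    # Ended hypotheses replace the worst remaining entry, starting from the
    # last slot and moving toward the front -- mirroring how res_k_idx
    # counts "which beam (from the end) we can replace".
    res_k_idx = len(result) - 1
    for hyp in sorted(finished, key=lambda h: h[0], reverse=True):
        if res_k_idx < 0:
            break
        if hyp[0] > result[res_k_idx][0]:
            result[res_k_idx] = hyp
            res_k_idx -= 1
    return sorted(result, key=lambda h: h[0], reverse=True)

print(merge_finished(
    survivors=[(-1.2, [5, 7, 9]), (-3.5, [5, 7, 2])],
    finished=[(-0.8, [5, 4, 3]), (-9.0, [1, 1, 3])],
    beam_width=2))
# -> [(-0.8, [5, 4, 3]), (-1.2, [5, 7, 9])]
```

My confusion is whether the repo intends this kind of replacement, and if so, why the slot that actually emitted EOS is not the one being overwritten.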

Also, the unit test for top_k_decoder is not deterministic: it fails when batch_size > 2 and sometimes also when batch_size == 2.

shubhamagarwal92 avatar Jun 30 '18 14:06 shubhamagarwal92

@shubhamagarwal92 thanks for pointing this out. I'll check their implementation and see what's different. Working on the test to make it more deterministic. Will test more beam sizes too.
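
For reference, a generic way to make such a test reproducible is to pin every RNG before building the models and inputs (a sketch, not tied to this repo's test file):

```python
import random

import numpy as np
import torch

def seed_everything(seed: int = 0):
    """Pin all RNGs a PyTorch test typically touches so repeated runs
    (and different batch sizes) see the same random weights/inputs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cuDNN can still pick non-deterministic kernels unless told otherwise.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)
```

If the flakiness comes from tie-breaking in topk rather than from unseeded randomness, this alone won't fix it, but it at least makes failures reproducible.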

pskrunner14 avatar Sep 02 '18 13:09 pskrunner14

Hi, are there any updates on this issue? I've implemented a similar beam search strategy that uses the _backtrack(...) function from this repo, but even with a beam_size of 1 I get worse results than greedy decoding. It would be really helpful if you could double-check the implementation. Thanks.

Mehrad0711 avatar Mar 25 '19 23:03 Mehrad0711

I studied the code recently, and I think you can use torch.repeat_interleave, as follows:

hidden = tuple([torch.repeat_interleave(h, self.k, dim=1) for h in encoder_hidden])
inflated_encoder_outputs = torch.repeat_interleave(encoder_outputs, self.k, dim=0)
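
For anyone trying this, a small shape check of the suggestion above with toy sizes (k is the beam width; all sizes here are made up):

```python
import torch

k = 3                                   # beam width
batch, seq_len, hidden_size = 2, 5, 8

encoder_outputs = torch.randn(batch, seq_len, hidden_size)
encoder_hidden = (torch.randn(1, batch, hidden_size),    # (layers, batch, hidden)
                  torch.randn(1, batch, hidden_size))

# Inflate every example k times so the decoder can run k beams per example.
inflated_outputs = torch.repeat_interleave(encoder_outputs, k, dim=0)
inflated_hidden = tuple(torch.repeat_interleave(h, k, dim=1)
                        for h in encoder_hidden)

print(inflated_outputs.shape)    # torch.Size([6, 5, 8])
print(inflated_hidden[0].shape)  # torch.Size([1, 6, 8])
```

Note that repeat_interleave groups the copies per example ([a, a, b, b]) rather than tiling the whole batch ([a, b, a, b]); it's worth double-checking that this matches the ordering the decoder's index_select expects.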

GZJAS avatar Jun 11 '19 14:06 GZJAS

@Mehrad0711 maybe you can try integrating beam search from allennlp; their implementation is here.
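
In case it helps, a rough sketch of what wiring in allennlp's BeamSearch might look like (the constructor/search signature below is from the versions I've used and may differ in yours; the toy take_step just fakes a decoder, so treat everything here as an assumption to verify):

```python
from typing import Dict, Tuple

import torch
from allennlp.nn.beam_search import BeamSearch  # check against your allennlp version

VOCAB_SIZE, EOS_ID, BATCH = 6, 5, 2

def take_step(last_predictions: torch.Tensor,
              state: Dict[str, torch.Tensor]
              ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    # Toy step function standing in for a real decoder: a real one would run
    # the decoder RNN on `last_predictions` using the tensors in `state`
    # (hidden state, encoder outputs) and return log-probs over the vocab.
    logits = state["dummy"] * torch.randn(last_predictions.size(0), VOCAB_SIZE)
    return torch.log_softmax(logits, dim=-1), state

beam_search = BeamSearch(end_index=EOS_ID, max_steps=10, beam_size=3)
start_predictions = torch.zeros(BATCH, dtype=torch.long)   # start-of-sequence ids
start_state = {"dummy": torch.ones(BATCH, 1)}              # per-example decoder state
top_k_predictions, log_probabilities = beam_search.search(
    start_predictions, start_state, take_step)
print(top_k_predictions.shape)   # (batch_size, beam_size, num_decoded_steps)
```

The nice part is that allennlp handles the EOS bookkeeping and backtracking internally, so there is no _backtrack to maintain.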

shubhamagarwal92 avatar Jun 11 '19 15:06 shubhamagarwal92

Hey @Mehrad0711 @shubhamagarwal92, sorry, I haven't had the time to work on this yet. You're welcome to submit a PR :)

pskrunner14 avatar Jun 11 '19 18:06 pskrunner14

Any update on this issue?

ghost avatar May 07 '21 13:05 ghost