SamplingSequenceGenerator not respecting `max_gen_len`
Describe the bug:
When max_gen_len is specified, SamplingSequenceGenerator can generate more than max_gen_len tokens for any batched sequence whose prompt is shorter than the longest prompt in the batch.
In this line the sequences are forced to end with EOS based on self._max_seq_len. However, self._max_seq_len is computed as the length of the longest prompt + max_gen_len (unless max_seq_len is explicitly specified). Therefore, any sequence in the batch whose prompt length $N$ is shorter than the maximum prompt length in the batch $M$ can generate up to $M - N +$ max_gen_len tokens (unless it naturally produces EOS beforehand).
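To make the arithmetic concrete, here is a minimal sketch of how that limit plays out for the batch used below (variable names are illustrative, not the actual fairseq2 internals):
max_gen_len = 10
prompt_lens = [1, 6]  # N = 1 for the short prompt, M = 6 for the longest
max_seq_len = max(prompt_lens) + max_gen_len  # 6 + 10 = 16
for n in prompt_lens:
    # EOS is only forced once the total length reaches max_seq_len, so a
    # prompt of length n can generate up to max_seq_len - n new tokens:
    print(n, max_seq_len - n)  # prints "1 15" and "6 10" instead of a flat 10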
Describe how to reproduce:
import torch
# Import paths below are an assumption and may differ slightly across fairseq2 versions.
from fairseq2.generation import SamplingSequenceGenerator, TopKSampler
from fairseq2.nn.padding import pad_seqs
# `model` is assumed to be an already-constructed decoder model.
tokens, padding_mask = pad_seqs(
    [torch.tensor([4], dtype=torch.long), torch.tensor([1, 5, 2, 6, 8, 4], dtype=torch.long)]
)
sampler = TopKSampler(k=1)
generator = SamplingSequenceGenerator(model, sampler, max_gen_len=10)
res = generator(prompt_seqs=tokens, prompt_padding_mask=padding_mask)
In [48]: res.hypotheses[0][0].seq.shape
Out[48]: torch.Size([15]) # max_gen_len + 5
In [49]: res.hypotheses[1][0].seq.shape
Out[49]: torch.Size([10]) # max_gen_len
Describe the expected behavior:
In [48]: res.hypotheses[0][0].seq.shape
Out[48]: torch.Size([10]) # max_gen_len
In [49]: res.hypotheses[1][0].seq.shape
Out[49]: torch.Size([10]) # max_gen_len
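For reference, the expected behavior corresponds to capping each sequence against its own prompt length rather than the batch-wide maximum; a hypothetical sketch of that check (not the actual generator code):
# Hypothetical per-sequence cap: force EOS once a sequence has produced
# max_gen_len tokens beyond its own prompt, regardless of the batch max.
max_gen_len = 10
prompt_lens = [1, 6]
per_seq_limits = [n + max_gen_len for n in prompt_lens]  # [11, 16]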
Environment: fairseq2==0.4.0
Hey! Can I address this issue?
Oops, you haven't pinged me and I haven't seen your approval yet. I'll open a PR then.