
SamplingSequenceGenerator not respecting `max_gen_len`

Open heffernankevin opened this issue 9 months ago • 2 comments

Describe the bug: When max_gen_len is specified, SamplingSequenceGenerator can generate more than max_gen_len tokens for any sequence in a batch whose prompt is shorter than the longest prompt in that batch.

In this line, sequences are forced to emit EOS based on self._max_seq_len. However, self._max_seq_len is computed as the length of the longest prompt plus max_gen_len (unless max_seq_len is specified explicitly). Therefore, any sequence in a batch whose prompt length ($N$) is shorter than the maximum prompt length in the batch ($M$) can generate up to $M - N$ + max_gen_len tokens (unless it naturally emits EOS earlier). A sketch of the per-sequence bookkeeping that would avoid this follows.
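A minimal sketch of a per-sequence fix (the helper name should_force_eos and the prompt_lens argument are hypothetical, not fairseq2 API): instead of one batch-wide cutoff, each sequence's forced-EOS step is derived from its own prompt length.

import torch

# Hypothetical helper, not part of fairseq2: decide per sequence whether
# EOS must be forced at the current decoding step.
def should_force_eos(
    step_nr: int, prompt_lens: torch.Tensor, max_gen_len: int
) -> torch.Tensor:
    # prompt_lens: (batch,) tensor of true (unpadded) prompt lengths.
    # A sequence with prompt length N may generate at most max_gen_len
    # tokens, so EOS is forced once step_nr reaches N + max_gen_len - 1,
    # rather than at a single batch-wide max_prompt_len + max_gen_len - 1.
    return step_nr >= prompt_lens + max_gen_len - 1

# For the batch below (prompt lengths 1 and 6, max_gen_len=10), the first
# sequence is cut off at step 10 while the second keeps generating:
print(should_force_eos(10, torch.tensor([1, 6]), 10))  # tensor([ True, False])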

Describe how to reproduce:

import torch
from fairseq2.generation import SamplingSequenceGenerator, TopKSampler
from fairseq2.nn.padding import pad_seqs

# Prompts of length 1 and 6; `model` is a loaded decoder model.
tokens, padding_mask = pad_seqs([torch.tensor([4]), torch.tensor([1, 5, 2, 6, 8, 4])])
sampler = TopKSampler(k=1)
generator = SamplingSequenceGenerator(model, sampler, max_gen_len=10)
res = generator(prompt_seqs=tokens, prompt_padding_mask=padding_mask)

In [48]: res.hypotheses[0][0].seq.shape
Out[48]: torch.Size([15])    # max_gen_len + (M - N) = 10 + 5

In [49]: res.hypotheses[1][0].seq.shape
Out[49]: torch.Size([10])    # max_gen_len
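
Until this is fixed, one possible workaround (a sketch only; it assumes, as the shapes above suggest, that each hypothesis's seq holds just the generated tokens) is to truncate each hypothesis after generation:

# Hypothetical post-hoc workaround: cap every hypothesis at max_gen_len
# (here 10) generated tokens, matching the expected behavior below.
truncated = [[hyp.seq[:10] for hyp in hyps] for hyps in res.hypotheses]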

Describe the expected behavior:

In [48]: res.hypotheses[0][0].seq.shape
Out[48]: torch.Size([10])    # max_gen_len

In [49]: res.hypotheses[1][0].seq.shape
Out[49]: torch.Size([10])   # max_gen_len

Environment: fairseq2==0.4.0

heffernankevin avatar Mar 06 '25 07:03 heffernankevin

Hey! Can I address this issue?

krammnic avatar Jul 17 '25 11:07 krammnic

Oops, you didn't ping me, so I hadn't seen your approval. I'll open a PR then.

krammnic avatar Sep 22 '25 12:09 krammnic