
Conflict in start index under `batched_forward_pass`

Open mertsayar8 opened this issue 8 months ago • 0 comments

The code and the comment do not align in: https://github.com/huggingface/trl/blob/b68ff96f0c74368961e194081e122959cd1f4d4d/trl/trainer/ppo_trainer.py#L1032

The comment states that `logprobs` starts from the second query token, but in the code `start` is set to the last query token. Everything outside `[last_query_token, last_response_token)` is masked out in: https://github.com/huggingface/trl/blob/b68ff96f0c74368961e194081e122959cd1f4d4d/trl/trainer/ppo_trainer.py#L1037-L1038

This causes a problem, especially when passing `response_masks` to the `step` function. The response mask is multiplied against the wrong indices in: https://github.com/huggingface/trl/blob/b68ff96f0c74368961e194081e122959cd1f4d4d/trl/trainer/ppo_trainer.py#L1040, so the masking is applied incorrectly.

I believe `start` should simply be set to the first response token, whose index is the length of the query. This would make `masks[j, start:end]` and `response_masks_batch[j]` correspond to the same indices and solve the problem.
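
To make the off-by-one concrete, here is a minimal toy sketch in plain Python. It is not the TRL code: `apply_response_mask` is a hypothetical helper, the lengths and mask values are made up, and it assumes `masks` is indexed by unshifted token position. If the mask is instead indexed in shifted logprob positions, the offsets below differ by one.

```python
def apply_response_mask(query_len, response_mask, start):
    """Toy stand-in for the masking step (hypothetical helper, not TRL code).

    Zeroes everything outside [start, end), then multiplies the window
    element-wise by response_mask.
    """
    total = query_len + len(response_mask)
    end = start + len(response_mask)
    # Keep only positions inside the [start, end) window.
    masks = [1 if start <= i < end else 0 for i in range(total)]
    # Multiply the window by the per-response-token mask.
    for k, m in enumerate(response_mask):
        masks[start + k] *= m
    return masks


query_len = 3
response_mask = [1, 0, 1, 1]  # made-up mask over the 4 response tokens

# start at the last query token, as described in this issue
current = apply_response_mask(query_len, response_mask, start=query_len - 1)
# start at the first response token, as proposed
proposed = apply_response_mask(query_len, response_mask, start=query_len)

print(current)   # [0, 0, 1, 0, 1, 1, 0]
print(proposed)  # [0, 0, 0, 1, 0, 1, 1]
```

Under that assumption, only the proposed `start` leaves the tail of the mask equal to `response_mask`; with `start = query_len - 1` every entry lands one position early, so the last query token is kept and the last response token is dropped.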

mertsayar8 avatar Jun 27 '24 12:06 mertsayar8