Conflict in start index under `batched_forward_pass`
The code and its comment disagree at: https://github.com/huggingface/trl/blob/b68ff96f0c74368961e194081e122959cd1f4d4d/trl/trainer/ppo_trainer.py#L1032
The comment states that `logprobs` starts from the second query token, but in the code `start`
is set to the last query token. Everything outside `[last_query_token, last_response_token)`
is then masked out at:
https://github.com/huggingface/trl/blob/b68ff96f0c74368961e194081e122959cd1f4d4d/trl/trainer/ppo_trainer.py#L1037-L1038
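To make this concrete, here is a toy sketch of which positions survive those two masking lines when `start = len(query) - 1`. Plain Python lists stand in for the tensors, the token names are made up, and it assumes no left padding and that `masks` is indexed by positions in the concatenated query + response sequence (the reading used in this issue). It is an illustration, not the trainer code.

```python
# Toy illustration (hypothetical tokens): which positions survive
# masks[j, :start] = 0 and masks[j, end:] = 0 when start = len(query) - 1.
# Assumes no left padding and that positions index the concatenated
# query + response token sequence.
query = ["q0", "q1", "q2"]
response = ["r0", "r1"]
tokens = query + response

start = len(query) - 1           # start as set at L1032
end = start + len(response)

masks = [1] * len(tokens)
masks[:start] = [0] * start                  # mirrors masks[j, :start] = 0
masks[end:] = [0] * (len(tokens) - end)      # mirrors masks[j, end:] = 0

kept = [t for t, m in zip(tokens, masks) if m]
print(masks)  # [0, 0, 1, 1, 0]
print(kept)   # ['q2', 'r0']
```

Under that reading, the window keeps the last query token (`q2`) and drops the last response token (`r1`).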
This is especially a problem when passing `response_masks` to the `step` function: `response_mask`
is multiplied using the wrong indices at:
https://github.com/huggingface/trl/blob/b68ff96f0c74368961e194081e122959cd1f4d4d/trl/trainer/ppo_trainer.py#L1040
so the masking is applied to the wrong positions.
I believe `start` should instead be set to the first response token, i.e. the length of the query.
That would make `masks[j, start:end]` and `response_masks_batch[j]` refer to the same indices
and solve the problem.
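A minimal sketch of the proposed change, using plain lists and a hypothetical `build_mask` helper rather than the actual `batched_forward_pass` code. It assumes no left padding and that positions index the concatenated query + response tokens. With `start = len(query)`, each entry of the response mask lines up with its own response token. With `start = len(query) - 1`, the same mask is shifted one position to the left.

```python
def build_mask(query_len, response_mask, start):
    """Hypothetical helper: zero everything outside [start, start + len(response_mask))
    and multiply the positions inside that window element-wise by response_mask."""
    total = query_len + len(response_mask)
    end = start + len(response_mask)
    masks = [1] * total
    for i in range(total):
        if i < start or i >= end:
            masks[i] = 0                          # outside [start, end)
        else:
            masks[i] *= response_mask[i - start]  # window offset maps to response index
    return masks


response_mask = [1, 0, 1]  # e.g. ignore the middle response token
print(build_mask(3, response_mask, start=3))  # [0, 0, 0, 1, 0, 1]  (proposed: aligned)
print(build_mask(3, response_mask, start=2))  # [0, 0, 1, 0, 1, 0]  (shifted by one)
```

In the shifted case, the last query token is kept, the last response token is dropped, and the zero meant for the second response token lands on the first one.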