Loss should be averaged over all samples instead of tokens
🐛 Describe the bug
Would it be more reasonable to calculate the average loss over the batch dimension instead of over all tokens? Currently, sequences of different lengths in a mini-batch affect each other's loss weighting, even though samples in a batch should be independent.
https://github.com/CarperAI/trlx/blob/07c962e13cbf91509f35c1a67c368393eac2333e/trlx/models/modeling_ppo.py#L195
A possible solution might be:

```python
vf_loss = 0.5 * (torch.sum(torch.max(vf_loss1, vf_loss2) * mask, dim=1) / mask.sum(dim=1)).mean()
```
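
For illustration, here is a minimal, self-contained sketch (with made-up per-token losses and masks, not the actual trlx tensors) showing how the two averaging schemes weight a short and a long sequence differently:

```python
import torch

# Toy batch: two sequences of lengths 1 and 4; mask marks valid tokens.
# The per-token losses below are hypothetical, chosen only to illustrate.
loss = torch.tensor([[4.0, 0.0, 0.0, 0.0],
                     [1.0, 1.0, 1.0, 1.0]])
mask = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                     [1.0, 1.0, 1.0, 1.0]])

# Token-level averaging: sum over all valid tokens, divide by total token count.
# The long sequence contributes 4 tokens, the short one only 1, so the short
# sample's loss is effectively down-weighted by the other sample's length.
token_mean = (loss * mask).sum() / mask.sum()  # (4 + 4) / 5 = 1.6

# Sample-level averaging: average within each sequence first, then across the
# batch, so every sample carries equal weight regardless of its length.
sample_mean = ((loss * mask).sum(dim=1) / mask.sum(dim=1)).mean()  # (4 + 1) / 2 = 2.5

print(token_mean.item(), sample_mean.item())  # 1.6 vs 2.5
```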
Please correct me if I'm wrong.
Which trlX version are you using?
trlx==0.6.0
Additional system and package information
Python 3.9.16, transformers==4.28.1, Linux