
Loss should be averaged over all samples instead of tokens

Open li-plus opened this issue 1 year ago • 0 comments

🐛 Describe the bug

Would it be more reasonable to average the loss over the batch dimension instead of over all tokens? Currently, sequences of different lengths in a mini-batch affect each other's weight in the loss: a longer sequence contributes more tokens and therefore dominates the average, even though samples in a batch should be independent.

https://github.com/CarperAI/trlx/blob/07c962e13cbf91509f35c1a67c368393eac2333e/trlx/models/modeling_ppo.py#L195

A possible solution might be:

vf_loss = 0.5 * (torch.sum(torch.max(vf_loss1, vf_loss2) * mask, dim=1) / mask.sum(dim=1)).mean()
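To illustrate the difference, here is a minimal sketch comparing the two averaging schemes on a toy batch. The tensor values are made up purely for illustration, and `vf_loss_per_token` stands in for `torch.max(vf_loss1, vf_loss2)` in the trlx code:

import torch

# Hypothetical toy batch: sample 0 has 4 valid tokens, sample 1 has 1.
vf_loss_per_token = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                                  [4.0, 0.0, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                     [1.0, 0.0, 0.0, 0.0]])

# Current behavior: average over all unmasked tokens in the batch.
# The long sequence contributes 4 of the 5 terms, so it dominates.
token_avg = 0.5 * torch.sum(vf_loss_per_token * mask) / mask.sum()

# Proposed behavior: average within each sample first, then across samples,
# so every sample contributes equally regardless of its length.
sample_avg = 0.5 * (torch.sum(vf_loss_per_token * mask, dim=1) / mask.sum(dim=1)).mean()

print(token_avg.item())   # 0.5 * (1+1+1+1+4) / 5 = 0.8
print(sample_avg.item())  # 0.5 * mean([1.0, 4.0]) = 1.25

With token-level averaging, the short sample's large per-token loss is diluted by the long sample; with sample-level averaging, both samples carry equal weight.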

Please correct me if I'm wrong.

Which trlX version are you using?

trlx==0.6.0

Additional system and package information

Python 3.9.16, transformers==4.28.1, Linux

li-plus avatar Apr 24 '23 16:04 li-plus