In SFT, the loss mask is misaligned with the final loss, masking out the eos_token's loss
System Info
In the latest code of fsdp_sft_trainer.py, the loss_mask drops its first token, i.e. it is shifted left by one.
The loss, however, drops the last token. The eos_token's loss_mask should be 1, but after the left shift it becomes 0.
As a result, models trained with this SFT code often fall into repetition loops at the end of their replies and fail to terminate.
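To make the off-by-one concrete, here is a minimal toy sketch (not verl code; the tensor values and the prompt/response layout are made up), assuming the convention that loss_mask[t] = 1 means the logits at position t, which predict token t+1, should be trained:

```python
import torch

# Toy sequence: [p0, p1, r0, r1, eos, pad]  (2 prompt tokens, response ends with eos)
# Under the assumed convention, positions 1, 2, 3 predict r0, r1 and eos respectively.
loss_mask = torch.tensor([[0, 1, 1, 1, 0, 0]], dtype=torch.float)

# Placeholder per-position losses; position t holds the loss of predicting token t+1.
# The last position has no valid target, so the trainer drops it.
full_loss = torch.arange(1, 7, dtype=torch.float).unsqueeze(0)  # [[1, 2, 3, 4, 5, 6]]
full_loss = full_loss[:, :-1]                                   # positions 0..4

expected = full_loss * loss_mask[:, :-1]  # position 3 (predicts eos) kept
current  = full_loss * loss_mask[:, 1:]   # position 3 masked, position 0 (prompt) leaks in

print(expected)  # tensor([[0., 2., 3., 4., 0.]])
print(current)   # tensor([[1., 2., 3., 0., 0.]])
```

With `[:, :-1]` the position that predicts the eos stays unmasked; with `[:, 1:]` that position is zeroed out, and a prompt position is counted in the loss instead.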
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
loss_mask = batch.pop("loss_mask")[:, 1:].reshape(-1).to(self.device_name) full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss full_loss = full_loss.reshape(-1) loss_mask = loss_mask.to(full_loss.device) loss = full_loss * loss_mask
Expected behavior
loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).to(self.device_name) full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss full_loss = full_loss.reshape(-1) loss_mask = loss_mask.to(full_loss.device) loss = full_loss * loss_mask
Same issue here. I think it is indeed a bug.
The code

```python
loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).to(self.device_name)
```

should be correct.
I also noticed the change. The reason for it can be found in https://github.com/volcengine/verl/pull/3287; I think it was meant as a fix for multi-turn datasets.
And I agree that `loss_mask = batch.pop("loss_mask")[:, 1:].reshape(-1).to(self.device_name)` will cause endless repetition for single-turn datasets.
Same question. Why mask the loss of the last token (usually eos)?
It seems the line below should be removed. Then the final loss mask would be correct, and the loss would be computed for the predictions of both the first response token and the last token (EOS): https://github.com/volcengine/verl/blob/504696245b39dd22162579b36303706ee61a731a/verl/utils/dataset/sft_dataset.py#L197
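Whichever fix is adopted, a small sanity check could catch this class of off-by-one in the future. The sketch below is a hypothetical helper, not part of verl; the `eos_token_id` argument and the `[:, :-1]` slicing are assumptions matching the fix proposed above:

```python
import torch

def eos_loss_is_kept(input_ids: torch.Tensor, loss_mask: torch.Tensor, eos_token_id: int) -> bool:
    """Hypothetical sanity check: after the trainer's slicing (loss_mask[:, :-1],
    as proposed above), does the position whose prediction target is the eos token
    still carry a mask value of 1 for every sample in the batch?"""
    sliced = loss_mask[:, :-1]  # aligned with full_loss[:, :-1]
    ok = True
    for ids, mask in zip(input_ids, sliced):
        eos_positions = (ids == eos_token_id).nonzero(as_tuple=True)[0]
        if len(eos_positions) == 0:
            continue  # no eos in this (possibly truncated) sample
        pred_pos = int(eos_positions[-1]) - 1  # position that predicts the eos token
        if pred_pos < 0 or pred_pos >= mask.size(0):
            continue
        ok = ok and bool(mask[pred_pos] > 0)
    return ok
```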