A serious bug in the DPO implementation
Describe the bug
In lines 418 to 424 of `cosyvoice/llm/llm.py`.
The original code is:
```python
chosen_lm_mask = chosen_lm_target == IGNORE_ID
rejected_lm_mask = rejected_lm_target == IGNORE_ID
chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2,
                            index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2,
                              index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
chosen_logps = (chosen_logps * chosen_lm_mask).mean(dim=-1)
rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)
```
In particular, the two lines

```python
chosen_logps = (chosen_logps * chosen_lm_mask).mean(dim=-1)
rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)
```

will lead to the following two problems:
- The values that actually enter the calculation are those of the IGNORE_ID positions, while the values that should be counted are zeroed out, because `chosen_lm_mask = chosen_lm_target == IGNORE_ID` and `rejected_lm_mask = rejected_lm_target == IGNORE_ID`.
- Taking the mean directly over `dim=-1` incorrectly includes the IGNORE_ID positions in the calculation (even though they are set to zero), because the denominator becomes the full sequence length. Something like `chosen_logps.sum(dim=-1) / chosen_lm_mask.sum(dim=-1)` should probably be used instead (see the toy sketch below).
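A minimal toy sketch of both effects (assuming `IGNORE_ID = -1` and made-up per-token log-probs, just for illustration):

```python
import torch

IGNORE_ID = -1  # assumed value, for this sketch only
# Toy per-token log-probs for one sequence of length 4.
logps = torch.tensor([[-1.0, -2.0, -3.0, -4.0]])
# The last two positions are padding.
target = torch.tensor([[5, 7, IGNORE_ID, IGNORE_ID]])

# Current code: the mask selects the IGNORE_ID positions, so the multiply keeps
# only the padded positions, and .mean(dim=-1) divides by the full length 4.
bad_mask = (target == IGNORE_ID)
bad_logps = (logps * bad_mask).mean(dim=-1)   # tensor([-1.7500]) -- only padding contributes

# Intended behaviour: keep the real tokens and divide by their count.
good_mask = (target != IGNORE_ID)
good_logps = (logps * good_mask).sum(dim=-1) / good_mask.sum(dim=-1)  # tensor([-1.5000])
```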
Additional context
Your work is excellent! If I have misunderstood something, please do not hesitate to point it out, as I am not familiar with all of the code details. Also, I am not sure whether this bug affects the results reported in the paper.
You mean the mean also takes the IGNORE_ID part into account, right? Yes, we will update it later; `chosen_logps` should be the mean value over the chosen tokens only.
Apart from this, have you found any other problems in the DPO code?
Is this right?

```python
chosen_lm_mask = (chosen_lm_target != IGNORE_ID)
rejected_lm_mask = (rejected_lm_target != IGNORE_ID)

chosen_logps = torch.gather(
    chosen_logits.log_softmax(dim=-1), dim=2,
    index=chosen_lm_target.masked_fill(~chosen_lm_mask, 0).unsqueeze(-1)
).squeeze(-1)
rejected_logps = torch.gather(
    rejected_logits.log_softmax(dim=-1), dim=2,
    index=rejected_lm_target.masked_fill(~rejected_lm_mask, 0).unsqueeze(-1)
).squeeze(-1)

chosen_logps = chosen_logps * chosen_lm_mask
rejected_logps = rejected_logps * rejected_lm_mask

chosen_counts = chosen_lm_mask.sum(dim=-1).clamp(min=1)
rejected_counts = rejected_lm_mask.sum(dim=-1).clamp(min=1)

chosen_logps = chosen_logps.sum(dim=-1) / chosen_counts
rejected_logps = rejected_logps.sum(dim=-1) / rejected_counts
```
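For reference, per-sequence averaged log-probs like these are what the DPO objective then consumes. A generic sketch of the standard DPO loss (not the code in this repository; `beta` and the policy/reference argument names are only illustrative):

```python
import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             reference_chosen_logps: torch.Tensor,
             reference_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    """Standard DPO objective on per-sequence (averaged) log-probs."""
    policy_logratios = policy_chosen_logps - policy_rejected_logps
    reference_logratios = reference_chosen_logps - reference_rejected_logps
    # -log sigmoid(beta * (policy margin - reference margin)), averaged over the batch.
    return -F.logsigmoid(beta * (policy_logratios - reference_logratios)).mean()
```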
> You mean the mean also takes the IGNORE_ID part into account, right? Yes, we will update it later; `chosen_logps` should be the mean value over the chosen tokens only.
Yes.
> Is this right?
Yes, thank you for your hard work!
Did you find any other errors in the DPO finetune code?
> Did you find any other errors in the DPO finetune code?
No. This is an excellent job.
This issue is stale because it has been open for 30 days with no activity.
It seems that this serious bug hasn't been fixed yet.
> It seems that this serious bug hasn't been fixed yet.
If I understand correctly, the official repo has already fixed this bug. See this line: https://github.com/FunAudioLLM/CosyVoice/blob/9f27b42cd98ed1e46cd5472f6cd47853ab49de01/cosyvoice/llm/llm.py#L431C9-L431C96
```python
chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
```
Please correct me if I am wrong.
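As a quick sanity check (a toy sketch, assuming `chosen_lm_mask` is the `!= IGNORE_ID` validity mask in both versions): the official one-liner and the earlier `clamp(min=1)` proposal give identical results whenever every sequence contains at least one valid token; the clamp only matters for an all-IGNORE_ID row, where the unclamped denominator would be zero.

```python
import torch

torch.manual_seed(0)
logps = torch.randn(2, 6)                                   # toy per-token log-probs
mask = torch.tensor([[1., 1., 1., 0., 0., 0.],
                     [1., 1., 1., 1., 1., 0.]])             # 1 = valid (non-IGNORE_ID) token

# Official fix: masked sum divided by the number of valid tokens.
official = (logps * mask).sum(dim=-1) / mask.sum(dim=-1)

# Earlier proposal: same, but with the denominator clamped to at least 1.
clamped = (logps * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)

assert torch.allclose(official, clamped)  # identical when every row has a valid token
```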