A serious bug in DPO implementation.

Open fispresent opened this issue 6 months ago • 10 comments

Describe the bug

The bug is in lines 418 to 424 of cosyvoice/llm/llm.py.

The original code is:

    chosen_lm_mask = chosen_lm_target == IGNORE_ID
    rejected_lm_mask = rejected_lm_target == IGNORE_ID
    chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
    rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
    chosen_logps = (chosen_logps * chosen_lm_mask).mean(dim=-1)
    rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)
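
The gather step works like this: IGNORE_ID is a padding sentinel rather than a real token id, so masked_fill first swaps those target positions for index 0, and the values gathered there are only log p(token 0) placeholders that have to be masked out afterwards. A minimal sketch with toy shapes, assuming IGNORE_ID = -1 (the actual value is not shown in this issue):

```python
import torch

IGNORE_ID = -1  # assumed sentinel value for illustration only

torch.manual_seed(0)
batch, seq_len, vocab = 2, 4, 5
logits = torch.randn(batch, seq_len, vocab)
# The second sequence is padded at its last two positions.
target = torch.tensor([[1, 3, 2, 4],
                       [0, 2, IGNORE_ID, IGNORE_ID]])

ignore_mask = target == IGNORE_ID                # True only at padding positions
safe_index = target.masked_fill(ignore_mask, 0)  # gather needs in-range indices

log_probs = logits.log_softmax(dim=-1)
token_logps = torch.gather(log_probs, dim=2, index=safe_index.unsqueeze(-1)).squeeze(-1)

print(token_logps.shape)  # torch.Size([2, 4])
# At padded positions the gathered value is just log p(token 0), a placeholder.
print(torch.equal(token_logps[1, 2:], log_probs[1, 2:, 0]))  # True
```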

The last two lines,

    chosen_logps = (chosen_logps * chosen_lm_mask).mean(dim=-1)
    rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)

lead to the following two problems:

  1. The mask is inverted. Because chosen_lm_mask = chosen_lm_target == IGNORE_ID and rejected_lm_mask = rejected_lm_target == IGNORE_ID, the masks are True only at the IGNORE_ID positions. Multiplying by them therefore keeps the placeholder values at the ignored positions and sets the log-probabilities that actually need to be computed to zero.

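  2. Taking the mean directly over dim=-1 still counts the IGNORE_ID positions (even if their values are set to zero), because the denominator is the full padded length instead of the number of valid tokens. The result is then scaled by (number of valid tokens) / (padded length) and depends on how much padding a sequence has. This likely requires computing (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1) instead, with the mask marking the non-IGNORE_ID positions.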
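
Both problems show up on a toy batch. The sketch below uses made-up per-token log-probabilities (not real model outputs) and assumes IGNORE_ID = -1; it compares the original expression with a masked mean over valid tokens only:

```python
import torch

IGNORE_ID = -1  # assumed sentinel value for illustration only

# Per-token log-probabilities for two sequences of padded length 4 (toy numbers).
# The last two entries of row 1 stand in for gather placeholders at padding.
token_logps = torch.tensor([[-1.0, -2.0, -3.0, -4.0],
                            [-1.0, -2.0, -0.5, -0.5]])
target = torch.tensor([[5, 6, 7, 8],
                       [5, 6, IGNORE_ID, IGNORE_ID]])

# Original expression: mask is True at IGNORE_ID, and mean() divides by the padded length.
ignore_mask = target == IGNORE_ID
buggy = (token_logps * ignore_mask).mean(dim=-1)

# Intended: keep only valid tokens and divide by how many there are.
valid_mask = target != IGNORE_ID
fixed = (token_logps * valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)

print(buggy)
print(fixed)
```

With these numbers the original expression gives 0 for the unpadded first sequence and -0.25 (the two placeholders averaged over all four positions) for the second, whereas the masked mean gives -2.5 and -1.5. Note also that the quoted line for rejected_logps multiplies by chosen_lm_mask rather than rejected_lm_mask; the corrected versions later in this thread use rejected_lm_mask.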

Additional context

Your work is excellent! If I have misunderstood something, please do not hesitate to point it out, as I am not familiar with all the details of the code. Also, I'm not sure whether this bug affects the results reported in the paper.

fispresent • Jul 12 '25 18:07

You mean the mean also takes the IGNORE_ID part into account, right? Yes, we will update it later. chosen_logps should be the mean over the chosen tokens only.

aluminumbox • Jul 13 '25 23:07
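
Stated as a formula, the intended quantity is a per-sequence mean over the non-ignored positions. For illustration only (this notation is not from the CosyVoice code), write $T$ for the padded length, $y_t$ for the target at position $t$, $\ell_t(v)$ for entry $v$ of `chosen_logits.log_softmax(dim=-1)` at position $t$, and $M = \{t : y_t \neq \mathrm{IGNORE\_ID}\}$. The intended value is

$$
\text{chosen\_logps} = \frac{1}{|M|} \sum_{t \in M} \ell_t(y_t),
$$

whereas the original expression, whose mask selects the IGNORE_ID positions after those targets were replaced by 0 for the gather, evaluates to

$$
\frac{1}{T} \sum_{t \notin M} \ell_t(0).
$$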

Besides this spot, did you find any other problems in the DPO code?

hizening • Jul 15 '25 09:07

Is this right?

    chosen_lm_mask = (chosen_lm_target != IGNORE_ID)
    rejected_lm_mask = (rejected_lm_target != IGNORE_ID)

    chosen_logps = torch.gather(
        chosen_logits.log_softmax(dim=-1), dim=2,
        index=chosen_lm_target.masked_fill(~chosen_lm_mask, 0).unsqueeze(-1)
    ).squeeze(-1)

    rejected_logps = torch.gather(
        rejected_logits.log_softmax(dim=-1), dim=2,
        index=rejected_lm_target.masked_fill(~rejected_lm_mask, 0).unsqueeze(-1)
    ).squeeze(-1)

    chosen_logps = chosen_logps * chosen_lm_mask
    rejected_logps = rejected_logps * rejected_lm_mask

    chosen_counts = chosen_lm_mask.sum(dim=-1).clamp(min=1)
    rejected_counts = rejected_lm_mask.sum(dim=-1).clamp(min=1)

    chosen_logps = chosen_logps.sum(dim=-1) / chosen_counts
    rejected_logps = rejected_logps.sum(dim=-1) / rejected_counts

hizening • Jul 15 '25 10:07

> you mean the mean also take IGNORE_ID part into account, right?

Yes.

fispresent • Jul 15 '25 14:07

> Is this right?

Yes, thank you for your hard work!

fispresent • Jul 15 '25 14:07

Did you find any other errors in the DPO fine-tuning code?

hizening • Jul 15 '25 14:07

No. This is excellent work.

fispresent • Jul 15 '25 14:07

This issue is stale because it has been open for 30 days with no activity.

github-actions[bot] • Aug 15 '25 02:08

It seems that this serious bug hasn't been fixed yet.

hm-li0420 • Dec 17 '25 08:12

If I understand correctly, the maintainers have already fixed this bug. See this line: https://github.com/FunAudioLLM/CosyVoice/blob/9f27b42cd98ed1e46cd5472f6cd47853ab49de01/cosyvoice/llm/llm.py#L431C9-L431C96

        chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
        rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)

Please correct me if I am wrong.

sweetice • Dec 18 '25 09:12
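
The merged lines divide by chosen_lm_mask.sum(dim=-1) directly, while the version proposed earlier in this thread clamps that count to at least 1. Assuming chosen_lm_mask in the merged code now marks the non-IGNORE_ID positions (its definition is not quoted here), the two forms agree whenever a sequence has at least one valid target and differ only for a row that is entirely IGNORE_ID: the unclamped division becomes 0/0 = NaN, while the clamped one gives 0. Whether such rows can occur in CosyVoice's DPO batches is not established in this thread. A toy sketch, with IGNORE_ID = -1 assumed and the helper names invented for illustration:

```python
import torch

IGNORE_ID = -1  # assumed sentinel value for illustration only

def masked_mean_merged(logps, target):
    # Hypothetical helper mirroring the merged form: divide by the raw valid-token count.
    mask = target != IGNORE_ID
    return (logps * mask).sum(dim=-1) / mask.sum(dim=-1)

def masked_mean_clamped(logps, target):
    # Hypothetical helper mirroring the proposed form: clamp the count to at least 1.
    mask = target != IGNORE_ID
    return (logps * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)

logps = torch.tensor([[-1.0, -3.0, -0.7],
                      [-0.2, -0.4, -0.6]])
target = torch.tensor([[4, 9, IGNORE_ID],
                       [IGNORE_ID, IGNORE_ID, IGNORE_ID]])  # fully padded row

merged = masked_mean_merged(logps, target)
clamped = masked_mean_clamped(logps, target)
print(merged)
print(clamped)
```

For the first row both forms give -2.0; for the fully padded second row the merged form yields NaN and the clamped form yields 0.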