
[WIP][Feature] DPO

Open amulil opened this issue 1 year ago • 10 comments

@pppppM As you suggested, my initial plan is to implement DPODataset under the dataset directory and DPO under the model directory; the other hooks can stay the same as SFT for now and don't need changes. One question though: DPO has two models, model and ref_model. Does the DeepSpeed-related part need to be modified?
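
To make the two-model setup concrete, here is a minimal sketch of such a wrapper, assuming the policy (llm) is an already-built causal LM. The class and method names are illustrative, not xtuner's actual implementation; typically only the trainable policy needs DeepSpeed/optimizer state, while the frozen reference just runs forward passes.

import copy

import torch
from torch import nn


class DPO(nn.Module):
    """Minimal sketch: trainable policy plus a frozen reference model."""

    def __init__(self, llm, beta=0.1):
        super().__init__()
        self.llm = llm                           # trainable policy model
        self.ref_model = copy.deepcopy(llm)      # frozen reference copy of the same weights
        self.ref_model.requires_grad_(False)
        self.ref_model.eval()
        self.beta = beta

    @torch.no_grad()
    def ref_forward(self, **data):
        # The reference model only supplies log-probs for the DPO loss; no gradients flow here.
        return self.ref_model(**data).logits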

amulil avatar Feb 25 '24 16:02 amulil

Updated the DPO implementation. With SFT data the pipeline runs end to end (NPROC_PER_NODE=8 xtuner train internlm2_chat_1_8b_qlora_dpo_ultra_e3 --deepspeed deepspeed_zero2), but two problems remain:

  1. The loss is NaN (screenshot: 2024-03-06 15:32).

  2. The deepcopy approach does not support quantized loading; the pipeline only runs with LoRA or non-quantized loading.

@xiaohangguo @pppppM could you take a look at what is causing these two problems?

amulil avatar Mar 06 '24 07:03 amulil

How about rebuilding ref_model directly from the llm's config?

For the NaN loss, @xiaohangguo may need to help check the formula details.

pppppM avatar Mar 07 '24 07:03 pppppM

How about rebuilding ref_model directly from the llm's config?

For the NaN loss, @xiaohangguo may need to help check the formula details.

OK, I'll switch to this branch tonight, try to reproduce it, and debug.

xiaohangguo avatar Mar 07 '24 07:03 xiaohangguo

How about rebuilding ref_model directly from the llm's config?

For the NaN loss, @xiaohangguo may need to help check the formula details.

Sure, I'll try changing it to rebuild ref_model from the llm's config.
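
A minimal sketch of that idea, assuming an mmengine-style config and xtuner's BUILDER registry (the import path and helper name are assumptions, and quantization handling is simplified):

import copy

from xtuner.registry import BUILDER  # assumed import path


def build_ref_model(llm_cfg, llm_module):
    """Build the frozen reference model.

    Rebuilding from the config avoids deepcopy, which fails for quantized
    (e.g. 4-bit) models whose weights and quant state cannot simply be copied.
    """
    if llm_cfg is not None:
        ref_model = BUILDER.build(llm_cfg)       # fresh instance from the same config
    else:
        ref_model = copy.deepcopy(llm_module)    # fallback for plain, non-quantized modules
    ref_model.requires_grad_(False)
    ref_model.eval()
    return ref_model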

amulil avatar Mar 07 '24 10:03 amulil

I wrote a mock-data pytest to verify the algorithm; based on the current test results, the loss computation itself looks correct.

import torch
import torch.nn.functional as F
from unittest import TestCase, main


class MockModelOutput:
    def __init__(self, logits):
        self.logits = logits


class TestModel:
    def __init__(self, beta):
        self.beta = beta

    def llm(self, **kwargs):
        return MockModelOutput(logits=torch.randn(10, 5, 20))

    def ref_model(self, **kwargs):
        return MockModelOutput(logits=torch.randn(10, 5, 20))

    def compute_loss(self, data, data_samples=None):
        # The batch is laid out as [chosen samples; rejected samples] along dim 0.
        len_chosen = data["input_ids"].shape[0] // 2
        assert len_chosen != 0
        all_logits = self.llm(**data).logits
        all_ref_logits = self.ref_model(**data).logits

        print("all_logits:", all_logits)
        print("all_ref_logits:", all_ref_logits)

        labels = data["labels"]
        # Map the ignore index (-100) to 0 so gather has valid indices, then build a
        # loss mask that skips positions whose label is 0 (including the former -100s).
        labels[labels == -100] = 0
        loss_mask = labels != 0

        print("labels:", labels)
        print("loss_mask:", loss_mask)

        # Log-probability of each label token under the policy and the reference model.
        per_token_logps = torch.gather(
            all_logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        per_ref_token_logps = torch.gather(
            all_ref_logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

        print("per_token_logps:", per_token_logps)
        print("per_ref_token_logps:", per_ref_token_logps)

        # Average log-probability per response token; epsilon would guard against an
        # empty mask but is kept at 0 here.
        epsilon = 0
        all_logps = (per_token_logps * loss_mask).sum(-1) / \
            (loss_mask.sum(-1) + epsilon)
        all_ref_logps = (per_ref_token_logps * loss_mask).sum(-1) / \
            (loss_mask.sum(-1) + epsilon)
        print("loss_mask.sum(-1)", loss_mask.sum(-1))
        print("all_logps:", all_logps)
        print("all_ref_logps:", all_ref_logps)

        policy_chosen_logps = all_logps[:len_chosen]
        policy_rejected_logps = all_logps[len_chosen:]
        reference_chosen_logps = all_ref_logps[:len_chosen]
        reference_rejected_logps = all_ref_logps[len_chosen:]

        print("policy_chosen_logps:", policy_chosen_logps)
        print("policy_rejected_logps:", policy_rejected_logps)
        print("reference_chosen_logps:", reference_chosen_logps)
        print("reference_rejected_logps:", reference_rejected_logps)

        # DPO objective: loss = -logsigmoid(beta * ((pi_chosen - pi_rejected)
        #                                           - (ref_chosen - ref_rejected)))
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps

        print("pi_logratios:", pi_logratios)
        print("ref_logratios:", ref_logratios)

        logits = pi_logratios - ref_logratios
        loss = -F.logsigmoid(self.beta * logits)

        print("logits:", logits)
        print("loss:", loss)

        chosen_rewards = self.beta * \
            (policy_chosen_logps - reference_chosen_logps)
        rejected_rewards = self.beta * \
            (policy_rejected_logps - reference_rejected_logps)

        print("chosen_rewards:", chosen_rewards)
        print("rejected_rewards:", rejected_rewards)

        loss_dict = {
            'loss': loss,
            'chosen_rewards': chosen_rewards,
            'rejected_rewards': rejected_rewards
        }
        return loss_dict


class LossComputationTest(TestCase):
    def test_compute_loss(self):
        model = TestModel(beta=0.1)
        data = {
            "input_ids": torch.randint(0, 20, (10, 5)),
            "labels": torch.randint(-100, 20, (10, 5))
        }

        # Ensure all label values are non-negative
        data["labels"] = torch.where(
            data["labels"] < 0, torch.tensor(0), data["labels"])

        loss_dict = model.compute_loss(data)
        loss, chosen_rewards, rejected_rewards = loss_dict['loss'], loss_dict[
            'chosen_rewards'], loss_dict['rejected_rewards']
        # print("Loss values:", loss)
        # print("chosen_rewards values:", chosen_rewards)
        # print("rejected_rewards values:", rejected_rewards)
        self.assertTrue(torch.all(loss >= 0))
        # self.assertTrue(torch.all(chosen_rewards <= 0))
        # self.assertTrue(torch.all(rejected_rewards >= 0))


if __name__ == "__main__":
    main()

The next step is to adapt the DPODataset class so that each batch keeps the (prompt, chosen, rejected) format; a sketch of one possible collate layout is below.
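
As a rough illustration, here is a hypothetical collate sketch that lays the batch out as [chosen; rejected] along dim 0, matching len_chosen = input_ids.shape[0] // 2 in the test above. The dict keys, function name, and padding details are assumptions, not the actual item_fn/encode_fn:

import torch
from torch.nn.utils.rnn import pad_sequence


def dpo_collate_fn(instances, pad_token_id=0, ignore_index=-100):
    # Each instance is assumed to hold token-id lists: 'prompt', 'chosen', 'rejected'.
    chosen_ids = [torch.tensor(x['prompt'] + x['chosen']) for x in instances]
    rejected_ids = [torch.tensor(x['prompt'] + x['rejected']) for x in instances]
    # Labels mask out the prompt so only the response tokens contribute to the log-probs.
    chosen_labels = [
        torch.tensor([ignore_index] * len(x['prompt']) + x['chosen']) for x in instances]
    rejected_labels = [
        torch.tensor([ignore_index] * len(x['prompt']) + x['rejected']) for x in instances]

    # First half of the batch: prompt+chosen; second half: prompt+rejected.
    input_ids = pad_sequence(chosen_ids + rejected_ids, batch_first=True,
                             padding_value=pad_token_id)
    labels = pad_sequence(chosen_labels + rejected_labels, batch_first=True,
                          padding_value=ignore_index)
    return {
        'input_ids': input_ids,
        'labels': labels,
        # Simplified mask; a real packer/collate would track padding explicitly.
        'attention_mask': input_ids.ne(pad_token_id),
    }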

xiaohangguo avatar Mar 07 '24 14:03 xiaohangguo

I reworked the item_fn, but it still feels like something is off. A single conversation should be handled correctly; I'm not sure whether it can be combined with the original encode_fn so that the whole dataset is processed properly and goes through the packer as usual. @LZHgrla could you take a look and check whether this works?

xiaohangguo avatar Mar 09 '24 15:03 xiaohangguo

With NPROC_PER_NODE=8 xtuner train internlm2_chat_1_8b_full_dpo_ultra_e3 --deepspeed deepspeed_zero2, the full DPO loss is now normal (screenshot: 2024-04-02 23:17). Next, I'll add QLoRA DPO following the TRL docs: https://moon-ci-docs.huggingface.co/docs/trl/pr_1193/en/dpo_trainer#downsides-to-merging-qlora-before-dpo-approach-2
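
For reference, a minimal sketch of the "merge the SFT QLoRA adapter first" approach described on that page, assuming a standard transformers + peft setup; the paths are placeholders and precision handling is simplified:

from transformers import AutoModelForCausalLM
from peft import PeftModel

# Hypothetical paths for illustration only.
base_path = 'internlm/internlm2-chat-1_8b'
adapter_path = './work_dirs/sft_qlora_adapter'

# 1. Load the base model in full/half precision (not 4-bit) and merge the SFT adapter into it.
base = AutoModelForCausalLM.from_pretrained(
    base_path, torch_dtype='auto', trust_remote_code=True)
merged = PeftModel.from_pretrained(base, adapter_path).merge_and_unload()
merged.save_pretrained('./merged_sft_model')

# 2. The merged checkpoint then serves as both the starting point and the frozen
#    reference for the subsequent QLoRA DPO run, per the linked TRL notes.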

amulil avatar Apr 02 '24 15:04 amulil

With NPROC_PER_NODE=8 xtuner train internlm2_chat_1_8b_full_dpo_ultra_e3 --deepspeed deepspeed_zero2, the full DPO loss is now normal (screenshot: 2024-04-02 23:17). Next, I'll add QLoRA DPO following the TRL docs: https://moon-ci-docs.huggingface.co/docs/trl/pr_1193/en/dpo_trainer#downsides-to-merging-qlora-before-dpo-approach-2

Awesome!

xiaohangguo avatar Apr 04 '24 02:04 xiaohangguo

@amulil Is there a metric comparison for DPO-trained models yet? I'd like to use this implementation as a reference for RLHF-V. Code: https://github.com/RLHF-V/RLHF-V, https://github.com/thunlp/Muffin

KooSung avatar Apr 07 '24 11:04 KooSung

@amulil Is there a metric comparison for DPO-trained models yet? I'd like to use this implementation as a reference for RLHF-V. Code: https://github.com/RLHF-V/RLHF-V, https://github.com/thunlp/Muffin

@KooSung Not yet. Later I'll compare metrics against the zephyr-7b-dpo-qlora model mentioned in https://github.com/huggingface/alignment-handbook/blob/main/recipes/zephyr-7b-beta/README.md.

amulil avatar Apr 07 '24 13:04 amulil