[WIP][Feature] DPO
@pppppM As you suggested, my initial plan is to implement DPODataset under the dataset directory and DPO under the model directory; the other hooks stay the same as SFT for now and need no changes. One question though: DPO involves two models, model and ref_model — do the DeepSpeed-related parts need to be modified for that?
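For context, the usual pattern (a general sketch, not xtuner-specific; `freeze_ref_model` is a hypothetical helper) is that ref_model is frozen and only does forward passes, so it carries no optimizer or gradient state and only the policy model's parameters are handed to the optimizer / DeepSpeed engine; whether ref_model needs its own engine mainly matters for ZeRO-3-style weight partitioning.

```python
import torch


def freeze_ref_model(ref_model: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical helper: the DPO reference model is inference-only."""
    ref_model.eval()                 # no dropout etc. during reference forward passes
    ref_model.requires_grad_(False)  # no gradients, hence no optimizer state
    return ref_model


# Only the trainable policy parameters go into the optimizer:
# trainable = [p for p in policy_model.parameters() if p.requires_grad]
```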
Updated the DPO implementation. Using the SFT data, the pipeline runs end to end with

```bash
NPROC_PER_NODE=8 xtuner train internlm2_chat_1_8b_qlora_dpo_ultra_e3 --deepspeed deepspeed_zero2
```

but there are two problems:

- the loss is NaN
- the deepcopy approach does not support quantized loading; only LoRA and non-quantized loading get through the pipeline

@xiaohangguo @pppppM could you take a look at what is causing these two issues?
How about rebuilding ref_model directly from the llm config?
For the NaN loss, maybe @xiaohangguo can help check the formula details.
OK, I'll check out this branch tonight, reproduce it, and debug.
Sure, I'll try switching to rebuilding ref_model from the llm config.
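A minimal sketch of the rebuild-from-config idea, assuming xtuner's BUILDER registry is used to turn the llm config into a model (the helper name and exact handling here are assumptions, not the final implementation):

```python
import copy

from xtuner.registry import BUILDER  # config-to-object builder used by xtuner


def build_ref_model(llm_cfg):
    """Sketch: build ref_model from the same config dict as the policy llm.

    Building a second instance from the config avoids deepcopy-ing an
    already-built (possibly 4-/8-bit quantized) model, which is where the
    deepcopy approach breaks. The reference model is frozen and inference-only.
    """
    ref_model = BUILDER.build(copy.deepcopy(llm_cfg))
    ref_model.eval()
    ref_model.requires_grad_(False)
    return ref_model
```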
I wrote a pytest with mock data to verify the algorithm. Based on the current test results, the loss computation should be fine.
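For reference, the quantity being tested is the standard DPO objective (the mock below uses a length-normalized variant, averaging per-token log-probs over the unmasked tokens instead of summing them):

$$
\mathcal{L}_\mathrm{DPO}(\pi_\theta;\pi_\mathrm{ref}) =
-\mathbb{E}_{(x,y_w,y_l)\sim\mathcal{D}}\left[\log\sigma\!\Big(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_\mathrm{ref}(y_w\mid x)}-\beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_\mathrm{ref}(y_l\mid x)}\Big)\right]
$$

with the implicit rewards $r(x,y)=\beta\big(\log\pi_\theta(y\mid x)-\log\pi_\mathrm{ref}(y\mid x)\big)$ logged as `chosen_rewards` / `rejected_rewards`.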
```python
import torch
import torch.nn.functional as F
from unittest import TestCase, main

# from utils import print


class MockModelOutput:
    """Minimal stand-in for a HF-style model output holding only logits."""

    def __init__(self, logits):
        self.logits = logits


class TestModel:
    """Mock policy/reference model pair that returns random logits."""

    def __init__(self, beta):
        self.beta = beta

    def llm(self, **kwargs):
        # (batch=10, seq_len=5, vocab_size=20)
        return MockModelOutput(logits=torch.randn(10, 5, 20))

    def ref_model(self, **kwargs):
        return MockModelOutput(logits=torch.randn(10, 5, 20))

    def compute_loss(self, data, data_samples=None):
        # First half of the batch is chosen samples, second half rejected.
        len_chosen = data["input_ids"].shape[0] // 2
        assert len_chosen != 0

        all_logits = self.llm(**data).logits
        all_ref_logits = self.ref_model(**data).logits
        print("all_logits:", all_logits)
        print("all_ref_logits:", all_ref_logits)

        # Ignored positions (-100) are remapped to token 0 and masked out.
        labels = data["labels"]
        labels[labels == -100] = 0
        loss_mask = labels != 0
        print("labels:", labels)
        print("loss_mask:", loss_mask)

        # Per-token log-probabilities of the label tokens.
        per_token_logps = torch.gather(
            all_logits.log_softmax(-1), dim=2,
            index=labels.unsqueeze(2)).squeeze(2)
        per_ref_token_logps = torch.gather(
            all_ref_logits.log_softmax(-1), dim=2,
            index=labels.unsqueeze(2)).squeeze(2)
        print("per_token_logps:", per_token_logps)
        print("per_ref_token_logps:", per_ref_token_logps)

        # Length-normalized sequence log-probs (masked mean over tokens).
        epsilon = 0
        all_logps = (per_token_logps * loss_mask).sum(-1) / \
            (loss_mask.sum(-1) + epsilon)
        all_ref_logps = (per_ref_token_logps * loss_mask).sum(-1) / \
            (loss_mask.sum(-1) + epsilon)
        print("loss_mask.sum(-1)", loss_mask.sum(-1))
        print("all_logps:", all_logps)
        print("all_ref_logps:", all_ref_logps)

        policy_chosen_logps = all_logps[:len_chosen]
        policy_rejected_logps = all_logps[len_chosen:]
        reference_chosen_logps = all_ref_logps[:len_chosen]
        reference_rejected_logps = all_ref_logps[len_chosen:]
        print("policy_chosen_logps:", policy_chosen_logps)
        print("policy_rejected_logps:", policy_rejected_logps)
        print("reference_chosen_logps:", reference_chosen_logps)
        print("reference_rejected_logps:", reference_rejected_logps)

        # DPO loss: -log sigmoid(beta * (policy log-ratio - ref log-ratio)).
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        print("pi_logratios:", pi_logratios)
        print("ref_logratios:", ref_logratios)
        logits = pi_logratios - ref_logratios
        loss = -F.logsigmoid(self.beta * logits)
        print("logits:", logits)
        print("loss:", loss)

        # Implicit rewards, useful for logging reward margins/accuracy.
        chosen_rewards = self.beta * \
            (policy_chosen_logps - reference_chosen_logps)
        rejected_rewards = self.beta * \
            (policy_rejected_logps - reference_rejected_logps)
        print("chosen_rewards:", chosen_rewards)
        print("rejected_rewards:", rejected_rewards)

        loss_dict = {
            'loss': loss,
            'chosen_rewards': chosen_rewards,
            'rejected_rewards': rejected_rewards
        }
        return loss_dict


class LossComputationTest(TestCase):

    def test_compute_loss(self):
        model = TestModel(beta=0.1)
        data = {
            "input_ids": torch.randint(0, 20, (10, 5)),
            "labels": torch.randint(-100, 20, (10, 5))
        }
        # Ensure all label values are non-negative.
        data["labels"] = torch.where(
            data["labels"] < 0, torch.tensor(0), data["labels"])
        # Keep at least one unmasked token per row so the masked mean in
        # compute_loss never divides by zero.
        data["labels"][:, -1] = torch.randint(1, 20, (10,))
        loss_dict = model.compute_loss(data)
        loss, chosen_rewards, rejected_rewards = loss_dict['loss'], loss_dict[
            'chosen_rewards'], loss_dict['rejected_rewards']
        # print("Loss values:", loss)
        # print("chosen_rewards values:", chosen_rewards)
        # print("rejected_rewards values:", rejected_rewards)
        self.assertTrue(torch.all(loss >= 0))
        # self.assertTrue(torch.all(chosen_rewards <= 0))
        # self.assertTrue(torch.all(rejected_rewards >= 0))


if __name__ == "__main__":
    main()
```
Next step: adapt the DPODataset class so that each batch keeps the (prompt, chosen, rejected) format.
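One possible layout (a sketch only; `dpo_collate_fn` and the field names are hypothetical) is to tokenize prompt+chosen and prompt+rejected per sample and stack them so the first half of the batch is chosen and the second half rejected, matching `len_chosen = input_ids.shape[0] // 2` in `compute_loss` above:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def dpo_collate_fn(instances):
    """Hypothetical collator for (prompt, chosen, rejected) triples.

    Each instance provides tokenized prompt+chosen and prompt+rejected
    sequences; chosen sequences are stacked first so that input_ids[:B]
    are chosen and input_ids[B:] are rejected.
    """
    chosen_ids = [torch.tensor(x['chosen_input_ids']) for x in instances]
    rejected_ids = [torch.tensor(x['rejected_input_ids']) for x in instances]
    chosen_labels = [torch.tensor(x['chosen_labels']) for x in instances]
    rejected_labels = [torch.tensor(x['rejected_labels']) for x in instances]

    input_ids = pad_sequence(
        chosen_ids + rejected_ids, batch_first=True, padding_value=0)
    labels = pad_sequence(
        chosen_labels + rejected_labels, batch_first=True, padding_value=-100)
    return {'input_ids': input_ids, 'labels': labels}
```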
I reworked the item_fn, but it still feels like something is off. It should work for a single conversation; I'm not sure whether it can be combined with the original encode_fn to process the whole dataset and then go through the packer as usual. @LZHgrla could you please take a look and see whether this works?
```bash
NPROC_PER_NODE=8 xtuner train internlm2_chat_1_8b_full_dpo_ultra_e3 --deepspeed deepspeed_zero2
```

Full DPO loss is now normal.

Next, add QLoRA DPO following the notes in the TRL docs:
https://moon-ci-docs.huggingface.co/docs/trl/pr_1193/en/dpo_trainer#downsides-to-merging-qlora-before-dpo-approach-2
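A minimal sketch of the setup those TRL notes point to: load the base model once in 4-bit, train a LoRA adapter as the policy, and compute reference log-probs from the same quantized base with the adapter disabled, instead of merging the QLoRA adapter beforehand. The model name, LoRA hyperparameters, and target modules below are placeholders:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16)
base = AutoModelForCausalLM.from_pretrained(
    'internlm/internlm2-chat-1_8b',  # placeholder checkpoint
    quantization_config=quant_cfg,
    trust_remote_code=True)
# target_modules depend on the architecture; InternLM2 projection names assumed.
policy = get_peft_model(
    base, LoraConfig(r=64, lora_alpha=16, task_type='CAUSAL_LM',
                     target_modules=['wqkv', 'wo']))


def reference_logits(policy, **inputs):
    """Reference forward pass: same 4-bit base weights, LoRA disabled."""
    with torch.no_grad(), policy.disable_adapter():
        return policy(**inputs).logits
```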
Awesome!
@amulil Is there a metric comparison for DPO-trained models yet? I'd like to use this as a reference to implement RLHF-V. Code: https://github.com/RLHF-V/RLHF-V, https://github.com/thunlp/Muffin
@KooSung Not yet. Later we'll compare metrics against the zephyr-7b-dpo-qlora model mentioned in https://github.com/huggingface/alignment-handbook/blob/main/recipes/zephyr-7b-beta/README.md.