
DPO training fails with KeyError: 'prompt_input_ids'

Open · JiaweiZhao-git opened this issue 6 months ago · 4 comments

When I train DPO on a dataset in the custom data format, the dataset map step fails with:

File "ms-swift/swift/trainers/dpo_trainer.py", line 114, in tokenize_row
if len(answer_tokens['prompt_input_ids']) + longer_response_length > self.max_length:
KeyError: 'prompt_input_ids'

Printing the keys of answer gives: dict_keys(['input_ids', 'attention_mask', 'prompt_inputs_embeds', 'prompt_attention_mask'])
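A minimal Python sketch of why the length check in tokenize_row fails: the dict whose keys are printed above holds prompt_inputs_embeds and prompt_attention_mask but no prompt_input_ids, so the subscript raises a KeyError. The values below are made-up placeholders, and the helper prompt_length with its fallback to prompt_attention_mask is hypothetical; it is not ms-swift's API or its actual fix, only an illustration of which keys the check could read.

```python
from typing import Any, Mapping

# Same key set as the printed dict; all values are placeholders.
answer = {
    "input_ids": [1, 2, 3, 4, 5],
    "attention_mask": [1, 1, 1, 1, 1],
    "prompt_inputs_embeds": [[0.0] * 4] * 3,  # stand-in for an embedding tensor
    "prompt_attention_mask": [1, 1, 1],
}

# Reproduce the failing lookup from dpo_trainer.py line 114.
try:
    answer["prompt_input_ids"]
except KeyError as exc:
    print("KeyError:", exc)  # prints: KeyError: 'prompt_input_ids'


def prompt_length(answer_tokens: Mapping[str, Any]) -> int:
    """Hypothetical helper: prompt length for a max_length check.

    Uses 'prompt_input_ids' when present; otherwise ASSUMES the length of
    'prompt_attention_mask' equals the prompt length (as for the
    embedding-based row above). Not ms-swift behavior.
    """
    if "prompt_input_ids" in answer_tokens:
        return len(answer_tokens["prompt_input_ids"])
    if "prompt_attention_mask" in answer_tokens:
        return len(answer_tokens["prompt_attention_mask"])
    raise KeyError("no prompt length key in " + repr(sorted(answer_tokens)))


def exceeds_max_length(answer_tokens: Mapping[str, Any],
                       longer_response_length: int, max_length: int) -> bool:
    """Mirror of the condition in tokenize_row, using the hypothetical helper."""
    return prompt_length(answer_tokens) + longer_response_length > max_length


print(prompt_length(answer))             # prints: 3
print(exceeds_max_length(answer, 2, 4))  # prints: True
```

The fallback only makes the length check explicit; whether the multimodal InternVL2 path should produce prompt_input_ids in the first place is the open question of this issue.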

Training command:

CUDA_VISIBLE_DEVICES=2 \
swift rlhf \
--rlhf_type dpo \
--model_type internvl2-4b \
--model_id_or_path ./OpenGVLab/InternVL2-4B \
--beta 0.1 \
--sft_beta 0.1 \
--sft_type lora \
--dataset {custom_dataset_path}.jsonl \
--num_train_epochs 2 \
--lora_target_modules DEFAULT \
--gradient_checkpointing true \
--batch_size 1 \
--learning_rate 5e-5 \
--gradient_accumulation_steps 16 \
--warmup_ratio 0.03 \
--save_total_limit 1

JiaweiZhao-git, Aug 14 '24, 11:08