ms-swift
ms-swift copied to clipboard
DPO训练报错KeyError: 'prompt_input_ids'
按自定义数据格式,训练DPO在Map时报错 File "ms-swift/swift/trainers/dpo_trainer.py", line 114, in tokenize_row if len(answer_tokens['prompt_input_ids']) + longer_response_length > self.max_length: KeyError: 'prompt_input_ids'
打印了下answer的key:dict_keys(['input_ids', 'attention_mask', 'prompt_inputs_embeds', 'prompt_attention_mask'])
训练代码:
CUDA_VISIBLE_DEVICES=2
swift rlhf
--rlhf_type dpo
--model_type internvl2-4b
--model_id_or_path ./OpenGVLab/InternVL2-4B
--beta 0.1
--sft_beta 0.1
--sft_type lora
--dataset {custom_dataset_path}.jsonl
--num_train_epochs 2
--lora_target_modules DEFAULT
--gradient_checkpointing true
--batch_size 1
--learning_rate 5e-5
--gradient_accumulation_steps 16
--warmup_ratio 0.03
--save_total_limit 1