LLaMA-Factory

How to implement do_predict for ORPO and compute metrics such as BLEU-4

Open bingwork opened this issue 9 months ago • 0 comments

Reminder

  • [X] I have read the README and searched the existing issues.

Reproduction

The SFT workflow (https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llmtuner/train/sft/workflow.py) supports do_predict: CustomSeq2SeqTrainer inherits from Seq2SeqTrainer, and Seq2SeqTrainer implements the predict method. However, the ORPO workflow (https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llmtuner/train/orpo/workflow.py) does not support do_predict yet: CustomORPOTrainer inherits from DPOTrainer, and DPOTrainer does not implement a predict method. I also checked https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py#L800, and it does not implement predict either.

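Questions:

1. Is this because DPOTrainer and orpo_trainer are, in principle, not suited to a predict method? Is there a reason predict is not implemented? Since they all inherit from Trainer, why not implement it?
2. If predict can be implemented, is there anything I can use as a reference? I am still trying, but it is proving difficult.
3. I also noticed that both DPOTrainer and orpo_trainer have a get_batch_samples method. Is that the method I should use?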
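For the metrics part, this is roughly what I have in mind once generated texts are available. It is only a minimal pure-Python sketch of a simplified sentence-level BLEU-4. The helper names `ngrams` and `bleu4`, the character-level tokenization, and the add-one smoothing are my own assumptions for illustration, not what LLaMA-Factory or trl actually does.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Count all n-grams (as tuples) in a list of tokens."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def bleu4(prediction, reference):
    """Simplified sentence-level BLEU-4 for one prediction/reference pair.

    Assumptions (hypothetical, for illustration only):
    - inputs are already tokenized lists (e.g. list(text) for characters);
    - orders 2..4 use add-one smoothing so short texts do not score zero;
    - a single reference is used.
    """
    if not prediction or not reference:
        return 0.0

    log_precision_sum = 0.0
    for n in range(1, 5):
        pred_counts = ngrams(prediction, n)
        ref_counts = ngrams(reference, n)
        # Clipped matches: each predicted n-gram counts at most as often
        # as it appears in the reference.
        overlap = sum(min(c, ref_counts[g]) for g, c in pred_counts.items())
        total = sum(pred_counts.values())
        if n == 1:
            if overlap == 0:
                return 0.0  # no shared tokens at all
            precision = overlap / total
        else:
            precision = (overlap + 1) / (total + 1)
        log_precision_sum += math.log(precision)

    # Brevity penalty: penalize predictions shorter than the reference.
    if len(prediction) >= len(reference):
        brevity_penalty = 1.0
    else:
        brevity_penalty = math.exp(1 - len(reference) / len(prediction))

    # Geometric mean of the four precisions, scaled by the brevity penalty.
    return brevity_penalty * math.exp(log_precision_sum / 4)


if __name__ == "__main__":
    # Hypothetical decoded outputs and labels, tokenized per character.
    pairs = [("the cat sat", "the cat sat"), ("a dog ran", "a dog sat")]
    scores = [bleu4(list(p), list(r)) for p, r in pairs]
    print(sum(scores) / len(scores))
```

The open part is still how to obtain the generated texts from the ORPO trainer in the first place. The scoring above only assumes that decoded predictions and labels already exist as strings.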

Thanks!

Expected behavior

No response

System Info

No response

Others

No response

bingwork · May 06 '24 12:05