RLHF
RLHF copied to clipboard
基于ChatGLM2的RLHF训练问题
[2023-08-12 01:22:11,409] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.10.0, git-hash=unknown, git-branch=unknown
Traceback (most recent call last):
File "/root/inpc_projects/ChatGLM-6B/rlhf_again/src/train_rlhf.py", line 373, in
在使用ChatGLM2作为sft和reward模型,在A100*8的环境上训练的时候,在第三阶段train_rlhf时出现如上报错,尝试了很多方法都没有解决,deepspeed版本是0.10.0,奇怪的点是当--actor_zero_stage是2的时候,能够成功装载actor模型,但是装载reference的时候仍然会报这个错,想请问一下作者有什么建议吗?
这个原因应该是系统认为在运行deepspeed.initialize()
之前world_size
一直都是1,所以ds_config['train_batch_size']
不需要乘上world_size
。只能在运行deepspeed.initialize()
之前,才把ds_config['train_batch_size']
改为乘上world_size
。
RL部分的代码还没来得及修复这个问题,具体可以参见pretrain_wo_trainer.py 第220-221行和pretrain_wo_trainer.py 第292行
这个原因应该是系统认为在运行
deepspeed.initialize()
之前world_size
一直都是1,所以ds_config['train_batch_size']
不需要乘上world_size
。只能在运行deepspeed.initialize()
之前,才把ds_config['train_batch_size']
改为乘上world_size
。RL部分的代码还没来得及修复这个问题,具体可以参见pretrain_wo_trainer.py 第220-221行和pretrain_wo_trainer.py 第292行 具体咋解决呢