OpenRLHF icon indicating copy to clipboard operation
OpenRLHF copied to clipboard

Unexpected long actor_time when train_ppo_ray

Open LSC527 opened this issue 11 months ago • 9 comments

训练配置如下:

--ref_num_nodes 1 --ref_num_gpus_per_node 2 --reward_num_nodes 1 --reward_num_gpus_per_node 2 --critic_num_nodes 1 --critic_num_gpus_per_node 4 --actor_num_nodes 2 --actor_num_gpus_per_node 8 --vllm_num_engines 2 --vllm_tensor_parallel_size 4 --micro_train_batch_size 4 --train_batch_size 64 --micro_rollout_batch_size 4 --rollout_batch_size 64 --max_epochs 1 --prompt_max_len 1024 --generate_max_len 1024 --zero_stage 3 --bf16 --adam_offload --flash_attn --gradient_checkpointing --perf

actor模型是70b llama2,critic模型是13b llama2。 在通过actor模型计算action_log_probs时发现耗时异常,actor_time高达150秒。

        # log probs
        start = time.time()
        action_log_probs = self.actor(sequences, num_actions, attention_mask)
        actor_time = time.time() - start

通过profile发现是由于actor模型计算action_log_probs的推理开始时出现了长达80秒的all_gather通信。 image 怀疑是多机通信问题,但额外perf了一下actor模型训练的耗时也只有50秒。不清楚actor模型耗时异常是什么导致的。

LSC527 avatar Mar 14 '24 08:03 LSC527

收到,我们研究一下。最近工作比较忙,不一定顾得上~

hijkzzz avatar Mar 14 '24 10:03 hijkzzz

@LSC527 如果开启vllm的话,因为训练和推理分离,所以actor model和critic model的推理和训练计算量是相当的,建议把两者的GPU数量调整成一致。

wuxibin89 avatar Mar 15 '24 14:03 wuxibin89

通过profile发现是由于actor模型计算action_log_probs的推理开始时出现了长达80秒的all_gather通信。

这个all_gather通信的开销来自于actor和vllm参数同步,在训练阶段结束后,需要通过一次all_gather把参数收集到actor model的rank 0,然后broadcast给vllm的所有rank https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/ppo_actor.py#L142-L145

wuxibin89 avatar Mar 15 '24 14:03 wuxibin89

这里应该有有一定优化空间,现在是每个参数都需要经过一次all_gather+broadcast,可以把多个参数组成一个chunk,以减少通信次数

wuxibin89 avatar Mar 15 '24 14:03 wuxibin89

@LSC527 如果开启vllm的话,因为训练和推理分离,所以actor model和critic model的推理和训练计算量是相当的,建议把两者的GPU数量调整成一致。

@wuxibin89 因为我actor模型是70b llama2,critic模型是13b llama2小很多,所以critic的GPU数量设置的少。

通过profile发现是由于actor模型计算action_log_probs的推理开始时出现了长达80秒的all_gather通信。

这个all_gather通信的开销来自于actor和vllm参数同步,在训练阶段结束后,需要通过一次all_gather把参数收集到actor model的rank 0,然后broadcast给vllm的所有rank https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/ppo_actor.py#L142-L145

_broadcast_to_vllm的开销为什么会在actor_time上体现呢?https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ppo_utils/experience_maker.py#L257-L259

LSC527 avatar Mar 15 '24 15:03 LSC527

hmmm...我理解cuda kernel包括nccl通信应该都是异步执行的,所以actor_time这里可能触发了同步操作。 https://pytorch.org/docs/stable/notes/cuda.html 感觉可以在_broadcast_to_vllm之后加个torch.cuda.synchronize()验证一下 https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/ppo_actor.py#L145

wuxibin89 avatar Mar 15 '24 16:03 wuxibin89

@LSC527 能否看一下,是每个micro rollout batch的actor_time都很长,还是只有第一个batch的耗时较长?

wuxibin89 avatar Mar 16 '24 03:03 wuxibin89

@wuxibin89 每一个step,actor_time耗时都很长。并且我直接去掉_broadcast_to_vllm后仍然是这样。目前观察到这个现象会出现在actor_num_nodes>1 + zero3的场景下。https://github.com/OpenLLMAI/OpenRLHF/blob/3c918755faa31ee810f3624a82ba5f7879e4f8d3/openrlhf/trainer/ray/ppo_actor.py#L116 _broadcast_to_vllm后面有个torch.distributed.barrier(),所以耗时应该不会计算到actor_time里。actor_time看起来就是单纯的actor model zero3 forward耗时。 我再继续排查一下。

LSC527 avatar Mar 20 '24 12:03 LSC527

@wuxibin89 最终在ray.get(llm.generate.remote())前后加了barrier,发现是这一行代码运行带来的额外耗时。如果没有加barrier,额外耗时会被记入actor_time中。 https://github.com/OpenLLMAI/OpenRLHF/blob/3c918755faa31ee810f3624a82ba5f7879e4f8d3/openrlhf/trainer/ppo_utils/experience_maker.py#L344

LSC527 avatar Mar 21 '24 02:03 LSC527