NCCL timed out during PPO full-parameter training when using the DeepSpeed ZeRO-3 strategy. How can I solve this problem?
Reminder
- [X] I have read the README and searched the existing issues.
Reproduction
accelerate launch --config_file "accelerate_config_ppo.yaml" \
${TRAIN_BASH_PY} \
--stage ${STAGE} \
--model_name_or_path ${SFT_MODEL_PATH} \
--do_train \
--dataset ${DATASET_NAME} \
--dataset_dir ${DATASET_DIR} \
--val_size ${VAL_SIZE} \
--cutoff_len ${MAX_SEQ_LENGTH} \
--mix_strategy concat \
--finetuning_type ${FINETUNING_TYPE} \
--reward_model ${RM_CKPT} \
--reward_model_type full \
--output_dir ${OUTPUT_DIR} \
--overwrite_cache \
--overwrite_output_dir \
--per_device_train_batch_size ${BATCH_SIZE} \
--gradient_accumulation_steps ${GRADIENT_STEPS} \
--lr_scheduler_type ${LR_SCHEDULER_TYPE} \
--max_grad_norm ${MAX_GRAD_NORM} \
--save_strategy epoch \
--logging_strategy steps \
--logging_steps ${LOGGING_STEPS} \
--learning_rate ${LEARNING_RATE} \
--top_k ${TOP_K} \
--top_p ${TOP_P} \
--warmup_ratio ${WARM_UP} \
--weight_decay ${WEIGHT_DECAY} \
--ppo_epochs ${PPO_EPOCHS} \
--ppo_buffer_size ${PPO_BUFFER_SIZE} \
--num_train_epochs ${NUM_EPOCHS} \
--plot_loss \
--template qwen \
--bf16 \
--preprocessing_num_workers 5 \
--save_safetensors False \
--save_total_limit 5 \
--ddp_timeout 3600000
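For reference, the accelerate_config_ppo.yaml passed via --config_file is not shown above. The snippet below is a hypothetical reconstruction of a single-node DeepSpeed ZeRO-3 Accelerate config that would match this launch command (8 processes, matching ranks 0-7 in the log below; no parameter or optimizer offloading); the actual file may differ:
# Hypothetical reconstruction of accelerate_config_ppo.yaml; values are assumptions.
cat > accelerate_config_ppo.yaml <<'EOF'
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  zero_stage: 3
  zero3_init_flag: true
  offload_optimizer_device: none
  offload_param_device: none
mixed_precision: bf16
num_machines: 1
num_processes: 8
main_training_function: main
use_cpu: false
EOF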
When running PPO-stage training with the DeepSpeed ZeRO-3 strategy, the job hangs after a certain number of iterations and eventually fails with an NCCL timeout. The detailed error output is as follows:
Traceback (most recent call last):
File "/LLaMA-Factory/src/train_bash.py", line 14, in <module>
main()
File "/LLaMA-Factory/src/train_bash.py", line 5, in main
run_exp()
File "/LLaMA-Factory/src/llmtuner/train/tuner.py", line 37, in run_exp
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
File "/LLaMA-Factory/src/llmtuner/train/ppo/workflow.py", line 58, in run_ppo
dataset=dataset,
File "/LLaMA-Factory/src/llmtuner/train/ppo/trainer.py", line 194, in ppo_train
mini_batch_queries, mini_batch_responses = self.get_inputs(
File "/usr/local/lib/python3.9/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/LLaMA-Factory/src/llmtuner/train/ppo/trainer.py", line 306, in get_inputs
generate_output: torch.Tensor = unwrapped_model.generate(
File "/usr/local/lib/python3.9/dist-packages/trl/models/modeling_value_head.py", line 203, in generate
return self.pretrained_model.generate(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/transformers/generation/utils.py", line 1525, in generate
return self.sample(
File "/usr/local/lib/python3.9/dist-packages/transformers/generation/utils.py", line 2622, in sample
outputs = self(
File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 1173, in forward
outputs = self.model(
File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 1003, in forward
inputs_embeds = self.embed_tokens(input_ids)
File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
result = hook(self, args)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 278, in _pre_forward_module_hook
self.pre_sub_module_forward_function(module)
File "/usr/local/lib/python3.9/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 452, in pre_sub_module_forward_function
param_coordinator.fetch_sub_module(sub_module, forward=True)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 290, in fetch_sub_module
self.__all_gather_params(params_to_fetch, forward)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 434, in __all_gather_params
self.__all_gather_params_(nonquantized_params, forward, quantize=self.zero_quantized_weights)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 463, in __all_gather_params_
handle = param_group[0].all_gather_coalesced(param_group, quantize=quantize)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/partition_parameters.py", line 1217, in all_gather_coalesced
handles = _dist_allgather_fn(
File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/partition_parameters.py", line 93, in _dist_allgather_fn
return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor, group=group, async_op=True)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/comm/comm.py", line 320, in allgather_fn
return all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/comm/comm.py", line 117, in log_wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/comm/comm.py", line 305, in all_gather_into_tensor
return cdb.all_gather_into_tensor(output_tensor=output_tensor, input_tensor=tensor, group=group, async_op=async_op)
File "/usr/local/lib/python3.9/dist-packages/deepspeed/comm/torch.py", line 219, in all_gather_into_tensor
return self.all_gather_function(output_tensor=output_tensor,
File "/usr/local/lib/python3.9/dist-packages/torch/distributed/distributed_c10d.py", line 1451, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/distributed/distributed_c10d.py", line 2532, in all_gather_into_tensor
work = group._allgather_base(output_tensor, input_tensor)
RuntimeError: NCCL communicator was aborted on rank 3. Original reason for failure was: [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=7065370, OpType=ALLREDUCE, Timeout(ms)=36000000) ran for 36006808 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=7065370, OpType=ALLREDUCE, Timeout(ms)=36000000) ran for 36007072 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 5] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=7065370, OpType=ALLREDUCE, Timeout(ms)=36000000) ran for 36007080 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=7065370, OpType=_ALLGATHER_BASE, Timeout(ms)=36000000) ran for 36007019 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 6] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=7065370, OpType=ALLREDUCE, Timeout(ms)=36000000) ran for 36007154 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=7065370, OpType=ALLREDUCE, Timeout(ms)=36000000) ran for 36007389 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=7065370, OpType=ALLREDUCE, Timeout(ms)=36000000) ran for 36007448 milliseconds before timing out.
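Note that in the log above, rank 0 times out in _ALLGATHER_BASE (gathering ZeRO-3 partitioned parameters inside generate) while ranks 1 and 3-7 time out in ALLREDUCE, i.e. the ranks have already fallen out of sync by the time the watchdog fires. To localize which rank stalls first, the standard NCCL and torch.distributed debug switches can be set before launching; these are generic PyTorch/NCCL environment variables, not LLaMA-Factory options:
# Generic PyTorch/NCCL debugging knobs (not LLaMA-Factory flags).
export NCCL_DEBUG=INFO                 # per-rank NCCL logging
export NCCL_DEBUG_SUBSYS=INIT,COLL     # focus on init and collective operations
export TORCH_DISTRIBUTED_DEBUG=DETAIL  # extra consistency checks in torch.distributed
export NCCL_ASYNC_ERROR_HANDLING=1     # fail fast instead of hanging until the watchdog
                                       # (named TORCH_NCCL_ASYNC_ERROR_HANDLING on newer torch)

accelerate launch --config_file "accelerate_config_ppo.yaml" ${TRAIN_BASH_PY} ...  # same command as above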
Expected behavior
No response
System Info
No response
Others
No response