policy.train slow at >=32 nodes b/c workers start at different times
With >=32 nodes, when calling policy.train, some ranks take a long time to perform the initial synchronization (the AllReduce NCCL kernel) while other ranks synchronize quickly. This indicates that different ranks start their jobs at different times (the gap is ~800ms), and we suspect the cause is Ray job submission overhead.
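For reference, a small sketch of how the skew can be confirmed per rank, assuming the workers already have a torch.distributed process group with the NCCL backend initialized (the function name is made up for illustration):

```python
import time

import torch
import torch.distributed as dist

def time_first_allreduce():
    """Measure how long this rank blocks in the first collective."""
    t0 = time.monotonic()
    x = torch.ones(1, device="cuda")
    dist.all_reduce(x)           # first NCCL kernel of the job
    torch.cuda.synchronize()
    waited = time.monotonic() - t0
    # Ranks that were dispatched early block here until the last rank joins,
    # so their wait should be close to the ~800ms submission skew.
    print(f"rank {dist.get_rank()}: {waited:.3f}s in first all_reduce")
```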
A potential solution is to use Ray compiled graph to reduce the job submission overhead.
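A minimal sketch of what that could look like with the experimental Ray Compiled Graphs API (ray.dag); `TrainWorker`, `train_step`, `num_workers`, and `dataloader` are placeholders rather than our actual classes:

```python
import ray
from ray.dag import InputNode, MultiOutputNode

@ray.remote(num_gpus=1)
class TrainWorker:
    def train_step(self, batch):
        ...  # one optimizer step; placeholder

workers = [TrainWorker.remote() for _ in range(num_workers)]  # num_workers: placeholder

# Build and compile the dispatch DAG once; execute() then reuses the
# pre-established channels instead of paying per-call task submission cost.
with InputNode() as batch:
    dag = MultiOutputNode([w.train_step.bind(batch) for w in workers])
compiled_dag = dag.experimental_compile()

for batch in dataloader:  # dataloader: placeholder
    metrics = ray.get(compiled_dag.execute(batch))
```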
@katec846 please update with the latest status
@euronymous-aithal I've implemented Ray compiled graph and tested it on the SFT algorithm. The original overhead was ~4s for 32 nodes, seqlen 48k, TP4 CP4, Qwen2.5-14B. With the compiled graph, the overhead dropped to <1s. However, the step time becomes extremely high after 9-10 steps: it went from 19 -> 22 -> 49 -> 73, even though the computation time stays the same. The overhead comes from the Python side. Still investigating the root cause.
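In case it helps with the investigation, one way to narrow down where the growing Python-side time goes is to profile the driver around a late (already-slow) step; this is just a generic cProfile sketch wrapped around the compiled-graph call from the sketch above:

```python
import cProfile
import pstats

profiler = cProfile.Profile()
profiler.enable()
metrics = ray.get(compiled_dag.execute(batch))  # profile one step after the slowdown appears
profiler.disable()
pstats.Stats(profiler).sort_stats("cumulative").print_stats(30)
```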
Renaming this issue to clarify that its purpose is to minimize Ray-related overhead in GRPO, including:
- returning samples to driver after generation
- dispatching training functions
- sending samples to train workers as arguments

A sketch of one way to avoid the sample round-trip through the driver is below. A similar issue, but focused on SFT, is tracked in a subissue.
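For the first and third items, a hedged sketch of one mitigation: keep generated samples in the Ray object store and pass ObjectRefs to the train workers, so the data moves worker-to-worker instead of being pulled back to the driver. `rollout_workers`, `train_workers`, `prompts`, `generate`, and `train_step` are placeholder names, not the actual API:

```python
import ray

# Each rollout actor returns an ObjectRef to its samples; the driver never
# calls ray.get on the sample data itself.
sample_refs = [w.generate.remote(prompts) for w in rollout_workers]

# Passing the refs as task arguments makes Ray resolve them on the receiving
# workers, fetching objects directly between nodes and skipping the driver copy.
train_futs = [
    w.train_step.remote(ref) for w, ref in zip(train_workers, sample_refs)
]
metrics = ray.get(train_futs)
```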