Separate data/control planes for token_id passing in GRPO
Is your feature request related to a problem? Please describe. Right now in nemo-rl GRPO, the generation workers return the token_ids to the head node on CPU; the head node then performs sequence packing and passes the token_ids as input arguments to the train() or logprob() functions of the policy workers. This is inefficient when the global batch size or mean sequence length is large, since the sheer size of the token_ids can cause excessive memory spilling to disk on the head node.
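For scale, a rough back-of-envelope (the batch size, rollout count, and mean length below are illustrative assumptions, not nemo-rl defaults or measurements) shows how quickly the full token_id payload on the head node grows:

```python
# Back-of-envelope only; all sizes below are illustrative assumptions.
num_sequences = 2048 * 8      # e.g. global batch of 2048 prompts x 8 rollouts each
mean_len = 8192               # mean generated sequence length in tokens
bytes_per_token = 8           # token_ids held as int64

ids_bytes = num_sequences * mean_len * bytes_per_token
print(f"token_ids alone: {ids_bytes / 2**30:.1f} GiB")   # ~1.0 GiB
# Per-token logprobs, masks, and any serialization copies held alongside the raw
# token_ids on the head node multiply this further, which is what can trigger
# spilling to disk.
```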
Describe the solution you'd like Implement a mode that minimizes the transfer of token_ids between CPU and GPU: ideally token_ids always stay on GPU and are passed to workers via GPU-direct transfer; workers return only metadata to the CPU for sequence packing, and the loss function is computed on GPU. A minimal sketch of this split is below.
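A minimal sketch of the proposed split, assuming hypothetical names (SequenceMeta, plan_packing, and PolicyWorker are for illustration only and are not existing nemo-rl APIs): the head node sees only per-sequence lengths and returns a packing plan (control plane), while the token_ids stay resident on the worker's GPU (data plane):

```python
# Minimal sketch; SequenceMeta, plan_packing, and PolicyWorker are hypothetical
# names for illustration, not existing nemo-rl APIs.
from dataclasses import dataclass
import torch


@dataclass
class SequenceMeta:
    worker_id: int
    local_index: int   # index into the worker's GPU-resident token_id cache
    length: int        # the only per-sequence info the head node needs to pack


def plan_packing(metas: list[SequenceMeta], pack_len: int) -> list[list[SequenceMeta]]:
    """Control plane: first-fit-decreasing packing on the head node, lengths only."""
    bins: list[list[SequenceMeta]] = []
    loads: list[int] = []
    for m in sorted(metas, key=lambda m: m.length, reverse=True):
        for i, load in enumerate(loads):
            if load + m.length <= pack_len:
                bins[i].append(m)
                loads[i] += m.length
                break
        else:
            bins.append([m])
            loads.append(m.length)
    return bins  # a few integers per sequence -- cheap to send back to workers


class PolicyWorker:
    """Data plane: token_ids stay on this worker's GPU (single-worker sketch)."""

    def __init__(self, worker_id: int = 0):
        self.worker_id = worker_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cache: list[torch.Tensor] = []

    def store(self, token_ids: torch.Tensor) -> SequenceMeta:
        # Keep the tensor on GPU; only lightweight metadata goes to the head node.
        self.cache.append(token_ids.to(self.device))
        return SequenceMeta(self.worker_id, len(self.cache) - 1, token_ids.numel())

    def assemble_packed(self, plan_bin: list[SequenceMeta]) -> torch.Tensor:
        # Build the packed input on GPU from the plan; train()/logprob() and the
        # loss would then run entirely on this tensor without a CPU round trip.
        return torch.cat([self.cache[m.local_index] for m in plan_bin])
```

Sequences cached on a different worker than the one that trains on them would need a GPU-direct transfer (e.g. torch.distributed/NCCL send/recv) rather than a CPU round trip; that transfer path is the main piece this sketch leaves out.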