
[Feature] Keep transitions from the same rollout into one mini batch

Open hzy46 opened this issue 2 months ago • 3 comments

In agent lightning, we collect all transitions and use balance_batch to reorder them for maximum efficiency.

This reordering can place transitions from the same rollout in different mini batches, which might hurt training performance. I will investigate.

hzy46 avatar Nov 10 '25 09:11 hzy46

Hello.

At line 346 of agentlightning/verl/trainer.py:

if self.config.trainer.balance_batch:
    self._balance_batch(batch, metrics=metrics)

# update critic
if self.use_critic:
    with _timer("update_critic", timing_raw):
        critic_output = self.critic_wg.update_critic(batch)  # <- batch is passed directly
    critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
    metrics.update(critic_output_metrics)

The code after the self._balance_batch call does not use any return value from it; it just passes batch to the critic. Does this mean that balance_batch is not yet used in this repo?

In the _balance_batch function of verl/verl/trainer/ppo/ray_trainer.py (line 919), lines 926–941 seem to perform partitioning by group only when keep_minibatch=True, while the default is keep_minibatch=False.

So it appears that group-based advantage computation has not been fully implemented yet, which is consistent with the first point above.

From https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py,

if keep_minibatch:
    # Decouple the DP balancing and mini-batching.
    minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size")
    minibatch_num = len(global_seqlen_lst) // minibatch_size
    global_partition_lst = [[] for _ in range(world_size)]
    for i in range(minibatch_num):
        rearrange_minibatch_lst = get_seqlen_balanced_partitions(
            global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
            k_partitions=world_size,
            equal_size=True,
        )
        for j, part in enumerate(rearrange_minibatch_lst):
            global_partition_lst[j].extend([x + minibatch_size * i for x in part])

It does not use the group ID when partitioning, so I think samples near a mini-batch boundary might end up in a different mini batch than the rest of their group, which could cause training problems.
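To make the concern concrete, here is a minimal sketch of what a group-preserving assignment could look like. It is not verl or agent-lightning code: `group_preserving_minibatches`, `rollout_ids`, and the greedy balancing rule are all hypothetical, chosen only to illustrate keeping every transition of a rollout in one mini batch while still roughly balancing token counts.

```python
from collections import defaultdict


def group_preserving_minibatches(seqlens, rollout_ids, num_minibatches):
    """Assign whole rollouts to mini batches, balancing total sequence length.

    Hypothetical helper for illustration; not part of verl or agent-lightning.
    """
    # Collect the sample indices that belong to each rollout.
    groups = defaultdict(list)
    for idx, rid in enumerate(rollout_ids):
        groups[rid].append(idx)

    # Place the heaviest rollouts first (greedy longest-first heuristic).
    ordered = sorted(
        groups.values(),
        key=lambda idxs: sum(seqlens[i] for i in idxs),
        reverse=True,
    )

    minibatches = [[] for _ in range(num_minibatches)]
    loads = [0] * num_minibatches
    for idxs in ordered:
        target = loads.index(min(loads))  # currently lightest mini batch
        minibatches[target].extend(idxs)
        loads[target] += sum(seqlens[i] for i in idxs)
    return minibatches


seqlens = [120, 80, 300, 40, 60, 200]
rollout_ids = ["a", "a", "b", "c", "c", "c"]
batches = group_preserving_minibatches(seqlens, rollout_ids, num_minibatches=2)
print(batches)
```

The trade-off in this sketch is that mini batches no longer have equal sample counts, which the equal_size=True path above guarantees, so it would need extra handling before it could replace the existing partitioning.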

genji970 avatar Nov 11 '25 14:11 genji970

@genji970

  1. _balance_batch is currently applied. It is an in-place operation, so after calling self._balance_batch(batch, metrics=metrics), the order of the samples inside batch has changed. This is also the default verl behavior.

  2. I am not sure I fully understand your point. If your question is whether the group-based advantage is calculated correctly, the answer should be yes: we moved all group-based advantage calculation to https://github.com/microsoft/agent-lightning/blob/main/agentlightning/verl/trainer.py#L310-L318 , which runs before we drop any transitions or reorder the batch.
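Both points can be illustrated with a toy example. This is only a sketch under simplifying assumptions: the batch is a plain dict of lists rather than verl's batch object, and `group_normalized_advantages` and `reorder_in_place` are hypothetical stand-ins, not the functions used in the trainer. It shows that an in-place reorder changes the caller's batch without returning anything, and that advantages computed before the reorder stay attached to the right samples.

```python
def group_normalized_advantages(rewards, group_ids):
    """Advantage = reward minus the mean reward of the sample's group."""
    totals, counts = {}, {}
    for r, g in zip(rewards, group_ids):
        totals[g] = totals.get(g, 0.0) + r
        counts[g] = counts.get(g, 0) + 1
    return [r - totals[g] / counts[g] for r, g in zip(rewards, group_ids)]


def reorder_in_place(batch, order):
    """Permute every field of `batch` in place; nothing is returned."""
    for values in batch.values():
        values[:] = [values[i] for i in order]


batch = {
    "sample_id": [0, 1, 2, 3],
    "group_id": ["a", "a", "b", "b"],
    "reward": [1.0, 0.0, 0.5, 0.5],
}

# Step 1: compute group-based advantages while groups are still intact.
batch["advantage"] = group_normalized_advantages(batch["reward"], batch["group_id"])
before = dict(zip(batch["sample_id"], batch["advantage"]))

# Step 2: reorder in place (stand-in for a balancing step).
reorder_in_place(batch, [3, 0, 2, 1])
after = dict(zip(batch["sample_id"], batch["advantage"]))

print(batch["sample_id"])  # the caller's batch is now in the new order
print(before == after)     # each sample keeps the advantage computed earlier
```

Because the advantage is stored per sample before any reordering, where a sample lands afterwards does not change its advantage value; whether splitting a rollout across mini batches affects optimization is a separate question.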

hzy46 avatar Nov 11 '25 14:11 hzy46

Thanks. I mistakenly assumed that batch was not affected. I should read the balance_batch function and the trainer code more carefully.

genji970 avatar Nov 11 '25 14:11 genji970