[Feature] Keep transitions from the same rollout into one mini batch
In Agent Lightning, we collect all transitions and use `balance_batch` to reorder them for maximum efficiency. This practice might cause performance issues, since transitions from the same rollout can end up scattered across mini-batches. I will try to investigate.
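As a starting point, here is a minimal sketch of what "keep transitions from the same rollout in one mini-batch" could look like. Everything in it is hypothetical: the `rollout_id` field and the `build_rollout_aware_minibatches` helper are illustrative names, not part of the current codebase.

```python
from collections import defaultdict

def build_rollout_aware_minibatches(transitions, minibatch_size):
    """Hypothetical sketch: pack whole rollouts into mini-batches without splitting them.

    `transitions` is assumed to be a list of dicts with a "rollout_id" key;
    this is an illustration, not the actual Agent Lightning data structure.
    """
    # Group transitions by the rollout they came from.
    groups = defaultdict(list)
    for t in transitions:
        groups[t["rollout_id"]].append(t)

    # Greedily pack whole rollouts; a rollout never straddles two mini-batches
    # (a rollout longer than minibatch_size simply becomes its own oversized batch).
    minibatches, current = [], []
    for rollout in groups.values():
        if current and len(current) + len(rollout) > minibatch_size:
            minibatches.append(current)
            current = []
        current.extend(rollout)
    if current:
        minibatches.append(current)
    return minibatches
```

The open question is how to reconcile this kind of packing with the sequence-length balancing that `balance_batch` currently performs.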
Hello.
In `agentlightning/verl/trainer.py`, around line 346:

```python
if self.config.trainer.balance_batch:
    self._balance_batch(batch, metrics=metrics)

# update critic
if self.use_critic:
    with _timer("update_critic", timing_raw):
        critic_output = self.critic_wg.update_critic(batch)  # <--
        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
        metrics.update(critic_output_metrics)
```
The code below does not use any return value from `self._balance_batch`; it just passes `batch` to the critic. Does this mean `balance_batch` is not actually used in this repo yet?
In the `_balance_batch` function of `verl/verl/trainer/ppo/ray_trainer.py` (line 919), lines 926–941 seem to partition per mini-batch only when `keep_minibatch=True`, while the default is `keep_minibatch=False`.
So it appears that group-based advantage computation hasn't been fully implemented yet, consistent with point 1).
From https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py:

```python
if keep_minibatch:
    # Decouple the DP balancing and mini-batching.
    minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size")
    minibatch_num = len(global_seqlen_lst) // minibatch_size
    global_partition_lst = [[] for _ in range(world_size)]
    for i in range(minibatch_num):
        rearrange_minibatch_lst = get_seqlen_balanced_partitions(
            global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
            k_partitions=world_size,
            equal_size=True,
        )
        for j, part in enumerate(rearrange_minibatch_lst):
            global_partition_lst[j].extend([x + minibatch_size * i for x in part])
```
It does not use the group id when partitioning, so I think samples near mini-batch boundaries could end up in different mini-batches from the rest of their group, which could cause training problems.
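A small self-contained illustration of the concern (toy lengths and group ids, and plain sorting instead of the real `get_seqlen_balanced_partitions`): if transitions are reordered purely by sequence length and then sliced into fixed-size mini-batches, groups that were originally contiguous can straddle mini-batch boundaries.

```python
# Toy illustration only: 8 transitions from 4 groups (2 transitions per group),
# with made-up sequence lengths. Not the real verl data structures.
seqlens   = [120, 80, 200, 60, 150, 90, 40, 170]
group_ids = [  0,  0,   1,  1,   2,  2,  3,   3]

# Length-based reordering that ignores group_ids, similar in spirit to
# seqlen balancing (a plain sort here just to keep the example simple).
order = sorted(range(len(seqlens)), key=lambda i: seqlens[i])

minibatch_size = 4
for m in range(len(order) // minibatch_size):
    chunk = order[m * minibatch_size : (m + 1) * minibatch_size]
    print(f"mini-batch {m}: groups {[group_ids[i] for i in chunk]}")

# Output:
# mini-batch 0: groups [3, 1, 0, 2]
# mini-batch 1: groups [0, 2, 3, 1]
# Every group is split across the two mini-batches, even though each
# group's transitions were adjacent before the reordering.
```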
@genji970
- Currently, `_balance_batch` is applied. As it is an in-place operation, after calling `self._balance_batch(batch, metrics=metrics)`, the order inside `batch` is changed. It is also the default verl behavior.
- I am not sure whether I fully get your point. If your question is "whether the group-based advantage has been correctly calculated", the answer should be yes: we do all group-based advantage calculation at https://github.com/microsoft/agent-lightning/blob/main/agentlightning/verl/trainer.py#L310-L318, before we drop any transitions or reorder anything (see the sketch below).
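To make the ordering argument concrete, here is a minimal toy sketch (made-up rewards and a mean-only normalization for simplicity; this is not the actual trainer code): the group-normalized advantage is attached to each transition while the groups are still intact, so any later reordering only permutes values that have already been computed.

```python
# Toy sketch: group-normalized advantages computed before any reordering.
# Rewards and group membership are made up for illustration.
import random

rewards   = [1.0, 0.0, 1.0, 0.5, 0.5, 0.0]
group_ids = [  0,   0,   0,   1,   1,   1]

# Compute each group's mean once, while transitions are still grouped.
group_means = {}
for g in set(group_ids):
    members = [r for r, gid in zip(rewards, group_ids) if gid == g]
    group_means[g] = sum(members) / len(members)

# The advantage is attached to each transition individually.
advantages = [r - group_means[g] for r, g in zip(rewards, group_ids)]

# Any later reordering (e.g. sequence-length balancing) permutes transitions
# together with their precomputed advantages, so the values stay correct.
order = list(range(len(rewards)))
random.shuffle(order)
reordered = [(rewards[i], advantages[i]) for i in order]
```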
Thanks, I misunderstood and thought `batch` was not affected. I should read the `balance_batch` function and the trainer code more carefully.