agent-lightning
Bug Fix: Prevent duplicate rollouts caused by stale task requeue + lat…
This PR addresses an issue with stale task requeueing that could cause duplicate rollouts to be received and stored.
Problem
- The server requeues tasks when they time out (`_check_and_requeue_stale_tasks`).
- If a worker eventually finishes and submits a rollout for a task that has already been requeued, the server still accepts the outdated result in `/rollout`.
- As a result, the same logical task could be processed more than once.
- In our training loop, when the duplicate rollout came from the previous batch, this led to one of two failures:
  - `assert len(self._completed_rollouts) == self._total_tasks_queued` failing, because duplicate rollouts from previous batches were counted.
  - A key error on the rollout ID, because the rollout does not belong to the current batch.
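The cross-batch failure can be reproduced in a few lines. This is a simplified sketch, not the project's server: `NaiveServer`, `start_batch`, and `requeue_stale` are hypothetical names, and the timeout is triggered by hand rather than by a clock.

```python
from typing import Dict, List, Optional


class NaiveServer:
    """Hypothetical server that accepts every rollout, with no attempt check."""

    def __init__(self) -> None:
        self.queue: List[str] = []
        self.completed_rollouts: Dict[str, str] = {}
        self.total_tasks_queued = 0

    def start_batch(self, rollout_ids: List[str]) -> None:
        # A new batch resets the bookkeeping used by the training loop.
        self.queue = list(rollout_ids)
        self.completed_rollouts = {}
        self.total_tasks_queued = len(rollout_ids)

    def get_next_task(self) -> Optional[str]:
        return self.queue.pop(0) if self.queue else None

    def requeue_stale(self, rollout_id: str) -> None:
        # Stand-in for the timeout path: the task goes back on the queue.
        self.queue.append(rollout_id)

    def store_rollout(self, rollout_id: str, result: str) -> None:
        # No validation: late results are stored like any other.
        self.completed_rollouts[rollout_id] = result


server = NaiveServer()

# Batch 1: worker A claims r1 and stalls; the task times out and worker B redoes it.
server.start_batch(["r1"])
slow_claim = server.get_next_task()
server.requeue_stale("r1")
fast_claim = server.get_next_task()
server.store_rollout(fast_claim, "from worker B")
batch1_ok = len(server.completed_rollouts) == server.total_tasks_queued

# Batch 2: worker A finally reports r1 while r2 is the only queued task.
server.start_batch(["r2"])
server.store_rollout(slow_claim, "late, from worker A")
server.store_rollout(server.get_next_task(), "from worker C")
batch2_ok = len(server.completed_rollouts) == server.total_tasks_queued
```

In this sketch batch 1 passes the length check, while batch 2 ends with two stored rollouts for one queued task, which is the condition the training-loop assertion trips on.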
Solution
- Introduced an `attempt_id` for each task claim.
  - Every time a task is handed out via `/task`, a new `attempt_id` (UUID) is generated.
  - Workers are required to include this `attempt_id` when submitting rollouts.
- When the server receives a rollout:
  - It checks if the `attempt_id` matches the currently active one in `_processing_tasks`.
  - If it doesn't match (i.e. the task was requeued and a newer attempt is active), the stale rollout is silently ignored.
This ensures at most one valid rollout is stored per logical task, and late stale results will not break training.
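A minimal sketch of the claim-and-validate flow described above, assuming an in-memory server. `TinyServer`, `queue_task`, `requeue`, and the field layout of `Task` are illustrative assumptions; only the names `attempt_id`, `get_next_task`, `store_rollout`, and `_processing_tasks` come from this PR, and the real signatures may differ.

```python
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Task:
    rollout_id: str
    payload: str
    attempt_id: Optional[str] = None  # set each time the task is claimed


class TinyServer:
    """Hypothetical in-memory stand-in for the real server."""

    def __init__(self) -> None:
        self._queue: List[Task] = []
        self._processing_tasks: Dict[str, Task] = {}
        self._completed_rollouts: Dict[str, str] = {}

    def queue_task(self, rollout_id: str, payload: str) -> None:
        self._queue.append(Task(rollout_id, payload))

    def get_next_task(self) -> Optional[Task]:
        if not self._queue:
            return None
        task = self._queue.pop(0)
        task.attempt_id = uuid.uuid4().hex  # fresh id for every claim
        self._processing_tasks[task.rollout_id] = task
        # Return a copy so a later re-claim cannot change what this worker holds.
        return Task(task.rollout_id, task.payload, task.attempt_id)

    def requeue(self, rollout_id: str) -> None:
        # Stand-in for the timeout path: move the task back to the queue.
        self._queue.append(self._processing_tasks.pop(rollout_id))

    def store_rollout(self, rollout_id: str, attempt_id: str, result: str) -> bool:
        active = self._processing_tasks.get(rollout_id)
        if active is None or active.attempt_id != attempt_id:
            return False  # stale or unknown attempt: drop the rollout
        del self._processing_tasks[rollout_id]
        self._completed_rollouts[rollout_id] = result
        return True


server = TinyServer()
server.queue_task("r1", "prompt")
first = server.get_next_task()    # slow worker
server.requeue("r1")              # task timed out
second = server.get_next_task()   # replacement worker
late_accepted = server.store_rollout("r1", first.attempt_id, "late")
fresh_accepted = server.store_rollout("r1", second.attempt_id, "fresh")
```

Here the late submission is rejected because its `attempt_id` no longer matches the active claim, and only the replacement worker's result is stored. Returning a boolean is a choice made for the sketch so the outcome is visible to the caller.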
Changes
- Updated `Task` to include `attempt_id`.
- `get_next_task` assigns a new `attempt_id` upon each claim.
- `store_rollout` validates the `attempt_id` before accepting results.
- Updated logging to make it clear when a stale rollout is dropped.
Why this is important
- Guarantees consistency between queued tasks and completed rollouts.
- Prevents assertion errors during training when tasks time out and later resurface.
- Makes the system more robust in long-running distributed training with occasional straggler workers.
Related issues
- Internal error `AssertionError: assert len(self._completed_rollouts) == self._total_tasks_queued` caused by duplicate rollouts from stale tasks.
In `server.py`, we already have `num_claims`. Maybe renaming it to `claim_sequence_id` would be a good idea? I don't see the API used elsewhere, so it should be safe.
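One possible reading of this suggestion is that the existing claim counter could double as the attempt token, in place of a UUID. The sketch below illustrates that reading only; `ClaimedTask`, `claim`, and `is_current` are hypothetical names and do not reflect the actual `server.py` code.

```python
from dataclasses import dataclass


@dataclass
class ClaimedTask:
    rollout_id: str
    claim_sequence_id: int = 0  # hypothetical rename of num_claims


def claim(task: ClaimedTask) -> int:
    """Bump the counter on every claim and return the value the worker must echo back."""
    task.claim_sequence_id += 1
    return task.claim_sequence_id


def is_current(task: ClaimedTask, submitted_sequence_id: int) -> bool:
    """Accept a rollout only if it carries the latest claim's sequence number."""
    return submitted_sequence_id == task.claim_sequence_id


task = ClaimedTask("r1")
first_claim = claim(task)    # slow worker
second_claim = claim(task)   # replacement worker after a requeue
```

With a counter, a stale submission is detected the same way as with a UUID: the value the slow worker holds no longer equals the task's current sequence number.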