
filter out empty triplets when batching

Open zxgx opened this issue 4 months ago • 9 comments

Reproduce bug

How to reproduce

Insert prompt = " ".join([prompt for _ in range(3000)]) below this line.

Full stack trace on the server side

ray.exceptions.RayTaskError(RuntimeError): ray::TaskRunner.run() (pid=2135564, ip=10.8.163.2, actor_id=29e7f19e7d1e8e75dd2fa30d16000000, repr=<agentlightning.verl.entrypoint.TaskRunner object at 0x7f370f04fc40>)
  File "/home/xxx/snap/code/agent-lightning/agentlightning/verl/entrypoint.py", line 152, in run
    trainer.fit()
  File "/home/xxx/snap/code/agent-lightning/agentlightning/verl/trainer.py", line 314, in fit
    metrics = self._train_step(batch_dict)
  File "/home/xxx/snap/code/agent-lightning/agentlightning/verl/trainer.py", line 141, in _train_step
    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/single_controller/ray/base.py", line 50, in __call__
    output = ray.get(output)
ray.exceptions.RayTaskError(RuntimeError): ray::WorkerDict.actor_rollout_compute_log_prob() (pid=2135803, ip=10.8.163.2, actor_id=ebb4ab7008cbbb8a3ac28dea16000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7183d1799c00>)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/single_controller/ray/base.py", line 705, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/single_controller/base/decorator.py", line 514, in inner
    return func(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/workers/fsdp_workers.py", line 782, in compute_log_prob
    output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/utils/profiler/performance.py", line 89, in f
    return self.log(decorated_function, *args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/utils/profiler/performance.py", line 102, in log
    output = func(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/workers/actor/dp_actor.py", line 332, in compute_log_prob
    entropy, log_probs = self._forward_micro_batch(
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/workers/actor/dp_actor.py", line 167, in _forward_micro_batch
    output = self.actor_module(
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 856, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 544, in forward
    outputs: BaseModelOutputWithPast = self.model(
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 432, in forward
    layer_outputs = decoder_layer(
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 856, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/modeling_layers.py", line 83, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 236, in forward
    hidden_states, self_attn_weights = self.self_attn(
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 154, in forward
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1, 128] because the unspecified dimension size -1 can be any value and is ambiguous
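The failure is easy to reproduce in isolation: once the sequence dimension of a micro-batch collapses to zero, the -1 in the attention reshape becomes ambiguous. A minimal standalone illustration, independent of verl/transformers (the hidden size and head dim below are made-up example values):

```python
# Minimal illustration of the reshape failure; 1536 and 128 are arbitrary example values.
import torch

hidden_states = torch.empty(1, 0, 1536)   # a micro-batch whose sequence length is 0
hidden_states.view(1, 0, -1, 128)         # RuntimeError: cannot reshape tensor of 0 elements ...
```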

Solution

As the stack trace shows, this issue is caused by a 0-length response. It does not affect inference mode, so a straightforward idea is to filter such samples out of training batches (see the sketch below). I am not sure whether this is too aggressive and might impact other components.
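A minimal sketch of that filtering, assuming a verl-style batch dict where response_mask marks valid response tokens; the field names and the helper itself are illustrative assumptions, not actual agent-lightning code:

```python
# Illustrative sketch only: the batch layout and field names are assumptions.
import torch

def filter_empty_responses(batch: dict) -> dict:
    """Drop samples whose response has zero valid tokens before compute_log_prob."""
    response_mask = batch["response_mask"]        # (batch_size, response_len), 1 = valid token
    keep = response_mask.sum(dim=-1) > 0          # keep only non-empty responses
    if keep.all():
        return batch
    # Index every per-sample tensor with the same boolean mask so the batch stays aligned.
    return {k: (v[keep] if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
```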

Other solutions

  1. Use another model. Since the bug surfaces in transformers, switching to another model avoids the issue; fixing it in transformers' modeling_qwen2 would also work.
  2. Drop at runtime. Add another field such as is_drop_mask to filter out 0-response samples before they enter compute_log_prob, and fill the dropped samples with some default value (see the sketch below).
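A rough sketch of option 2, again with hypothetical field names (responses, response_mask, is_drop_mask) and a padded-tensor layout assumed for illustration:

```python
# Hypothetical sketch of "drop at runtime": flag empty responses and fill them with a
# placeholder token so the forward pass never sees a zero-length sequence.
# Field names are assumptions, not part of verl or agent-lightning.
import torch

def mark_and_pad_empty_responses(batch: dict, pad_token_id: int) -> dict:
    responses = batch["responses"]                 # (batch_size, response_len) token ids
    response_mask = batch["response_mask"]         # (batch_size, response_len), 1 = valid
    is_drop_mask = response_mask.sum(dim=-1) == 0  # True for 0-response samples
    if is_drop_mask.any():
        responses[is_drop_mask, 0] = pad_token_id  # give dropped samples one dummy token
        response_mask[is_drop_mask, 0] = 1
    batch["is_drop_mask"] = is_drop_mask           # downstream code can zero out their loss
    return batch
```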

zxgx avatar Aug 15 '25 18:08 zxgx


Have you tested locally whether the proposed fix will mitigate issue #50? Show the behavior before the fix and after.

ultmaster avatar Aug 16 '25 11:08 ultmaster

Before fixing

The server side reports the above error and exits. The client side reports an error and hangs after the server exits; the reported error is what the request returns (screenshot omitted).

After fixing

The client side reports the same error to notify users, and the server side successfully passes training steps even without valid samples (screenshot omitted).

zxgx avatar Aug 16 '25 12:08 zxgx

Have you checked wandb? I don't think the training makes any sense if all the data are thrown away.

I think a design choice has to be made here, with pros and cons: ask users to intervene if that is what is happening. Simply silencing all errors might not be a good idea.

ultmaster avatar Aug 16 '25 16:08 ultmaster

Clarification

In my debugging code, all generated samples exceed the length limit, so there are no valid samples in any batch and the wandb log is empty. However, the skipping mechanism filters out invalid samples at the granularity of individual rollouts. In practice I suppose such cases are rare, so a batch would still contain valid samples to train on.

Potential improvement

To allow user intervention, would a config option like actor_rollout_ref.rollout.skip_badrequest be preferred?

Behaviour

When the option is enabled and a bad request like this one happens, the server skips that rollout sample during batching since it contains no valid data. When the option is disabled, the server throws the exception to notify the user.

Default value of the option

The option should be enabled by default; otherwise users may hit this issue in late training steps and waste most of their compute. For example, Figure 3 of the DeepSeek-R1 paper shows that response length gradually increases with training steps. A sketch of how the switch might gate the filtering is shown below.
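The following sketch only illustrates the proposed behaviour; skip_badrequest and the field names are part of the proposal, not an existing verl or agent-lightning option:

```python
# Sketch of the proposed switch: drop empty rollouts when enabled, raise when disabled.
import torch

def handle_empty_rollouts(batch: dict, skip_badrequest: bool = True) -> dict:
    empty = batch["response_mask"].sum(dim=-1) == 0   # 0-response samples
    if not empty.any():
        return batch
    if skip_badrequest:
        # Enabled (default): silently drop the offending rollouts from the training batch.
        keep = ~empty
        return {k: (v[keep] if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
    # Disabled: surface the problem so the user can inspect their agent.
    raise RuntimeError(
        f"{int(empty.sum())} rollout(s) produced an empty response; "
        "enable skip_badrequest to drop them instead."
    )
```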

zxgx avatar Aug 17 '25 04:08 zxgx

Well, I think there are two cases. If it's the prompt length that is too long, it's a problem on the agent side, and the server should clearly warn or point out that the user should take a look at their agent. If it's a response-length problem, like the DeepSeek-R1 example you gave, we may need to discuss further what is happening and have a proper mechanism to handle it.

ultmaster avatar Aug 17 '25 05:08 ultmaster

To warn about the prompt length on the agent client, we would need to initialize a tokenizer to precompute the number of tokens when enqueuing a task (queue_task). Then we can emit a warning if the token count is over the limit. Is this a preferred design? A rough sketch is below.
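The sketch assumes a HuggingFace tokenizer is available at enqueue time; the model name, the max_prompt_length value, and the function itself are placeholders for whatever queue_task would actually have access to:

```python
# Rough sketch of the proposed client-side warning; all names here are placeholders.
import logging
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # example model

def warn_if_prompt_too_long(prompt: str, max_prompt_length: int = 4096) -> int:
    """Count prompt tokens when a task is enqueued and warn before the rollout runs."""
    n_tokens = len(tokenizer.encode(prompt))
    if n_tokens > max_prompt_length:
        logger.warning(
            "Prompt has %d tokens, exceeding max_prompt_length=%d; "
            "the rollout will likely be truncated or rejected.",
            n_tokens, max_prompt_length,
        )
    return n_tokens
```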

As for the response length, execution is wrapped by verl's RayPPOTrainer, which in turn wraps an OpenAIServingChat from vLLM. While we could add a patch like PatchedvLLMServer to handle this, I think such bad requests should be removed from the training batch anyway, as the model cannot get a correct reward for them.

zxgx avatar Aug 17 '25 07:08 zxgx

> Before fixing: the server side reports the above error and exits; the client side reports an error and hangs after the server exits.
> After fixing: the client side reports the same error to notify users; the server side passes training steps without valid samples.

I'm hitting the same issue right now. In the verl training script, replacing the 'error' setting of data.truncation with 'truncate' seems to suppress the error and let the training process continue; alternatively, we can set a larger data.max_prompt_length if a 4096 length is not enough.

However, it is also important to note that if too many responses are dropped, RL training gets stuck and advantage estimation becomes incorrect, which may degrade performance.

IsaacGHX avatar Aug 26 '25 07:08 IsaacGHX

> To warn about the prompt length on the agent client, we would need to initialize a tokenizer to precompute token counts when enqueuing a task (queue_task) ... I suppose such bad requests should be removed from the training batch, as the model cannot get a correct reward.

I think truncating the input prompt when it exceeds the agent's length limit (i.e., when the LLM inference max tokens is set too small) is a really good solution. However, for long agent interactions, such as 10000+ input tokens and 4000+ output tokens, this can waste a lot of training time.

IsaacGHX avatar Aug 27 '25 14:08 IsaacGHX