filter out empty triplets when batching
Reproduce bug
How to
Insert `prompt = " ".join([prompt for _ in range(3000)])` below this line.
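For context, a minimal sketch of what that insertion does (purely illustrative; `prompt` is whatever string the agent was about to send):

```python
# Illustrative only: repeating the agent's prompt ~3000x makes the request far
# exceed the configured max prompt length. The rollout then comes back with a
# 0-length response, which later crashes compute_log_prob with the reshape
# error shown in the trace below.
prompt = " ".join([prompt for _ in range(3000)])
```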
Full stack trace on the server side
ray.exceptions.RayTaskError(RuntimeError): ray::TaskRunner.run() (pid=2135564, ip=10.8.163.2, actor_id=29e7f19e7d1e8e75dd2fa30d16000000, repr=<agentlightning.verl.entrypoint.TaskRunner object at 0x7f370f04fc40>)
File "/home/xxx/snap/code/agent-lightning/agentlightning/verl/entrypoint.py", line 152, in run
trainer.fit()
File "/home/xxx/snap/code/agent-lightning/agentlightning/verl/trainer.py", line 314, in fit
metrics = self._train_step(batch_dict)
File "/home/xxx/snap/code/agent-lightning/agentlightning/verl/trainer.py", line 141, in _train_step
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/single_controller/ray/base.py", line 50, in __call__
output = ray.get(output)
ray.exceptions.RayTaskError(RuntimeError): ray::WorkerDict.actor_rollout_compute_log_prob() (pid=2135803, ip=10.8.163.2, actor_id=ebb4ab7008cbbb8a3ac28dea16000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7183d1799c00>)
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 451, in result
return self.__get_result()
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/single_controller/ray/base.py", line 705, in func
return getattr(self.worker_dict[key], name)(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/single_controller/base/decorator.py", line 514, in inner
return func(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/workers/fsdp_workers.py", line 782, in compute_log_prob
output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/utils/profiler/performance.py", line 89, in f
return self.log(decorated_function, *args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/utils/profiler/performance.py", line 102, in log
output = func(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/workers/actor/dp_actor.py", line 332, in compute_log_prob
entropy, log_probs = self._forward_micro_batch(
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/verl/workers/actor/dp_actor.py", line 167, in _forward_micro_batch
output = self.actor_module(
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 856, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/utils/generic.py", line 943, in wrapper
output = func(self, *args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 544, in forward
outputs: BaseModelOutputWithPast = self.model(
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/utils/generic.py", line 943, in wrapper
output = func(self, *args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 432, in forward
layer_outputs = decoder_layer(
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 856, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/modeling_layers.py", line 83, in __call__
return super().__call__(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 236, in forward
hidden_states, self_attn_weights = self.self_attn(
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xxx/venvs/debug/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 154, in forward
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1, 128] because the unspecified dimension size -1 can be any value and is ambiguous
Solution
As shown in the stack trace, this issue is caused by a 0-length response. It does not affect inference mode, so a straightforward fix is to filter such samples out of training batches. I'm not sure whether this is too aggressive and might impact other components.
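A minimal sketch of that idea, assuming the batch behaves like a dict of per-sample tensors with a `response_mask` field (these names are illustrative, not the actual verl/agent-lightning data structures):

```python
import torch

def filter_empty_responses(batch: dict) -> dict:
    """Drop samples with an empty response before they reach compute_log_prob.

    Assumes `batch` maps field names to tensors whose first dimension indexes
    samples, and that `response_mask` marks valid response tokens. These names
    are illustrative, not the actual verl DataProto API.
    """
    response_lengths = batch["response_mask"].sum(dim=-1)  # response tokens per sample
    keep = response_lengths > 0                            # True for non-empty responses
    if bool(keep.all()):
        return batch
    return {key: value[keep] for key, value in batch.items()}
```

In the real code path this would presumably go through verl's batch-selection utilities rather than a plain dict, but the filtering criterion is the same.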
Other solutions
- Use other models. Since the error surfaces in transformers' `modeling_qwen2`, using another model avoids the issue; fixing it inside `modeling_qwen2` would also work.
- Drop at runtime. Add another field such as `is_drop_mask` to filter out 0-response samples before entering `compute_log_prob`, and fill the dropped samples with some default value (see the sketch after this list).
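A sketch of the drop-at-runtime variant under the same illustrative assumptions (the `is_drop_mask` field and the padding scheme are hypothetical):

```python
import torch

def mark_and_pad_empty(batch: dict, pad_token_id: int) -> dict:
    """Hypothetical drop-at-runtime variant that keeps the batch shape fixed.

    Adds an `is_drop_mask` field marking 0-response samples and gives those
    samples a single pad token so downstream reshapes never see a 0-length
    sequence. Losses and advantages for masked samples would then be zeroed
    out using `is_drop_mask`.
    """
    lengths = batch["response_mask"].sum(dim=-1)
    drop = lengths == 0
    batch["is_drop_mask"] = drop
    if bool(drop.any()):
        batch["responses"][drop, 0] = pad_token_id  # one dummy token per dropped row
        batch["response_mask"][drop, 0] = 1         # keep shapes non-degenerate
    return batch
```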
Have you tested locally whether the proposed fix will mitigate issue #50? Show the behavior before the fix and after.
Before fixing
The server side reports the above error and exits.
The client side reports an error (returned from the request) and hangs after the server exits.
After fixing
The client side reports the same error to notify users.
The server side successfully completes training steps even without valid samples.
Have you checked wandb? I don't think the training makes any sense if all the data are thrown away.
I think a design choice has to be made here, with pros and cons, about asking users to intervene if that is what is happening. Simply silencing all errors might not be a good idea.
Clarification
In my debugging setup, all generated samples exceed the length limit, so there are no valid samples in any batch and the wandb log is empty. However, the skipping mechanism filters out invalid samples at the granularity of individual rollouts. In practice, I expect such cases to be rare, so a batch would still contain valid samples to train on.
Potential improvement
To allow user intervention, would a config option like `actor_rollout_ref.rollout.skip_badrequest` be preferred?
Behaviour
With this option enabled, when a bad request like the one in this issue happens, the server skips the rollout sample during batching, since it contains no valid data. With the option disabled, the server raises the exception to notify the user.
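A sketch of that behaviour (the option name follows the proposal above; the handler, the sample fields, and the logger are hypothetical):

```python
import logging

logger = logging.getLogger(__name__)

def handle_rollout_result(sample, skip_badrequest: bool):
    """Hypothetical handler for one finished rollout.

    `sample` is assumed to expose `response_ids` and `rollout_id`;
    `skip_badrequest` stands in for the proposed
    actor_rollout_ref.rollout.skip_badrequest option.
    """
    if len(sample.response_ids) > 0:
        return sample  # valid sample, joins the training batch as usual
    if skip_badrequest:
        logger.warning("Skipping rollout %s: empty response", sample.rollout_id)
        return None    # excluded from batching
    raise RuntimeError(
        f"Rollout {sample.rollout_id} produced an empty response; "
        "check the agent's prompt/response length limits."
    )
```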
Default value of the option
This option should be enabled by default; otherwise users may only hit this issue in late training steps and waste most of their compute. For example, Figure 3 of the DeepSeek-R1 paper shows that response length gradually increases as training progresses.
Well, I think there are two cases. If the prompt length is too long, it's a problem on the agent side, and the server should clearly warn or point out that the user should take a look at their agent. If it's a response-length problem, like the DeepSeek-R1 example you gave, we need to discuss further what is happening there and have a proper mechanism to handle it.
To warn about the prompt length on the agent client, we need to initialize a tokenizer to precompute the number of tokens when enqueuing a task (`queue_task`). Then we can emit a warning if the token count exceeds the limit. Is this the preferred design?
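A sketch of that check, assuming a HuggingFace tokenizer and a known prompt-length budget (the model name and the `queue_task` integration point are only examples):

```python
import logging
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)
# Example model; in practice this would be the model being trained.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

def warn_if_prompt_too_long(prompt: str, max_prompt_length: int = 4096) -> int:
    """Count prompt tokens before enqueuing a task and warn if over budget."""
    num_tokens = len(tokenizer.encode(prompt))
    if num_tokens > max_prompt_length:
        logger.warning(
            "Prompt has %d tokens, exceeding max_prompt_length=%d; "
            "the rollout will likely be truncated or rejected.",
            num_tokens, max_prompt_length,
        )
    return num_tokens
```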
To get the response length: the execution is wrapped by verl's `RayPPOTrainer`, which in turn wraps an `OpenAIServingChat` from vLLM. While we could add a patch like `PatchedvLLMServer` to handle this, I think such bad requests should simply be removed from the training batch, since the model cannot get a correct reward for them.
I'm fixing the same issue right now. In the verl training script, setting `data.truncation` to `'truncate'` can avoid the error and let the training process continue; we can also set a bigger `data.max_prompt_length` if 4096 is not enough.
However, it's also important to note that if too many responses are dropped, the RL training gets stuck and the advantage estimation becomes incorrect, which may degrade performance.
I think truncating the input prompt when it exceeds the agent's length limit (because the LLM inference max tokens is set too small) is a really good solution. However, for long agent interactions, such as 10000+ input tokens and 4000+ output tokens, it can waste a lot of training time.