Agentic RL Support in GPT-OSS
Feature request
Motivation
Hey Community,
@HJSang and I are from the LinkedIn Core AI team. Over the past few weeks, we’ve been working on adapting GPT-OSS models for agentic RL post-training with tool calls, using SGLang as the inference engine in an agent-loop setup (token-in, token-out).
We’ve gotten it working end to end on both internal and retool datasets, but we’ve observed some unusual behavior in gradient norms and entropy.
We’d love to share reproducible code and collaborate with the community to better understand and resolve these issues.
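For concreteness, a minimal sketch of the agent loop is below; the `engine`, `parse_tool_call`, and `run_tool` callables are hypothetical placeholders, not the actual SGLang or verl APIs.

```python
# Minimal sketch of a token-in, token-out agent loop. All injected callables
# (engine.generate, parse_tool_call, run_tool) are hypothetical placeholders,
# not the actual SGLang or verl APIs.
def agent_loop(engine, tokenizer, parse_tool_call, run_tool, prompt_ids, max_turns=8):
    token_ids = list(prompt_ids)
    for _ in range(max_turns):
        # The engine consumes and returns raw token ids, so training later
        # scores exactly the tokens that were sampled (no re-tokenization drift).
        new_ids = engine.generate(token_ids)
        token_ids.extend(new_ids)
        call = parse_tool_call(tokenizer.decode(new_ids))
        if call is None:
            break  # final answer; no further tool use
        # Tool output is rendered as a tool message, re-tokenized, and appended
        # before the next engine call.
        token_ids.extend(tokenizer.encode(run_tool(call)))
    return token_ids
```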
Let's use this issue to track GPT-OSS support in verl. cc @chenhaiq @eric-haibin-lin
Your contribution
- Add a tool-call parser for GPT-OSS (a rough parser sketch follows this list)
- Make the templated chat messages compatible with compute_score in the reward computation
- Use the triton attention backend for GPT-OSS rollouts
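As a starting point for the parser item above, here is a rough sketch of a Harmony-format tool-call parser. The regex encodes our reading of the format (a commentary-channel message addressed to `functions.<name>` with JSON arguments, terminated by `<|call|>`); please check it against the official openai-harmony spec before relying on it.

```python
import json
import re

# Rough sketch of a GPT-OSS (Harmony-format) tool-call parser. The pattern
# below is our best-effort reading of the format, not an official grammar.
TOOL_CALL_RE = re.compile(
    r"<\|channel\|>commentary to=functions\.(?P<name>[\w.]+)"
    r"(?:\s*<\|constrain\|>json)?<\|message\|>(?P<args>.*?)<\|call\|>",
    re.DOTALL,
)

def parse_tool_call(text: str):
    """Return (tool_name, args_dict) for the first tool call in text, else None."""
    m = TOOL_CALL_RE.search(text)
    if m is None:
        return None
    try:
        return m.group("name"), json.loads(m.group("args"))
    except json.JSONDecodeError:
        return None  # malformed arguments; let the caller decide how to score this
```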
The current blocker for training the GPT-OSS model: the GRPO grad_norm grows too fast, which prevents the model from reaching reasonably good performance.
- Train on gsm8k PR: [reward-curve plots at reasoning effort medium and reasoning effort low]; as a comparison, Qwen3-4B easily reaches 90%.
- agent loop training using math-expression example: https://github.com/volcengine/verl/blob/main/recipe/langgraph_agent/example/run_gpt_oss_20b_bf16.sh
We rule out MoE instability by setting batch_size = mini_batch_size to enforce on-policy updates (see the note below).
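To spell out why this isolates the problem: with batch_size equal to mini_batch_size, each rollout batch is consumed in a single gradient step, so at update time the training policy still equals the behavior policy and the importance ratio is identically one,

$$
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)} = 1,
$$

so clipping is inactive and the observed gradient spikes cannot be attributed to off-policy drift.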
My only hypothesis is that there is some issue in the current GPT-OSS model implementation in transformers that causes the gradient instability. Your investigation would be really appreciated.
Tried importance sampling for GPT-OSS training (a sketch of the correction is below). The good news is that the grad norm is no longer exploding, but the reward and val curves are not looking good. My hypothesis is that there is some unknown issue in the model implementation.
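For reference, a sketch of the token-level truncated importance sampling correction we tried; the truncation threshold and placement in the loss are our own choices, not a verl-prescribed recipe.

```python
import torch

def tis_weights(train_logprobs: torch.Tensor,
                rollout_logprobs: torch.Tensor,
                c: float = 2.0) -> torch.Tensor:
    """Per-token importance ratios pi_train / pi_rollout, truncated at c.

    Both inputs are log-probs of the sampled tokens, shape [batch, seq].
    Weights are detached so they only rescale the policy gradient and do not
    contribute gradient terms of their own.
    """
    ratio = torch.exp(train_logprobs - rollout_logprobs)
    return torch.clamp(ratio, max=c).detach()

# Usage sketch: fold the weights into the per-token policy-gradient loss,
# e.g. loss = -(tis_weights(train_lp, rollout_lp) * advantages * train_lp),
# then apply the response mask and average as usual.
```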
The rollout-training mismatch metrics: [plots]
We investigated a critical compatibility issue: flash_attention_2 doesn't support the GPT-OSS attention sink, causing gradient-norm spikes during training. The existing verl codebase lacked the ability to override the attn_implementation parameter, forcing users into incompatible attention mechanisms. Our solution in PR #3978 adds flexible attention-implementation override support, allowing users to switch attn_implementation via configuration. Initial testing with attn_implementation=eager confirmed this eliminates the gradient-norm spikes and restores proper score convergence.
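To make the fix concrete, this is roughly what the override controls at the transformers level (model path and dtype here are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative: what +actor_rollout_ref.model.override_config.attn_implementation=eager
# ultimately selects. flash_attention_2 lacks backward support for the GPT-OSS
# attention sinks, while eager (and reportedly flex_attention) handle them.
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",         # illustrative model path
    attn_implementation="eager",  # or "flex_attention"
    torch_dtype=torch.bfloat16,
)
```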
Thank you for your solution. It works~ If you want to gain training speed and reduce GPU memory, replace eager with "kernels-community/vllm-flash-attn3".
Implementation Details: https://huggingface.co/kernels-community/vllm-flash-attn3
@yinzhangyue it didn't implement the backward pass with attention-sink support.
I want to point out that, both empirically and from looking at the code, it seems the HF implementation of Flex Attention supports sinks for both forward and backward passes, and is much more efficient than eager. Let me know if I'm mistaken!
I also want to point out that the latest released version of vLLM doesn't support LoRA for GPT-OSS; it applies LoRA only to the attention layers, which causes a discrepancy between the training policy and the rollout policy. It seems the latest dev version has a fix, but I haven't tested it.
@aghyad-deeb Yeah, I have tried flex attention and was able to replicate the eager-mode reward curve on the gsm8k example, but when I tried the retool example, which is more complex and involves multi-turn tool use, I got the following error:
(RewardManagerWorker pid=1080802) Debugging: num_turns: 16 [repeated 7x across cluster]
Error executing job with overrides: ['algorithm.adv_estimator=gae', 'algorithm.use_kl_in_reward=False', 'algorithm.kl_ctrl.kl_coef=0.0', 'algorithm.gamma=1.0', 'algorithm.lam=1.0', "data.train_files=['/home/jobuser/dataset/BytedTsinghua-SIA/DAPO-Math-17k']", "data.val_files=['/home/jobuser/dataset/yentinglin/aime_2025']", 'data.return_raw_chat=True', 'data.train_batch_size=512', 'data.max_prompt_length=2048', 'data.max_response_length=16384', 'data.filter_overlong_prompts=True', '+data.apply_chat_template_kwargs.reasoning_effort=medium', 'data.truncation=error', 'data.custom_cls.path=/home/jobuser/verl/recipe/retool/retool.py', 'data.custom_cls.name=CustomRLHFDataset', 'custom_reward_function.path=/home/jobuser/verl/recipe/retool/retool.py', 'custom_reward_function.name=compute_score', 'actor_rollout_ref.model.path=/shared/public/sharing/hsang/gpt-oss-20b-bf16', 'actor_rollout_ref.model.use_remove_padding=True', 'actor_rollout_ref.model.enable_gradient_checkpointing=True', 'actor_rollout_ref.actor.use_kl_loss=False', 'actor_rollout_ref.actor.kl_loss_coef=0.0', '+actor_rollout_ref.model.override_config.attn_implementation=flex_attention', 'actor_rollout_ref.actor.clip_ratio_low=0.2', 'actor_rollout_ref.actor.clip_ratio_high=0.28', 'actor_rollout_ref.actor.clip_ratio_c=10.0', 'actor_rollout_ref.actor.optim.lr=1e-6', 'actor_rollout_ref.actor.use_dynamic_bsz=True', 'actor_rollout_ref.actor.ppo_mini_batch_size=512', 'actor_rollout_ref.actor.ppo_max_token_len_per_gpu=36864', 'actor_rollout_ref.actor.ulysses_sequence_parallel_size=4', 'actor_rollout_ref.actor.fsdp_config.param_offload=True', 'actor_rollout_ref.actor.fsdp_config.optimizer_offload=True', 'actor_rollout_ref.rollout.name=sglang', 'actor_rollout_ref.rollout.mode=async', 'actor_rollout_ref.rollout.tensor_model_parallel_size=4', 'actor_rollout_ref.rollout.multi_turn.enable=True', 'actor_rollout_ref.rollout.multi_turn.max_user_turns=8', 'actor_rollout_ref.rollout.multi_turn.max_assistant_turns=8', 'actor_rollout_ref.rollout.multi_turn.tool_config_path=/home/jobuser/verl/recipe/retool/sandbox_fusion_tool_config.yaml', 'actor_rollout_ref.rollout.multi_turn.format=gpt-oss', '+actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=triton', 'actor_rollout_ref.rollout.gpu_memory_utilization=0.8', 'actor_rollout_ref.rollout.val_kwargs.top_p=1.0', 'actor_rollout_ref.rollout.val_kwargs.temperature=1.0', 'actor_rollout_ref.rollout.val_kwargs.n=30', 'critic.optim.lr=2e-6', 'critic.model.use_remove_padding=True', 'critic.model.path=/shared/public/sharing/hsang/gpt-oss-20b-bf16', 'critic.model.enable_gradient_checkpointing=True', 'critic.ppo_max_token_len_per_gpu=73728', 'critic.ulysses_sequence_parallel_size=4', 'critic.model.fsdp_config.param_offload=True', 'critic.model.fsdp_config.optimizer_offload=True', 'trainer.critic_warmup=20', 'trainer.logger=[mlflow]', 'trainer.project_name=wuxibin_retool', 'trainer.experiment_name=gpt-oss-20b-bf16_ppo_jaszhu', 'trainer.n_gpus_per_node=8', 'trainer.val_before_train=True', 'trainer.log_val_generations=100', 'trainer.nnodes=1', 'trainer.save_freq=30', 'trainer.default_local_dir=/home/jobuser/checkpoint/gpt-oss-20b-bf16_ppo_jaszhu', 'trainer.test_freq=5', 'trainer.total_epochs=1']
Traceback (most recent call last):
File "/home/jobuser/verl/verl/trainer/main_ppo.py", line 42, in main
run_ppo(config)
File "/home/jobuser/verl/verl/trainer/main_ppo.py", line 96, in run_ppo
ray.get(runner.run.remote(config))
File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
return fn(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
return func(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 2782, in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 929, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError: ray::TaskRunner.run() (pid=1037235, ip=100.96.49.172, actor_id=96bebcc7e8e13436722d792401000000, repr=<main_ppo.TaskRunner object at 0x7adc328a3070>)
File "/home/jobuser/verl/verl/trainer/main_ppo.py", line 341, in run
trainer.fit()
File "/home/jobuser/verl/verl/trainer/ppo/ray_trainer.py", line 1120, in fit
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
File "/home/jobuser/verl/verl/single_controller/ray/base.py", line 48, in __call__
output = ray.get(output)
ray.exceptions.RayTaskError: ray::WorkerDict.actor_rollout_compute_log_prob() (pid=1057174, ip=100.96.49.172, actor_id=1312a0b6dd7413234af06c8b01000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7048fc39c220>)
File "/export/apps/python/3.10/lib/python3.10/concurrent/futures/_base.py", line 451, in result
return self.__get_result()
File "/export/apps/python/3.10/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
File "/home/jobuser/verl/verl/single_controller/ray/base.py", line 700, in func
return getattr(self.worker_dict[key], name)(*args, **kwargs)
File "/home/jobuser/verl/verl/single_controller/base/decorator.py", line 442, in inner
return func(*args, **kwargs)
File "/home/jobuser/verl/verl/utils/transferqueue_utils.py", line 199, in dummy_inner
return func(*args, **kwargs)
File "/home/jobuser/verl/verl/utils/profiler/profile.py", line 256, in wrapper
return func(self_instance, *args, **kwargs_inner)
File "/home/jobuser/verl/verl/workers/fsdp_workers.py", line 979, in compute_log_prob
output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
File "/home/jobuser/verl/verl/utils/profiler/performance.py", line 105, in f
return self.log(decorated_function, *args, **kwargs)
File "/home/jobuser/verl/verl/utils/profiler/performance.py", line 118, in log
output = func(*args, **kwargs)
File "/home/jobuser/verl/verl/workers/actor/dp_actor.py", line 339, in compute_log_prob
entropy, log_probs = self._forward_micro_batch(
File "/home/jobuser/verl/verl/workers/actor/dp_actor.py", line 170, in _forward_micro_batch
output = self.actor_module(
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/utils/generic.py", line 940, in wrapper
output = func(self, *args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 663, in forward
outputs: MoeModelOutputWithPast = self.model(
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/utils/generic.py", line 1064, in wrapper
outputs = func(self, *args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 502, in forward
hidden_states = decoder_layer(
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/modeling_layers.py", line 94, in __call__
return super().__call__(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 366, in forward
hidden_states, _ = self.self_attn(
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 323, in forward
attn_output, attn_weights = attention_interface(
File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/integrations/flex_attention.py", line 296, in flex_attention_forward
flex_attention_output = compile_friendly_flex_attention(
File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/integrations/flex_attention.py", line 97, in compile_friendly_flex_attention
return flex_attention_compiled(
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 923, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 907, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1578, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1377, in codegen_and_compile
graph.run(*example_inputs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/graph.py", line 921, in run
return super().run(*args)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 173, in run
self.env[node] = self.run_node(node)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1599, in run_node
result = super().run_node(n)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 242, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1268, in call_function
raise LoweringException(e, target, args, kwargs).with_traceback(
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1258, in call_function
out = lowerings[target](*args, **kwargs) # type: ignore[index]
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 446, in wrapped
out = decomp_fn(*args, **kwargs)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/kernel/flex_attention.py", line 1534, in flex_attention
error = flex_attention_template.maybe_append_choice(
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1315, in maybe_append_choice
choices.append(self.generate(generate_with_caching=True, **kwargs))
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1563, in generate
result = self.generate_and_load(
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1490, in generate_and_load
result = generate_code(kernel)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1442, in generate_code
template = kernel.render(self.template, kwargs, caching_enabled)
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1038, in render
template.render(**template_env, **kwargs),
File "/home/jobuser/.local/lib/python3.10/site-packages/jinja2/environment.py", line 1295, in render
self.environment.handle_exception()
File "/home/jobuser/.local/lib/python3.10/site-packages/jinja2/environment.py", line 942, in handle_exception
raise rewrite_traceback_stack(source=source)
File "<template>", line 398, in top-level template code
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 734, in modification
out = subgraph.data.inner_fn(())
File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 600, in inner_fn
assert len(index) == len(ranges), f"wrong ndim {index} {ranges}"
torch._inductor.exc.InductorError: LoweringException: AssertionError: wrong ndim () [64]
target: flex_attention
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[1, 64, 36371, 64], stride=[148975616, 2327744, 64, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[1, 8, 36371, 64], stride=[18621952, 2327744, 64, 1]))
))
args[2]: TensorBox(StorageBox(
InputBuffer(name='arg2_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[1, 8, 36371, 64], stride=[18621952, 64, 512, 1]))
))
args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
args[4]: (36371, 36371, TensorBox(StorageBox(
InputBuffer(name='arg4_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285], stride=[285, 285, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg3_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285, 285], stride=[81225, 81225, 285, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg7_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285], stride=[285, 285, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg8_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285, 285], stride=[81225, 81225, 285, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg9_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285], stride=[285, 285, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg10_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285, 285], stride=[81225, 81225, 285, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg11_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285], stride=[285, 285, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='arg12_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285, 285], stride=[81225, 81225, 285, 1]))
)), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
args[5]: 0.125
args[6]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}
args[7]: (TensorBox(StorageBox(
InputBuffer(name='arg5_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[64], stride=[1]))
)),)
args[8]: (TensorBox(StorageBox(
InputBuffer(name='arg6_1', layout=FixedLayout('cuda:0', torch.int64, size=[], stride=[]))
)),)
Reward curve with flex attention
@chenhaiq Any thoughts on above issue with flex attention?
@JasonZhu1313 can you share your config for running flex attention on gsm8k?
@MarkYangjiayi You can just pass +actor_rollout_ref.model.override_config.attn_implementation=flex_attention, since the team earlier fixed the issue where flash attention v2 was hard-coded.