
Agentic RL Support in GPT-OSS

Open JasonZhu1313 opened this issue 2 months ago • 15 comments

Feature request

Agentic RL Support in GPT-OSS

Motivation

Hey Community,

@HJSang and I are from the LinkedIn Core AI team. Over the past few weeks, we’ve been working on adapting GPT-OSS models for agentic RL post-training with tool calls, using SGLang as the inference engine in an agent-loop setup (token-in, token-out).

We’ve successfully gotten it working on both internal and retool datasets, but we’ve observed some unusual behavior in gradient norms and entropy.

We’d love to share reproducible code and collaborate with the community to better understand and resolve these issues.

Let's use this issue to track GPT-OSS support in verl. cc @chenhaiq @eric-haibin-lin

Your contribution

  • Add support for a tool-call parser for GPT-OSS
  • Make the templated chat messages compatible with compute_score in the reward computation (see the sketch after this list)
  • Use the Triton attention backend for GPT-OSS
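
On the compute_score point, here is a hedged sketch of what this could look like (assuming verl's custom_reward_function hook and the published harmony channel markers; the parsing and matching rule are illustrative, not an actual implementation):

    import re

    # Hedged sketch, not the actual implementation: verl's custom_reward_function
    # hook receives the decoded response string. With GPT-OSS, the response
    # carries harmony channel markers, so the score function has to drop the
    # analysis channel and keep only the final channel before matching the answer.
    def compute_score(data_source, solution_str, ground_truth, extra_info=None):
        # Keep the text after the last final-channel marker, if one is present.
        final = solution_str.split("<|channel|>final<|message|>")[-1]
        final = re.sub(r"<\|[^|>]+\|>", "", final)  # strip remaining special tokens
        return 1.0 if str(ground_truth).strip() in final else 0.0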

JasonZhu1313 · Oct 16 '25 22:10

The current issue with training the GPT-OSS model: the GRPO grad_norm grows too fast, which prevents the model from reaching reasonably good performance.

  • Train on gsm8k (PR)

reasoning effort: medium

[screenshots: training curves]

reasoning effort: low

[screenshots: training curves]

As a comparison, Qwen3-4B easily reaches 90%.

[screenshots: Qwen3-4B training curves]
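
A note on the reasoning-effort settings above: GPT-OSS takes reasoning effort as a chat-template kwarg, and these runs forward it through verl via +data.apply_chat_template_kwargs.reasoning_effort=.... A minimal transformers-side sketch (assuming the stock GPT-OSS chat template):

    from transformers import AutoTokenizer

    # Sketch: the GPT-OSS chat template accepts a reasoning_effort kwarg
    # ("low" / "medium" / "high"); verl passes it through
    # data.apply_chat_template_kwargs in the data config.
    tok = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "What is 17 * 24?"}],
        add_generation_prompt=True,
        tokenize=False,
        reasoning_effort="medium",
    )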

HJSang · Oct 21 '25 02:10

Train on retool with the tool agent (PR): grad_norm can grow as large as 1500.

[screenshot: grad_norm curve]

HJSang · Oct 21 '25 02:10

  • agent loop training using math-expression example: https://github.com/volcengine/verl/blob/main/recipe/langgraph_agent/example/run_gpt_oss_20b_bf16.sh
[screenshot: training curve]

HJSang · Oct 21 '25 02:10

We ruled out MoE instability by setting batch_size = mini_batch_size to enforce on-policy training.
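
Concretely, the arithmetic behind that check (a sketch using the batch sizes from the runs in this thread):

    # With train_batch_size == mini_batch_size there is exactly one optimizer
    # step per rollout batch, so every sample is scored by the same policy that
    # generated it (strictly on-policy); any remaining instability is therefore
    # not stale-sample (off-policy) noise.
    train_batch_size = 512   # data.train_batch_size
    mini_batch_size = 512    # actor_rollout_ref.actor.ppo_mini_batch_size
    updates_per_rollout = train_batch_size // mini_batch_size
    assert updates_per_rollout == 1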

HJSang · Oct 21 '25 02:10

My only hypothesis is that there is some issue in the current GPT-OSS model implementation in transformers that causes the gradient instability. Your investigation would be much appreciated.

HJSang · Oct 21 '25 02:10

Tried importance sampling for GPT-OSS training. The good news is that the grad norm is no longer exploding, but the reward and validation scores are not looking good. My hypothesis: there is some unknown issue in the model implementation.

[screenshots: grad_norm and reward/val curves]

The rollout-training mismatch metrics:

[screenshot: mismatch metrics]
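
For reference, a minimal sketch of the truncated importance-sampling correction described above (illustrative only, not verl's actual loss; the function name and truncation constant are assumptions):

    import torch

    # Hedged sketch: token-level truncated importance weights downweight tokens
    # where the rollout policy (inference engine) and the training policy
    # disagree, so the mismatch cannot blow up the gradient.
    def truncated_is_pg_loss(train_logprobs, rollout_logprobs, advantages, mask, c=2.0):
        # w_t = pi_train(a_t | s_t) / pi_rollout(a_t | s_t), truncated at c
        weights = torch.exp(train_logprobs - rollout_logprobs).clamp(max=c).detach()
        return -(weights * advantages * train_logprobs * mask).sum() / mask.sum()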

HJSang · Oct 27 '25 17:10

We investigated a critical compatibility issue: flash_attention_2 doesn't support GPT-OSS's attention sinks, causing the gradient norm spikes during training. The existing verl codebase lacked the ability to override the attn_implementation parameter, forcing users onto the incompatible attention mechanism. Our solution in PR #3978 adds flexible attention implementation override support, allowing users to switch attn_implementation via configuration. Initial testing with attn_implementation=eager confirmed this eliminates the gradient norm spikes and restores proper score convergence.

[screenshots: grad_norm and score curves with eager attention]
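
The verl-side switch is the override added in PR #3978 (+actor_rollout_ref.model.override_config.attn_implementation=...); underneath, it maps onto the standard transformers loading path. A minimal sketch of that mechanism (assuming the openai/gpt-oss-20b checkpoint):

    import torch
    from transformers import AutoModelForCausalLM

    # "eager" handles GPT-OSS's learned attention sinks in both the forward and
    # backward passes, whereas flash_attention_2 does not, which is what
    # produced the grad_norm spikes above.
    model = AutoModelForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )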

arde171 · Oct 31 '25 20:10

Thank you for your solution. It works~ If you want to gain training speed and save GPU memory, replace eager with "kernels-community/vllm-flash-attn3".

Implementation Details: https://huggingface.co/kernels-community/vllm-flash-attn3

yinzhangyue · Nov 05 '25 01:11

@yinzhangyue it doesn't implement the backward pass with attention-sink support.

arde171 · Nov 05 '25 01:11

I want to point out that, both empirically and from looking at the code, it seems the HF implementation of Flex Attention supports sinks for both forward and backward passes, and is much more efficient than eager. Let me know if I'm mistaken!

aghyad-deeb · Nov 14 '25 20:11

I also want to point out that the latest released version of vLLM doesn't support LoRA for GPT-OSS; it applies LoRA only to the attention layers, which causes a discrepancy between the training policy and the rollout policy. It seems like the latest dev version has a fix, but I haven't tested it.

aghyad-deeb · Nov 14 '25 20:11

@aghyad-deeb Yeah, I have tried flex attention and was able to replicate the eager-mode reward curve on the gsm8k example, but when I tried the retool example, which is more complex and involves multi-turn tool use, I got the following error in the backward pass:

(RewardManagerWorker pid=1080802) Debugging: num_turns: 16 [repeated 7x across cluster]
Error executing job with overrides: ['algorithm.adv_estimator=gae', 'algorithm.use_kl_in_reward=False', 'algorithm.kl_ctrl.kl_coef=0.0', 'algorithm.gamma=1.0', 'algorithm.lam=1.0', "data.train_files=['/home/jobuser/dataset/BytedTsinghua-SIA/DAPO-Math-17k']", "data.val_files=['/home/jobuser/dataset/yentinglin/aime_2025']", 'data.return_raw_chat=True', 'data.train_batch_size=512', 'data.max_prompt_length=2048', 'data.max_response_length=16384', 'data.filter_overlong_prompts=True', '+data.apply_chat_template_kwargs.reasoning_effort=medium', 'data.truncation=error', 'data.custom_cls.path=/home/jobuser/verl/recipe/retool/retool.py', 'data.custom_cls.name=CustomRLHFDataset', 'custom_reward_function.path=/home/jobuser/verl/recipe/retool/retool.py', 'custom_reward_function.name=compute_score', 'actor_rollout_ref.model.path=/shared/public/sharing/hsang/gpt-oss-20b-bf16', 'actor_rollout_ref.model.use_remove_padding=True', 'actor_rollout_ref.model.enable_gradient_checkpointing=True', 'actor_rollout_ref.actor.use_kl_loss=False', 'actor_rollout_ref.actor.kl_loss_coef=0.0', '+actor_rollout_ref.model.override_config.attn_implementation=flex_attention', 'actor_rollout_ref.actor.clip_ratio_low=0.2', 'actor_rollout_ref.actor.clip_ratio_high=0.28', 'actor_rollout_ref.actor.clip_ratio_c=10.0', 'actor_rollout_ref.actor.optim.lr=1e-6', 'actor_rollout_ref.actor.use_dynamic_bsz=True', 'actor_rollout_ref.actor.ppo_mini_batch_size=512', 'actor_rollout_ref.actor.ppo_max_token_len_per_gpu=36864', 'actor_rollout_ref.actor.ulysses_sequence_parallel_size=4', 'actor_rollout_ref.actor.fsdp_config.param_offload=True', 'actor_rollout_ref.actor.fsdp_config.optimizer_offload=True', 'actor_rollout_ref.rollout.name=sglang', 'actor_rollout_ref.rollout.mode=async', 'actor_rollout_ref.rollout.tensor_model_parallel_size=4', 'actor_rollout_ref.rollout.multi_turn.enable=True', 'actor_rollout_ref.rollout.multi_turn.max_user_turns=8', 'actor_rollout_ref.rollout.multi_turn.max_assistant_turns=8', 'actor_rollout_ref.rollout.multi_turn.tool_config_path=/home/jobuser/verl/recipe/retool/sandbox_fusion_tool_config.yaml', 'actor_rollout_ref.rollout.multi_turn.format=gpt-oss', '+actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=triton', 'actor_rollout_ref.rollout.gpu_memory_utilization=0.8', 'actor_rollout_ref.rollout.val_kwargs.top_p=1.0', 'actor_rollout_ref.rollout.val_kwargs.temperature=1.0', 'actor_rollout_ref.rollout.val_kwargs.n=30', 'critic.optim.lr=2e-6', 'critic.model.use_remove_padding=True', 'critic.model.path=/shared/public/sharing/hsang/gpt-oss-20b-bf16', 'critic.model.enable_gradient_checkpointing=True', 'critic.ppo_max_token_len_per_gpu=73728', 'critic.ulysses_sequence_parallel_size=4', 'critic.model.fsdp_config.param_offload=True', 'critic.model.fsdp_config.optimizer_offload=True', 'trainer.critic_warmup=20', 'trainer.logger=[mlflow]', 'trainer.project_name=wuxibin_retool', 'trainer.experiment_name=gpt-oss-20b-bf16_ppo_jaszhu', 'trainer.n_gpus_per_node=8', 'trainer.val_before_train=True', 'trainer.log_val_generations=100', 'trainer.nnodes=1', 'trainer.save_freq=30', 'trainer.default_local_dir=/home/jobuser/checkpoint/gpt-oss-20b-bf16_ppo_jaszhu', 'trainer.test_freq=5', 'trainer.total_epochs=1']
Traceback (most recent call last):
  File "/home/jobuser/verl/verl/trainer/main_ppo.py", line 42, in main
    run_ppo(config)
  File "/home/jobuser/verl/verl/trainer/main_ppo.py", line 96, in run_ppo
    ray.get(runner.run.remote(config))
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 2782, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 929, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError: ray::TaskRunner.run() (pid=1037235, ip=100.96.49.172, actor_id=96bebcc7e8e13436722d792401000000, repr=<main_ppo.TaskRunner object at 0x7adc328a3070>)
  File "/home/jobuser/verl/verl/trainer/main_ppo.py", line 341, in run
    trainer.fit()
  File "/home/jobuser/verl/verl/trainer/ppo/ray_trainer.py", line 1120, in fit
    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
  File "/home/jobuser/verl/verl/single_controller/ray/base.py", line 48, in __call__
    output = ray.get(output)
ray.exceptions.RayTaskError: ray::WorkerDict.actor_rollout_compute_log_prob() (pid=1057174, ip=100.96.49.172, actor_id=1312a0b6dd7413234af06c8b01000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7048fc39c220>)
  File "/export/apps/python/3.10/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/export/apps/python/3.10/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/jobuser/verl/verl/single_controller/ray/base.py", line 700, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/home/jobuser/verl/verl/single_controller/base/decorator.py", line 442, in inner
    return func(*args, **kwargs)
  File "/home/jobuser/verl/verl/utils/transferqueue_utils.py", line 199, in dummy_inner
    return func(*args, **kwargs)
  File "/home/jobuser/verl/verl/utils/profiler/profile.py", line 256, in wrapper
    return func(self_instance, *args, **kwargs_inner)
  File "/home/jobuser/verl/verl/workers/fsdp_workers.py", line 979, in compute_log_prob
    output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
  File "/home/jobuser/verl/verl/utils/profiler/performance.py", line 105, in f
    return self.log(decorated_function, *args, **kwargs)
  File "/home/jobuser/verl/verl/utils/profiler/performance.py", line 118, in log
    output = func(*args, **kwargs)
  File "/home/jobuser/verl/verl/workers/actor/dp_actor.py", line 339, in compute_log_prob
    entropy, log_probs = self._forward_micro_batch(
  File "/home/jobuser/verl/verl/workers/actor/dp_actor.py", line 170, in _forward_micro_batch
    output = self.actor_module(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/utils/generic.py", line 940, in wrapper
    output = func(self, *args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 663, in forward
    outputs: MoeModelOutputWithPast = self.model(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/utils/generic.py", line 1064, in wrapper
    outputs = func(self, *args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 502, in forward
    hidden_states = decoder_layer(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 366, in forward
    hidden_states, _ = self.self_attn(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 323, in forward
    attn_output, attn_weights = attention_interface(
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/integrations/flex_attention.py", line 296, in flex_attention_forward
    flex_attention_output = compile_friendly_flex_attention(
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/integrations/flex_attention.py", line 97, in compile_friendly_flex_attention
    return flex_attention_compiled(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 923, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 907, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1578, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1377, in codegen_and_compile
    graph.run(*example_inputs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/graph.py", line 921, in run
    return super().run(*args)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 173, in run
    self.env[node] = self.run_node(node)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1599, in run_node
    result = super().run_node(n)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 242, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1268, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1258, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 446, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/kernel/flex_attention.py", line 1534, in flex_attention
    error = flex_attention_template.maybe_append_choice(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1315, in maybe_append_choice
    choices.append(self.generate(generate_with_caching=True, **kwargs))
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1563, in generate
    result = self.generate_and_load(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1490, in generate_and_load
    result = generate_code(kernel)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1442, in generate_code
    template = kernel.render(self.template, kwargs, caching_enabled)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1038, in render
    template.render(**template_env, **kwargs),
  File "/home/jobuser/.local/lib/python3.10/site-packages/jinja2/environment.py", line 1295, in render
    self.environment.handle_exception()
  File "/home/jobuser/.local/lib/python3.10/site-packages/jinja2/environment.py", line 942, in handle_exception
    raise rewrite_traceback_stack(source=source)
  File "<template>", line 398, in top-level template code
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 734, in modification
    out = subgraph.data.inner_fn(())
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 600, in inner_fn
    assert len(index) == len(ranges), f"wrong ndim {index} {ranges}"
torch._inductor.exc.InductorError: LoweringException: AssertionError: wrong ndim () [64]
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[1, 64, 36371, 64], stride=[148975616, 2327744, 64, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[1, 8, 36371, 64], stride=[18621952, 2327744, 64, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[1, 8, 36371, 64], stride=[18621952, 64, 512, 1]))
  ))
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (36371, 36371, TensorBox(StorageBox(
    InputBuffer(name='arg4_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285], stride=[285, 285, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285, 285], stride=[81225, 81225, 285, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg7_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285], stride=[285, 285, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg8_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285, 285], stride=[81225, 81225, 285, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg9_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285], stride=[285, 285, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg10_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285, 285], stride=[81225, 81225, 285, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg11_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285], stride=[285, 285, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg12_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 285, 285], stride=[81225, 81225, 285, 1]))
  )), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.125
  args[6]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}
  args[7]: (TensorBox(StorageBox(
    InputBuffer(name='arg5_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[64], stride=[1]))
  )),)
  args[8]: (TensorBox(StorageBox(
    InputBuffer(name='arg6_1', layout=FixedLayout('cuda:0', torch.int64, size=[], stride=[]))
  )),)


Reward curve with flex attention: [screenshot]

JasonZhu1313 · Nov 15 '25 01:11

@chenhaiq Any thoughts on the above issue with flex attention?

JasonZhu1313 · Nov 15 '25 01:11

@JasonZhu1313 can you share your config for running flex attention on gsm8k?

MarkYangjiayi · Nov 17 '25 15:11

@MarkYangjiayi You can just pass +actor_rollout_ref.model.override_config.attn_implementation=flex_attention, since the team earlier fixed the issue where flash attention v2 was hard-coded.

JasonZhu1313 · Nov 17 '25 23:11