
[worker, trainer, recipe] feat: add FP16 training and inference support

Open · Xuekai-Zhu opened this pull request 4 weeks ago · 7 comments

This PR adds FP16 (float16) training and inference support to verl. The implementation covers:

| Component | Precision |
| --- | --- |
| Training (Actor) | float16 |
| Training (Ref) | float16 |
| Inference (vLLM Rollout) | float16 |
| Gradient Scaler | ShardedGradScaler (enabled) |

A new script demonstrates end-to-end FP16 training:

File: recipe/flowrl/run_flowrl_qwen2.5_7b_fp16.sh

This script launches FlowRL with the appropriate precision overrides:

    actor_rollout_ref.actor.dtype=float16 \
    actor_rollout_ref.ref.dtype=float16 \
    actor_rollout_ref.rollout.dtype=float16 \
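
For reference, here is a minimal sketch of how dot-path overrides like these resolve with OmegaConf, which verl's Hydra-style configs build on. The base values below are illustrative placeholders, not verl's actual defaults:

```python
from omegaconf import OmegaConf

# Illustrative base config; verl's real defaults live in its YAML config files.
base = OmegaConf.create({
    "actor_rollout_ref": {
        "actor": {"dtype": "bfloat16"},
        "ref": {"dtype": "bfloat16"},
        "rollout": {"dtype": "bfloat16"},
    }
})

# The same dot-path overrides the script passes on the command line.
overrides = OmegaConf.from_dotlist([
    "actor_rollout_ref.actor.dtype=float16",
    "actor_rollout_ref.ref.dtype=float16",
    "actor_rollout_ref.rollout.dtype=float16",
])

cfg = OmegaConf.merge(base, overrides)
assert cfg.actor_rollout_ref.actor.dtype == "float16"
```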

Implementation Details:

1. verl/workers/actor/dp_actor.py
+ # Add ShardedGradScaler for fp16
+ if self.config.dtype == "float16":
+     from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+     self.scaler = ShardedGradScaler(growth_interval=400)
+ else:
+     self.scaler = None

+ # Unscale gradients before grad-norm clipping when using fp16
+ if self.scaler is not None:
+     self.scaler.unscale_(self.actor_optimizer)

+ # Use scaler in optimizer step
+ if self.scaler is not None:
+     self.scaler.step(self.actor_optimizer)
+     self.scaler.update()
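
For context, a minimal sketch of the order these scaler calls follow within one optimizer update. It assumes an FSDP-wrapped model and an already-built optimizer; the function name is illustrative, not a verl helper:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

def optimizer_update(model: FSDP, optimizer, scaler: ShardedGradScaler | None,
                     max_grad_norm: float = 1.0):
    if scaler is not None:
        # Bring gradients back to their true (unscaled) values first...
        scaler.unscale_(optimizer)
        # ...so gradient clipping sees real magnitudes, not scaled ones.
        grad_norm = model.clip_grad_norm_(max_grad_norm)
        # step() applies the update only if no gradient contains inf/nan.
        scaler.step(optimizer)
        # update() adjusts the loss scale; growth_interval=400 above means the
        # scale is raised only after 400 consecutive overflow-free steps.
        scaler.update()
    else:
        grad_norm = model.clip_grad_norm_(max_grad_norm)
        optimizer.step()
    optimizer.zero_grad()
    return grad_norm
```
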
2. recipe/flowrl/flowrl_fsdp_worker.py
- param_dtype = torch.bfloat16
+ param_dtype = PrecisionType.to_dtype(self.config.actor.get("dtype", "float16"))

+ vllm_dtype = PrecisionType.to_dtype(self.config.rollout.dtype)
- torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
+ torch_dtype = torch.float32 if self._is_actor else vllm_dtype
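
For context, a sketch of how the configured dtype would typically feed FSDP's mixed-precision policy instead of a hard-coded bfloat16. The reduce/buffer dtypes here are illustrative assumptions, not necessarily what verl's worker uses:

```python
import torch
from torch.distributed.fsdp import MixedPrecision

def build_mixed_precision(dtype_name: str) -> MixedPrecision:
    # String-to-dtype mapping in the spirit of PrecisionType.to_dtype.
    param_dtype = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }[dtype_name]
    return MixedPrecision(
        param_dtype=param_dtype,      # dtype used for parameters during compute
        reduce_dtype=torch.float32,   # keep gradient reduction in fp32 for stability
        buffer_dtype=torch.float32,   # keep buffers (e.g. rotary caches) in fp32
    )

policy = build_mixed_precision("float16")
```
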
3. recipe/flowrl/flowrl_actor.py
- with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
+ from verl.utils.torch_dtypes import PrecisionType
+ torch_dtype = PrecisionType.to_dtype(self.config.dtype)
+ with torch.autocast(device_type=self.device_name, dtype=torch_dtype):

- loss.backward()
+ if self.scaler is not None:
+     self.scaler.scale(loss).backward()
+ else:
+     loss.backward()
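
Putting the autocast and scaling pieces together, a minimal sketch of the forward/backward half of the fp16 path. It assumes a Hugging Face-style model that returns an output with a `.loss` field; `model`, `batch`, and `scaler` stand in for the corresponding verl objects:

```python
import torch

def forward_backward(model, batch, scaler, device_name: str = "cuda",
                     torch_dtype: torch.dtype = torch.float16):
    # Autocast runs matmuls in the configured low-precision dtype, while ops on
    # autocast's fp32 list (e.g. softmax) stay in float32.
    with torch.autocast(device_type=device_name, dtype=torch_dtype):
        loss = model(**batch).loss
    if scaler is not None:
        # Multiply the loss by the current scale before backprop so small fp16
        # gradients do not underflow to zero.
        scaler.scale(loss).backward()
    else:
        loss.backward()
    return loss.detach()
```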

Xuekai-Zhu · Nov 06 '25 10:11

CLA assistant check
All committers have signed the CLA.

CLAassistant · Nov 06 '25 10:11

Same as @ISEEKYAN, bf16 should be kept.

FP16 for RL training should also be verified with more comprehensive tests.

PeterSH6 · Nov 06 '25 10:11

Yes, I agree. @ISEEKYAN @PeterSH6 I’ve kept the default bf16 dtype unchanged, and added a new example script in the flowrl recipe that enables FP16 through configuration overrides.

recipe/flowrl/run_flowrl_qwen2.5_7b_fp16.sh

    actor_rollout_ref.actor.dtype=float16 \
    actor_rollout_ref.ref.dtype=float16 \
    actor_rollout_ref.rollout.dtype=float16 \
    actor_rollout_ref.actor.use_amp=True \

This way, users can experiment with FP16 without affecting the existing default setup. I'm also currently running FP16 training experiments to further verify its stability and performance on RL workloads.

[W&B chart: 2025_11_6 19_09_05]
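
As a rough illustration of how the `dtype` and `use_amp` overrides above could combine, here is a hypothetical helper; the name and gating logic are assumptions, not verl's actual wiring. The idea is that a loss scaler is only needed for fp16, since bf16 has enough dynamic range to train without one:

```python
import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

def build_amp_state(dtype_name: str, use_amp: bool):
    dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype_name]
    # Hypothetical gating: construct a ShardedGradScaler only when AMP is on and
    # the training dtype is fp16; bf16 runs without loss scaling.
    scaler = ShardedGradScaler(growth_interval=400) if (use_amp and dtype is torch.float16) else None
    return dtype, scaler

dtype, scaler = build_amp_state("float16", use_amp=True)
```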

Xuekai-Zhu · Nov 06 '25 11:11


Could you also show the eval score? Thanks!

vermouth1992 · Nov 06 '25 13:11


Absolutely! Here are the current online eval results; I'll update with more later.

[W&B chart: 2025_11_6 21_13_06]

Xuekai-Zhu · Nov 06 '25 13:11

@Xuekai-Zhu Is this ready for merge? Some CI checks are failing.

wuxibin89 · Nov 13 '25 04:11


@wuxibin89 I'll work on fixing these CI issues today! Once the fixes are ready, I'll reach out to you again for another review. Thanks again!

Xuekai-Zhu · Nov 13 '25 04:11