verl
[worker, trainer, recipe] feat: add FP16 training and inference support
This PR adds FP16 (float16) training and inference support to verl. The implementation covers:
| Component | Precision |
|---|---|
| Training (Actor) | float16 |
| Training (Ref) | float16 |
| Inference (vLLM Rollout) | float16 |
| Gradient Scaler | ShardedGradScaler (enabled) |
- Inspired by Precision-RL Patch
- Reference Paper: Defeating the Training–Inference Mismatch via FP16
A new script demonstrates end-to-end FP16 training:
File: recipe/flowrl/run_flowrl_qwen2.5_7b_fp16.sh
This script launches FlowRL with the appropriate precision overrides:
actor_rollout_ref.actor.dtype=float16 \
actor_rollout_ref.ref.dtype=float16 \
actor_rollout_ref.rollout.dtype=float16 \
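For reference, these string overrides resolve to torch dtypes through verl's PrecisionType helper (the same helper used in the diffs below). A minimal sketch of that mapping; the consistency check at the end is purely illustrative and not code from this PR:

```python
# Sketch: resolving the dtype overrides above into torch dtypes.
# The assert is illustrative only: it states the invariant this PR targets,
# namely that training and rollout run in the same (fp16) precision.
import torch
from verl.utils.torch_dtypes import PrecisionType

actor_dtype = PrecisionType.to_dtype("float16")    # actor_rollout_ref.actor.dtype
rollout_dtype = PrecisionType.to_dtype("float16")  # actor_rollout_ref.rollout.dtype
assert actor_dtype == rollout_dtype == torch.float16
```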
Implementation Details:
verl/workers/actor/dp_actor.py
+ # Add ShardedGradScaler for fp16 training
+ if self.config.dtype == "float16":
+     from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+     self.scaler = ShardedGradScaler(growth_interval=400)
+ else:
+     self.scaler = None
+ # Unscale gradients before clipping so the clip threshold sees true gradient magnitudes
+ if self.scaler is not None:
+     self.scaler.unscale_(self.actor_optimizer)
+ # Step through the scaler (skips the update on overflow) and adjust the loss scale
+ if self.scaler is not None:
+     self.scaler.step(self.actor_optimizer)
+     self.scaler.update()
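For context, the fp16 path follows the standard GradScaler lifecycle. The snippet below is not the PR's exact code, just a minimal sketch of how the pieces above fit together in one update, assuming actor_module is the FSDP-wrapped model and actor_optimizer / max_grad_norm stand in for the worker's real attributes:

```python
# Minimal sketch of one fp16 update with ShardedGradScaler (illustrative only).
import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler(growth_interval=400)

def fp16_update(loss, actor_module, actor_optimizer, max_grad_norm=1.0):
    # Scale the loss so small fp16 gradients do not underflow during backward.
    scaler.scale(loss).backward()
    # Unscale in place before clipping, so the clip threshold is applied
    # to the true gradient magnitudes.
    scaler.unscale_(actor_optimizer)
    # FSDP's own clip_grad_norm_ handles sharded gradients correctly.
    actor_module.clip_grad_norm_(max_grad_norm)
    # step() skips the optimizer update if any gradient overflowed to inf/nan;
    # update() then adjusts the loss scale for the next iteration.
    scaler.step(actor_optimizer)
    scaler.update()
    actor_optimizer.zero_grad()
```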
recipe/flowrl/flowrl_fsdp_worker.py
- param_dtype = torch.bfloat16
+ param_dtype = PrecisionType.to_dtype(self.config.actor.get("dtype", "float16"))
+ vllm_dtype = PrecisionType.to_dtype(self.config.rollout.dtype)
- torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
+ torch_dtype = torch.float32 if self._is_actor else vllm_dtype
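The dtype string from the config drives both the FSDP mixed-precision policy and the rollout dtype handed to vLLM. A minimal sketch of that wiring, assuming torch's MixedPrecision policy; the exact arguments the worker passes may differ:

```python
# Sketch: deriving FSDP mixed precision and the rollout dtype from config strings
# (illustrative; the worker's actual construction may differ).
import torch
from torch.distributed.fsdp import MixedPrecision
from verl.utils.torch_dtypes import PrecisionType

param_dtype = PrecisionType.to_dtype("float16")   # config.actor.dtype
vllm_dtype = PrecisionType.to_dtype("float16")    # config.rollout.dtype

# Hold and communicate parameters in fp16 while keeping gradient reduction
# and buffers in fp32 for numerical stability.
mixed_precision = MixedPrecision(
    param_dtype=param_dtype,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
```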
recipe/flowrl/flowrl_actor.py
- with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
+ from verl.utils.torch_dtypes import PrecisionType
+ torch_dtype = PrecisionType.to_dtype(self.config.dtype)
+ with torch.autocast(device_type=self.device_name, dtype=torch_dtype):
- loss.backward()
+ if self.scaler is not None:
+     self.scaler.scale(loss).backward()
+ else:
+     loss.backward()
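Putting the actor-side pieces together: the forward pass runs under autocast in the configured dtype, and the backward pass goes through the scaler only when fp16 is active. A minimal sketch with a hypothetical helper (micro_batch_backward and compute_loss are not names from the recipe):

```python
# Sketch of one micro-batch: autocast forward in the configured dtype,
# scaled backward only when an fp16 scaler is present (hypothetical helper).
import torch
from verl.utils.torch_dtypes import PrecisionType

def micro_batch_backward(compute_loss, batch, config, scaler, device_name="cuda"):
    torch_dtype = PrecisionType.to_dtype(config.dtype)  # "float16" or "bfloat16"
    with torch.autocast(device_type=device_name, dtype=torch_dtype):
        loss = compute_loss(batch)
    if scaler is not None:   # fp16 path: scale to avoid gradient underflow
        scaler.scale(loss).backward()
    else:                    # bf16/fp32 path: plain backward
        loss.backward()
    return loss.detach()
```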
Same as @ISEEKYAN: bf16 should be kept as the default.
FP16 RL training should also be verified with more comprehensive tests.
Yes, I agree. @ISEEKYAN @PeterSH6 I’ve kept the default bf16 dtype unchanged, and added a new example script in the flowrl recipe that enables FP16 through configuration overrides.
recipe/flowrl/run_flowrl_qwen2.5_7b_fp16.sh
actor_rollout_ref.actor.dtype=float16 \
actor_rollout_ref.ref.dtype=float16 \
actor_rollout_ref.rollout.dtype=float16 \
actor_rollout_ref.actor.use_amp=True \
This way, users can experiment with FP16 without affecting the existing default setup.
Also, I’m currently running FP16 training experiments to further verify its stability and performance on RL workloads.
Could you also show the eval score? Thanks!
Absolutely! Here are the current online eval results; I’ll update more later.
@Xuekai-Zhu Is this ready for merge? There's some ci failed.
@wuxibin89 I’ll work on fixing these CI issues today! Once the fixes are ready, I’ll reach out to you for another review. Thanks again!
