feat: KV cache quantization support in fp8 rollout in GRPO
What does this PR do?
Adds FP8 KV cache quantization support to the FP8 rollout path in GRPO, currently for the Megatron + vLLM (batch) backend.
Hi @guyueh1, as discussed, I'm creating this draft PR so you can review the initial design and we can discuss the refactor and next steps.
- Issue to be addressed: https://github.com/NVIDIA-NeMo/RL/issues/1185
- For the design and flow, please refer to https://github.com/NVIDIA-NeMo/RL/issues/1185#issuecomment-3324411200
- The current implementation supports only the Megatron + vLLM (batch) backend.
Current experiment results for Qwen3-8B (orange: BF16; green: default FP8 rollout; blue: default FP8 rollout + FP8 KV cache):
Note that there is still room to optimize calibration to reduce the total step time.
@zpqiu fyi.
Usage
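A minimal usage sketch. The config keys and call signatures below are illustrative assumptions, not the exact schema (see `nemo_rl/models/generation/fp8.py` and `nemo_rl/algorithms/grpo.py` for the actual interfaces); the important part is setting an FP8 `kv_cache_dtype` in the FP8 generation config and letting GRPO calibrate and sync the KV scales during refit.

```python
# Illustrative only: config keys and signatures are assumptions, not the exact schema.
fp8_generation_overrides = {
    "vllm_cfg": {
        "precision": "fp8",
        "fp8_cfg": {
            "kv_cache_dtype": "fp8",  # new option wired to FP8Config.kv_cache_dtype
        },
    },
}

# Inside the GRPO loop (interfaces described in this PR, signatures approximate):
#   kv_scales = policy.calibrate_qkv_fp8_scales(calibration_batch)
#   refit_policy_generation(policy, policy_generation, ..., kv_scales=kv_scales)
```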
Before your PR is "Ready for review"
Pre checks:
- [x] Make sure you read and followed Contributor guidelines
- [x] Did you write any new necessary tests?
- [ ] Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
- [ ] Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.
Additional Information
- ...
Summary by CodeRabbit
- New Features
- Added FP8 KV-cache support with configurable kv_cache dtype and optional KV-scale calibration/synchronization across generation and training.
- Enabled KV-scale propagation during model refits and weight updates with vLLM.
- Bug Fixes
- Improved checkpoint saving stability by reducing GPU memory pressure and clearing caches to avoid OOM errors.
- Documentation
- Expanded docs for FP8 KV-cache settings and KV-scale usage.
- Chores
- Added runtime diagnostics and compatibility checks for FP8 KV-cache across backends.
📝 Walkthrough
Adds FP8 KV-cache scale handling across GRPO refit and vLLM update paths: computes/calibrates Q/K/V FP8 scales, caches and threads them through IPC/NCCL weight updates, and applies them post-load in vLLM. Introduces config for kv_cache_dtype, backend compatibility checks, and new calibration methods in policy interfaces and workers.
Changes
| Cohort / File(s) | Summary of Changes |
|---|---|
| GRPO algorithm refit and sync<br>`nemo_rl/algorithms/grpo.py` | Adds FP8 KV-scale threading through `refit_policy_generation` (new `kv_scales` arg). Computes/reuses the kv_scales cache during generation/refit, propagates it to the IPC/NCCL update paths, and marks generation stale when recalibration occurs. Adds `_should_sync_kv_scales` and FP8 KV-cache compatibility checks. |
| vLLM update pipeline (IPC/collective)<br>`nemo_rl/models/generation/vllm/vllm_backend.py`, `nemo_rl/models/generation/vllm/vllm_generation.py`, `nemo_rl/models/generation/vllm/vllm_worker.py` | Extends `update_weights_from_local_ipc_handles`/`update_weights_from_collective` to accept an optional `kv_scales`, append scale tensors to the payload, and invoke `process_weights_after_loading` when provided. Threads `kv_scales` through generation and worker RPCs; adds diagnostics. |
| FP8 config and weight processing<br>`nemo_rl/models/generation/fp8.py` | Adds `FP8Config.kv_cache_dtype`. Introduces `kv_cache_process_weights_after_loading` to retain and remap Q/K/V/prob scales on load. Wires `kv_cache_dtype` into vLLM init kwargs and weight loading; adds debug output. |
| Policy interfaces and implementations<br>`nemo_rl/models/policy/interfaces.py`, `nemo_rl/models/policy/lm_policy.py`, `nemo_rl/models/policy/megatron_policy_worker.py` | Adds `calibrate_qkv_fp8_scales` to `PolicyInterface` and `Policy` (sharded dispatch/gather). Implements calibration in `MegatronPolicyWorker` using forward hooks, percentile-based amax, margin, optional JSON save, and distributed merging. Also augments checkpoint save with memory handling. |
Sequence Diagram(s)
```mermaid
sequenceDiagram
    autonumber
    actor Trainer as GRPO Trainer
    participant Policy as Policy
    participant Worker as MegatronPolicyWorker
    participant Gen as vLLMGeneration
    participant Wkr as vLLMWorker
    participant BE as vLLMBackend
    participant Model as Model
    rect rgba(230,240,255,0.5)
        note over Trainer,Worker: Optional FP8 Q/K/V scale calibration
        Trainer->>Policy: calibrate_qkv_fp8_scales(data, opts)
        Policy->>Worker: calibrate_qkv_fp8_scales(shard, opts)
        Worker-->>Policy: kv_scales (per-layer)
        Policy-->>Trainer: kv_scales
    end
    rect rgba(240,255,240,0.5)
        note over Trainer,Model: Refit with KV scales
        Trainer->>Gen: update_weights_from_ipc_handles/collective(kv_scales)
        alt IPC
            Gen->>Wkr: update_weights_from_ipc_handles(..., kv_scales)
            Wkr->>BE: update_weights_from_local_ipc_handles(..., kv_scales)
        else Collective
            Gen->>Wkr: update_weights_from_collective(kv_scales)
            Wkr->>BE: update_weights_from_collective(kv_scales)
        end
        BE->>Model: load weights (+kv scale tensors)
        BE->>Model: process_weights_after_loading(apply kv scales)
        Model-->>BE: ready
        BE-->>Wkr: ok
        Wkr-->>Gen: ok
        Gen-->>Trainer: ok
    end
    note over Trainer: Continue rollout/generation
```
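To make the diagram concrete, here is a hedged sketch of the backend-side plumbing: the scale tensors ride along with the regular weight payload, and a post-load step applies them. Helper names and the exact scale parameter names (`k_scale`, etc.) are illustrative assumptions rather than the PR's exact API.

```python
from typing import Iterable, Optional

import torch

def build_weight_payload(
    named_weights: Iterable[tuple[str, torch.Tensor]],
    kv_scales: Optional[dict[str, float]] = None,
) -> list[tuple[str, torch.Tensor]]:
    """Append per-layer KV-scale tensors to the (name, tensor) weight payload."""
    payload = list(named_weights)
    if kv_scales:
        # e.g. "model.layers.0.self_attn.k_scale" -> tensor(0.0123)
        payload.extend(
            (name, torch.tensor(value, dtype=torch.float32))
            for name, value in kv_scales.items()
        )
    return payload

def load_and_apply(model, payload, kv_scales_present: bool, post_load_fn) -> None:
    # vLLM models expose load_weights() taking an iterable of (name, tensor) pairs.
    model.load_weights(payload)
    if kv_scales_present:
        # post_load_fn stands in for the backend's process_weights_after_loading step,
        # which pushes the freshly loaded KV scales into the attention implementation.
        post_load_fn(model)
```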
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~60 minutes
Suggested labels
CI:L1, r0.4.0
Suggested reviewers
- yuki-97
- parthchadha
- joyang-nv
- jgerh
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 68.75% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |
| Test Results For Major Changes | ⚠️ Warning | The PR introduces a major feature by adding FP8 KV-cache quantization support, which directly impacts numerics and performance, but the description only mentions an illustrative image and leaves the testing checklist unchecked without documenting concrete test results, regression checks, or performance metrics, so the required evidence for major changes is absent. | Please update the PR description with explicit test results or benchmarking details, including configurations and outcomes that demonstrate numerical stability and performance impact for the new FP8 KV-cache quantization flow. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately summarizes the main change: adding KV cache quantization support to FP8 rollout in GRPO, which aligns with the core modifications across fp8.py, vllm_backend.py, vllm_generation.py, vllm_worker.py, and grpo.py. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
✨ Finishing touches
- [ ] 📝 Generate docstrings
🧪 Generate unit tests (beta)
- [ ] Create PR with unit tests
- [ ] Post copyable unit tests in a comment
@zpqiu sorry for the long delay, I have left some comments; could you first merge in main and then address them? I think this solution can be further optimized in terms of performance, but I'm OK with merging this version first.
Sure. I will rebase the code and resolve these comments first.
ℹ️ File Consistency Check
Check based on commit: 0dbf7ab7775d78dbac2d0e4097edeec75d22f98b (PR #1212 from kv-cache-fp8)
✅ DTensor Policy Worker Synchronization Check
Both DTensor policy worker files were modified in this PR:
- `nemo_rl/models/policy/dtensor_policy_worker.py`
- `nemo_rl/models/policy/dtensor_policy_worker_v2.py`
Please ensure that the changes are consistent between both files where applicable.
This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
The changes generally look good; there are some minor issues. We should add tests for this feature. I suggest adding FP8 KV cache tests to the following unit tests:
- https://github.com/NVIDIA-NeMo/RL/blob/5f6cfc7f66180461ce1f6d12c00e9e462d4d90ec/tests/unit/models/generation/test_vllm_generation.py#L916
- https://github.com/NVIDIA-NeMo/RL/blob/5f6cfc7f66180461ce1f6d12c00e9e462d4d90ec/tests/unit/models/generation/test_vllm_generation.py#L864
Hi @guyueh1, thanks for the comments. Since the current implementation is for the Megatron backend only, I added a test at https://github.com/NVIDIA-NeMo/RL/blob/4089ab58ea0d70f9904a9899f31d74a51be5858f/tests/unit/models/generation/test_vllm_generation.py#L1924
@zpqiu can you fix the functional test failure? Also, I think the L1 functional tests run on Ampere GPUs, so you may need to conditionally skip this test for CUDA architectures before sm_90.
> Also, I think the L1 functional tests run on Ampere GPUs, so you may need to conditionally skip this test for CUDA architectures before sm_90.

If so, do we need to delete this test?
> If so, do we need to delete this test?

I'm OK with adding the test just to the nightly test suite, as you do now.
@terrykong please review
@terrykong this is the last FP8 feature we want to merge before v0.5; after this I plan to refactor the code to make it cleaner and more structured. Please review when you have time.
Updating here with the latest experimental results.
Configuration:
- Model: Qwen3-8B-Base
- Method: dynamically calculate Q/K/V scales at the end of each training step and synchronize them to vLLM
- Framework: NeMo-RL, vLLM + MCore, batch rollout mode
- Correction: token-level TIS, C=2
Observations:
- Mismatch: Enabling FP8 for the KV cache and attention increases mismatch compared to using FP8 only for Linear layers.
- Accuracy: Applying token-level TIS realigns the accuracy curve with BF16.
- KV Cache Capacity: The FP8 KV cache provides an additional 2x token capacity and concurrency.
  - BF16: GPU KV cache size 249,952 tokens, maximum concurrency 11.10x
  - FP8 Linear-only: GPU KV cache size 299,344 tokens, maximum concurrency 13.29x
  - FP8 Linear + KV cache: GPU KV cache size 598,672 tokens, maximum concurrency 26.57x
- Speedup: Adding FP8 KV cache/attention yields an additional ~30% rollout speedup over FP8 Linear-only; the total speedup over BF16 is approximately 48%.
- Observation: Longer response lengths benefit more due to the higher portion of computation spent in attention.
Synced offline with @guyueh1, who is following up with another PR to do some cleanup, so I'll defer my comments to her PR:
- https://github.com/NVIDIA-NeMo/RL/pull/1212#discussion_r2575847867
- https://github.com/NVIDIA-NeMo/RL/pull/1212#discussion_r2575860333
- https://github.com/NVIDIA-NeMo/RL/pull/1212/files#r2582264513
- https://github.com/NVIDIA-NeMo/RL/pull/1212#discussion_r2575890680
- https://github.com/NVIDIA-NeMo/RL/pull/1212#discussion_r2575917551