
Commit 40e7040

Authored by sharonyu-115 (Shuang Yu), zpqiu, and guyueh1
feat: KV cache quantization support in fp8 rollout in GRPO (#1212)
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Signed-off-by: alexchiu <qiuzhaopeng@foxmail.com>
Signed-off-by: alexchiu <alexq@nvidia.com>
Co-authored-by: Shuang Yu <shuangy@shuangy-mlt.client.nvidia.com>
Co-authored-by: Shuang Yu <shuangy@cw-dfw-cs-001-vscode-02.cm.cluster>
Co-authored-by: Zhaopeng Qiu <alexq@nvidia.com>
Co-authored-by: Guyue Huang <140554423+guyueh1@users.noreply.github.com>
Co-authored-by: alexchiu <qiuzhaopeng@foxmail.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
1 parent 3817189 commit 40e7040

20 files changed: +827, -67 lines

examples/configs/distillation_math.yaml

Lines changed: 1 addition & 0 deletions
@@ -173,6 +173,7 @@ policy: &POLICY_BASE
     vllm_cfg:
       async_engine: false
       precision: ${...precision}
+      kv_cache_dtype: "auto"
       tensor_parallel_size: 1
       pipeline_parallel_size: 1
       expert_parallel_size: 1 # When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP
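
Note: the new kv_cache_dtype key defaults to "auto", which keeps the KV cache in the model's compute dtype; setting it to "fp8" (as the new recipe further down does) quantizes the cache. The plumbing from NeMo-RL's vllm_cfg into the engine is not shown on this page; as a rough sketch, the value lines up with vLLM's engine argument of the same name (kv_cache_dtype, quantization, and tensor_parallel_size are real vLLM LLM parameters, while the dict and wiring below are illustrative assumptions):

from vllm import LLM

# Illustrative only: not NeMo-RL's actual vLLM worker code.
vllm_cfg = {
    "precision": "fp8",         # rollout weight/activation precision
    "kv_cache_dtype": "fp8",    # "auto" keeps the KV cache in the compute dtype
    "tensor_parallel_size": 1,
}

llm = LLM(
    model="Qwen/Qwen3-8B-Base",
    quantization="fp8" if vllm_cfg["precision"] == "fp8" else None,
    kv_cache_dtype=vllm_cfg["kv_cache_dtype"],
    tensor_parallel_size=vllm_cfg["tensor_parallel_size"],
)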

examples/configs/grpo_math_1B.yaml

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ policy:
     vllm_cfg:
       async_engine: false
       precision: ${policy.precision}
+      kv_cache_dtype: "auto"
       tensor_parallel_size: 1
       pipeline_parallel_size: 1
       expert_parallel_size: 1 # When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+defaults: ../../grpo_math_1B.yaml
+grpo:
+  val_period: 20
+checkpointing:
+  enabled: false
+  checkpoint_dir: results/grpo_qwen3_8b_fp8_kvcache
+loss_fn:
+  use_importance_sampling_correction: true
+policy:
+  model_name: Qwen/Qwen3-8B-Base
+  train_micro_batch_size: 1
+  logprob_batch_size: 1
+  max_total_sequence_length: 8192
+  dtensor_cfg:
+    enabled: false
+  optimizer: null
+  scheduler: null
+  megatron_cfg:
+    enabled: true
+    converter_type: Qwen3ForCausalLM
+    tensor_model_parallel_size: 4
+    optimizer:
+      lr: 1.0e-06
+      min_lr: 1.0e-06
+      weight_decay: 0.1
+      use_precision_aware_optimizer: false
+    scheduler:
+      lr_decay_iters: null
+      lr_warmup_iters: 10
+      lr_warmup_init: 1.0e-07
+  make_sequence_length_divisible_by: ${mul:${policy.megatron_cfg.tensor_model_parallel_size},
+    2}
+  generation:
+    vllm_cfg:
+      precision: fp8
+      kv_cache_dtype: fp8
+      use_deep_gemm: true
+data:
+  max_input_seq_length: 2048
+  prompt_file: null
+  dataset_name: DAPOMath17K
+env:
+  dapo:
+    num_workers: 16
+  math:
+    num_workers: 16
+    math_verify_impl: dapo_math_verify
+cluster:
+  gpus_per_node: 8
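
The new recipe above (its file path is not shown on this page) only carries overrides: defaults: ../../grpo_math_1B.yaml pulls in the base GRPO config, and ${mul:...} is resolved at load time. The loader itself is not part of this commit; a minimal OmegaConf-style sketch of how such an overlay could be resolved, under the assumption of a "mul" resolver and standard OmegaConf merging:

from omegaconf import OmegaConf

# Assumed loader sketch, not NeMo-RL's actual config machinery.
if not OmegaConf.has_resolver("mul"):
    OmegaConf.register_new_resolver("mul", lambda a, b: int(a) * int(b))

base = OmegaConf.load("examples/configs/grpo_math_1B.yaml")
recipe = OmegaConf.load("path/to/this_recipe.yaml")  # recipe path is not shown in this diff
recipe.pop("defaults", None)  # 'defaults' is loader metadata, not a config key

cfg = OmegaConf.merge(base, recipe)  # recipe values override the base
print(OmegaConf.to_yaml(cfg.policy.generation.vllm_cfg))  # precision and kv_cache_dtype are now fp8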

examples/configs/vlm_grpo_3B.yaml

Lines changed: 1 addition & 0 deletions
@@ -206,6 +206,7 @@ policy:
     vllm_cfg:
       async_engine: false # Only for internal testing, will be enabled by https://github.com/NVIDIA/NeMo-RL/issues/447.
       precision: ${policy.precision}
+      kv_cache_dtype: "auto"
       tensor_parallel_size: 1
       pipeline_parallel_size: 1
       expert_parallel_size: 1

examples/configs/vlm_grpo_3B_megatron.yaml

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ policy:
     vllm_cfg:
       async_engine: false
       precision: ${policy.precision}
+      kv_cache_dtype: "auto"
       tensor_parallel_size: 1
       pipeline_parallel_size: 1
       expert_parallel_size: 1

nemo_rl/algorithms/grpo.py

Lines changed: 82 additions & 4 deletions
@@ -500,6 +500,24 @@ def init_vllm():
    assert loss_config["use_importance_sampling_correction"] is True, (
        "Importance sampling must be enabled for vLLM FP8 generation for good convergence!"
    )
+   if generation_config["vllm_cfg"]["kv_cache_dtype"].startswith("fp8"):
+       # FP8 KV cache requires FP8 model precision
+       assert generation_config["vllm_cfg"]["precision"] == "fp8", (
+           f"kv_cache_dtype='{generation_config['vllm_cfg']['kv_cache_dtype']}' requires precision='fp8'. "
+           "FP8 KV cache can only be used together with FP8 model weights."
+       )
+       # FP8 KV cache compatibility checks
+       assert policy_config["dtensor_cfg"]["enabled"] == False, (
+           "DTensor backend is not supported with kv cache fp8 enabled."
+       )
+       assert not _should_use_async_rollouts(master_config), (
+           "Async rollouts is not supported with kv cache fp8 enabled."
+       )
+       assert policy_config["megatron_cfg"]["pipeline_model_parallel_size"] == 1, (
+           "Currently when using FP8 KV cache in generation, then in megatron we only support pipeline_model_parallel_size=1. We will add more support in future."
+       )
+
+   ## make vllm hf overrides match the training policy
    generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get(
        "hf_config_overrides", {}
    )
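
For illustration, the compatibility matrix enforced by the new asserts can be restated as a standalone check (a sketch, not a NeMo-RL API): FP8 KV cache requires FP8 model weights, the Megatron backend (DTensor disabled), synchronous rollouts, and pipeline_model_parallel_size=1.

# Sketch only: mirrors the constraints asserted above, outside the trainer.
def fp8_kv_cache_supported(vllm_cfg: dict, policy_cfg: dict, async_rollouts: bool) -> tuple[bool, str]:
    if not vllm_cfg["kv_cache_dtype"].startswith("fp8"):
        return True, "kv_cache_dtype is not fp8; nothing to check"
    if vllm_cfg["precision"] != "fp8":
        return False, "FP8 KV cache requires precision='fp8' (FP8 model weights)"
    if policy_cfg["dtensor_cfg"]["enabled"]:
        return False, "DTensor backend is not supported with FP8 KV cache"
    if async_rollouts:
        return False, "async rollouts are not supported with FP8 KV cache"
    if policy_cfg["megatron_cfg"]["pipeline_model_parallel_size"] != 1:
        return False, "only pipeline_model_parallel_size=1 is supported with FP8 KV cache"
    return True, "ok"

ok, reason = fp8_kv_cache_supported(
    {"kv_cache_dtype": "fp8", "precision": "fp8"},
    {"dtensor_cfg": {"enabled": False}, "megatron_cfg": {"pipeline_model_parallel_size": 1}},
    async_rollouts=False,
)
assert ok, reason
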
@@ -877,6 +895,7 @@ def refit_policy_generation(
    colocated_inference: bool,
    _refit_buffer_size_gb: Optional[int] = None,
    timer: Optional[Timer] = None,
+   kv_scales: Optional[dict[str, float]] = None,
) -> None:
    """Refit the policy generation interface with the latest policy weights.

@@ -887,6 +906,7 @@ def refit_policy_generation(
            If it is None, the buffer size will be computed by the remaining memory.
            This parameter is primarily used for testing.
        timer: Optional Timer used to time the prepare/transfer/update phase
+       kv_scales: Optional dictionary of KV cache scales for FP8 quantization.
    """
    if colocated_inference:
        policy.offload_before_refit()

@@ -914,7 +934,7 @@ def refit_policy_generation(
        )

        futures_train = policy.stream_weights_via_ipc_zmq(
-           buffer_size_bytes=buffer_size_bytes
+           buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales
        )
        futures_inference = policy_generation.update_weights_via_ipc_zmq()
        # wait for all futures to complete

@@ -923,7 +943,7 @@ def refit_policy_generation(
        update_success = all(result for result in results if result is not None)
    else:
        # update weights through nccl
-       futures_train = policy.broadcast_weights_for_collective()
+       futures_train = policy.broadcast_weights_for_collective(kv_scales=kv_scales)
        futures_inference = policy_generation.update_weights_from_collective()
        # wait for all futures to complete
        ray.get(futures_train)
@@ -973,6 +993,8 @@ def grpo_train(
    )
    timeout.start_iterations()

+   kv_scales_cache = None  # Cache reused for computed kv scales
+
    NEED_REFIT = True
    # If policy_generation is None, use the policy as the generation interface (megatron framework backend)
    if policy_generation is None:

@@ -981,6 +1003,10 @@
    POLICY_GENERATION_STALE = True  # tracks if generation needs a refit before running
    assert policy_generation is not None  # for mypy type check

+   # Check if we need to sync KV cache scales
+   # When fallback to policy as the policy_generation, we use getattr to check.
+   sync_kv_scales = getattr(policy_generation, "requires_kv_scale_sync", False)
+
    # common config/state itmes
    current_step = grpo_save_state["current_step"]  # current step within an epoch
    total_steps = grpo_save_state["total_steps"]  # total steps across all epochs

@@ -1002,6 +1028,7 @@
    colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]

    # Run validation at the start if configured
+   # TODO: Add validation with kv scales if needed
    if val_at_start and current_step == 0:
        print("\n🔍 Running initial validation...", flush=True)
        if NEED_REFIT and POLICY_GENERATION_STALE:
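
The sync_kv_scales flag introduced above is read with getattr because policy_generation may be the policy object itself (Megatron framework backend), which does not define the attribute. A tiny illustration of that fallback (class names are made up; only the attribute name comes from this commit):

class Fp8VllmGeneration:
    requires_kv_scale_sync = True  # FP8 KV cache needs calibrated scales at refit time

class PolicyActingAsGeneration:
    pass  # no such attribute -> getattr falls back to False

for gen in (Fp8VllmGeneration(), PolicyActingAsGeneration()):
    print(type(gen).__name__, getattr(gen, "requires_kv_scale_sync", False))
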
@@ -1063,8 +1090,43 @@
                )
            with timer.time("prepare_for_generation/total"):
                if NEED_REFIT and POLICY_GENERATION_STALE:
+                    # Compute KV scales if needed for FP8 quantization
+                    if sync_kv_scales and kv_scales_cache is None:
+                        print("▶ Computing KV cache scales...", flush=True)
+                        policy.prepare_for_lp_inference()
+                        # Align with training data processing to ensure parallel training compatibility
+                        calib_flat, calib_input_lengths = (
+                            batched_message_log_to_flat_message(
+                                repeated_batch["message_log"],
+                                pad_value_dict={
+                                    "token_ids": tokenizer.pad_token_id
+                                },
+                                make_sequence_length_divisible_by=master_config[
+                                    "policy"
+                                ]["make_sequence_length_divisible_by"],
+                            )
+                        )
+                        # Create calibration data from flattened messages
+                        calibration_data = BatchedDataDict[ClippedPGLossDataDict](
+                            {
+                                "input_ids": calib_flat["token_ids"],
+                                "input_lengths": calib_input_lengths,
+                            }
+                        )
+                        calibration_data.update(
+                            calib_flat.get_multimodal_dict(as_tensors=False)
+                        )
+                        calibration_data.to("cpu")
+                        kv_scales_cache = policy.calibrate_qkv_fp8_scales(
+                            calibration_data, include_q=True
+                        )["layers"]
+
                    refit_policy_generation(
-                        policy, policy_generation, colocated_inference, timer=timer
+                        policy,
+                        policy_generation,
+                        colocated_inference,
+                        timer=timer,
+                        kv_scales=kv_scales_cache if sync_kv_scales else None,
                    )
                    POLICY_GENERATION_STALE = False
                else:
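
calibrate_qkv_fp8_scales itself lives in the policy workers and is not shown in this diff. Conceptually, an FP8 scale is the largest absolute value observed in the Q/K/V activations on a calibration batch divided by the FP8 E4M3 maximum (448); a minimal sketch under that assumption (tensor shapes and the k_scale/v_scale key names are illustrative):

import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in torch.float8_e4m3fn

def fp8_scale(amax: torch.Tensor) -> float:
    # Choose scale so that x / scale fits the FP8 E4M3 range.
    return (amax.clamp(min=1e-12) / FP8_E4M3_MAX).item()

# Pretend these are K/V activations captured per layer during a calibration forward pass.
captured = {
    0: {"k": torch.randn(4, 128) * 3.0, "v": torch.randn(4, 128) * 0.5},
    1: {"k": torch.randn(4, 128) * 1.2, "v": torch.randn(4, 128) * 2.0},
}

kv_scales = {
    layer: {
        "k_scale": fp8_scale(acts["k"].abs().amax()),
        "v_scale": fp8_scale(acts["v"].abs().amax()),
    }
    for layer, acts in captured.items()
}
print(kv_scales)
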
@@ -1277,6 +1339,19 @@
            with timer.time("policy_training"):
                train_results = policy.train(train_data, loss_fn)

+            # Recompute KV scales after policy training if needed
+            if sync_kv_scales:
+                with timer.time("recompute_kv_scales"):
+                    print(
+                        "▶ Recomputing KV cache scales after policy update...",
+                        flush=True,
+                    )
+                    kv_scales_cache = policy.calibrate_qkv_fp8_scales(
+                        train_data, include_q=True
+                    )["layers"]
+                    # Set generation as stale to force refit with new scales
+                    POLICY_GENERATION_STALE = True
+
            is_last_step = (total_steps + 1 >= max_num_steps) or (
                (current_epoch + 1 == max_num_epochs)
                and (current_step + 1 == len(dataloader))

@@ -1286,7 +1361,10 @@
            if val_period > 0 and (total_steps + 1) % val_period == 0:
                if NEED_REFIT and POLICY_GENERATION_STALE:
                    refit_policy_generation(
-                        policy, policy_generation, colocated_inference
+                        policy,
+                        policy_generation,
+                        colocated_inference,
+                        kv_scales=kv_scales_cache if sync_kv_scales else None,
                    )
                    POLICY_GENERATION_STALE = False
                else:
