@@ -500,6 +500,24 @@ def init_vllm():
500500 assert loss_config ["use_importance_sampling_correction" ] is True , (
501501 "Importance sampling must be enabled for vLLM FP8 generation for good convergence!"
502502 )
503+ if generation_config ["vllm_cfg" ]["kv_cache_dtype" ].startswith ("fp8" ):
504+ # FP8 KV cache requires FP8 model precision
505+ assert generation_config ["vllm_cfg" ]["precision" ] == "fp8" , (
506+ f"kv_cache_dtype='{ generation_config ['vllm_cfg' ]['kv_cache_dtype' ]} ' requires precision='fp8'. "
507+ "FP8 KV cache can only be used together with FP8 model weights."
508+ )
509+ # FP8 KV cache compatibility checks
510+ assert policy_config ["dtensor_cfg" ]["enabled" ] == False , (
511+ "DTensor backend is not supported with kv cache fp8 enabled."
512+ )
513+ assert not _should_use_async_rollouts (master_config ), (
514+ "Async rollouts is not supported with kv cache fp8 enabled."
515+ )
516+ assert policy_config ["megatron_cfg" ]["pipeline_model_parallel_size" ] == 1 , (
517+ "Currently when using FP8 KV cache in generation, then in megatron we only support pipeline_model_parallel_size=1. We will add more support in future."
518+ )
519+
520+ ## make vllm hf overrides match the training policy
503521 generation_config ["vllm_cfg" ]["hf_overrides" ] = policy_config .get (
504522 "hf_config_overrides" , {}
505523 )
@@ -877,6 +895,7 @@ def refit_policy_generation(
877895 colocated_inference : bool ,
878896 _refit_buffer_size_gb : Optional [int ] = None ,
879897 timer : Optional [Timer ] = None ,
898+ kv_scales : Optional [dict [str , float ]] = None ,
880899) -> None :
881900 """Refit the policy generation interface with the latest policy weights.
882901
@@ -887,6 +906,7 @@ def refit_policy_generation(
887906 If it is None, the buffer size will be computed by the remaining memory.
888907 This parameter is primarily used for testing.
889908 timer: Optional Timer used to time the prepare/transfer/update phase
909+ kv_scales: Optional dictionary of KV cache scales for FP8 quantization.
890910 """
891911 if colocated_inference :
892912 policy .offload_before_refit ()
@@ -914,7 +934,7 @@ def refit_policy_generation(
914934 )
915935
916936 futures_train = policy .stream_weights_via_ipc_zmq (
917- buffer_size_bytes = buffer_size_bytes
937+ buffer_size_bytes = buffer_size_bytes , kv_scales = kv_scales
918938 )
919939 futures_inference = policy_generation .update_weights_via_ipc_zmq ()
920940 # wait for all futures to complete
@@ -923,7 +943,7 @@ def refit_policy_generation(
923943 update_success = all (result for result in results if result is not None )
924944 else :
925945 # update weights through nccl
926- futures_train = policy .broadcast_weights_for_collective ()
946+ futures_train = policy .broadcast_weights_for_collective (kv_scales = kv_scales )
927947 futures_inference = policy_generation .update_weights_from_collective ()
928948 # wait for all futures to complete
929949 ray .get (futures_train )
@@ -973,6 +993,8 @@ def grpo_train(
973993 )
974994 timeout .start_iterations ()
975995
996+ kv_scales_cache = None # Cache reused for computed kv scales
997+
976998 NEED_REFIT = True
977999 # If policy_generation is None, use the policy as the generation interface (megatron framework backend)
9781000 if policy_generation is None :
@@ -981,6 +1003,10 @@ def grpo_train(
9811003 POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
9821004 assert policy_generation is not None # for mypy type check
9831005
1006+ # Check if we need to sync KV cache scales
1007+ # When falling back to the policy as policy_generation, the attribute may be absent, so use getattr with a False default.
1008+ sync_kv_scales = getattr (policy_generation , "requires_kv_scale_sync" , False )
1009+
9841010 # common config/state items
9851011 current_step = grpo_save_state ["current_step" ] # current step within an epoch
9861012 total_steps = grpo_save_state ["total_steps" ] # total steps across all epochs
@@ -1002,6 +1028,7 @@ def grpo_train(
10021028 colocated_inference = master_config ["policy" ]["generation" ]["colocated" ]["enabled" ]
10031029
10041030 # Run validation at the start if configured
1031+ # TODO: Add validation with kv scales if needed
10051032 if val_at_start and current_step == 0 :
10061033 print ("\n 🔍 Running initial validation..." , flush = True )
10071034 if NEED_REFIT and POLICY_GENERATION_STALE :
@@ -1063,8 +1090,43 @@ def grpo_train(
10631090 )
10641091 with timer .time ("prepare_for_generation/total" ):
10651092 if NEED_REFIT and POLICY_GENERATION_STALE :
1093+ # Compute KV scales if needed for FP8 quantization
1094+ if sync_kv_scales and kv_scales_cache is None :
1095+ print ("▶ Computing KV cache scales..." , flush = True )
1096+ policy .prepare_for_lp_inference ()
1097+ # Align with training data processing to ensure parallel training compatibility
1098+ calib_flat , calib_input_lengths = (
1099+ batched_message_log_to_flat_message (
1100+ repeated_batch ["message_log" ],
1101+ pad_value_dict = {
1102+ "token_ids" : tokenizer .pad_token_id
1103+ },
1104+ make_sequence_length_divisible_by = master_config [
1105+ "policy"
1106+ ]["make_sequence_length_divisible_by" ],
1107+ )
1108+ )
1109+ # Create calibration data from flattened messages
1110+ calibration_data = BatchedDataDict [ClippedPGLossDataDict ](
1111+ {
1112+ "input_ids" : calib_flat ["token_ids" ],
1113+ "input_lengths" : calib_input_lengths ,
1114+ }
1115+ )
1116+ calibration_data .update (
1117+ calib_flat .get_multimodal_dict (as_tensors = False )
1118+ )
1119+ calibration_data .to ("cpu" )
1120+ kv_scales_cache = policy .calibrate_qkv_fp8_scales (
1121+ calibration_data , include_q = True
1122+ )["layers" ]
1123+
10661124 refit_policy_generation (
1067- policy , policy_generation , colocated_inference , timer = timer
1125+ policy ,
1126+ policy_generation ,
1127+ colocated_inference ,
1128+ timer = timer ,
1129+ kv_scales = kv_scales_cache if sync_kv_scales else None ,
10681130 )
10691131 POLICY_GENERATION_STALE = False
10701132 else :
@@ -1277,6 +1339,19 @@ def grpo_train(
12771339 with timer .time ("policy_training" ):
12781340 train_results = policy .train (train_data , loss_fn )
12791341
1342+ # Recompute KV scales after policy training if needed
1343+ if sync_kv_scales :
1344+ with timer .time ("recompute_kv_scales" ):
1345+ print (
1346+ "▶ Recomputing KV cache scales after policy update..." ,
1347+ flush = True ,
1348+ )
1349+ kv_scales_cache = policy .calibrate_qkv_fp8_scales (
1350+ train_data , include_q = True
1351+ )["layers" ]
1352+ # Set generation as stale to force refit with new scales
1353+ POLICY_GENERATION_STALE = True
1354+
12801355 is_last_step = (total_steps + 1 >= max_num_steps ) or (
12811356 (current_epoch + 1 == max_num_epochs )
12821357 and (current_step + 1 == len (dataloader ))
@@ -1286,7 +1361,10 @@ def grpo_train(
12861361 if val_period > 0 and (total_steps + 1 ) % val_period == 0 :
12871362 if NEED_REFIT and POLICY_GENERATION_STALE :
12881363 refit_policy_generation (
1289- policy , policy_generation , colocated_inference
1364+ policy ,
1365+ policy_generation ,
1366+ colocated_inference ,
1367+ kv_scales = kv_scales_cache if sync_kv_scales else None ,
12901368 )
12911369 POLICY_GENERATION_STALE = False
12921370 else :
0 commit comments