From d35b4dd734e0f9de1180a79cfd8c02dcb308bcee Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 3 Nov 2025 13:11:39 -0700 Subject: [PATCH 01/10] First commit. --- open_instruct/grpo_fast.py | 372 ++++++++++++++++++++----------------- 1 file changed, 199 insertions(+), 173 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 84abd902b..3f76505b1 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -92,6 +92,7 @@ cleanup_all_llm_judge_clients, soft_format_reward_func, ) +from open_instruct.metrics import MetricsTracker from open_instruct.model_utils import ( Batch, ModelConfig, @@ -135,6 +136,81 @@ INVALID_LOGPROB = 1.0 +class LossStatistics: + def __init__(self, num_batches: int, record_entropy: bool = False): + self.kl1_stats = torch.zeros(num_batches) + self.kl2_stats = torch.zeros(num_batches) + self.kl3_stats = torch.zeros(num_batches) + self.kl4_stats = torch.zeros(num_batches) + self.kl_loss_stats = torch.zeros(num_batches) + self.pg_clipfrac_stats = torch.zeros(num_batches) + self.pg_loss_stats = torch.zeros(num_batches) + self.loss_stats = torch.zeros(num_batches) + self.ratio_stats = torch.zeros(num_batches) + self.entropy_stats = torch.zeros(num_batches) if record_entropy else None + self.kl1 = None + self.kl2 = None + self.kl3 = None + self.kl4 = None + + def update_kl_estimates(self, ref_logprobs_diff, ratio, mb_response_masks_bool, args): + self.kl1 = ref_logprobs_diff + self.kl2 = (ref_logprobs_diff) ** 2 / 2 + self.kl3 = torch.expm1(-ref_logprobs_diff) + ref_logprobs_diff + self.kl4 = ratio * ref_logprobs_diff + + def kl(self, args): + if args.kl_estimator == "kl1": + return self.kl1 + elif args.kl_estimator == "kl2": + return self.kl2 + elif args.kl_estimator == "kl3": + return self.kl3 + elif args.kl_estimator == "kl4": + return self.kl4 + + def update_stats( + self, i, mb_response_masks_bool, pg_losses, pg_losses2, pg_loss_max, ratio, loss, mb_entropy, args + ): + self.kl1_stats[i] = masked_mean( + self.kl1, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ).float() + self.kl2_stats[i] = masked_mean( + self.kl2, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ).float() + self.kl3_stats[i] = masked_mean( + self.kl3, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ).float() + self.kl4_stats[i] = masked_mean( + self.kl4, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ).float() + if args.kl_estimator == "kl1": + self.kl_loss_stats[i] = self.kl1_stats[i] * args.beta + elif args.kl_estimator == "kl2": + self.kl_loss_stats[i] = self.kl2_stats[i] * args.beta + elif args.kl_estimator == "kl3": + self.kl_loss_stats[i] = self.kl3_stats[i] * args.beta + elif args.kl_estimator == "kl4": + self.kl_loss_stats[i] = self.kl4_stats[i] * args.beta + self.pg_clipfrac_stats[i] = masked_mean( + (pg_losses2 > pg_losses).float(), + mb_response_masks_bool, + args.masked_mean_axis, + args.masked_mean_denominator, + ) + self.pg_loss_stats[i] = masked_mean( + pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ) + self.loss_stats[i] = loss + self.ratio_stats[i] = masked_mean( + ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ) + if args.record_entropy and self.entropy_stats is not None: + self.entropy_stats[i] = masked_mean( + mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ).float() + + class 
ShutdownSentinel: """Sentinel value to signal thread shutdown via queue.""" @@ -522,29 +598,60 @@ def masked_mean( return (numerator / denom).mean() -class MetricsTracker: - """A simple class to prellocate all metrics in an array - so we can do only one allreduce operation to get the metrics mean""" +def compare_vllm_logprobs_to_local(mb_new_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args, metrics_tracker): + with torch.no_grad(): + valid_mask = mb_response_masks_bool & ~torch.isnan(mb_vllm_logprobs) + logprob_diff = (mb_new_logprobs - mb_vllm_logprobs).abs() + masked_diff = torch.masked_fill(logprob_diff, ~valid_mask, 0.0) + mean_diff = masked_diff.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0 + max_diff = masked_diff.max() + std_diff = masked_diff[valid_mask].std() if valid_mask.sum() > 1 else 0.0 + reverse_kl = masked_mean( + torch.expm1(mb_vllm_logprobs - mb_new_logprobs) + (mb_vllm_logprobs - mb_new_logprobs), + mb_response_masks_bool, + args.masked_mean_axis, + args.masked_mean_denominator, + ) + metrics_tracker.add("debug/vllm_vs_local_logprob_diff_mean", mean_diff) + metrics_tracker.add("debug/vllm_vs_local_logprob_diff_max", max_diff) + metrics_tracker.add("debug/vllm_vs_local_logprob_diff_std", std_diff) + metrics_tracker.add("debug/vllm_local_reverse_kl", reverse_kl) - def __init__(self, max_metrics: int = 32, device: str = "cuda"): - self.metrics = torch.zeros(max_metrics, device=device) - self.names2idx = {} - self.current_idx = 0 - self.max_metrics = max_metrics - def add(self, name: str, value: torch.tensor): - if name not in self.names2idx: - if self.current_idx >= self.max_metrics: - raise ValueError(f"Exceeded maximum number of metrics ({self.max_metrics})") - self.names2idx[name] = self.current_idx - self.current_idx += 1 +def maybe_apply_importance_sampling( + pg_losses, pg_losses2, mb_old_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args +): + if args.truncated_importance_sampling_ratio_cap > 0 and mb_vllm_logprobs is not None: + old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB + vllm_logprobs_mask = mb_vllm_logprobs != INVALID_LOGPROB + + assert torch.all(old_logprobs_mask == mb_response_masks_bool), ( + f"Old logprobs mask should match response mask. " + f"old_mask sum={old_logprobs_mask.sum()}, " + f"response_mask sum={mb_response_masks_bool.sum()}" + ) + assert torch.all(vllm_logprobs_mask == mb_response_masks_bool), ( + f"vLLM logprobs mask should match response mask. 
" + f"vllm_mask sum={vllm_logprobs_mask.sum()}, " + f"response_mask sum={mb_response_masks_bool.sum()}" + ) - self.metrics[self.names2idx[name]] = value - return self + valid_mask = mb_response_masks_bool - def get_metrics_list(self) -> dict[str, float]: - metrics_list = self.metrics.tolist() - return {name: metrics_list[idx] for name, idx in self.names2idx.items()} + tis_imp_ratio = torch.ones_like(mb_old_logprobs) + + if valid_mask.any(): + logprob_diff_is = mb_old_logprobs - mb_vllm_logprobs + logprob_diff_is = torch.where( + valid_mask, logprob_diff_is.clamp(-10.0, 10.0), torch.zeros_like(logprob_diff_is) + ) + tis_imp_ratio = torch.where(valid_mask, torch.exp(logprob_diff_is), tis_imp_ratio) + tis_imp_ratio = torch.clamp(tis_imp_ratio, max=args.truncated_importance_sampling_ratio_cap) + + pg_losses = pg_losses * tis_imp_ratio + pg_losses2 = pg_losses2 * tis_imp_ratio + + return pg_losses, pg_losses2 def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: @@ -901,6 +1008,51 @@ def update_ref_policy(self): else: ref_param.data.mul_(1.0 - self.args.alpha).add_(param.data, alpha=self.args.alpha) + def calculate_loss( + self, + i, + loss_statistics, + mb_new_logprobs, + mb_old_logprobs, + mb_ref_logprob, + mb_advantages, + mb_response_masks_bool, + mb_vllm_logprobs, + mb_entropy, + accumulation_steps, + local_step, + args, + ): + logprobs_diff = mb_new_logprobs - mb_old_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantages[:, 1:] * ratio + pg_losses2 = -mb_advantages[:, 1:] * torch.clamp(ratio, 1.0 - args.clip_lower, 1.0 + args.clip_higher) + + pg_losses, pg_losses2 = maybe_apply_importance_sampling( + pg_losses, pg_losses2, mb_old_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args + ) + + pg_loss_max = torch.max(pg_losses, pg_losses2) + + ref_logprobs_diff = (mb_new_logprobs - mb_ref_logprob).clamp(-40.0, 40.0) + loss_statistics.update_kl_estimates(ref_logprobs_diff, ratio, mb_response_masks_bool, args) + kl = loss_statistics.kl(args) + + loss = masked_mean( + pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ) + loss = loss / accumulation_steps + self.model.backward(loss) + if (local_step + 1) % accumulation_steps == 0: + self.model.step() + + with torch.no_grad(): + loss_statistics.update_stats( + i, mb_response_masks_bool, pg_losses, pg_losses2, pg_loss_max, ratio, loss, mb_entropy, args + ) + + return local_step + 1 + def train( self, collated_query_responses, @@ -1005,18 +1157,10 @@ def train( local_step = 0 # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch with Timer("[Training Processes] Loss calculation", noop=self.rank != 0): - kl1_stats = torch.zeros(len(collated_query_responses)) - kl2_stats = torch.zeros(len(collated_query_responses)) - kl3_stats = torch.zeros(len(collated_query_responses)) - kl4_stats = torch.zeros(len(collated_query_responses)) - kl_loss_stats = torch.zeros(len(collated_query_responses)) - pg_clipfrac_stats = torch.zeros(len(collated_query_responses)) - pg_loss_stats = torch.zeros(len(collated_query_responses)) - loss_stats = torch.zeros(len(collated_query_responses)) - ratio_stats = torch.zeros(len(collated_query_responses)) - entropy_stats = torch.zeros(len(collated_query_responses)) + loss_statistics = LossStatistics(len(collated_query_responses), record_entropy=args.record_entropy) for epoch_idx in range(args.num_epochs): for i in 
range(len(collated_query_responses)): + # mb = mini-batch mb_ref_logprob = collated_ref_logprobs[i] mb_query_responses = collated_query_responses[i] mb_tool_mask = collated_tool_masks[i] @@ -1043,23 +1187,9 @@ def train( # Replace any remaining NaN values (query tokens in packed sequences are set to NaN by pack_sequences in rl_utils.py) mb_vllm_logprobs = torch.nan_to_num(mb_vllm_logprobs, nan=INVALID_LOGPROB) - # Compare vLLM logprobs with local logprobs - with torch.no_grad(): - valid_mask = mb_response_masks_bool & ~torch.isnan(mb_vllm_logprobs) - logprob_diff = (mb_local_logprobs - mb_vllm_logprobs).abs() - masked_diff = torch.masked_fill(logprob_diff, ~valid_mask, 0.0) - mean_diff = masked_diff.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0 - max_diff = masked_diff.max() - std_diff = masked_diff[valid_mask].std() if valid_mask.sum() > 1 else 0.0 - - self.local_metrics.add("debug/vllm_vs_local_logprob_diff_mean", mean_diff.item()) - self.local_metrics.add("debug/vllm_vs_local_logprob_diff_max", max_diff.item()) - self.local_metrics.add("debug/vllm_vs_local_logprob_diff_std", std_diff.item()) - - reverse_kl = torch.exp(mb_vllm_logprobs) * (mb_vllm_logprobs - mb_local_logprobs) - masked_reverse_kl = torch.masked_fill(reverse_kl, ~valid_mask, 0.0) - mean_reverse_kl = masked_reverse_kl.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0 - self.local_metrics.add("debug/vllm_local_reverse_kl", mean_reverse_kl.item()) + compare_vllm_logprobs_to_local( + mb_local_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args, self.local_metrics + ) mb_new_logprobs = mb_local_logprobs @@ -1082,138 +1212,34 @@ def train( f"response_mask sum={mb_response_masks_bool.sum()}" ) - # Calculate the policy's loss - logprobs_diff = mb_new_logprobs - mb_old_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantages[:, 1:] * ratio - pg_losses2 = -mb_advantages[:, 1:] * torch.clamp( - ratio, 1.0 - args.clip_lower, 1.0 + args.clip_higher - ) - - # Apply truncated importance sampling if enabled - if args.truncated_importance_sampling_ratio_cap > 0 and mb_vllm_logprobs is not None: - old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB - vllm_logprobs_mask = mb_vllm_logprobs != INVALID_LOGPROB - - assert torch.all(old_logprobs_mask == mb_response_masks_bool), ( - f"Old logprobs mask should match response mask. " - f"old_mask sum={old_logprobs_mask.sum()}, " - f"response_mask sum={mb_response_masks_bool.sum()}" - ) - assert torch.all(vllm_logprobs_mask == mb_response_masks_bool), ( - f"vLLM logprobs mask should match response mask. 
" - f"vllm_mask sum={vllm_logprobs_mask.sum()}, " - f"response_mask sum={mb_response_masks_bool.sum()}" - ) - - valid_mask = mb_response_masks_bool - - # Initialize importance ratio to 1.0 (no effect) for all positions - tis_imp_ratio = torch.ones_like(mb_old_logprobs) - - if valid_mask.any(): - # Calculate logprob difference only for valid positions - logprob_diff_is = mb_old_logprobs - mb_vllm_logprobs - # Clamp to prevent numerical overflow in exp - logprob_diff_is = torch.where( - valid_mask, logprob_diff_is.clamp(-10.0, 10.0), torch.zeros_like(logprob_diff_is) - ) - # Compute importance ratio only for valid positions - tis_imp_ratio = torch.where(valid_mask, torch.exp(logprob_diff_is), tis_imp_ratio) - # Apply cap - tis_imp_ratio = torch.clamp( - tis_imp_ratio, max=args.truncated_importance_sampling_ratio_cap - ) - - # Apply importance sampling to losses - pg_losses = pg_losses * tis_imp_ratio - pg_losses2 = pg_losses2 * tis_imp_ratio - - pg_loss_max = torch.max(pg_losses, pg_losses2) - - # Here we recalculate kl: we want the KL loss to backpropagate through the model - # We also clamp the KL loss to avoid numerical instability - # https://chatgpt.com/share/679d0ed9-8f48-8011-926e-e274b15ae8ae - ref_logprobs_diff = (mb_new_logprobs - mb_ref_logprob).clamp(-40.0, 40.0) - kl1 = ref_logprobs_diff - kl2 = (ref_logprobs_diff) ** 2 / 2 - kl3 = torch.expm1(-ref_logprobs_diff) + ref_logprobs_diff # this is more numerically stable - kl4 = ratio * ref_logprobs_diff - if args.kl_estimator == "kl1": - kl = kl1 - elif args.kl_estimator == "kl2": - kl = kl2 - elif args.kl_estimator == "kl3": - kl = kl3 - elif args.kl_estimator == "kl4": - kl = kl4 - - # grpo change: directly subtract KL in loss (add) - loss = masked_mean( - pg_loss_max + (args.beta * kl), + local_step = self.calculate_loss( + i, + loss_statistics, + mb_new_logprobs, + mb_old_logprobs, + mb_ref_logprob, + mb_advantages, mb_response_masks_bool, - args.masked_mean_axis, - args.masked_mean_denominator, + mb_vllm_logprobs, + mb_entropy, + accumulation_steps, + local_step, + args, ) - loss = loss / accumulation_steps - self.model.backward(loss) - if (local_step + 1) % accumulation_steps == 0: - self.model.step() - local_step += 1 - with torch.no_grad(): - # NOTE: in packed implementation, kl calculation are averages over response tokens - kl1_stats[i] = masked_mean( - kl1, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl2_stats[i] = masked_mean( - kl2, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl3_stats[i] = masked_mean( - kl3, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl4_stats[i] = masked_mean( - kl4, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - if args.kl_estimator == "kl1": - kl_loss_stats[i] = kl1_stats[i] * args.beta - elif args.kl_estimator == "kl2": - kl_loss_stats[i] = kl2_stats[i] * args.beta - elif args.kl_estimator == "kl3": - kl_loss_stats[i] = kl3_stats[i] * args.beta - elif args.kl_estimator == "kl4": - kl_loss_stats[i] = kl4_stats[i] * args.beta - pg_clipfrac_stats[i] = masked_mean( - (pg_losses2 > pg_losses).float(), - mb_response_masks_bool, - args.masked_mean_axis, - args.masked_mean_denominator, - ) - pg_loss_stats[i] = masked_mean( - pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ) - loss_stats[i] = loss - ratio_stats[i] = masked_mean( - ratio, mb_response_masks_bool, args.masked_mean_axis, 
args.masked_mean_denominator - ) - if args.record_entropy: - # Calculate entropy statistics - entropy_stats[i] = masked_mean( - mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() with torch.no_grad(): - self.local_metrics.add("objective/kl_avg", kl1_stats.mean()) - self.local_metrics.add("objective/kl2_avg", kl2_stats.mean()) - self.local_metrics.add("objective/kl3_avg", kl3_stats.mean()) - self.local_metrics.add("objective/kl4_avg", kl4_stats.mean()) - self.local_metrics.add("loss/policy_avg", pg_loss_stats.mean()) - self.local_metrics.add("loss/kl_avg", kl_loss_stats.mean()) - self.local_metrics.add("loss/total_avg", loss_stats.mean()) - self.local_metrics.add("policy/clipfrac_avg", pg_clipfrac_stats.mean()) - self.local_metrics.add("val/ratio", ratio_stats.mean()) - self.local_metrics.add("val/ratio_var", ratio_stats.var()) + self.local_metrics.add("objective/kl_avg", loss_statistics.kl1_stats.mean()) + self.local_metrics.add("objective/kl2_avg", loss_statistics.kl2_stats.mean()) + self.local_metrics.add("objective/kl3_avg", loss_statistics.kl3_stats.mean()) + self.local_metrics.add("objective/kl4_avg", loss_statistics.kl4_stats.mean()) + self.local_metrics.add("loss/policy_avg", loss_statistics.pg_loss_stats.mean()) + self.local_metrics.add("loss/kl_avg", loss_statistics.kl_loss_stats.mean()) + self.local_metrics.add("loss/total_avg", loss_statistics.loss_stats.mean()) + self.local_metrics.add("policy/clipfrac_avg", loss_statistics.pg_clipfrac_stats.mean()) + self.local_metrics.add("val/ratio", loss_statistics.ratio_stats.mean()) + self.local_metrics.add("val/ratio_var", loss_statistics.ratio_stats.var()) if args.record_entropy: - self.local_metrics.add("policy/entropy_avg", entropy_stats.mean()) + self.local_metrics.add("policy/entropy_avg", loss_statistics.entropy_stats.mean()) self.local_metrics.add("lr", self.scheduler.get_last_lr()[0]) return self.local_metrics.get_metrics_list() From 2d450c26c460e63726bc654357ed3905afc34804 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 3 Nov 2025 13:26:01 -0700 Subject: [PATCH 02/10] Cleaned up code --- open_instruct/grpo_fast.py | 88 +++----------------------------------- 1 file changed, 6 insertions(+), 82 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3f76505b1..a08166b57 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -92,7 +92,7 @@ cleanup_all_llm_judge_clients, soft_format_reward_func, ) -from open_instruct.metrics import MetricsTracker +from open_instruct.metrics import LossStatistics, MetricsTracker from open_instruct.model_utils import ( Batch, ModelConfig, @@ -136,81 +136,6 @@ INVALID_LOGPROB = 1.0 -class LossStatistics: - def __init__(self, num_batches: int, record_entropy: bool = False): - self.kl1_stats = torch.zeros(num_batches) - self.kl2_stats = torch.zeros(num_batches) - self.kl3_stats = torch.zeros(num_batches) - self.kl4_stats = torch.zeros(num_batches) - self.kl_loss_stats = torch.zeros(num_batches) - self.pg_clipfrac_stats = torch.zeros(num_batches) - self.pg_loss_stats = torch.zeros(num_batches) - self.loss_stats = torch.zeros(num_batches) - self.ratio_stats = torch.zeros(num_batches) - self.entropy_stats = torch.zeros(num_batches) if record_entropy else None - self.kl1 = None - self.kl2 = None - self.kl3 = None - self.kl4 = None - - def update_kl_estimates(self, ref_logprobs_diff, ratio, mb_response_masks_bool, args): - self.kl1 = ref_logprobs_diff - self.kl2 = (ref_logprobs_diff) ** 2 / 2 - 
self.kl3 = torch.expm1(-ref_logprobs_diff) + ref_logprobs_diff - self.kl4 = ratio * ref_logprobs_diff - - def kl(self, args): - if args.kl_estimator == "kl1": - return self.kl1 - elif args.kl_estimator == "kl2": - return self.kl2 - elif args.kl_estimator == "kl3": - return self.kl3 - elif args.kl_estimator == "kl4": - return self.kl4 - - def update_stats( - self, i, mb_response_masks_bool, pg_losses, pg_losses2, pg_loss_max, ratio, loss, mb_entropy, args - ): - self.kl1_stats[i] = masked_mean( - self.kl1, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - self.kl2_stats[i] = masked_mean( - self.kl2, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - self.kl3_stats[i] = masked_mean( - self.kl3, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - self.kl4_stats[i] = masked_mean( - self.kl4, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - if args.kl_estimator == "kl1": - self.kl_loss_stats[i] = self.kl1_stats[i] * args.beta - elif args.kl_estimator == "kl2": - self.kl_loss_stats[i] = self.kl2_stats[i] * args.beta - elif args.kl_estimator == "kl3": - self.kl_loss_stats[i] = self.kl3_stats[i] * args.beta - elif args.kl_estimator == "kl4": - self.kl_loss_stats[i] = self.kl4_stats[i] * args.beta - self.pg_clipfrac_stats[i] = masked_mean( - (pg_losses2 > pg_losses).float(), - mb_response_masks_bool, - args.masked_mean_axis, - args.masked_mean_denominator, - ) - self.pg_loss_stats[i] = masked_mean( - pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ) - self.loss_stats[i] = loss - self.ratio_stats[i] = masked_mean( - ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ) - if args.record_entropy and self.entropy_stats is not None: - self.entropy_stats[i] = masked_mean( - mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - - class ShutdownSentinel: """Sentinel value to signal thread shutdown via queue.""" @@ -1035,8 +960,7 @@ def calculate_loss( pg_loss_max = torch.max(pg_losses, pg_losses2) ref_logprobs_diff = (mb_new_logprobs - mb_ref_logprob).clamp(-40.0, 40.0) - loss_statistics.update_kl_estimates(ref_logprobs_diff, ratio, mb_response_masks_bool, args) - kl = loss_statistics.kl(args) + kl = loss_statistics.update_kl_estimates(i, ref_logprobs_diff, ratio, mb_response_masks_bool, args) loss = masked_mean( pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator @@ -1228,10 +1152,10 @@ def train( ) with torch.no_grad(): - self.local_metrics.add("objective/kl_avg", loss_statistics.kl1_stats.mean()) - self.local_metrics.add("objective/kl2_avg", loss_statistics.kl2_stats.mean()) - self.local_metrics.add("objective/kl3_avg", loss_statistics.kl3_stats.mean()) - self.local_metrics.add("objective/kl4_avg", loss_statistics.kl4_stats.mean()) + self.local_metrics.add("objective/kl_avg", loss_statistics.kl_stats[0].mean()) + self.local_metrics.add("objective/kl2_avg", loss_statistics.kl_stats[1].mean()) + self.local_metrics.add("objective/kl3_avg", loss_statistics.kl_stats[2].mean()) + self.local_metrics.add("objective/kl4_avg", loss_statistics.kl_stats[3].mean()) self.local_metrics.add("loss/policy_avg", loss_statistics.pg_loss_stats.mean()) self.local_metrics.add("loss/kl_avg", loss_statistics.kl_loss_stats.mean()) self.local_metrics.add("loss/total_avg", 
loss_statistics.loss_stats.mean()) From 41bdc23cf1ce318177fa7b052a7eb3e075777646 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 3 Nov 2025 13:26:54 -0700 Subject: [PATCH 03/10] Added metrics class --- open_instruct/metrics.py | 90 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 open_instruct/metrics.py diff --git a/open_instruct/metrics.py b/open_instruct/metrics.py new file mode 100644 index 000000000..3147dd699 --- /dev/null +++ b/open_instruct/metrics.py @@ -0,0 +1,90 @@ +import torch + +# TODO: Add docstrings to MetricsTracker, LossStatistics, masked_mean, and all methods +# added in this refactoring branch (compare_vllm_logprobs_to_local, +# maybe_apply_importance_sampling, calculate_loss) + + +def masked_mean( + values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None +) -> torch.Tensor: + numerator = (values * mask).sum(axis=axis) + denom = mask.sum(axis=axis) if denominator is None else denominator + return (numerator / denom).mean() + + +class LossStatistics: + def __init__(self, num_batches: int, record_entropy: bool = False): + self.kl_stats = torch.zeros(4, num_batches) + self.kl_loss_stats = torch.zeros(num_batches) + self.pg_clipfrac_stats = torch.zeros(num_batches) + self.pg_loss_stats = torch.zeros(num_batches) + self.loss_stats = torch.zeros(num_batches) + self.ratio_stats = torch.zeros(num_batches) + self.entropy_stats = torch.zeros(num_batches) if record_entropy else None + + def update_kl_estimates(self, i, ref_logprobs_diff, ratio, mb_response_masks_bool, args): + kl_values = torch.stack( + [ + ref_logprobs_diff, + ref_logprobs_diff**2 / 2, + torch.expm1(-ref_logprobs_diff) + ref_logprobs_diff, + ratio * ref_logprobs_diff, + ] + ) + + vmapped_fn = torch.vmap( + lambda v: masked_mean(v, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator) + ) + self.kl_stats[:, i] = vmapped_fn(kl_values).float() + + kl_idx = {"kl1": 0, "kl2": 1, "kl3": 2, "kl4": 3}[args.kl_estimator] + return kl_values[kl_idx] + + def update_stats( + self, i, mb_response_masks_bool, pg_losses, pg_losses2, pg_loss_max, ratio, loss, mb_entropy, args + ): + kl_idx = {"kl1": 0, "kl2": 1, "kl3": 2, "kl4": 3}[args.kl_estimator] + self.kl_loss_stats[i] = self.kl_stats[kl_idx, i] * args.beta + self.pg_clipfrac_stats[i] = masked_mean( + (pg_losses2 > pg_losses).float(), + mb_response_masks_bool, + args.masked_mean_axis, + args.masked_mean_denominator, + ) + self.pg_loss_stats[i] = masked_mean( + pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ) + self.loss_stats[i] = loss + self.ratio_stats[i] = masked_mean( + ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ) + if args.record_entropy and self.entropy_stats is not None: + self.entropy_stats[i] = masked_mean( + mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ).float() + + +class MetricsTracker: + """A simple class to prellocate all metrics in an array + so we can do only one allreduce operation to get the metrics mean""" + + def __init__(self, max_metrics: int = 32, device: str = "cuda"): + self.metrics = torch.zeros(max_metrics, device=device) + self.names2idx = {} + self.current_idx = 0 + self.max_metrics = max_metrics + + def add(self, name: str, value: torch.tensor): + if name not in self.names2idx: + if self.current_idx >= self.max_metrics: + raise ValueError(f"Exceeded maximum number of metrics ({self.max_metrics})") 
+ self.names2idx[name] = self.current_idx + self.current_idx += 1 + + self.metrics[self.names2idx[name]] = value + return self + + def get_metrics_list(self) -> dict[str, float]: + metrics_list = self.metrics.tolist() + return {name: metrics_list[idx] for name, idx in self.names2idx.items()} From e8b6f32f9c592bfff7198d00fa614b8909997277 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 3 Nov 2025 13:47:06 -0700 Subject: [PATCH 04/10] Updated code --- open_instruct/grpo_fast.py | 13 +------------ open_instruct/metrics.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index a08166b57..763dcc9e3 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1152,18 +1152,7 @@ def train( ) with torch.no_grad(): - self.local_metrics.add("objective/kl_avg", loss_statistics.kl_stats[0].mean()) - self.local_metrics.add("objective/kl2_avg", loss_statistics.kl_stats[1].mean()) - self.local_metrics.add("objective/kl3_avg", loss_statistics.kl_stats[2].mean()) - self.local_metrics.add("objective/kl4_avg", loss_statistics.kl_stats[3].mean()) - self.local_metrics.add("loss/policy_avg", loss_statistics.pg_loss_stats.mean()) - self.local_metrics.add("loss/kl_avg", loss_statistics.kl_loss_stats.mean()) - self.local_metrics.add("loss/total_avg", loss_statistics.loss_stats.mean()) - self.local_metrics.add("policy/clipfrac_avg", loss_statistics.pg_clipfrac_stats.mean()) - self.local_metrics.add("val/ratio", loss_statistics.ratio_stats.mean()) - self.local_metrics.add("val/ratio_var", loss_statistics.ratio_stats.var()) - if args.record_entropy: - self.local_metrics.add("policy/entropy_avg", loss_statistics.entropy_stats.mean()) + self.local_metrics.add_dict(loss_statistics.to_dict()) self.local_metrics.add("lr", self.scheduler.get_last_lr()[0]) return self.local_metrics.get_metrics_list() diff --git a/open_instruct/metrics.py b/open_instruct/metrics.py index 3147dd699..a965d805f 100644 --- a/open_instruct/metrics.py +++ b/open_instruct/metrics.py @@ -64,6 +64,23 @@ def update_stats( mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator ).float() + def to_dict(self) -> dict[str, torch.Tensor]: + metrics = { + "objective/kl_avg": self.kl_stats[0].mean(), + "objective/kl2_avg": self.kl_stats[1].mean(), + "objective/kl3_avg": self.kl_stats[2].mean(), + "objective/kl4_avg": self.kl_stats[3].mean(), + "loss/policy_avg": self.pg_loss_stats.mean(), + "loss/kl_avg": self.kl_loss_stats.mean(), + "loss/total_avg": self.loss_stats.mean(), + "policy/clipfrac_avg": self.pg_clipfrac_stats.mean(), + "val/ratio": self.ratio_stats.mean(), + "val/ratio_var": self.ratio_stats.var(), + } + if self.entropy_stats is not None: + metrics["policy/entropy_avg"] = self.entropy_stats.mean() + return metrics + class MetricsTracker: """A simple class to prellocate all metrics in an array @@ -85,6 +102,11 @@ def add(self, name: str, value: torch.tensor): self.metrics[self.names2idx[name]] = value return self + def add_dict(self, metrics_dict: dict[str, torch.Tensor]): + for k, v in metrics_dict.items(): + self.add(k, v) + return self + def get_metrics_list(self) -> dict[str, float]: metrics_list = self.metrics.tolist() return {name: metrics_list[idx] for name, idx in self.names2idx.items()} From d677b337bd3f900403309502cc40083341553e44 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 5 Nov 2025 11:08:38 -0700 Subject: [PATCH 05/10] cleaned up code --- open_instruct/grpo_fast.py | 
228 ++++++++++++++++++++++++------------- open_instruct/metrics.py | 104 ++++++++++++++++- 2 files changed, 248 insertions(+), 84 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 763dcc9e3..82fabf185 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -523,17 +523,35 @@ def masked_mean( return (numerator / denom).mean() -def compare_vllm_logprobs_to_local(mb_new_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args, metrics_tracker): +def compare_logprobs( + new_logprobs: torch.Tensor, + old_logprobs: torch.Tensor, + mask: torch.Tensor, + args: Namespace, + metrics_tracker: MetricsTracker, +) -> None: + """Compare locally computed log probabilities with reference log probabilities. + + Computes statistics on the difference between two sets of log probabilities and records + debugging metrics including mean/max/std differences and reverse KL divergence. + + Args: + new_logprobs: Locally computed log probabilities (shape: [batch, seq_len]) + old_logprobs: Reference log probabilities from behavior policy (shape: [batch, seq_len]) + mask: Boolean mask indicating valid response tokens (shape: [batch, seq_len]) + args: Training arguments containing masked_mean_axis and masked_mean_denominator + metrics_tracker: MetricsTracker instance for recording debug metrics + """ with torch.no_grad(): - valid_mask = mb_response_masks_bool & ~torch.isnan(mb_vllm_logprobs) - logprob_diff = (mb_new_logprobs - mb_vllm_logprobs).abs() + valid_mask = mask & ~torch.isnan(old_logprobs) + logprob_diff = (new_logprobs - old_logprobs).abs() masked_diff = torch.masked_fill(logprob_diff, ~valid_mask, 0.0) mean_diff = masked_diff.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0 max_diff = masked_diff.max() std_diff = masked_diff[valid_mask].std() if valid_mask.sum() > 1 else 0.0 reverse_kl = masked_mean( - torch.expm1(mb_vllm_logprobs - mb_new_logprobs) + (mb_vllm_logprobs - mb_new_logprobs), - mb_response_masks_bool, + torch.expm1(old_logprobs - new_logprobs) + (old_logprobs - new_logprobs), + mask, args.masked_mean_axis, args.masked_mean_denominator, ) @@ -544,41 +562,138 @@ def compare_vllm_logprobs_to_local(mb_new_logprobs, mb_vllm_logprobs, mb_respons def maybe_apply_importance_sampling( - pg_losses, pg_losses2, mb_old_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args -): - if args.truncated_importance_sampling_ratio_cap > 0 and mb_vllm_logprobs is not None: - old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB - vllm_logprobs_mask = mb_vllm_logprobs != INVALID_LOGPROB - - assert torch.all(old_logprobs_mask == mb_response_masks_bool), ( - f"Old logprobs mask should match response mask. " - f"old_mask sum={old_logprobs_mask.sum()}, " - f"response_mask sum={mb_response_masks_bool.sum()}" - ) - assert torch.all(vllm_logprobs_mask == mb_response_masks_bool), ( - f"vLLM logprobs mask should match response mask. " - f"vllm_mask sum={vllm_logprobs_mask.sum()}, " - f"response_mask sum={mb_response_masks_bool.sum()}" - ) + pg_losses: torch.Tensor, + pg_losses2: torch.Tensor, + old_logprobs: torch.Tensor, + vllm_logprobs: torch.Tensor | None, + response_mask: torch.Tensor, + args: Namespace, +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply truncated importance sampling (TIS) to policy gradient losses if enabled. + + Importance sampling corrects for the distribution mismatch between the behavior policy + (vLLM inference) and the current policy. The importance ratio is capped to prevent + high variance in gradient estimates. 
- valid_mask = mb_response_masks_bool + Args: + pg_losses: Unclipped policy gradient losses (shape: [batch, seq_len]) + pg_losses2: Clipped policy gradient losses (shape: [batch, seq_len]) + old_logprobs: Log probabilities from current policy (shape: [batch, seq_len]) + vllm_logprobs: Log probabilities from behavior policy, or None (shape: [batch, seq_len]) + response_mask: Boolean mask indicating valid response tokens (shape: [batch, seq_len]) + args: Training arguments containing truncated_importance_sampling_ratio_cap - tis_imp_ratio = torch.ones_like(mb_old_logprobs) + Returns: + Tuple of (potentially scaled pg_losses, potentially scaled pg_losses2) + """ + if args.truncated_importance_sampling_ratio_cap <= 0 or vllm_logprobs is None: + return pg_losses, pg_losses2 - if valid_mask.any(): - logprob_diff_is = mb_old_logprobs - mb_vllm_logprobs - logprob_diff_is = torch.where( - valid_mask, logprob_diff_is.clamp(-10.0, 10.0), torch.zeros_like(logprob_diff_is) - ) - tis_imp_ratio = torch.where(valid_mask, torch.exp(logprob_diff_is), tis_imp_ratio) - tis_imp_ratio = torch.clamp(tis_imp_ratio, max=args.truncated_importance_sampling_ratio_cap) + old_logprobs_mask = old_logprobs != INVALID_LOGPROB + vllm_logprobs_mask = vllm_logprobs != INVALID_LOGPROB - pg_losses = pg_losses * tis_imp_ratio - pg_losses2 = pg_losses2 * tis_imp_ratio + assert torch.allclose(old_logprobs_mask.float(), response_mask.float()), ( + f"Old logprobs mask should match response mask. " + f"old_mask sum={old_logprobs_mask.sum()}, " + f"response_mask sum={response_mask.sum()}" + ) + assert torch.allclose(vllm_logprobs_mask.float(), response_mask.float()), ( + f"vLLM logprobs mask should match response mask. " + f"vllm_mask sum={vllm_logprobs_mask.sum()}, " + f"response_mask sum={response_mask.sum()}" + ) + + valid_mask = response_mask + importance_ratio = torch.ones_like(old_logprobs) + + if valid_mask.any(): + logprob_diff = old_logprobs - vllm_logprobs + logprob_diff = torch.where(valid_mask, logprob_diff.clamp(-10.0, 10.0), torch.zeros_like(logprob_diff)) + importance_ratio = torch.where(valid_mask, torch.exp(logprob_diff), importance_ratio) + importance_ratio = torch.clamp(importance_ratio, max=args.truncated_importance_sampling_ratio_cap) + + pg_losses = pg_losses * importance_ratio + pg_losses2 = pg_losses2 * importance_ratio return pg_losses, pg_losses2 +def calculate_loss( + model: Any, + i: int, + loss_statistics: LossStatistics, + mb_new_logprobs: torch.Tensor, + mb_old_logprobs: torch.Tensor, + mb_ref_logprob: torch.Tensor, + mb_advantages: torch.Tensor, + mb_response_masks_bool: torch.Tensor, + mb_vllm_logprobs: torch.Tensor | None, + mb_entropy: torch.Tensor, + accumulation_steps: int, + local_step: int, + args: Namespace, +) -> int: + """Calculate and apply GRPO loss for a single minibatch. + + Computes the policy gradient loss using the clipped surrogate objective from PPO, + combines it with a KL penalty term, performs the backward pass, and optionally + steps the optimizer. 
+ + Args: + model: Model wrapper with backward() and step() methods (e.g., DeepSpeed engine) + i: Minibatch index for tracking statistics + loss_statistics: LossStatistics object to accumulate training metrics + mb_new_logprobs: Log probabilities from current policy (shape: [batch, seq_len]) + mb_old_logprobs: Log probabilities from old/cached policy (shape: [batch, seq_len]) + mb_ref_logprob: Log probabilities from reference model (shape: [batch, seq_len]) + mb_advantages: Advantage estimates for policy gradient (shape: [batch, seq_len+1]) + mb_response_masks_bool: Boolean mask for valid response tokens (shape: [batch, seq_len]) + mb_vllm_logprobs: Log probabilities from behavior policy, or None (shape: [batch, seq_len]) + mb_entropy: Entropy of the policy distribution (shape: [batch, seq_len]) + accumulation_steps: Number of gradient accumulation steps before optimizer update + local_step: Current local training step (used to determine when to step optimizer) + args: Training arguments containing hyperparameters + + Returns: + Updated local_step (incremented by 1) + """ + logprobs_diff = mb_new_logprobs - mb_old_logprobs + ratio = torch.exp(logprobs_diff) + + # PPO clipped surrogate objective: we compute two losses and take the element-wise maximum. + # - unclipped_pg_loss: standard policy gradient loss using the raw importance ratio + # - clipped_pg_loss: policy gradient loss with the ratio clipped to prevent large updates + # Taking the maximum implements a pessimistic bound that prevents the policy from + # deviating too far from the old policy. The clipfrac metric tracks how often clipping + # is active (clipped_pg_loss > unclipped_pg_loss), which indicates constraint saturation. + unclipped_pg_loss = -mb_advantages[:, 1:] * ratio + clipped_pg_loss = -mb_advantages[:, 1:] * torch.clamp(ratio, 1.0 - args.clip_lower, 1.0 + args.clip_higher) + + unclipped_pg_loss, clipped_pg_loss = maybe_apply_importance_sampling( + unclipped_pg_loss, clipped_pg_loss, mb_old_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args + ) + + pg_loss_max = torch.max(unclipped_pg_loss, clipped_pg_loss) + + ref_logprobs_diff = (mb_new_logprobs - mb_ref_logprob).clamp(-40.0, 40.0) + kl = loss_statistics.update_kl_estimates(i, ref_logprobs_diff, ratio, mb_response_masks_bool, args) + + loss = masked_mean( + pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ) + loss = loss / accumulation_steps + model.backward(loss) + if (local_step + 1) % accumulation_steps == 0: + model.step() + + with torch.no_grad(): + loss_statistics.update_stats( + i, mb_response_masks_bool, unclipped_pg_loss, clipped_pg_loss, pg_loss_max, ratio, loss, mb_entropy, args + ) + + return local_step + 1 + + def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) if pin_memory: @@ -933,50 +1048,6 @@ def update_ref_policy(self): else: ref_param.data.mul_(1.0 - self.args.alpha).add_(param.data, alpha=self.args.alpha) - def calculate_loss( - self, - i, - loss_statistics, - mb_new_logprobs, - mb_old_logprobs, - mb_ref_logprob, - mb_advantages, - mb_response_masks_bool, - mb_vllm_logprobs, - mb_entropy, - accumulation_steps, - local_step, - args, - ): - logprobs_diff = mb_new_logprobs - mb_old_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantages[:, 1:] * ratio - pg_losses2 = -mb_advantages[:, 1:] * 
torch.clamp(ratio, 1.0 - args.clip_lower, 1.0 + args.clip_higher) - - pg_losses, pg_losses2 = maybe_apply_importance_sampling( - pg_losses, pg_losses2, mb_old_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args - ) - - pg_loss_max = torch.max(pg_losses, pg_losses2) - - ref_logprobs_diff = (mb_new_logprobs - mb_ref_logprob).clamp(-40.0, 40.0) - kl = loss_statistics.update_kl_estimates(i, ref_logprobs_diff, ratio, mb_response_masks_bool, args) - - loss = masked_mean( - pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ) - loss = loss / accumulation_steps - self.model.backward(loss) - if (local_step + 1) % accumulation_steps == 0: - self.model.step() - - with torch.no_grad(): - loss_statistics.update_stats( - i, mb_response_masks_bool, pg_losses, pg_losses2, pg_loss_max, ratio, loss, mb_entropy, args - ) - - return local_step + 1 - def train( self, collated_query_responses, @@ -1111,7 +1182,7 @@ def train( # Replace any remaining NaN values (query tokens in packed sequences are set to NaN by pack_sequences in rl_utils.py) mb_vllm_logprobs = torch.nan_to_num(mb_vllm_logprobs, nan=INVALID_LOGPROB) - compare_vllm_logprobs_to_local( + compare_logprobs( mb_local_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args, self.local_metrics ) @@ -1136,7 +1207,8 @@ def train( f"response_mask sum={mb_response_masks_bool.sum()}" ) - local_step = self.calculate_loss( + local_step = calculate_loss( + self.model, i, loss_statistics, mb_new_logprobs, diff --git a/open_instruct/metrics.py b/open_instruct/metrics.py index a965d805f..52e50d9d7 100644 --- a/open_instruct/metrics.py +++ b/open_instruct/metrics.py @@ -1,20 +1,41 @@ import torch -# TODO: Add docstrings to MetricsTracker, LossStatistics, masked_mean, and all methods -# added in this refactoring branch (compare_vllm_logprobs_to_local, -# maybe_apply_importance_sampling, calculate_loss) - def masked_mean( values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None ) -> torch.Tensor: + """Compute the mean of tensor values considering only masked (valid) positions. + + Args: + values: Input tensor to compute mean over + mask: Boolean or binary mask tensor (same shape as values) indicating valid positions + axis: Axis along which to sum before computing mean, or None for all axes + denominator: Optional fixed denominator to use instead of mask.sum(). Useful when + the denominator should be consistent across batches. + + Returns: + Masked mean of the input values as a scalar tensor + """ numerator = (values * mask).sum(axis=axis) denom = mask.sum(axis=axis) if denominator is None else denominator return (numerator / denom).mean() class LossStatistics: + """Accumulates training statistics across minibatches for GRPO training. + + Tracks KL divergence estimates, policy gradient losses, clipping statistics, + and importance ratios across multiple minibatches. Provides methods to update + statistics and convert accumulated values to a metrics dictionary. + """ + def __init__(self, num_batches: int, record_entropy: bool = False): + """Initialize loss statistics storage. 
+ + Args: + num_batches: Number of minibatches to track statistics for + record_entropy: Whether to track policy entropy statistics + """ self.kl_stats = torch.zeros(4, num_batches) self.kl_loss_stats = torch.zeros(num_batches) self.pg_clipfrac_stats = torch.zeros(num_batches) @@ -24,6 +45,21 @@ def __init__(self, num_batches: int, record_entropy: bool = False): self.entropy_stats = torch.zeros(num_batches) if record_entropy else None def update_kl_estimates(self, i, ref_logprobs_diff, ratio, mb_response_masks_bool, args): + """Compute and store KL divergence estimates for a minibatch. + + Computes four different KL estimators (kl1-kl4) based on log probability + differences between the current policy and reference policy. + + Args: + i: Minibatch index + ref_logprobs_diff: Log probability differences (new - ref) [batch, seq_len] + ratio: Importance ratio exp(new_logprobs - old_logprobs) [batch, seq_len] + mb_response_masks_bool: Boolean mask for valid response tokens [batch, seq_len] + args: Training arguments containing kl_estimator, masked_mean settings + + Returns: + KL divergence values for the selected estimator (shape: [batch, seq_len]) + """ kl_values = torch.stack( [ ref_logprobs_diff, @@ -44,6 +80,19 @@ def update_kl_estimates(self, i, ref_logprobs_diff, ratio, mb_response_masks_boo def update_stats( self, i, mb_response_masks_bool, pg_losses, pg_losses2, pg_loss_max, ratio, loss, mb_entropy, args ): + """Update all training statistics for a minibatch. + + Args: + i: Minibatch index + mb_response_masks_bool: Boolean mask for valid response tokens [batch, seq_len] + pg_losses: Unclipped policy gradient losses [batch, seq_len] + pg_losses2: Clipped policy gradient losses [batch, seq_len] + pg_loss_max: Element-wise max of pg_losses and pg_losses2 [batch, seq_len] + ratio: Importance ratio [batch, seq_len] + loss: Total loss value (scalar) + mb_entropy: Policy entropy [batch, seq_len] + args: Training arguments containing beta, record_entropy, masked_mean settings + """ kl_idx = {"kl1": 0, "kl2": 1, "kl3": 2, "kl4": 3}[args.kl_estimator] self.kl_loss_stats[i] = self.kl_stats[kl_idx, i] * args.beta self.pg_clipfrac_stats[i] = masked_mean( @@ -65,6 +114,11 @@ def update_stats( ).float() def to_dict(self) -> dict[str, torch.Tensor]: + """Convert accumulated statistics to a metrics dictionary. + + Returns: + Dictionary mapping metric names to their averaged values across all minibatches + """ metrics = { "objective/kl_avg": self.kl_stats[0].mean(), "objective/kl2_avg": self.kl_stats[1].mean(), @@ -83,16 +137,41 @@ def to_dict(self) -> dict[str, torch.Tensor]: class MetricsTracker: - """A simple class to prellocate all metrics in an array - so we can do only one allreduce operation to get the metrics mean""" + """Preallocated tensor-based metrics storage for efficient distributed reduction. + + Stores all metrics in a single preallocated tensor to enable efficient all-reduce + operations in distributed training. Maintains a mapping from metric names to + tensor indices for fast access. + """ def __init__(self, max_metrics: int = 32, device: str = "cuda"): + """Initialize metrics tracker. + + Args: + max_metrics: Maximum number of unique metrics to track + device: Device to allocate metrics tensor on (default: "cuda") + """ self.metrics = torch.zeros(max_metrics, device=device) self.names2idx = {} self.current_idx = 0 self.max_metrics = max_metrics def add(self, name: str, value: torch.tensor): + """Add or update a metric value. 
+ + If the metric name is new, allocates a new index in the metrics tensor. + If the metric already exists, updates its value at the existing index. + + Args: + name: Metric name (e.g., "loss/policy_avg") + value: Metric value (scalar tensor or convertible to tensor) + + Returns: + Self for method chaining + + Raises: + ValueError: If max_metrics limit is exceeded + """ if name not in self.names2idx: if self.current_idx >= self.max_metrics: raise ValueError(f"Exceeded maximum number of metrics ({self.max_metrics})") @@ -103,10 +182,23 @@ def add(self, name: str, value: torch.tensor): return self def add_dict(self, metrics_dict: dict[str, torch.Tensor]): + """Add multiple metrics from a dictionary. + + Args: + metrics_dict: Dictionary mapping metric names to values + + Returns: + Self for method chaining + """ for k, v in metrics_dict.items(): self.add(k, v) return self def get_metrics_list(self) -> dict[str, float]: + """Convert tracked metrics to a dictionary of Python floats. + + Returns: + Dictionary mapping metric names to their float values + """ metrics_list = self.metrics.tolist() return {name: metrics_list[idx] for name, idx in self.names2idx.items()} From 4f152185ed92488dad72ae0e5dfe9399180fe3b8 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 5 Nov 2025 12:06:02 -0700 Subject: [PATCH 06/10] Cleaned up code --- open_instruct/grpo_fast.py | 89 ++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 42 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 82fabf185..0e6e8bdab 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -527,7 +527,8 @@ def compare_logprobs( new_logprobs: torch.Tensor, old_logprobs: torch.Tensor, mask: torch.Tensor, - args: Namespace, + masked_mean_axis: int | None, + masked_mean_denominator: float | None, metrics_tracker: MetricsTracker, ) -> None: """Compare locally computed log probabilities with reference log probabilities. @@ -539,7 +540,8 @@ def compare_logprobs( new_logprobs: Locally computed log probabilities (shape: [batch, seq_len]) old_logprobs: Reference log probabilities from behavior policy (shape: [batch, seq_len]) mask: Boolean mask indicating valid response tokens (shape: [batch, seq_len]) - args: Training arguments containing masked_mean_axis and masked_mean_denominator + masked_mean_axis: Axis for masked mean reduction + masked_mean_denominator: Denominator for masked mean computation metrics_tracker: MetricsTracker instance for recording debug metrics """ with torch.no_grad(): @@ -552,8 +554,8 @@ def compare_logprobs( reverse_kl = masked_mean( torch.expm1(old_logprobs - new_logprobs) + (old_logprobs - new_logprobs), mask, - args.masked_mean_axis, - args.masked_mean_denominator, + masked_mean_axis, + masked_mean_denominator, ) metrics_tracker.add("debug/vllm_vs_local_logprob_diff_mean", mean_diff) metrics_tracker.add("debug/vllm_vs_local_logprob_diff_max", max_diff) @@ -562,12 +564,12 @@ def compare_logprobs( def maybe_apply_importance_sampling( - pg_losses: torch.Tensor, - pg_losses2: torch.Tensor, + unclipped_pg_loss: torch.Tensor, + clipped_pg_loss: torch.Tensor, old_logprobs: torch.Tensor, vllm_logprobs: torch.Tensor | None, response_mask: torch.Tensor, - args: Namespace, + args: Args, ) -> tuple[torch.Tensor, torch.Tensor]: """Apply truncated importance sampling (TIS) to policy gradient losses if enabled. @@ -576,18 +578,18 @@ def maybe_apply_importance_sampling( high variance in gradient estimates. 
Args: - pg_losses: Unclipped policy gradient losses (shape: [batch, seq_len]) - pg_losses2: Clipped policy gradient losses (shape: [batch, seq_len]) + unclipped_pg_loss: Unclipped policy gradient losses (shape: [batch, seq_len]) + clipped_pg_loss: Clipped policy gradient losses (shape: [batch, seq_len]) old_logprobs: Log probabilities from current policy (shape: [batch, seq_len]) vllm_logprobs: Log probabilities from behavior policy, or None (shape: [batch, seq_len]) response_mask: Boolean mask indicating valid response tokens (shape: [batch, seq_len]) args: Training arguments containing truncated_importance_sampling_ratio_cap Returns: - Tuple of (potentially scaled pg_losses, potentially scaled pg_losses2) + Tuple of (potentially scaled unclipped_pg_loss, potentially scaled clipped_pg_loss) """ if args.truncated_importance_sampling_ratio_cap <= 0 or vllm_logprobs is None: - return pg_losses, pg_losses2 + return unclipped_pg_loss, clipped_pg_loss old_logprobs_mask = old_logprobs != INVALID_LOGPROB vllm_logprobs_mask = vllm_logprobs != INVALID_LOGPROB @@ -612,26 +614,26 @@ def maybe_apply_importance_sampling( importance_ratio = torch.where(valid_mask, torch.exp(logprob_diff), importance_ratio) importance_ratio = torch.clamp(importance_ratio, max=args.truncated_importance_sampling_ratio_cap) - pg_losses = pg_losses * importance_ratio - pg_losses2 = pg_losses2 * importance_ratio + unclipped_pg_loss = unclipped_pg_loss * importance_ratio + clipped_pg_loss = clipped_pg_loss * importance_ratio - return pg_losses, pg_losses2 + return unclipped_pg_loss, clipped_pg_loss def calculate_loss( - model: Any, + model: deepspeed.DeepSpeedEngine, i: int, loss_statistics: LossStatistics, - mb_new_logprobs: torch.Tensor, - mb_old_logprobs: torch.Tensor, - mb_ref_logprob: torch.Tensor, - mb_advantages: torch.Tensor, - mb_response_masks_bool: torch.Tensor, - mb_vllm_logprobs: torch.Tensor | None, - mb_entropy: torch.Tensor, + local_logprobs: torch.Tensor, + old_logprobs: torch.Tensor, + ref_logprob: torch.Tensor, + advantages: torch.Tensor, + response_masks_bool: torch.Tensor, + vllm_logprobs: torch.Tensor | None, + entropy: torch.Tensor, accumulation_steps: int, local_step: int, - args: Namespace, + args: Args, ) -> int: """Calculate and apply GRPO loss for a single minibatch. 
@@ -643,13 +645,13 @@ def calculate_loss( model: Model wrapper with backward() and step() methods (e.g., DeepSpeed engine) i: Minibatch index for tracking statistics loss_statistics: LossStatistics object to accumulate training metrics - mb_new_logprobs: Log probabilities from current policy (shape: [batch, seq_len]) - mb_old_logprobs: Log probabilities from old/cached policy (shape: [batch, seq_len]) - mb_ref_logprob: Log probabilities from reference model (shape: [batch, seq_len]) - mb_advantages: Advantage estimates for policy gradient (shape: [batch, seq_len+1]) - mb_response_masks_bool: Boolean mask for valid response tokens (shape: [batch, seq_len]) - mb_vllm_logprobs: Log probabilities from behavior policy, or None (shape: [batch, seq_len]) - mb_entropy: Entropy of the policy distribution (shape: [batch, seq_len]) + local_logprobs: Log probabilities from current policy (shape: [batch, seq_len]) + old_logprobs: Log probabilities from old/cached policy (shape: [batch, seq_len]) + ref_logprob: Log probabilities from reference model (shape: [batch, seq_len]) + advantages: Advantage estimates for policy gradient (shape: [batch, seq_len+1]) + response_masks_bool: Boolean mask for valid response tokens (shape: [batch, seq_len]) + vllm_logprobs: Log probabilities from behavior policy, or None (shape: [batch, seq_len]) + entropy: Entropy of the policy distribution (shape: [batch, seq_len]) accumulation_steps: Number of gradient accumulation steps before optimizer update local_step: Current local training step (used to determine when to step optimizer) args: Training arguments containing hyperparameters @@ -657,7 +659,7 @@ def calculate_loss( Returns: Updated local_step (incremented by 1) """ - logprobs_diff = mb_new_logprobs - mb_old_logprobs + logprobs_diff = local_logprobs - old_logprobs ratio = torch.exp(logprobs_diff) # PPO clipped surrogate objective: we compute two losses and take the element-wise maximum. @@ -666,20 +668,20 @@ def calculate_loss( # Taking the maximum implements a pessimistic bound that prevents the policy from # deviating too far from the old policy. The clipfrac metric tracks how often clipping # is active (clipped_pg_loss > unclipped_pg_loss), which indicates constraint saturation. 
- unclipped_pg_loss = -mb_advantages[:, 1:] * ratio - clipped_pg_loss = -mb_advantages[:, 1:] * torch.clamp(ratio, 1.0 - args.clip_lower, 1.0 + args.clip_higher) + unclipped_pg_loss = -advantages[:, 1:] * ratio + clipped_pg_loss = -advantages[:, 1:] * torch.clamp(ratio, 1.0 - args.clip_lower, 1.0 + args.clip_higher) unclipped_pg_loss, clipped_pg_loss = maybe_apply_importance_sampling( - unclipped_pg_loss, clipped_pg_loss, mb_old_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args + unclipped_pg_loss, clipped_pg_loss, old_logprobs, vllm_logprobs, response_masks_bool, args ) pg_loss_max = torch.max(unclipped_pg_loss, clipped_pg_loss) - ref_logprobs_diff = (mb_new_logprobs - mb_ref_logprob).clamp(-40.0, 40.0) - kl = loss_statistics.update_kl_estimates(i, ref_logprobs_diff, ratio, mb_response_masks_bool, args) + ref_logprobs_diff = (local_logprobs - ref_logprob).clamp(-40.0, 40.0) + kl = loss_statistics.update_kl_estimates(i, ref_logprobs_diff, ratio, response_masks_bool, args) loss = masked_mean( - pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + pg_loss_max + (args.beta * kl), response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator ) loss = loss / accumulation_steps model.backward(loss) @@ -688,7 +690,7 @@ def calculate_loss( with torch.no_grad(): loss_statistics.update_stats( - i, mb_response_masks_bool, unclipped_pg_loss, clipped_pg_loss, pg_loss_max, ratio, loss, mb_entropy, args + i, response_masks_bool, unclipped_pg_loss, clipped_pg_loss, pg_loss_max, ratio, loss, entropy, args ) return local_step + 1 @@ -1183,11 +1185,14 @@ def train( mb_vllm_logprobs = torch.nan_to_num(mb_vllm_logprobs, nan=INVALID_LOGPROB) compare_logprobs( - mb_local_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args, self.local_metrics + mb_local_logprobs, + mb_vllm_logprobs, + mb_response_masks_bool, + args.masked_mean_axis, + args.masked_mean_denominator, + self.local_metrics, ) - mb_new_logprobs = mb_local_logprobs - # Cache the old logprobs if num_mini_batches > 1: mb_old_logprobs = old_logprobs[i] @@ -1211,7 +1216,7 @@ def train( self.model, i, loss_statistics, - mb_new_logprobs, + mb_local_logprobs, mb_old_logprobs, mb_ref_logprob, mb_advantages, From abe438769df7b79307ad2ab5aebfee543b5911bc Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 6 Nov 2025 09:41:31 -0700 Subject: [PATCH 07/10] updated code to remove metricstracker --- open_instruct/grpo_fast.py | 33 +++++++------- open_instruct/metrics.py | 92 +++++--------------------------------- 2 files changed, 29 insertions(+), 96 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index bc2f4daad..cd3e22226 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -92,7 +92,7 @@ cleanup_all_llm_judge_clients, soft_format_reward_func, ) -from open_instruct.metrics import LossStatistics, MetricsTracker +from open_instruct.metrics import LossStatistics from open_instruct.model_utils import ( Batch, ModelConfig, @@ -529,11 +529,10 @@ def compare_logprobs( mask: torch.Tensor, masked_mean_axis: int | None, masked_mean_denominator: float | None, - metrics_tracker: MetricsTracker, -) -> None: +) -> dict[str, float]: """Compare locally computed log probabilities with reference log probabilities. 
- Computes statistics on the difference between two sets of log probabilities and records + Computes statistics on the difference between two sets of log probabilities and returns debugging metrics including mean/max/std differences and reverse KL divergence. Args: @@ -542,7 +541,9 @@ def compare_logprobs( mask: Boolean mask indicating valid response tokens (shape: [batch, seq_len]) masked_mean_axis: Axis for masked mean reduction masked_mean_denominator: Denominator for masked mean computation - metrics_tracker: MetricsTracker instance for recording debug metrics + + Returns: + Dictionary of debug metrics """ with torch.no_grad(): valid_mask = mask & ~torch.isnan(old_logprobs) @@ -557,10 +558,12 @@ def compare_logprobs( masked_mean_axis, masked_mean_denominator, ) - metrics_tracker.add("debug/vllm_vs_local_logprob_diff_mean", mean_diff) - metrics_tracker.add("debug/vllm_vs_local_logprob_diff_max", max_diff) - metrics_tracker.add("debug/vllm_vs_local_logprob_diff_std", std_diff) - metrics_tracker.add("debug/vllm_local_reverse_kl", reverse_kl) + return { + "debug/vllm_vs_local_logprob_diff_mean": mean_diff.item(), + "debug/vllm_vs_local_logprob_diff_max": max_diff.item(), + "debug/vllm_vs_local_logprob_diff_std": std_diff.item(), + "debug/vllm_local_reverse_kl": reverse_kl.item(), + } def maybe_apply_importance_sampling( @@ -1009,7 +1012,6 @@ def load(self, path: str, map_location=None): else: self.ref_policy.load_state_dict(state_dict) logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}") - self.local_metrics = MetricsTracker(max_metrics=32, device=self.device) return optimization_steps_done def forward( @@ -1147,6 +1149,7 @@ def train( num_mini_batches: int, ): args = self.args + local_metrics = {} to_device_inplace(collated_query_responses, self.device) to_device_inplace(collated_tool_masks, self.device) to_device_inplace(collated_attention_masks, self.device) @@ -1268,13 +1271,12 @@ def train( # Replace any remaining NaN values (query tokens in packed sequences are set to NaN by pack_sequences in rl_utils.py) mb_vllm_logprobs = torch.nan_to_num(mb_vllm_logprobs, nan=INVALID_LOGPROB) - compare_logprobs( + local_metrics |= compare_logprobs( mb_local_logprobs, mb_vllm_logprobs, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator, - self.local_metrics, ) # Cache the old logprobs @@ -1312,10 +1314,9 @@ def train( args, ) - with torch.no_grad(): - self.local_metrics.add_dict(loss_statistics.to_dict()) - self.local_metrics.add("lr", self.scheduler.get_last_lr()[0]) - return self.local_metrics.get_metrics_list() + local_metrics |= loss_statistics.to_dict() + local_metrics["lr"] = self.scheduler.get_last_lr()[0] + return local_metrics def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: dict[str, Any]) -> None: args = self.args diff --git a/open_instruct/metrics.py b/open_instruct/metrics.py index 52e50d9d7..53dcddee3 100644 --- a/open_instruct/metrics.py +++ b/open_instruct/metrics.py @@ -113,92 +113,24 @@ def update_stats( mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator ).float() - def to_dict(self) -> dict[str, torch.Tensor]: + def to_dict(self) -> dict[str, float]: """Convert accumulated statistics to a metrics dictionary. 
Returns: Dictionary mapping metric names to their averaged values across all minibatches """ metrics = { - "objective/kl_avg": self.kl_stats[0].mean(), - "objective/kl2_avg": self.kl_stats[1].mean(), - "objective/kl3_avg": self.kl_stats[2].mean(), - "objective/kl4_avg": self.kl_stats[3].mean(), - "loss/policy_avg": self.pg_loss_stats.mean(), - "loss/kl_avg": self.kl_loss_stats.mean(), - "loss/total_avg": self.loss_stats.mean(), - "policy/clipfrac_avg": self.pg_clipfrac_stats.mean(), - "val/ratio": self.ratio_stats.mean(), - "val/ratio_var": self.ratio_stats.var(), + "objective/kl_avg": self.kl_stats[0].mean().item(), + "objective/kl2_avg": self.kl_stats[1].mean().item(), + "objective/kl3_avg": self.kl_stats[2].mean().item(), + "objective/kl4_avg": self.kl_stats[3].mean().item(), + "loss/policy_avg": self.pg_loss_stats.mean().item(), + "loss/kl_avg": self.kl_loss_stats.mean().item(), + "loss/total_avg": self.loss_stats.mean().item(), + "policy/clipfrac_avg": self.pg_clipfrac_stats.mean().item(), + "val/ratio": self.ratio_stats.mean().item(), + "val/ratio_var": self.ratio_stats.var().item(), } if self.entropy_stats is not None: - metrics["policy/entropy_avg"] = self.entropy_stats.mean() + metrics["policy/entropy_avg"] = self.entropy_stats.mean().item() return metrics - - -class MetricsTracker: - """Preallocated tensor-based metrics storage for efficient distributed reduction. - - Stores all metrics in a single preallocated tensor to enable efficient all-reduce - operations in distributed training. Maintains a mapping from metric names to - tensor indices for fast access. - """ - - def __init__(self, max_metrics: int = 32, device: str = "cuda"): - """Initialize metrics tracker. - - Args: - max_metrics: Maximum number of unique metrics to track - device: Device to allocate metrics tensor on (default: "cuda") - """ - self.metrics = torch.zeros(max_metrics, device=device) - self.names2idx = {} - self.current_idx = 0 - self.max_metrics = max_metrics - - def add(self, name: str, value: torch.tensor): - """Add or update a metric value. - - If the metric name is new, allocates a new index in the metrics tensor. - If the metric already exists, updates its value at the existing index. - - Args: - name: Metric name (e.g., "loss/policy_avg") - value: Metric value (scalar tensor or convertible to tensor) - - Returns: - Self for method chaining - - Raises: - ValueError: If max_metrics limit is exceeded - """ - if name not in self.names2idx: - if self.current_idx >= self.max_metrics: - raise ValueError(f"Exceeded maximum number of metrics ({self.max_metrics})") - self.names2idx[name] = self.current_idx - self.current_idx += 1 - - self.metrics[self.names2idx[name]] = value - return self - - def add_dict(self, metrics_dict: dict[str, torch.Tensor]): - """Add multiple metrics from a dictionary. - - Args: - metrics_dict: Dictionary mapping metric names to values - - Returns: - Self for method chaining - """ - for k, v in metrics_dict.items(): - self.add(k, v) - return self - - def get_metrics_list(self) -> dict[str, float]: - """Convert tracked metrics to a dictionary of Python floats. 
- - Returns: - Dictionary mapping metric names to their float values - """ - metrics_list = self.metrics.tolist() - return {name: metrics_list[idx] for name, idx in self.names2idx.items()} From 3f716d6ecc500f216c85d188f0e2784de19cbde9 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 6 Nov 2025 10:00:16 -0700 Subject: [PATCH 08/10] Updated code --- CLAUDE.md | 5 +++++ open_instruct/grpo_fast.py | 15 +++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 3d2e569a7..dd2c9a16f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,3 +8,8 @@ - To run the `./scripts/train/build_image_and_launch.sh` script, you must commit the current changes. - Launch tool use experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/tool_grpo_fast.sh`. - Launch multi-node non-tool experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/large_test_script.sh`. + +# Comments Policy +- NEVER remove existing comments from code when making edits unless they are obviously outdated, in which case ALWAYS ask for permission. +- Always preserve all existing comments, especially explanatory ones +- Only add comments when they are needed for clarity diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index cd3e22226..3a4b50a42 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -623,7 +623,7 @@ def maybe_apply_importance_sampling( return unclipped_pg_loss, clipped_pg_loss -def calculate_loss( +def calculate_loss_and_backward( model: deepspeed.DeepSpeedEngine, i: int, loss_statistics: LossStatistics, @@ -638,14 +638,13 @@ def calculate_loss( local_step: int, args: Args, ) -> int: - """Calculate and apply GRPO loss for a single minibatch. + """Calculate GRPO loss and perform backward pass for a single minibatch. Computes the policy gradient loss using the clipped surrogate objective from PPO, - combines it with a KL penalty term, performs the backward pass, and optionally - steps the optimizer. + combines it with a KL penalty term, and performs the backward pass. 
Args: - model: Model wrapper with backward() and step() methods (e.g., DeepSpeed engine) + model: Model wrapper with backward() method (e.g., DeepSpeed engine) i: Minibatch index for tracking statistics loss_statistics: LossStatistics object to accumulate training metrics local_logprobs: Log probabilities from current policy (shape: [batch, seq_len]) @@ -688,8 +687,6 @@ def calculate_loss( ) loss = loss / accumulation_steps model.backward(loss) - if (local_step + 1) % accumulation_steps == 0: - model.step() with torch.no_grad(): loss_statistics.update_stats( @@ -1298,7 +1295,7 @@ def train( f"response_mask sum={mb_response_masks_bool.sum()}" ) - local_step = calculate_loss( + local_step = calculate_loss_and_backward( self.model, i, loss_statistics, @@ -1313,6 +1310,8 @@ def train( local_step, args, ) + if local_step % accumulation_steps == 0: + self.model.step() local_metrics |= loss_statistics.to_dict() local_metrics["lr"] = self.scheduler.get_last_lr()[0] From 196a26b299964e71791394f40303cfde124be1a1 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 6 Nov 2025 10:05:54 -0700 Subject: [PATCH 09/10] Update code --- open_instruct/grpo_fast.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3a4b50a42..0b53d5370 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1282,10 +1282,9 @@ def train( else: with torch.no_grad(): if epoch_idx == 0: - if args.use_vllm_logprobs: - old_logprobs[i] = mb_vllm_logprobs - else: - old_logprobs[i] = mb_local_logprobs.detach() + old_logprobs[i] = ( + mb_vllm_logprobs if args.use_vllm_logprobs else mb_local_logprobs.detach() + ) mb_old_logprobs = old_logprobs[i] old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB From adb108b3d22326ab663a5f5a639f7284a38bc3c3 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 7 Nov 2025 11:30:50 -0700 Subject: [PATCH 10/10] Added tests --- open_instruct/test_grpo_fast.py | 236 ++++++++++++++++++++++++++++++++ 1 file changed, 236 insertions(+) diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index e633f5314..21a8a5107 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -1086,5 +1086,241 @@ def test_distribution_and_structure( self.assertTrue(torch.all(row[first_pad_idx:] == pad_token_id)) +class TestCompareLogprobs(unittest.TestCase): + @parameterized.expand( + [ + ( + torch.tensor([[1.0, 2.0, 3.0]]), + torch.tensor([[1.1, 2.1, 3.1]]), + torch.tensor([[True, True, True]]), + None, + None, + ), + ( + torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + torch.tensor([[1.2, 2.2], [3.2, 4.2]]), + torch.tensor([[True, True], [True, True]]), + None, + None, + ), + ] + ) + def test_basic_functionality(self, new_logprobs, old_logprobs, mask, masked_mean_axis, masked_mean_denominator): + result = grpo_fast.compare_logprobs( + new_logprobs, old_logprobs, mask, masked_mean_axis, masked_mean_denominator + ) + self.assertIsInstance(result, dict) + self.assertIn("debug/vllm_vs_local_logprob_diff_mean", result) + self.assertIn("debug/vllm_vs_local_logprob_diff_max", result) + self.assertIn("debug/vllm_vs_local_logprob_diff_std", result) + self.assertIn("debug/vllm_local_reverse_kl", result) + self.assertGreater(result["debug/vllm_vs_local_logprob_diff_mean"], 0.0) + + def test_with_nan_values(self): + new_logprobs = torch.tensor([[1.0, 2.0, 3.0]]) + old_logprobs = torch.tensor([[1.1, float("nan"), 3.1]]) + mask = torch.tensor([[True, True, True]]) + result = 
grpo_fast.compare_logprobs(new_logprobs, old_logprobs, mask, None, None) + self.assertIsInstance(result, dict) + self.assertFalse(torch.isnan(torch.tensor(result["debug/vllm_vs_local_logprob_diff_mean"]))) + + def test_multiple_elements(self): + new_logprobs = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) + old_logprobs = torch.tensor([[1.1, 2.1, 3.1, 4.1]]) + mask = torch.tensor([[True, True, True, True]]) + result = grpo_fast.compare_logprobs(new_logprobs, old_logprobs, mask, None, None) + self.assertIsInstance(result, dict) + self.assertGreater(result["debug/vllm_vs_local_logprob_diff_mean"], 0.0) + self.assertGreater(result["debug/vllm_vs_local_logprob_diff_std"], 0.0) + + +class TestMaybeApplyImportanceSampling(unittest.TestCase): + def test_early_return_zero_cap(self): + unclipped = torch.tensor([[-1.0, -2.0]]) + clipped = torch.tensor([[-1.1, -2.1]]) + old_logprobs = torch.tensor([[-0.5, -0.6]]) + vllm_logprobs = torch.tensor([[-0.4, -0.5]]) + mask = torch.tensor([[True, True]]) + mock_args = Mock() + mock_args.truncated_importance_sampling_ratio_cap = 0.0 + result_unclipped, result_clipped = grpo_fast.maybe_apply_importance_sampling( + unclipped, clipped, old_logprobs, vllm_logprobs, mask, mock_args + ) + self.assertTrue(torch.equal(result_unclipped, unclipped)) + self.assertTrue(torch.equal(result_clipped, clipped)) + + def test_early_return_none_vllm_logprobs(self): + unclipped = torch.tensor([[-1.0, -2.0]]) + clipped = torch.tensor([[-1.1, -2.1]]) + old_logprobs = torch.tensor([[-0.5, -0.6]]) + mask = torch.tensor([[True, True]]) + mock_args = Mock() + mock_args.truncated_importance_sampling_ratio_cap = 2.0 + result_unclipped, result_clipped = grpo_fast.maybe_apply_importance_sampling( + unclipped, clipped, old_logprobs, None, mask, mock_args + ) + self.assertTrue(torch.equal(result_unclipped, unclipped)) + self.assertTrue(torch.equal(result_clipped, clipped)) + + def test_importance_ratio_computation(self): + unclipped = torch.tensor([[-1.0, -2.0]]) + clipped = torch.tensor([[-1.0, -2.0]]) + old_logprobs = torch.tensor([[-2.0, -2.0]]) + vllm_logprobs = torch.tensor([[-3.0, -3.0]]) + mask = torch.tensor([[True, True]]) + mock_args = Mock() + mock_args.truncated_importance_sampling_ratio_cap = 10.0 + result_unclipped, result_clipped = grpo_fast.maybe_apply_importance_sampling( + unclipped, clipped, old_logprobs, vllm_logprobs, mask, mock_args + ) + expected_ratio = torch.exp(torch.tensor([[1.0, 1.0]])) + self.assertTrue(torch.allclose(result_unclipped, unclipped * expected_ratio)) + self.assertTrue(torch.allclose(result_clipped, clipped * expected_ratio)) + + def test_ratio_capping(self): + unclipped = torch.tensor([[-1.0, -1.0]]) + clipped = torch.tensor([[-1.0, -1.0]]) + old_logprobs = torch.tensor([[-2.0, -2.0]]) + vllm_logprobs = torch.tensor([[-7.0, -7.0]]) + mask = torch.tensor([[True, True]]) + mock_args = Mock() + mock_args.truncated_importance_sampling_ratio_cap = 2.0 + result_unclipped, result_clipped = grpo_fast.maybe_apply_importance_sampling( + unclipped, clipped, old_logprobs, vllm_logprobs, mask, mock_args + ) + self.assertTrue(torch.all(result_unclipped / unclipped <= 2.0)) + self.assertTrue(torch.all(result_clipped / clipped <= 2.0)) + + def test_logprob_diff_clamping(self): + unclipped = torch.tensor([[-1.0, -1.0]]) + clipped = torch.tensor([[-1.0, -1.0]]) + old_logprobs = torch.tensor([[-2.0, -17.0]]) + vllm_logprobs = torch.tensor([[-17.0, -2.0]]) + mask = torch.tensor([[True, True]]) + mock_args = Mock() + mock_args.truncated_importance_sampling_ratio_cap = 1000.0 + 
result_unclipped, result_clipped = grpo_fast.maybe_apply_importance_sampling( + unclipped, clipped, old_logprobs, vllm_logprobs, mask, mock_args + ) + max_expected_ratio = torch.exp(torch.tensor(10.0)) + min_expected_ratio = torch.exp(torch.tensor(-10.0)) + self.assertTrue(torch.all(result_unclipped / unclipped <= max_expected_ratio)) + self.assertTrue(torch.all(result_unclipped / unclipped >= min_expected_ratio)) + + +class TestCalculateLossAndBackward(unittest.TestCase): + def test_basic_loss_computation(self): + mock_model = Mock() + mock_model.backward = Mock() + loss_statistics = Mock() + loss_statistics.update_kl_estimates = Mock(return_value=torch.tensor(0.5)) + loss_statistics.update_stats = Mock() + local_logprobs = torch.tensor([[1.0, 2.0]]) + old_logprobs = torch.tensor([[0.9, 1.9]]) + ref_logprob = torch.tensor([[0.8, 1.8]]) + advantages = torch.tensor([[0.0, 1.0, 1.0]]) + response_masks_bool = torch.tensor([[True, True]]) + entropy = torch.tensor([[0.5, 0.5]]) + mock_args = Mock() + mock_args.clip_lower = 0.2 + mock_args.clip_higher = 0.2 + mock_args.beta = 0.1 + mock_args.masked_mean_axis = None + mock_args.masked_mean_denominator = None + mock_args.truncated_importance_sampling_ratio_cap = 0.0 + result = grpo_fast.calculate_loss_and_backward( + mock_model, + 0, + loss_statistics, + local_logprobs, + old_logprobs, + ref_logprob, + advantages, + response_masks_bool, + None, + entropy, + 1, + 0, + mock_args, + ) + mock_model.backward.assert_called_once() + self.assertEqual(result, 1) + + def test_gradient_accumulation(self): + mock_model = Mock() + mock_model.backward = Mock() + loss_statistics = Mock() + loss_statistics.update_kl_estimates = Mock(return_value=torch.tensor(0.5)) + loss_statistics.update_stats = Mock() + local_logprobs = torch.tensor([[1.0, 2.0]]) + old_logprobs = torch.tensor([[0.9, 1.9]]) + ref_logprob = torch.tensor([[0.8, 1.8]]) + advantages = torch.tensor([[0.0, 1.0, 1.0]]) + response_masks_bool = torch.tensor([[True, True]]) + entropy = torch.tensor([[0.5, 0.5]]) + mock_args = Mock() + mock_args.clip_lower = 0.2 + mock_args.clip_higher = 0.2 + mock_args.beta = 0.1 + mock_args.masked_mean_axis = None + mock_args.masked_mean_denominator = None + mock_args.truncated_importance_sampling_ratio_cap = 0.0 + grpo_fast.calculate_loss_and_backward( + mock_model, + 0, + loss_statistics, + local_logprobs, + old_logprobs, + ref_logprob, + advantages, + response_masks_bool, + None, + entropy, + 4, + 0, + mock_args, + ) + mock_model.backward.assert_called_once() + loss_arg = mock_model.backward.call_args[0][0] + self.assertIsInstance(loss_arg, torch.Tensor) + + def test_advantages_slicing(self): + mock_model = Mock() + mock_model.backward = Mock() + loss_statistics = Mock() + loss_statistics.update_kl_estimates = Mock(return_value=torch.tensor(0.5)) + loss_statistics.update_stats = Mock() + local_logprobs = torch.tensor([[1.0, 2.0, 3.0]]) + old_logprobs = torch.tensor([[1.0, 2.0, 3.0]]) + ref_logprob = torch.tensor([[1.0, 2.0, 3.0]]) + advantages = torch.tensor([[0.0, 1.0, 1.0, 1.0]]) + response_masks_bool = torch.tensor([[True, True, True]]) + entropy = torch.tensor([[0.5, 0.5, 0.5]]) + mock_args = Mock() + mock_args.clip_lower = 0.2 + mock_args.clip_higher = 0.2 + mock_args.beta = 0.1 + mock_args.masked_mean_axis = None + mock_args.masked_mean_denominator = None + mock_args.truncated_importance_sampling_ratio_cap = 0.0 + result = grpo_fast.calculate_loss_and_backward( + mock_model, + 0, + loss_statistics, + local_logprobs, + old_logprobs, + ref_logprob, + advantages, + 
response_masks_bool, + None, + entropy, + 1, + 0, + mock_args, + ) + mock_model.backward.assert_called_once() + self.assertEqual(result, 1) + + if __name__ == "__main__": unittest.main()
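
As a quick sanity check on the behaviour the new TestMaybeApplyImportanceSampling cases pin down, the short sketch below reimplements, in isolation, the truncated importance-sampling weight those tests imply: the log-prob gap between the cached ("old") policy and the vLLM behaviour policy is clamped to roughly plus or minus 10, the resulting ratio is capped at truncated_importance_sampling_ratio_cap, and both surrogate losses are scaled by that weight. The helper name, the exact clamp value, and the standalone structure are assumptions read off the test expectations, not a copy of maybe_apply_importance_sampling in grpo_fast.py.

    import torch

    def truncated_importance_weight(old_logprobs: torch.Tensor,
                                    vllm_logprobs: torch.Tensor,
                                    cap: float,
                                    clamp: float = 10.0) -> torch.Tensor:
        # Importance ratio between the cached ("old") policy and the vLLM
        # behaviour policy: the log-prob gap is clamped for numerical stability
        # (matching the clamping test) and the ratio is truncated at `cap`
        # (matching the capping test).
        logprob_diff = (old_logprobs - vllm_logprobs).clamp(-clamp, clamp)
        return torch.exp(logprob_diff).clamp(max=cap)

    # Mirrors test_importance_ratio_computation: a gap of 1.0 gives a weight of e^1, below the cap.
    w = truncated_importance_weight(torch.tensor([-2.0]), torch.tensor([-3.0]), cap=10.0)
    assert torch.allclose(w, torch.exp(torch.tensor([1.0])))

    # Mirrors test_ratio_capping: a gap of 5.0 would give e^5, which exceeds cap=2.0 and is truncated.
    w = truncated_importance_weight(torch.tensor([-2.0]), torch.tensor([-7.0]), cap=2.0)
    assert torch.all(w <= 2.0)

The new test classes can also be exercised individually, e.g. `python -m unittest open_instruct.test_grpo_fast.TestCalculateLossAndBackward`, assuming the `open_instruct` package is importable from the repository root.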