
Commit abe4387

updated code to remove metricstracker

1 parent 783579d

File tree: 2 files changed, +29 -96 lines

open_instruct/grpo_fast.py

Lines changed: 17 additions & 16 deletions
```diff
@@ -92,7 +92,7 @@
     cleanup_all_llm_judge_clients,
     soft_format_reward_func,
 )
-from open_instruct.metrics import LossStatistics, MetricsTracker
+from open_instruct.metrics import LossStatistics
 from open_instruct.model_utils import (
     Batch,
     ModelConfig,
@@ -529,11 +529,10 @@ def compare_logprobs(
     mask: torch.Tensor,
     masked_mean_axis: int | None,
     masked_mean_denominator: float | None,
-    metrics_tracker: MetricsTracker,
-) -> None:
+) -> dict[str, float]:
     """Compare locally computed log probabilities with reference log probabilities.
 
-    Computes statistics on the difference between two sets of log probabilities and records
+    Computes statistics on the difference between two sets of log probabilities and returns
     debugging metrics including mean/max/std differences and reverse KL divergence.
 
     Args:
@@ -542,7 +541,9 @@ def compare_logprobs(
         mask: Boolean mask indicating valid response tokens (shape: [batch, seq_len])
         masked_mean_axis: Axis for masked mean reduction
        masked_mean_denominator: Denominator for masked mean computation
-        metrics_tracker: MetricsTracker instance for recording debug metrics
+
+    Returns:
+        Dictionary of debug metrics
     """
     with torch.no_grad():
         valid_mask = mask & ~torch.isnan(old_logprobs)
@@ -557,10 +558,12 @@
             masked_mean_axis,
             masked_mean_denominator,
         )
-        metrics_tracker.add("debug/vllm_vs_local_logprob_diff_mean", mean_diff)
-        metrics_tracker.add("debug/vllm_vs_local_logprob_diff_max", max_diff)
-        metrics_tracker.add("debug/vllm_vs_local_logprob_diff_std", std_diff)
-        metrics_tracker.add("debug/vllm_local_reverse_kl", reverse_kl)
+        return {
+            "debug/vllm_vs_local_logprob_diff_mean": mean_diff.item(),
+            "debug/vllm_vs_local_logprob_diff_max": max_diff.item(),
+            "debug/vllm_vs_local_logprob_diff_std": std_diff.item(),
+            "debug/vllm_local_reverse_kl": reverse_kl.item(),
+        }
 
 
 def maybe_apply_importance_sampling(
@@ -1009,7 +1012,6 @@ def load(self, path: str, map_location=None):
         else:
             self.ref_policy.load_state_dict(state_dict)
         logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}")
-        self.local_metrics = MetricsTracker(max_metrics=32, device=self.device)
         return optimization_steps_done
 
     def forward(
@@ -1147,6 +1149,7 @@ def train(
         num_mini_batches: int,
     ):
         args = self.args
+        local_metrics = {}
         to_device_inplace(collated_query_responses, self.device)
         to_device_inplace(collated_tool_masks, self.device)
         to_device_inplace(collated_attention_masks, self.device)
@@ -1268,13 +1271,12 @@
                 # Replace any remaining NaN values (query tokens in packed sequences are set to NaN by pack_sequences in rl_utils.py)
                 mb_vllm_logprobs = torch.nan_to_num(mb_vllm_logprobs, nan=INVALID_LOGPROB)
 
-                compare_logprobs(
+                local_metrics |= compare_logprobs(
                     mb_local_logprobs,
                     mb_vllm_logprobs,
                     mb_response_masks_bool,
                     args.masked_mean_axis,
                     args.masked_mean_denominator,
-                    self.local_metrics,
                 )
 
                 # Cache the old logprobs
@@ -1312,10 +1314,9 @@ def train(
             args,
         )
 
-        with torch.no_grad():
-            self.local_metrics.add_dict(loss_statistics.to_dict())
-            self.local_metrics.add("lr", self.scheduler.get_last_lr()[0])
-            return self.local_metrics.get_metrics_list()
+        local_metrics |= loss_statistics.to_dict()
+        local_metrics["lr"] = self.scheduler.get_last_lr()[0]
+        return local_metrics
 
     def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: dict[str, Any]) -> None:
         args = self.args
```
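
With `MetricsTracker` gone, per-step metrics flow through plain dictionaries: helpers return `dict[str, float]` with values converted via `.item()` at the point of creation, and `train` accumulates them with the in-place merge operator `|=` (Python 3.9+). Below is a minimal runnable sketch of that pattern; `compare_logprobs_sketch` and its inputs are illustrative stand-ins, not the actual training code:

```python
import torch


def compare_logprobs_sketch(
    local: torch.Tensor, vllm: torch.Tensor, mask: torch.Tensor
) -> dict[str, float]:
    """Return debug metrics as plain floats instead of writing into a tracker."""
    with torch.no_grad():
        diff = (local - vllm).abs()[mask]  # keep only valid response tokens
        return {
            "debug/logprob_diff_mean": diff.mean().item(),
            "debug/logprob_diff_max": diff.max().item(),
        }


local_metrics: dict[str, float] = {}
mask = torch.tensor([True, True, False])
local_metrics |= compare_logprobs_sketch(
    torch.tensor([-1.0, -2.0, -3.0]), torch.tensor([-1.1, -1.9, -3.5]), mask
)
local_metrics["lr"] = 1e-6  # scalar entries are assigned directly
print(local_metrics)
```

Because values are plain floats, the `with torch.no_grad():` wrapper and the `metrics_tracker` parameter become unnecessary at the call sites, which is exactly what the hunks above delete.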

open_instruct/metrics.py

Lines changed: 12 additions & 80 deletions
```diff
@@ -113,92 +113,24 @@ def update_stats(
             mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
         ).float()
 
-    def to_dict(self) -> dict[str, torch.Tensor]:
+    def to_dict(self) -> dict[str, float]:
         """Convert accumulated statistics to a metrics dictionary.
 
         Returns:
             Dictionary mapping metric names to their averaged values across all minibatches
         """
         metrics = {
-            "objective/kl_avg": self.kl_stats[0].mean(),
-            "objective/kl2_avg": self.kl_stats[1].mean(),
-            "objective/kl3_avg": self.kl_stats[2].mean(),
-            "objective/kl4_avg": self.kl_stats[3].mean(),
-            "loss/policy_avg": self.pg_loss_stats.mean(),
-            "loss/kl_avg": self.kl_loss_stats.mean(),
-            "loss/total_avg": self.loss_stats.mean(),
-            "policy/clipfrac_avg": self.pg_clipfrac_stats.mean(),
-            "val/ratio": self.ratio_stats.mean(),
-            "val/ratio_var": self.ratio_stats.var(),
+            "objective/kl_avg": self.kl_stats[0].mean().item(),
+            "objective/kl2_avg": self.kl_stats[1].mean().item(),
+            "objective/kl3_avg": self.kl_stats[2].mean().item(),
+            "objective/kl4_avg": self.kl_stats[3].mean().item(),
+            "loss/policy_avg": self.pg_loss_stats.mean().item(),
+            "loss/kl_avg": self.kl_loss_stats.mean().item(),
+            "loss/total_avg": self.loss_stats.mean().item(),
+            "policy/clipfrac_avg": self.pg_clipfrac_stats.mean().item(),
+            "val/ratio": self.ratio_stats.mean().item(),
+            "val/ratio_var": self.ratio_stats.var().item(),
         }
         if self.entropy_stats is not None:
-            metrics["policy/entropy_avg"] = self.entropy_stats.mean()
+            metrics["policy/entropy_avg"] = self.entropy_stats.mean().item()
         return metrics
-
-
-class MetricsTracker:
-    """Preallocated tensor-based metrics storage for efficient distributed reduction.
-
-    Stores all metrics in a single preallocated tensor to enable efficient all-reduce
-    operations in distributed training. Maintains a mapping from metric names to
-    tensor indices for fast access.
-    """
-
-    def __init__(self, max_metrics: int = 32, device: str = "cuda"):
-        """Initialize metrics tracker.
-
-        Args:
-            max_metrics: Maximum number of unique metrics to track
-            device: Device to allocate metrics tensor on (default: "cuda")
-        """
-        self.metrics = torch.zeros(max_metrics, device=device)
-        self.names2idx = {}
-        self.current_idx = 0
-        self.max_metrics = max_metrics
-
-    def add(self, name: str, value: torch.tensor):
-        """Add or update a metric value.
-
-        If the metric name is new, allocates a new index in the metrics tensor.
-        If the metric already exists, updates its value at the existing index.
-
-        Args:
-            name: Metric name (e.g., "loss/policy_avg")
-            value: Metric value (scalar tensor or convertible to tensor)
-
-        Returns:
-            Self for method chaining
-
-        Raises:
-            ValueError: If max_metrics limit is exceeded
-        """
-        if name not in self.names2idx:
-            if self.current_idx >= self.max_metrics:
-                raise ValueError(f"Exceeded maximum number of metrics ({self.max_metrics})")
-            self.names2idx[name] = self.current_idx
-            self.current_idx += 1
-
-        self.metrics[self.names2idx[name]] = value
-        return self
-
-    def add_dict(self, metrics_dict: dict[str, torch.Tensor]):
-        """Add multiple metrics from a dictionary.
-
-        Args:
-            metrics_dict: Dictionary mapping metric names to values
-
-        Returns:
-            Self for method chaining
-        """
-        for k, v in metrics_dict.items():
-            self.add(k, v)
-        return self
-
-    def get_metrics_list(self) -> dict[str, float]:
-        """Convert tracked metrics to a dictionary of Python floats.
-
-        Returns:
-            Dictionary mapping metric names to their float values
-        """
-        metrics_list = self.metrics.tolist()
-        return {name: metrics_list[idx] for name, idx in self.names2idx.items()}
```
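
The deleted class's own docstring states why it existed: packing every metric into one preallocated tensor makes distributed reduction cheap (a single all-reduce instead of one collective per metric). A dict of floats can recover the same property by packing values at reduction time. The sketch below is an assumption about how a caller might do that, not code from this repository; `reduce_metrics` is a hypothetical helper:

```python
import torch
import torch.distributed as dist


def reduce_metrics(local_metrics: dict[str, float]) -> dict[str, float]:
    """Average scalar metrics across ranks with a single all-reduce.

    Packing the dict values into one tensor keeps the collective count at
    one, which is what MetricsTracker's preallocated tensor provided.
    """
    names = sorted(local_metrics)  # fixed order so every rank packs identically
    packed = torch.tensor([local_metrics[n] for n in names])
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(packed, op=dist.ReduceOp.AVG)  # AVG needs the NCCL backend
    return dict(zip(names, packed.tolist()))


# Standalone (no process group), the reduce is skipped and inputs pass through.
print(reduce_metrics({"loss/total_avg": 0.42, "lr": 1e-6}))
```

A side effect of the switch worth noting: the hard `max_metrics=32` cap and its `ValueError` disappear along with the class, at the cost of any distributed reduction now having to marshal the values itself.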
