@@ -113,92 +113,24 @@ def update_stats(
113113 mb_entropy , mb_response_masks_bool , args .masked_mean_axis , args .masked_mean_denominator
114114 ).float ()
115115
116- def to_dict (self ) -> dict [str , torch . Tensor ]:
116+ def to_dict (self ) -> dict [str , float ]:
117117 """Convert accumulated statistics to a metrics dictionary.
118118
119119 Returns:
120120 Dictionary mapping metric names to their averaged values across all minibatches
121121 """
122122 metrics = {
123- "objective/kl_avg" : self .kl_stats [0 ].mean (),
124- "objective/kl2_avg" : self .kl_stats [1 ].mean (),
125- "objective/kl3_avg" : self .kl_stats [2 ].mean (),
126- "objective/kl4_avg" : self .kl_stats [3 ].mean (),
127- "loss/policy_avg" : self .pg_loss_stats .mean (),
128- "loss/kl_avg" : self .kl_loss_stats .mean (),
129- "loss/total_avg" : self .loss_stats .mean (),
130- "policy/clipfrac_avg" : self .pg_clipfrac_stats .mean (),
131- "val/ratio" : self .ratio_stats .mean (),
132- "val/ratio_var" : self .ratio_stats .var (),
123+ "objective/kl_avg" : self .kl_stats [0 ].mean (). item () ,
124+ "objective/kl2_avg" : self .kl_stats [1 ].mean (). item () ,
125+ "objective/kl3_avg" : self .kl_stats [2 ].mean (). item () ,
126+ "objective/kl4_avg" : self .kl_stats [3 ].mean (). item () ,
127+ "loss/policy_avg" : self .pg_loss_stats .mean (). item () ,
128+ "loss/kl_avg" : self .kl_loss_stats .mean (). item () ,
129+ "loss/total_avg" : self .loss_stats .mean (). item () ,
130+ "policy/clipfrac_avg" : self .pg_clipfrac_stats .mean (). item () ,
131+ "val/ratio" : self .ratio_stats .mean (). item () ,
132+ "val/ratio_var" : self .ratio_stats .var (). item () ,
133133 }
134134 if self .entropy_stats is not None :
135- metrics ["policy/entropy_avg" ] = self .entropy_stats .mean ()
135+ metrics ["policy/entropy_avg" ] = self .entropy_stats .mean (). item ()
136136 return metrics
137-
138-
139- class MetricsTracker :
140- """Preallocated tensor-based metrics storage for efficient distributed reduction.
141-
142- Stores all metrics in a single preallocated tensor to enable efficient all-reduce
143- operations in distributed training. Maintains a mapping from metric names to
144- tensor indices for fast access.
145- """
146-
147- def __init__ (self , max_metrics : int = 32 , device : str = "cuda" ):
148- """Initialize metrics tracker.
149-
150- Args:
151- max_metrics: Maximum number of unique metrics to track
152- device: Device to allocate metrics tensor on (default: "cuda")
153- """
154- self .metrics = torch .zeros (max_metrics , device = device )
155- self .names2idx = {}
156- self .current_idx = 0
157- self .max_metrics = max_metrics
158-
159- def add (self , name : str , value : torch .tensor ):
160- """Add or update a metric value.
161-
162- If the metric name is new, allocates a new index in the metrics tensor.
163- If the metric already exists, updates its value at the existing index.
164-
165- Args:
166- name: Metric name (e.g., "loss/policy_avg")
167- value: Metric value (scalar tensor or convertible to tensor)
168-
169- Returns:
170- Self for method chaining
171-
172- Raises:
173- ValueError: If max_metrics limit is exceeded
174- """
175- if name not in self .names2idx :
176- if self .current_idx >= self .max_metrics :
177- raise ValueError (f"Exceeded maximum number of metrics ({ self .max_metrics } )" )
178- self .names2idx [name ] = self .current_idx
179- self .current_idx += 1
180-
181- self .metrics [self .names2idx [name ]] = value
182- return self
183-
184- def add_dict (self , metrics_dict : dict [str , torch .Tensor ]):
185- """Add multiple metrics from a dictionary.
186-
187- Args:
188- metrics_dict: Dictionary mapping metric names to values
189-
190- Returns:
191- Self for method chaining
192- """
193- for k , v in metrics_dict .items ():
194- self .add (k , v )
195- return self
196-
197- def get_metrics_list (self ) -> dict [str , float ]:
198- """Convert tracked metrics to a dictionary of Python floats.
199-
200- Returns:
201- Dictionary mapping metric names to their float values
202- """
203- metrics_list = self .metrics .tolist ()
204- return {name : metrics_list [idx ] for name , idx in self .names2idx .items ()}
0 commit comments