diff --git a/docs/sphinx_doc/source/index.rst b/docs/sphinx_doc/source/index.rst
index e232fa8155..9b589fa53e 100644
--- a/docs/sphinx_doc/source/index.rst
+++ b/docs/sphinx_doc/source/index.rst
@@ -26,6 +26,7 @@ Welcome to Trinity-RFT's documentation!
tutorial/trinity_gpu_configs.md
tutorial/synchronizer.md
tutorial/align_with_verl.md
+ tutorial/metrics_reference.md
.. toctree::
diff --git a/docs/sphinx_doc/source/tutorial/metrics_reference.md b/docs/sphinx_doc/source/tutorial/metrics_reference.md
new file mode 100644
index 0000000000..9f9b7b3368
--- /dev/null
+++ b/docs/sphinx_doc/source/tutorial/metrics_reference.md
@@ -0,0 +1,183 @@
+# Metrics Reference
+
+This document provides an overview of the metric categories used in Trinity-RFT for tracking exploration, evaluation, and training progress.
+
+## Metric Naming Convention
+
+Most metrics follow a hierarchical naming convention: `{category}/{taskset_name}/{metric_name}/{statistic}`
+
+- **Category**: Broad functional area (rollout, eval, time, actor, critic, etc.)
+- **Taskset name**: Name of the taskset used, only applicable for eval metrics
+- **Metric name**: Specific metric being measured
+- **Statistic**: Aggregation method (mean, max, min, std, etc.) if applicable
+
+
+## Metric Categories
+
+In the following, metrics are categorized by their source component (where they are generated) and their metric prefix (the first part of the metric name).
+
+### Explorer Metrics
+
+Explorer metrics track performance during the rollout phase where the model generates responses, including rollout metrics (`rollout/`), eval metrics (`eval/`), and some time metrics (`time/`).
+
+#### Metric Aggregation Levels
+
+Consider an exploration step with `batch_size` tasks, where each task has `repeat_times` runs. Rollout metrics (e.g., `rollout/`) are computed and aggregated at different levels:
+
+- **Task level**: Metrics aggregated across `repeat_times` runs of the same task. For example, `rollout/accuracy` is the average accuracy of all runs of the task.
+
+- **Step level**: Metrics are reported at the step level. For example, `rollout/accuracy/mean`, `rollout/accuracy/max`, `rollout/accuracy/min` are the average, max, and min accuracy (`rollout/accuracy`) of all tasks in the step.
+
+The following diagram illustrates the aggregation process for rollout metrics:
+
+```mermaid
+graph TD
+ subgraph Step["Batch_size=3 Tasks"]
+ subgraph Task1["Task 1 (repeat_times=2)"]
+ Run1_1["Run 1
accuracy: 0.8"]
+ Run1_2["Run 2
accuracy: 0.9"]
+ Run1_1 --> Task1_Metric["rollout/accuracy
= 0.85"]
+ Run1_2 --> Task1_Metric
+ end
+
+ subgraph Task2["Task 2 (repeat_times=2)"]
+ Run2_1["Run 1
accuracy: 0.6"]
+ Run2_2["Run 2
accuracy: 0.9"]
+ Run2_1 --> Task2_Metric["rollout/accuracy
= 0.75"]
+ Run2_2 --> Task2_Metric
+ end
+
+ subgraph TaskN["Task 3 (repeat_times=2)"]
+ Run3_1["Run 1
accuracy: 0.95"]
+ Run3_2["Run 2
accuracy: 0.85"]
+ Run3_1 --> Task3_Metric["rollout/accuracy
= 0.9"]
+ Run3_2 --> Task3_Metric
+ end
+
+ Task1_Metric --> Step_Metrics["Step Level Metrics
rollout/accuracy/mean=0.83
rollout/accuracy/max=0.9
rollout/accuracy/min=0.6"]
+ Task2_Metric --> Step_Metrics
+ Task3_Metric --> Step_Metrics
+ end
+```
+
+Consider an evaluation step with `len(eval_taskset)` tasks, where each task has `repeat_times` runs. Evaluation metrics (e.g., `eval/`, `bench/`) are computed and aggregated at different levels:
+
+- **Task level**: Task-level metrics include (e.g., `mean@4`, `std@4`, `best@2`, `worst@2`) that are computed from k runs of the task.
+
+- **Step level**: By default, we report the mean of the metric across all evaluation tasks. For example, `mean@k`, `std@k`, `best@k`, `worst@k` of the metrics across all evaluation tasks are reported. If you want to return detailed statistics, including mean, std, min, max, you can set `monitor.detailed_stats` to `True` in the config.
+
+The following diagram illustrates the aggregation process on a dummy dataset with three tasks for evaluation metrics. By default, `mean@k`, `std@k`, `best@k`, `worst@k` of the metrics across all evaluation tasks are reported. you may configure `monitor.detailed_stats` to `True` in the config to return detailed statistics, including mean, std, min, max, e.g., `eval/dummy/accuracy/mean@2/mean=0.83`, `eval/dummy/accuracy/mean@2/std=0.062`, `eval/dummy/accuracy/mean@2/max=0.9`, and `eval/dummy/accuracy/mean@2/min=0.75`.
+
+```mermaid
+graph TD
+ subgraph Step["len(eval_taskset)=3 Tasks"]
+ subgraph Task1["Task 1 (repeat_times=2)"]
+ Run1_1["Run 1
accuracy: 0.8"]
+ Run1_2["Run 2
accuracy: 0.9"]
+ Run1_1 --> Task1_Metric["eval/dummy/accuracy/mean@2=0.85
eval/dummy/accuracy/std@2=0.05"]
+ Run1_2 --> Task1_Metric
+ end
+
+ subgraph Task2["Task 2 (repeat_times=2)"]
+ Run2_1["Run 1
accuracy: 0.6"]
+ Run2_2["Run 2
accuracy: 0.9"]
+ Run2_1 --> Task2_Metric["eval/dummy/accuracy/mean@2=0.75
eval/dummy/accuracy/std@2=0.15"]
+ Run2_2 --> Task2_Metric
+ end
+
+ subgraph TaskN["Task 3 (repeat_times=2)"]
+ Run3_1["Run 1
accuracy: 0.95"]
+ Run3_2["Run 2
accuracy: 0.85"]
+ Run3_1 --> Task3_Metric["eval/dummy/accuracy/mean@2=0.9
eval/dummy/accuracy/std@2=0.05"]
+ Run3_2 --> Task3_Metric
+ end
+
+ Task1_Metric --> Step_Metrics["Step Level Metrics
eval/dummy/accuracy/mean@2=0.83
eval/dummy/accuracy/std@2=0.083"]
+ Task2_Metric --> Step_Metrics
+ Task3_Metric --> Step_Metrics
+ end
+```
+
+
+#### Rollout Metrics (`rollout/`)
+
+Rollout metrics track performance during the rollout phase where the model generates responses.
+
+- **Format**: `rollout/{metric_name}/{statistic}`
+- **Examples**:
+ - `rollout/accuracy/mean`: Average accuracy of generated responses
+ - `rollout/format_score/mean`: Average format correctness score
+
+#### Eval Metrics (`eval/`) and Benchmark Metrics (`bench/`)
+
+Evaluation metrics measure model performance on held-out evaluation tasks. These metrics are computed during periodic evaluation runs.
+
+- **Format**: `eval/{task_name}/{metric_name}/{statistic}` or `bench/{task_name}/{metric_name}/{statistic}`
+- **Examples**:
+ - `eval/gsm8k-eval/accuracy/mean@4`: Mean accuracy across repeat_times=4 runs
+ - `bench/gsm8k-eval/accuracy/best@4`: Best accuracy value across repeat_times=4 runs
+
+- **Note**:
+ - Eval and bench metrics are computed in the same way, the only difference is the prefix of the metric name.
+ - By default, only the *mean* of the metric is returned. If you want to return detailed statistics, you can set `monitor.detailed_stats` to `True` in the config.
+
+
+#### Time Metrics (`time/`)
+
+Time metrics measure execution duration for various operations throughout the training pipeline.
+
+- **Format**: `time/{operation_name}`
+- **Examples**:
+ - `time/eval`: Time from the start of submitting evaluation tasks to the end of the evaluation phase; this duration includes both evaluation tasks and some rollout tasks.
+ - `time/train_step`: Total time for one training step
+
+**Note**:
+ - Time measuring can be inaccurate due to the asynchronous nature of the exploration pipeline, but it is still useful for monitoring the overall training progress.
+ - Above metrics are reported in seconds unless otherwise specified.
+ - Some training operations also report per-token timing metrics with the prefix `timing_per_token_ms/` (e.g., `timing_per_token_ms/update_actor`, `timing_per_token_ms/update_critic`, `timing_per_token_ms/adv`, `timing_per_token_ms/values`). These metrics normalize execution time by the number of tokens processed, providing efficiency measurements independent of batch size.
+
+
+### Training Metrics
+
+This category includes metrics that track the training dynamics of the policy (actor) model (`actor/`) and the value function (critic) model (`critic/`), as well as some performance metrics (`perf/`, `global_seqlen/`, `response_length/`, `prompt_length/`, `time/`). These metrics are adapted from [veRL](https://github.com/volcengine/verl). Interested users can refer to the [veRL documentation](https://verl.readthedocs.io/en/latest/index.html) for more details.
+
+
+### Data Processing Metrics
+
+This category includes metrics that track the processing of experiences through various pipeline operators (`experience_pipeline/`) and data sampling statistics (`sample/`). These metrics are aggregated at the step level, as the experience pipeline and data sampling are performed in each step.
+
+
+#### Experience Pipeline Metrics (`experience_pipeline/` and `time/experience_pipeline/`)
+
+Experience pipeline metrics track the processing of experiences through various pipeline operators. Each metric represents the count of the specific operator in one step.
+
+- **Format**: `experience_pipeline/{metric_name}`
+- **Examples**:
+ - `experience_pipeline/experience_count`: Number of experiences processed
+ - `experience_pipeline/group_advantages/reward_mean/mean`: Here `reward_mean` is the mean reward of each task, then we compute the mean of the mean rewards of all tasks in the step.
+
+The following diagram illustrates the aggregation process for data processing metrics:
+```mermaid
+graph TD
+ subgraph Step["4 Experiences in one step"]
+ subgraph Task1["Experience 1"]
+ Run1_1["Run 1
reward_mean: 0.8"]
+ Run1_2["Run 2
reward_mean: 0.8"]
+ Run2_1["Run 3
reward_mean: 0.9"]
+ Run2_2["Run 4
reward_mean: 0.9"]
+ Run1_1 --> Task1_Metric["rollout/accuracy
= 0.85"]
+ Run1_2 --> Task1_Metric
+ Run2_1 --> Task1_Metric
+ Run2_2 --> Task1_Metric
+ end
+ end
+```
+
+#### Sample Metrics (`sample/`)
+
+Sample metrics track data sampling statistics during training.
+
+- **Format**: `sample/{metric_name}`
+- **Examples**:
+ - `sample/model_version/mean`: Mean model version of sampled experiences
+ - `sample/task_count`: Number of tasks in the sampled batch
diff --git a/docs/sphinx_doc/source_zh/index.rst b/docs/sphinx_doc/source_zh/index.rst
index 09105b12bb..dcb749b8a5 100644
--- a/docs/sphinx_doc/source_zh/index.rst
+++ b/docs/sphinx_doc/source_zh/index.rst
@@ -25,6 +25,7 @@
tutorial/trinity_gpu_configs.md
tutorial/synchronizer.md
tutorial/align_with_verl.md
+ tutorial/metrics_reference.md
.. toctree::
:maxdepth: 1
diff --git a/docs/sphinx_doc/source_zh/tutorial/metrics_reference.md b/docs/sphinx_doc/source_zh/tutorial/metrics_reference.md
new file mode 100644
index 0000000000..93ebe5caf5
--- /dev/null
+++ b/docs/sphinx_doc/source_zh/tutorial/metrics_reference.md
@@ -0,0 +1,183 @@
+# 指标解释
+
+本文档解释了 Trinity-RFT 中用于跟踪探索、评估和训练进度的指标类别。
+
+## 指标命名规范
+
+大多数指标遵循分层命名规范:`{category}/{taskset_name}/{metric_name}/{statistic}`
+
+- **Category(类别)**:广泛的功能领域(rollout、eval、time、actor、critic 等)
+- **Taskset name(任务集名称)**:使用的任务集名称,仅适用于评估指标
+- **Metric name(指标名称)**:正在测量的具体指标
+- **Statistic(统计量)**:统计指标(mean、max、min、std 等,如适用)
+
+
+## 指标类别
+
+以下内容按指标来源(生成位置)和指标前缀(指标名称的第一部分)对指标进行分类。
+
+### Explorer 相关指标
+
+探索器指标跟踪模型生成响应的 rollout 阶段的性能,包括 rollout 指标(`rollout/`)、评估指标(`eval/`)和一些时间指标(`time/`)。
+
+#### 指标聚合级别
+
+考虑一个包含 `batch_size` 个任务的探索步骤,其中每个任务有 `repeat_times` 次运行。Rollout 指标(例如,`rollout/`)在不同级别计算和聚合:
+
+- **任务级别**:跨同一任务的 `repeat_times` 次运行聚合的指标。例如,`rollout/accuracy` 是该任务所有运行的平均准确率。
+
+- **步骤级别**:在步骤级别报告指标。例如,`rollout/accuracy/mean`、`rollout/accuracy/max`、`rollout/accuracy/min` 分别是步骤中所有任务的准确率(`rollout/accuracy`)的平均值、最大值和最小值。
+
+以下图表说明了 rollout 指标的聚合过程:
+
+```mermaid
+graph TD
+ subgraph Step["Batch_size=3 Tasks"]
+ subgraph Task1["Task 1 (repeat_times=2)"]
+ Run1_1["Run 1
accuracy: 0.8"]
+ Run1_2["Run 2
accuracy: 0.9"]
+ Run1_1 --> Task1_Metric["rollout/accuracy
= 0.85"]
+ Run1_2 --> Task1_Metric
+ end
+
+ subgraph Task2["Task 2 (repeat_times=2)"]
+ Run2_1["Run 1
accuracy: 0.6"]
+ Run2_2["Run 2
accuracy: 0.9"]
+ Run2_1 --> Task2_Metric["rollout/accuracy
= 0.75"]
+ Run2_2 --> Task2_Metric
+ end
+
+ subgraph TaskN["Task 3 (repeat_times=2)"]
+ Run3_1["Run 1
accuracy: 0.95"]
+ Run3_2["Run 2
accuracy: 0.85"]
+ Run3_1 --> Task3_Metric["rollout/accuracy
= 0.9"]
+ Run3_2 --> Task3_Metric
+ end
+
+ Task1_Metric --> Step_Metrics["Step Level Metrics
rollout/accuracy/mean=0.83
rollout/accuracy/max=0.9
rollout/accuracy/min=0.6"]
+ Task2_Metric --> Step_Metrics
+ Task3_Metric --> Step_Metrics
+ end
+```
+
+考虑一个包含 `len(eval_taskset)` 个任务的评估步骤,其中每个任务有 `repeat_times` 次运行。评估指标(例如,`eval/`、`bench/`)在不同级别计算和聚合:
+
+- **任务级别**:任务级别指标包括(例如,`mean@4`、`std@4`、`best@2`、`worst@2`),这些指标是从任务的 k 次运行中计算的。
+
+- **步骤级别**:默认情况下,我们报告所有评估任务中指标的平均值。例如,报告所有评估任务中指标的 `mean@k`、`std@k`、`best@k`、`worst@k`。如果你想返回详细统计信息,包括 mean、std、min、max,可以在配置中将 `monitor.detailed_stats` 设置为 `True`。
+
+以下图表说明了在包含三个任务的虚拟数据集上评估指标的聚合过程。默认情况下,报告所有评估任务中指标的 `mean@k`、`std@k`、`best@k`、`worst@k`。你可以在配置中将 `monitor.detailed_stats` 设置为 `True` 以返回详细统计信息,包括 mean、std、min、max,例如 `eval/dummy/accuracy/mean@2/mean=0.83`、`eval/dummy/accuracy/mean@2/std=0.062`、`eval/dummy/accuracy/mean@2/max=0.9` 和 `eval/dummy/accuracy/mean@2/min=0.75`。
+
+```mermaid
+graph TD
+ subgraph Step["len(eval_taskset)=3 Tasks"]
+ subgraph Task1["Task 1 (repeat_times=2)"]
+ Run1_1["Run 1
accuracy: 0.8"]
+ Run1_2["Run 2
accuracy: 0.9"]
+ Run1_1 --> Task1_Metric["eval/dummy/accuracy/mean@2=0.85
eval/dummy/accuracy/std@2=0.05"]
+ Run1_2 --> Task1_Metric
+ end
+
+ subgraph Task2["Task 2 (repeat_times=2)"]
+ Run2_1["Run 1
accuracy: 0.6"]
+ Run2_2["Run 2
accuracy: 0.9"]
+ Run2_1 --> Task2_Metric["eval/dummy/accuracy/mean@2=0.75
eval/dummy/accuracy/std@2=0.15"]
+ Run2_2 --> Task2_Metric
+ end
+
+ subgraph TaskN["Task 3 (repeat_times=2)"]
+ Run3_1["Run 1
accuracy: 0.95"]
+ Run3_2["Run 2
accuracy: 0.85"]
+ Run3_1 --> Task3_Metric["eval/dummy/accuracy/mean@2=0.9
eval/dummy/accuracy/std@2=0.05"]
+ Run3_2 --> Task3_Metric
+ end
+
+ Task1_Metric --> Step_Metrics["Step Level Metrics
eval/dummy/accuracy/mean@2=0.83
eval/dummy/accuracy/std@2=0.083"]
+ Task2_Metric --> Step_Metrics
+ Task3_Metric --> Step_Metrics
+ end
+```
+
+
+#### Rollout 指标(`rollout/`)
+
+Rollout 指标跟踪模型生成响应的 rollout 阶段的性能。
+
+- **格式**:`rollout/{metric_name}/{statistic}`
+- **示例**:
+ - `rollout/accuracy/mean`:生成响应的平均准确率
+ - `rollout/format_score/mean`:平均格式正确性分数
+
+#### 评估指标(`eval/`)和基准测试指标(`bench/`)
+
+评估指标衡量模型在保留的评估任务上的性能。这些指标在定期评估运行期间计算。
+
+- **格式**:`eval/{task_name}/{metric_name}/{statistic}` 或 `bench/{task_name}/{metric_name}/{statistic}`
+- **示例**:
+ - `eval/gsm8k-eval/accuracy/mean@4`:跨 repeat_times=4 次运行的平均准确率
+ - `bench/gsm8k-eval/accuracy/best@4`:跨 repeat_times=4 次运行的最佳准确率值
+
+- **注意**:
+ - Eval 和 bench 指标的计算方式相同,唯一的区别是指标名称的前缀。
+ - 默认情况下,只返回指标的*平均值*。如果你想返回详细统计信息,可以在配置中将 `monitor.detailed_stats` 设置为 `True`。
+
+
+#### 时间指标(`time/`)
+
+时间指标测量整个训练管道中各种操作的执行持续时间。
+
+- **格式**:`time/{operation_name}`
+- **示例**:
+ - `time/eval`:从提交评估任务开始到评估阶段结束的时间;此持续时间包括评估任务和一些 rollout 任务。
+ - `time/train_step`:一个训练步骤的总时间
+
+**注意**:
+ - 由于探索管道的异步性质,时间测量可能不准确,但对于监控整体训练进度仍然有用。
+ - 除非另有说明,上述指标以秒为单位报告。
+ - 一些训练操作还报告每 token 的时间指标,前缀为 `timing_per_token_ms/`(例如,`timing_per_token_ms/update_actor`、`timing_per_token_ms/update_critic`、`timing_per_token_ms/adv`、`timing_per_token_ms/values`)。这些指标通过处理的 token 数量对执行时间进行归一化,提供独立于批次大小的效率测量。
+
+
+### 训练指标
+
+此类别包括跟踪策略(actor)模型(`actor/`)和价值函数(critic)模型(`critic/`)的训练动态的指标,以及一些性能指标(`perf/`、`global_seqlen/`、`response_length/`、`prompt_length/`、`time/`)。这些指标改编自 [veRL](https://github.com/volcengine/verl)。感兴趣的用户可以参考 [veRL 文档](https://verl.readthedocs.io/en/latest/index.html) 了解更多详细信息。
+
+
+### 数据处理指标
+
+此类别包括跟踪通过各种管道操作符处理经验(`experience_pipeline/`)和数据采样统计(`sample/`)的指标。这些指标在步骤级别聚合,因为经验管道和数据采样在每个步骤中执行。
+
+
+#### Experience Pipeline 相关指标(`experience_pipeline/` 和 `time/experience_pipeline/`)
+
+经验管道指标跟踪通过各种管道操作符处理经验。每个指标表示一个步骤中特定操作符的计数。
+
+- **格式**:`experience_pipeline/{metric_name}`
+- **示例**:
+ - `experience_pipeline/experience_count`:处理的经验数量
+ - `experience_pipeline/group_advantages/reward_mean/mean`:这里 `reward_mean` 是每个任务的平均奖励,然后我们计算步骤中所有任务的平均奖励的平均值。
+
+以下图表说明了数据处理指标的聚合过程:
+```mermaid
+graph TD
+ subgraph Step["4 Experiences in one step"]
+ subgraph Task1["Experience 1"]
+ Run1_1["Run 1
reward_mean: 0.8"]
+ Run1_2["Run 2
reward_mean: 0.8"]
+ Run2_1["Run 3
reward_mean: 0.9"]
+ Run2_2["Run 4
reward_mean: 0.9"]
+ Run1_1 --> Task1_Metric["rollout/accuracy
= 0.85"]
+ Run1_2 --> Task1_Metric
+ Run2_1 --> Task1_Metric
+ Run2_2 --> Task1_Metric
+ end
+ end
+```
+
+#### 采样相关指标(`sample/`)
+
+采样指标跟踪训练期间的数据采样统计。
+
+- **格式**:`sample/{metric_name}`
+- **示例**:
+ - `sample/model_version/mean`:采样经验的平均模型版本
+ - `sample/task_count`:采样批次中的任务数量
diff --git a/pyproject.toml b/pyproject.toml
index 29970181b2..9de73beeb6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,7 +56,7 @@ data = [
"py-data-juicer>=1.4.3"
]
agent = [
- "agentscope>=1.0.9"
+ "agentscope>=1.0.12"
]
rm_gallery = [
"rm-gallery>=0.1.5"
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000000..d5385884f3
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,38 @@
+import datetime
+
+import pytest
+
+
+# Get the result of each test
+@pytest.hookimpl(tryfirst=True, hookwrapper=True)
+def pytest_runtest_makereport(item, call):
+ outcome = yield
+ rep = outcome.get_result()
+ setattr(item, "rep_" + rep.when, rep)
+
+
+# Real-time print of start and end of test
+@pytest.fixture(autouse=True)
+def log_test_lifecycle(request):
+ node_id = request.node.nodeid
+ start_time = datetime.datetime.now().strftime("%H:%M:%S")
+
+ print(f"\n[START] {start_time} - Running: {node_id}")
+
+ yield
+
+ end_time = datetime.datetime.now().strftime("%H:%M:%S")
+ # Get the result of each test (setup, call, teardown)
+ report = getattr(request.node, "rep_call", None)
+
+ if report:
+ if report.passed:
+ status = "PASSED"
+ elif report.failed:
+ status = "FAILED"
+ else:
+ status = report.outcome.upper()
+ else:
+ status = "UNKNOWN"
+
+ print(f"\n[END] {end_time} - Result: {status} - {node_id}")
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index b17bf7709b..dd4bf6c58e 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -43,6 +43,7 @@ def setUp(self):
self.config.checkpoint_root_dir = get_checkpoint_path()
self.config.synchronizer.sync_interval = 2
self.config.explorer.eval_interval = 4
+ self.config.monitor.detailed_stats = False
class TestExplorerCountdownEval(BaseExplorerCase):
@@ -69,14 +70,49 @@ def test_explorer(self):
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
for eval_taskset, k_list in zip(eval_tasksets, [[1], [2, 4, 6], [2, 4, 8, 10]]):
- for eval_stats in ["mean", "best", "worst"]:
- for k in k_list:
- for stats in ["mean", "std"]:
- metric_name = "score" if eval_taskset.name == "countdown" else "accuracy"
- self.assertIn(
- f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}/{stats}",
- eval_metrics,
- )
+ metric_name = "score" if eval_taskset.name == "countdown" else "accuracy"
+ repeat_times = k_list[-1]
+ expected_stat_suffixes = [f"mean@{repeat_times}", f"std@{repeat_times}"]
+ for k in k_list:
+ if k == 1:
+ continue
+ expected_stat_suffixes.extend([f"best@{k}", f"worst@{k}"])
+ # only return the mean of the column
+ for stat_suffix in expected_stat_suffixes:
+ self.assertIn(
+ f"eval/{eval_taskset.name}/{metric_name}/{stat_suffix}",
+ eval_metrics,
+ )
+
+
+class TestExplorerEvalDetailedStats(BaseExplorerCase):
+ def test_explorer(self):
+ self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
+ self.config.monitor.detailed_stats = True
+ eval_taskset = get_unittest_dataset_config("eval_short")
+ eval_taskset.repeat_times = 6
+ self.config.buffer.explorer_input.eval_tasksets = [eval_taskset]
+ self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+ self.config.check_and_update()
+ explore(self.config)
+ parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+ rollout_metrics = parser.metric_list("rollout")
+ self.assertTrue(len(rollout_metrics) > 0)
+ eval_metrics = parser.metric_list("eval")
+ self.assertTrue(len(eval_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
+ self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
+ metric_name, repeat_times, k_list = "accuracy", 6, [2, 4, 6]
+ expected_stat_suffixes = [f"mean@{repeat_times}", f"std@{repeat_times}"]
+ for k in k_list: # k_list does not include 1
+ expected_stat_suffixes.extend([f"best@{k}", f"worst@{k}"])
+ # test detailed stats
+ for stat_suffix in expected_stat_suffixes:
+ for stats in ["mean", "std", "max", "min"]:
+ self.assertIn(
+ f"eval/{eval_taskset.name}/{metric_name}/{stat_suffix}/{stats}",
+ eval_metrics,
+ )
class TestExplorerGSM8KRULERNoEval(BaseExplorerCase):
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 94182937a2..eb225f2390 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -172,12 +172,14 @@ def test_trainer(self):
for taskset_name in ["countdown", "copy_countdown"]:
metrics = parser.metric_list(f"{prefix}/{taskset_name}")
self.assertGreater(len(metrics), 0, f"{prefix}/{taskset_name} metrics not found")
- for eval_stats in ["mean", "best", "worst"]:
- for k in [2, 4]:
- for stats in ["mean", "std"]:
- metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}/{stats}"
- metric_steps = parser.metric_steps(metric_name)
- self.assertEqual(metric_steps, [0, 4, 8])
+ repeat_times, k_list = 4, [2, 4]
+ expected_stat_suffixes = [f"mean@{repeat_times}", f"std@{repeat_times}"]
+ for k in k_list:
+ expected_stat_suffixes.extend([f"best@{k}", f"worst@{k}"])
+ for stat_suffix in expected_stat_suffixes:
+ metric_name = f"{prefix}/{taskset_name}/score/{stat_suffix}"
+ metric_steps = parser.metric_steps(metric_name)
+ self.assertEqual(metric_steps, [0, 4, 8])
def tearDown(self):
# remove dir only when the test passed
@@ -1338,12 +1340,14 @@ def test_trainer(self):
for prefix in ["eval", "bench"]:
gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k")
self.assertGreater(len(gsm8k_metrics), 0, f"{prefix}/gsm8k metrics not found")
- for eval_stats in ["mean", "best", "worst"]:
- for k in [2, 4, 8]:
- for stats in ["mean", "std"]:
- metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}/{stats}"
- metric_steps = parser.metric_steps(metric_name)
- self.assertEqual(metric_steps, [0, 2])
+ repeat_times, k_list = 8, [2, 4, 8]
+ expected_stat_suffixes = [f"mean@{repeat_times}", f"std@{repeat_times}"]
+ for k in k_list:
+ expected_stat_suffixes.extend([f"best@{k}", f"worst@{k}"])
+ for stat_suffix in expected_stat_suffixes:
+ metric_name = f"{prefix}/gsm8k/accuracy/{stat_suffix}"
+ metric_steps = parser.metric_steps(metric_name)
+ self.assertEqual(metric_steps, [0, 2])
def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 581e07d6a7..4eed89540e 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -746,6 +746,8 @@ class MonitorConfig:
monitor_type: str = "tensorboard"
# the default args for monitor
monitor_args: Optional[Dict] = None
+ # whether to return detailed stats (mean, std, max, min) for evaluation metrics
+ detailed_stats: bool = False
# whether to enable ray timeline profile
# the output file will be saved to `cache_dir/timeline.json`
enable_ray_timeline: bool = False
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index b0893b8c52..617e4c4a1b 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -30,7 +30,7 @@
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.annotations import Experimental
from trinity.utils.log import get_logger
-from trinity.utils.monitor import MONITOR, gather_metrics
+from trinity.utils.monitor import MONITOR, gather_eval_metrics, gather_metrics
from trinity.utils.plugin_loader import load_plugins
from trinity.utils.timer import Timer
@@ -66,6 +66,7 @@ def __init__(self, config: Config):
role=self.config.explorer.name,
config=config,
)
+ self.detailed_stats = config.monitor.detailed_stats
if config.explorer.over_rollout.ratio > 0.0:
self.min_wait_num = math.ceil(
config.buffer.batch_size * (1 - config.explorer.over_rollout.ratio)
@@ -431,10 +432,10 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
statuses, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}")
metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses)
metric.update(
- gather_metrics(
+ gather_eval_metrics(
[status.metrics[0] for status in statuses],
f"{prefix}/{eval_task_name}",
- output_stats=["mean", "std"],
+ detailed_stats=self.detailed_stats,
)
)
if self.eval_start_time is not None:
diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index f84bb6ea26..e1bb8fde5d 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -6,8 +6,9 @@
import traceback
from collections import defaultdict, deque
from dataclasses import dataclass, field, replace
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import numpy as np
import ray
from trinity.common.config import Config
@@ -31,6 +32,48 @@ class TaskWrapper:
results: List[Tuple[Status, List[Experience]]] = field(default_factory=list)
+# Adapted from verl/trainer/ppo/metric_utils.py
+def bootstrap_metric(
+ data: list[Any],
+ subset_size: int,
+ reduce_fns: list[Callable[[np.ndarray], float]],
+ n_bootstrap: int = 1000,
+ seed: int = 42,
+) -> list[tuple[float, float]]:
+ """
+ Performs bootstrap resampling to estimate statistics of metrics.
+
+ This function uses bootstrap resampling to estimate the mean and standard deviation
+ of metrics computed by the provided reduction functions on random subsets of the data.
+
+ Args:
+ data: List of data points to bootstrap from.
+ subset_size: Size of each bootstrap sample.
+ reduce_fns: List of functions that compute a metric from a subset of data.
+ n_bootstrap: Number of bootstrap iterations. Defaults to 1000.
+ seed: Random seed for reproducibility. Defaults to 42.
+
+ Returns:
+ A list of tuples, where each tuple contains (mean, std) for a metric
+ corresponding to each reduction function in reduce_fns.
+
+ Example:
+ >>> data = [1, 2, 3, 4, 5]
+ >>> reduce_fns = [np.mean, np.max]
+ >>> bootstrap_metric(data, 3, reduce_fns)
+ [(3.0, 0.5), (4.5, 0.3)] # Example values
+ """
+ np.random.seed(seed)
+
+ bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))]
+ for _ in range(n_bootstrap):
+ bootstrap_idxs = np.random.choice(len(data), size=subset_size, replace=True)
+ bootstrap_data = [data[i] for i in bootstrap_idxs]
+ for i, reduce_fn in enumerate(reduce_fns):
+ bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_data))
+ return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts]
+
+
def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str, float]:
"""Calculate task level metrics (mean) from multiple runs of the same task.
@@ -54,16 +97,25 @@ def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str
if "time/task_execution" in key or "time/run_execution" in key:
result[key] = sum(values) / len(values)
continue
- k_list = []
- k = 2
- while k < len(values):
- k_list.append(k)
- k *= 2
- k_list.append(len(values))
- for k in k_list:
- result[f"{key}/mean@{k}"] = sum(values[:k]) / k
- result[f"{key}/best@{k}"] = max(values[:k])
- result[f"{key}/worst@{k}"] = min(values[:k])
+
+ n_values = len(values)
+ result[f"{key}/mean@{n_values}"] = np.mean(values)
+ result[f"{key}/std@{n_values}"] = np.std(values)
+
+ if n_values > 1:
+ ns = []
+ n = 2
+ while n < n_values:
+ ns.append(n)
+ n *= 2
+ ns.append(n_values)
+
+ for n in ns:
+ [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(
+ data=values, subset_size=n, reduce_fns=[np.max, np.min], seed=42
+ )
+ result[f"{key}/best@{n}"] = bon_mean
+ result[f"{key}/worst@{n}"] = won_mean
return result
else:
return {
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index 21ef7726f1..d421e2b279 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -57,6 +57,32 @@ def gather_metrics(
raise ValueError(f"Failed to gather metrics: {e}") from e
+def gather_eval_metrics(
+ metric_list: List[Dict],
+ prefix: str,
+ output_stats: List[str] = ["mean", "max", "min", "std"],
+ detailed_stats: bool = False,
+) -> Dict:
+ if not metric_list:
+ return {}
+ try:
+ df = pd.DataFrame(metric_list)
+ numeric_df = df.select_dtypes(include=[np.number])
+ metric = {}
+ for col in numeric_df.columns:
+ if detailed_stats:
+ stats_df = numeric_df[[col]].agg(output_stats)
+ for stats in output_stats:
+ metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item()
+ else:
+ # only return the mean of the column
+ metric[f"{prefix}/{col}"] = numeric_df[col].mean()
+
+ return metric
+ except Exception as e:
+ raise ValueError(f"Failed to gather eval metrics: {e}") from e
+
+
class Monitor(ABC):
"""Monitor"""