From 593087c0461379fcc2317417c4042bf8dfdf0ee3 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 21 Jan 2026 15:07:50 +0800 Subject: [PATCH 1/4] Move `selector` into `TaskFileReader` --- trinity/buffer/reader/file_reader.py | 44 ++++++++++++++++------------ trinity/buffer/task_scheduler.py | 31 ++++++++------------ 2 files changed, 37 insertions(+), 38 deletions(-) diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index ac4d728263..44f79e677f 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -104,20 +104,16 @@ def __init__(self, config: StorageConfig): def read(self, batch_size: Optional[int] = None, **kwargs) -> List: return self.reader.read(batch_size) - def read_with_indices(self, indices: List[int]) -> List: - """Read tasks with indices.""" - return self.reader.read_with_indices(indices) - - async def read_with_indices_async(self, indices: List[int]) -> List: - """Read tasks with indices asynchronously.""" - return await self.reader.read_with_indices_async(indices) - def state_dict(self): return self.reader.state_dict() def load_state_dict(self, state_dict): return self.reader.load_state_dict(state_dict) + def update(self, **pipeline_metrics: dict): + if self.reader.selector is not None: + self.reader.selector.update(**pipeline_metrics) + def __len__(self): return self.reader.__len__() @@ -139,6 +135,7 @@ def __init__(self, config: StorageConfig): total_steps=config.total_steps, enable_progress_bar=config.enable_progress_bar, ) + self.selector = None def read(self, batch_size: Optional[int] = None, **kwargs) -> List: samples, _ = self.dataset.read_batch(batch_size or self.read_batch_size) @@ -178,6 +175,15 @@ def __init__(self, config: StorageConfig): enable_progress_bar=self.config.enable_progress_bar, ) self.formatter = FORMATTER.get("task")(config) + if self.config.task_selector is not None: + from trinity.buffer.selector import SELECTORS + from trinity.buffer.selector.selector import BaseSelector + + self.selector: BaseSelector = SELECTORS.get(self.config.task_selector.selector_type)( + self.dataset, self.config.task_selector + ) + else: + self.selector = None def _get_tasks(self, samples: List, indices: List) -> List: tasks = [] @@ -189,23 +195,23 @@ def _get_tasks(self, samples: List, indices: List) -> List: def read(self, batch_size: Optional[int] = None, **kwargs) -> List: batch_size = batch_size or self.read_batch_size - samples, indices = self.dataset.read_batch(batch_size) - return self._get_tasks(samples, indices) - - def read_with_indices(self, indices: List[int]) -> List: - """Read tasks with indices.""" - samples = self.dataset.select_batch(indices) + if self.selector is not None: + indices = self.selector.get_indices(batch_size) + samples = self.dataset.select_batch(indices) + else: + samples, indices = self.dataset.read_batch(batch_size) return self._get_tasks(samples, indices) - async def read_with_indices_async(self, indices: List[int]) -> List: - """Read tasks with indices asynchronously.""" - return self.read_with_indices(indices) - def state_dict(self): + if self.selector is not None: + return self.selector.state_dict() return {"current_index": self.dataset.current_offset} def load_state_dict(self, state_dict): - self.dataset.current_offset = state_dict["current_index"] + if self.selector is not None: + self.selector.load_state_dict(state_dict) + else: + self.dataset.current_offset = state_dict["current_index"] def __len__(self): return self.dataset.dataset_size diff --git a/trinity/buffer/task_scheduler.py b/trinity/buffer/task_scheduler.py index 713f836e90..2592269053 100644 --- a/trinity/buffer/task_scheduler.py +++ b/trinity/buffer/task_scheduler.py @@ -2,12 +2,12 @@ """The taskset scheduler.""" from collections import Counter +from copy import deepcopy from typing import Dict, List import numpy as np from trinity.buffer.buffer import get_buffer_reader -from trinity.buffer.selector import SELECTORS from trinity.common.config import Config from trinity.common.constants import SELECTOR_METRIC from trinity.utils.annotations import Experimental @@ -68,8 +68,10 @@ def __init__(self, explorer_state: Dict, config: Config): index = self.explorer_state.get("taskset_states", [{"current_index": 0}])[0].get( "current_index", 0 ) - self.config.buffer.explorer_input.tasksets[0].index = index - self.reader = get_buffer_reader(config.buffer.explorer_input.tasksets[0]) + taskset_config = deepcopy(self.config.buffer.explorer_input.tasksets[0]) + taskset_config.index = index + taskset_config.task_selector = None # disable selection + self.reader = get_buffer_reader(taskset_config) async def read_async(self) -> List: return await self.reader.read_async() @@ -127,7 +129,6 @@ def __init__(self, explorer_state: Dict, config: Config): "taskset_states", [{"current_index": 0}] * len(taskset_configs) ) self.tasksets = [] - self.selectors = [] for taskset_config, taskset_state in zip(taskset_configs, taskset_states): assert not taskset_config.is_eval # assume drop last taskset = get_buffer_reader(taskset_config) @@ -136,15 +137,8 @@ def __init__(self, explorer_state: Dict, config: Config): f"Taskset '{taskset_config.name}' has an unsupported type '{type(taskset).__name__}'." f"Currently, only 'FileReader' is supported by TasksetScheduler." ) - - # Create selector based on type specified in config (e.g., 'sequential', 'shuffle') - selector = SELECTORS.get(taskset_config.task_selector.selector_type)( - taskset.reader.dataset, taskset_config.task_selector - ) - selector.load_state_dict(taskset_state) # Restore any prior state - + taskset.load_state_dict(taskset_state) # Restore any prior state self.tasksets.append(taskset) - self.selectors.append(selector) # Each explorer step calls read_async once → track step globally self.step = explorer_state.get("latest_iteration", 0) @@ -224,8 +218,7 @@ async def read_async(self) -> List: counter = Counter(taskset_ids) batch = [] for taskset_id, count in counter.items(): - indices = self.selectors[taskset_id].get_indices(batch_size=count) - tasks = await self.tasksets[taskset_id].read_with_indices_async(indices) + tasks = await self.tasksets[taskset_id].read_async(batch_size=count) # Annotate each task with its origin for task in tasks: task.index["taskset_id"] = taskset_id @@ -239,13 +232,13 @@ def state_dict(self) -> List[Dict]: Save persistent state for checkpointing. Returns: - List[Dict]: State dicts for all selectors (one per taskset) + List[Dict]: State dicts for all tasksets """ - return [selector.state_dict() for selector in self.selectors] + return [taskset.state_dict() for taskset in self.tasksets] def update(self, pipeline_metrics: Dict) -> None: """ - Update selectors using feedback from the training pipeline. + Update selectors in tasksets using feedback from the training pipeline. Expected format: pipeline_metrics = { @@ -265,5 +258,5 @@ def update(self, pipeline_metrics: Dict) -> None: return selector_metric = pipeline_metrics.pop(SELECTOR_METRIC, {}) for taskset_id, taskset_kwargs in selector_metric.items(): - selector = self.selectors[taskset_id] - selector.update(**taskset_kwargs) + taskset = self.tasksets[taskset_id] + taskset.update(**taskset_kwargs) From e9feba4097952155e3ebd699968c3394f3d214de Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 21 Jan 2026 17:27:12 +0800 Subject: [PATCH 2/4] apply reviews and fix unittest --- tests/buffer/task_scheduler_test.py | 2 +- trinity/buffer/reader/file_reader.py | 14 ++++++++++---- trinity/buffer/selector/selector.py | 12 ++++++------ trinity/buffer/task_scheduler.py | 8 ++++---- trinity/explorer/explorer.py | 2 +- 5 files changed, 22 insertions(+), 16 deletions(-) diff --git a/tests/buffer/task_scheduler_test.py b/tests/buffer/task_scheduler_test.py index 901bd51aa3..6792290d68 100644 --- a/tests/buffer/task_scheduler_test.py +++ b/tests/buffer/task_scheduler_test.py @@ -340,7 +340,7 @@ async def test_task_scheduler_simple(self): self.assertEqual(len(task_scheduler_state), 1) self.assertEqual(task_scheduler_state[0]["current_index"], 4) # no effect - task_scheduler.update({"metric1": 0.5}) + task_scheduler.feedback({"metric1": 0.5}) task_scheduler = get_taskset_scheduler( { diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 44f79e677f..6b668d8994 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -79,7 +79,14 @@ def select_batch(self, indices: List[int]) -> List: batch = [] for i in indices: assert 0 <= i < self.dataset_size + if self.current_offset >= self.total_samples: + if not self.drop_last and len(batch) > 0: + break + self.progress_bar.close() + raise StopIteration batch.append(self.dataset[int(i)]) + self.current_offset += 1 + self.progress_bar.update(len(batch)) # update progress bar return batch @@ -110,9 +117,9 @@ def state_dict(self): def load_state_dict(self, state_dict): return self.reader.load_state_dict(state_dict) - def update(self, **pipeline_metrics: dict): + def feedback(self, **pipeline_metrics): if self.reader.selector is not None: - self.reader.selector.update(**pipeline_metrics) + self.reader.selector.feedback(**pipeline_metrics) def __len__(self): return self.reader.__len__() @@ -210,8 +217,7 @@ def state_dict(self): def load_state_dict(self, state_dict): if self.selector is not None: self.selector.load_state_dict(state_dict) - else: - self.dataset.current_offset = state_dict["current_index"] + self.dataset.current_offset = state_dict["current_index"] def __len__(self): return self.dataset.dataset_size diff --git a/trinity/buffer/selector/selector.py b/trinity/buffer/selector/selector.py index 8782244f55..a67036dd45 100644 --- a/trinity/buffer/selector/selector.py +++ b/trinity/buffer/selector/selector.py @@ -44,7 +44,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[ """ raise NotImplementedError - def update(self, indices: List[int], values: List[float]) -> None: + def feedback(self, indices: List[int], values: List[float]) -> None: """ Update internal state based on feedback (e.g., model loss, accuracy). @@ -95,7 +95,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[ return list(range(start, end)) return list(range(start, self.dataset_size)) + list(range(0, end - self.dataset_size)) - def update(self, indices: List[int], values: List[float]) -> None: + def feedback(self, indices: List[int], values: List[float]) -> None: # No-op: sequential selection doesn't adapt based on feedback pass @@ -150,7 +150,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[ self.current_index += batch_size return ret - def update(self, indices: List[int], values: List[float]) -> None: + def feedback(self, indices: List[int], values: List[float]) -> None: # No-op: static shuffling does not adapt pass @@ -188,7 +188,7 @@ def get_indices(self, batch_size, return_extra_info=False): else: return selected_indices - def update(self, indices: List[int], values: List[float]) -> None: + def feedback(self, indices: List[int], values: List[float]) -> None: # No-op: basic random selection doesn't adapt pass @@ -239,7 +239,7 @@ def __init__(self, data_source, config: TaskSelectorConfig): self.dataset_size = data_source.dataset_size self.current_index = 0 - def update(self, indices: List[int], values: List[float]) -> None: + def feedback(self, indices: List[int], values: List[float]) -> None: # No-op: this selector does not adapt based on runtime feedback pass @@ -340,7 +340,7 @@ def build_diff_estimator(self, dataset, feature_keys: List[str], config: dict): adaptive_rho=adaptive_rho, ) - def update(self, indices: List[int], values: List[float]) -> None: + def feedback(self, indices: List[int], values: List[float]) -> None: """ Updates the difficulty estimator with observed performance on selected samples. diff --git a/trinity/buffer/task_scheduler.py b/trinity/buffer/task_scheduler.py index 2592269053..c511886145 100644 --- a/trinity/buffer/task_scheduler.py +++ b/trinity/buffer/task_scheduler.py @@ -47,7 +47,7 @@ def state_dict(self) -> List[Dict]: """ raise NotImplementedError - def update(self, pipeline_metrics: Dict) -> None: + def feedback(self, pipeline_metrics: Dict) -> None: """Update selectors using feedback from the training pipeline.""" raise NotImplementedError @@ -79,7 +79,7 @@ async def read_async(self) -> List: def state_dict(self) -> List[Dict]: return [self.reader.state_dict()] - def update(self, pipeline_metrics: Dict) -> None: + def feedback(self, pipeline_metrics: Dict) -> None: # do nothing here return @@ -236,7 +236,7 @@ def state_dict(self) -> List[Dict]: """ return [taskset.state_dict() for taskset in self.tasksets] - def update(self, pipeline_metrics: Dict) -> None: + def feedback(self, pipeline_metrics: Dict) -> None: """ Update selectors in tasksets using feedback from the training pipeline. @@ -259,4 +259,4 @@ def update(self, pipeline_metrics: Dict) -> None: selector_metric = pipeline_metrics.pop(SELECTOR_METRIC, {}) for taskset_id, taskset_kwargs in selector_metric.items(): taskset = self.tasksets[taskset_id] - taskset.update(**taskset_kwargs) + taskset.feedback(**taskset_kwargs) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index b0893b8c52..fdd2f63932 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -411,7 +411,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None: batch_id=step, min_num=self.min_wait_num ) pipeline_metrics = await self.experience_pipeline.process.remote(exps) - self.taskset.update(pipeline_metrics) + self.taskset.feedback(pipeline_metrics) metric.update(pipeline_metrics) if statuses: metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout")) From 8613d8e182cc5c15018376b53aea61dcb4fecb12 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 21 Jan 2026 17:40:18 +0800 Subject: [PATCH 3/4] add `model_config` property for `ModelWrapper` --- trinity/common/models/model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 5622511ca5..ed982fe9b9 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -454,6 +454,11 @@ def model_name(self) -> Optional[str]: """Get the name of the model.""" return self._model_name + @property + def model_config(self) -> InferenceModelConfig: + """Get the model config.""" + return self.config + @property def generate_kwargs(self) -> Dict[str, Any]: """Get the generation kwargs for openai client.""" From 21dd88785dd5971568c9eb86363185e8ab854b8f Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 22 Jan 2026 10:42:01 +0800 Subject: [PATCH 4/4] update docs --- docs/sphinx_doc/source/tutorial/develop_selector.md | 4 ++-- docs/sphinx_doc/source_zh/tutorial/develop_selector.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/develop_selector.md b/docs/sphinx_doc/source/tutorial/develop_selector.md index a63da515a1..d0ddbbce31 100644 --- a/docs/sphinx_doc/source/tutorial/develop_selector.md +++ b/docs/sphinx_doc/source/tutorial/develop_selector.md @@ -51,7 +51,7 @@ To create a new selector, inherit from `BaseSelector` and implement the followin | Method | Purpose | |-------|--------| | `get_indices(batch_size: int, return_extra_info=False) -> List[int]` | Return a list of sample indices to read next. | -| `update(indices: List[int], values: List[float])` | Update internal state using feedback (e.g., rewards, losses). | +| `feedback(indices: List[int], values: List[float])` | Update internal state using feedback (e.g., rewards, losses). | | `state_dict() -> Dict` | Serialize current state for checkpointing. | | `load_state_dict(state_dict: Dict)` | Restore state from a saved dictionary. | @@ -113,7 +113,7 @@ class DifficultyBasedSelector(BaseSelector): else: return selected_indices - def update(self, indices: List[int], values: List[float]) -> None: + def feedback(self, indices: List[int], values: List[float]) -> None: # Update difficulty model with observed rewards self.diff_estimator.update(indices, values) diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_selector.md b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md index 1d08b42508..a4e7fedb15 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_selector.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md @@ -49,7 +49,7 @@ | 方法 | 功能说明 | |------|---------| | `get_indices(batch_size: int, return_extra_info=False) -> List[int]` | 返回接下来要读取的样本索引列表。 | -| `update(indices: List[int], values: List[float])` | 使用反馈信息(如奖励、损失)更新内部状态,用于自适应调整。 | +| `feedback(indices: List[int], values: List[float])` | 使用反馈信息(如奖励、损失)更新内部状态,用于自适应调整。 | | `state_dict() -> Dict` | 序列化当前状态,用于保存检查点。 | | `load_state_dict(state_dict: Dict)` | 从保存的状态字典中恢复选择器状态。 | @@ -111,7 +111,7 @@ class DifficultyBasedSelector(BaseSelector): else: return selected_indices - def update(self, indices: List[int], values: List[float]) -> None: + def feedback(self, indices: List[int], values: List[float]) -> None: # 使用观测到的奖励更新难度模型 self.diff_estimator.update(indices, values)