Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/buffer/task_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ async def test_task_scheduler_simple(self):
self.assertEqual(len(task_scheduler_state), 1)
self.assertEqual(task_scheduler_state[0]["current_index"], 4)
# no effect
task_scheduler.update({"metric1": 0.5})
task_scheduler.feedback({"metric1": 0.5})

task_scheduler = get_taskset_scheduler(
{
Expand Down
48 changes: 30 additions & 18 deletions trinity/buffer/reader/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@ def select_batch(self, indices: List[int]) -> List:
batch = []
for i in indices:
assert 0 <= i < self.dataset_size
if self.current_offset >= self.total_samples:
if not self.drop_last and len(batch) > 0:
break
self.progress_bar.close()
raise StopIteration
batch.append(self.dataset[int(i)])
self.current_offset += 1

self.progress_bar.update(len(batch)) # update progress bar
return batch

Expand All @@ -104,20 +111,16 @@ def __init__(self, config: StorageConfig):
def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
return self.reader.read(batch_size)

def read_with_indices(self, indices: List[int]) -> List:
"""Read tasks with indices."""
return self.reader.read_with_indices(indices)

async def read_with_indices_async(self, indices: List[int]) -> List:
"""Read tasks with indices asynchronously."""
return await self.reader.read_with_indices_async(indices)

def state_dict(self):
return self.reader.state_dict()

def load_state_dict(self, state_dict):
return self.reader.load_state_dict(state_dict)

def feedback(self, **pipeline_metrics):
if self.reader.selector is not None:
self.reader.selector.feedback(**pipeline_metrics)

def __len__(self):
return self.reader.__len__()

Expand All @@ -139,6 +142,7 @@ def __init__(self, config: StorageConfig):
total_steps=config.total_steps,
enable_progress_bar=config.enable_progress_bar,
)
self.selector = None

def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
samples, _ = self.dataset.read_batch(batch_size or self.read_batch_size)
Expand Down Expand Up @@ -178,6 +182,15 @@ def __init__(self, config: StorageConfig):
enable_progress_bar=self.config.enable_progress_bar,
)
self.formatter = FORMATTER.get("task")(config)
if self.config.task_selector is not None:
from trinity.buffer.selector import SELECTORS
from trinity.buffer.selector.selector import BaseSelector

self.selector: BaseSelector = SELECTORS.get(self.config.task_selector.selector_type)(
self.dataset, self.config.task_selector
)
else:
self.selector = None

def _get_tasks(self, samples: List, indices: List) -> List:
tasks = []
Expand All @@ -189,22 +202,21 @@ def _get_tasks(self, samples: List, indices: List) -> List:

def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
batch_size = batch_size or self.read_batch_size
samples, indices = self.dataset.read_batch(batch_size)
return self._get_tasks(samples, indices)

def read_with_indices(self, indices: List[int]) -> List:
"""Read tasks with indices."""
samples = self.dataset.select_batch(indices)
if self.selector is not None:
indices = self.selector.get_indices(batch_size)
samples = self.dataset.select_batch(indices)
else:
samples, indices = self.dataset.read_batch(batch_size)
return self._get_tasks(samples, indices)

async def read_with_indices_async(self, indices: List[int]) -> List:
"""Read tasks with indices asynchronously."""
return self.read_with_indices(indices)

def state_dict(self):
if self.selector is not None:
return self.selector.state_dict()
return {"current_index": self.dataset.current_offset}

def load_state_dict(self, state_dict):
if self.selector is not None:
self.selector.load_state_dict(state_dict)
self.dataset.current_offset = state_dict["current_index"]

def __len__(self):
Expand Down
12 changes: 6 additions & 6 deletions trinity/buffer/selector/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[
"""
raise NotImplementedError

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation should be updated accordingly.

"""
Update internal state based on feedback (e.g., model loss, accuracy).

Expand Down Expand Up @@ -95,7 +95,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[
return list(range(start, end))
return list(range(start, self.dataset_size)) + list(range(0, end - self.dataset_size))

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
# No-op: sequential selection doesn't adapt based on feedback
pass

Expand Down Expand Up @@ -150,7 +150,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[
self.current_index += batch_size
return ret

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
# No-op: static shuffling does not adapt
pass

Expand Down Expand Up @@ -188,7 +188,7 @@ def get_indices(self, batch_size, return_extra_info=False):
else:
return selected_indices

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
# No-op: basic random selection doesn't adapt
pass

Expand Down Expand Up @@ -239,7 +239,7 @@ def __init__(self, data_source, config: TaskSelectorConfig):
self.dataset_size = data_source.dataset_size
self.current_index = 0

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
# No-op: this selector does not adapt based on runtime feedback
pass

Expand Down Expand Up @@ -340,7 +340,7 @@ def build_diff_estimator(self, dataset, feature_keys: List[str], config: dict):
adaptive_rho=adaptive_rho,
)

def update(self, indices: List[int], values: List[float]) -> None:
def feedback(self, indices: List[int], values: List[float]) -> None:
"""
Updates the difficulty estimator with observed performance on selected samples.

Expand Down
37 changes: 15 additions & 22 deletions trinity/buffer/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
"""The taskset scheduler."""

from collections import Counter
from copy import deepcopy
from typing import Dict, List

import numpy as np

from trinity.buffer.buffer import get_buffer_reader
from trinity.buffer.selector import SELECTORS
from trinity.common.config import Config
from trinity.common.constants import SELECTOR_METRIC
from trinity.utils.annotations import Experimental
Expand Down Expand Up @@ -47,7 +47,7 @@ def state_dict(self) -> List[Dict]:
"""
raise NotImplementedError

def update(self, pipeline_metrics: Dict) -> None:
def feedback(self, pipeline_metrics: Dict) -> None:
"""Update selectors using feedback from the training pipeline."""
raise NotImplementedError

Expand All @@ -68,16 +68,18 @@ def __init__(self, explorer_state: Dict, config: Config):
index = self.explorer_state.get("taskset_states", [{"current_index": 0}])[0].get(
"current_index", 0
)
self.config.buffer.explorer_input.tasksets[0].index = index
self.reader = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
taskset_config = deepcopy(self.config.buffer.explorer_input.tasksets[0])
taskset_config.index = index
taskset_config.task_selector = None # disable selection
self.reader = get_buffer_reader(taskset_config)

async def read_async(self) -> List:
return await self.reader.read_async()

def state_dict(self) -> List[Dict]:
return [self.reader.state_dict()]

def update(self, pipeline_metrics: Dict) -> None:
def feedback(self, pipeline_metrics: Dict) -> None:
# do nothing here
return

Expand Down Expand Up @@ -127,7 +129,6 @@ def __init__(self, explorer_state: Dict, config: Config):
"taskset_states", [{"current_index": 0}] * len(taskset_configs)
)
self.tasksets = []
self.selectors = []
for taskset_config, taskset_state in zip(taskset_configs, taskset_states):
assert not taskset_config.is_eval # assume drop last
taskset = get_buffer_reader(taskset_config)
Expand All @@ -136,15 +137,8 @@ def __init__(self, explorer_state: Dict, config: Config):
f"Taskset '{taskset_config.name}' has an unsupported type '{type(taskset).__name__}'."
f"Currently, only 'FileReader' is supported by TasksetScheduler."
)

# Create selector based on type specified in config (e.g., 'sequential', 'shuffle')
selector = SELECTORS.get(taskset_config.task_selector.selector_type)(
taskset.reader.dataset, taskset_config.task_selector
)
selector.load_state_dict(taskset_state) # Restore any prior state

taskset.load_state_dict(taskset_state) # Restore any prior state
self.tasksets.append(taskset)
self.selectors.append(selector)

# Each explorer step calls read_async once → track step globally
self.step = explorer_state.get("latest_iteration", 0)
Expand Down Expand Up @@ -224,8 +218,7 @@ async def read_async(self) -> List:
counter = Counter(taskset_ids)
batch = []
for taskset_id, count in counter.items():
indices = self.selectors[taskset_id].get_indices(batch_size=count)
tasks = await self.tasksets[taskset_id].read_with_indices_async(indices)
tasks = await self.tasksets[taskset_id].read_async(batch_size=count)
# Annotate each task with its origin
for task in tasks:
task.index["taskset_id"] = taskset_id
Expand All @@ -239,13 +232,13 @@ def state_dict(self) -> List[Dict]:
Save persistent state for checkpointing.

Returns:
List[Dict]: State dicts for all selectors (one per taskset)
List[Dict]: State dicts for all tasksets
"""
return [selector.state_dict() for selector in self.selectors]
return [taskset.state_dict() for taskset in self.tasksets]

def update(self, pipeline_metrics: Dict) -> None:
def feedback(self, pipeline_metrics: Dict) -> None:
"""
Update selectors using feedback from the training pipeline.
Update selectors in tasksets using feedback from the training pipeline.

Expected format:
pipeline_metrics = {
Expand All @@ -265,5 +258,5 @@ def update(self, pipeline_metrics: Dict) -> None:
return
selector_metric = pipeline_metrics.pop(SELECTOR_METRIC, {})
for taskset_id, taskset_kwargs in selector_metric.items():
selector = self.selectors[taskset_id]
selector.update(**taskset_kwargs)
taskset = self.tasksets[taskset_id]
taskset.feedback(**taskset_kwargs)
5 changes: 5 additions & 0 deletions trinity/common/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,11 @@ def model_name(self) -> Optional[str]:
"""Get the name of the model."""
return self._model_name

@property
def model_config(self) -> InferenceModelConfig:
"""Get the model config."""
return self.config

@property
def generate_kwargs(self) -> Dict[str, Any]:
"""Get the generation kwargs for openai client."""
Expand Down
2 changes: 1 addition & 1 deletion trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None:
batch_id=step, min_num=self.min_wait_num
)
pipeline_metrics = await self.experience_pipeline.process.remote(exps)
self.taskset.update(pipeline_metrics)
self.taskset.feedback(pipeline_metrics)
metric.update(pipeline_metrics)
if statuses:
metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout"))
Expand Down