Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions verifiers/envs/env_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,29 @@ async def rollout(
def get_env_for_task(self, task: str) -> vf.Environment:
return self.env_map.get(task, self.envs[0])

def get_prompt_components(self) -> dict[str, str]:
"""Return shared prompt components across all sub-environments.

Enforces strict matching: every sub-environment must expose identical
prompt component keys and identical initial values.
"""
if not self.envs:
return {}

reference = self.envs[0].get_prompt_components()
reference_name = self.env_names[0]

for env, name in zip(self.envs[1:], self.env_names[1:]):
components = env.get_prompt_components()
if components != reference:
raise ValueError(
"EnvGroup GEPA requires all sub-environments to expose identical "
"prompt components and initial values. "
f"Mismatch between '{reference_name}' and '{name}'."
)

return dict(reference)

def set_max_seq_len(self, max_seq_len: int | None) -> None:
"""Set the max_seq_len value for this environment group and all sub-environments."""
self.max_seq_len = max_seq_len
Expand Down
35 changes: 35 additions & 0 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,15 @@ async def init_state(
state_input = deepcopy(input)
if "info" in state_input and isinstance(state_input["info"], str):
state_input["info"] = json.loads(state_input["info"])
info = state_input.get("info")
if isinstance(info, dict):
prompt_components = info.get("prompt_components")
if isinstance(prompt_components, dict):
system_prompt = prompt_components.get("system_prompt")
if isinstance(system_prompt, str) and system_prompt:
state_input["prompt"] = self._inject_system_prompt_to_prompt(
state_input.get("prompt"), system_prompt
)
if "task" not in state_input:
state_input["task"] = self.env_id or "default"
state = State(input=RolloutInput(**state_input)) # type: ignore[missing-typed-dict-key]
Expand Down Expand Up @@ -718,6 +727,32 @@ async def rollout(
"""
pass

def get_prompt_components(self) -> dict[str, str]:
"""Return optimizable prompt components for GEPA.

Default is to just return the system prompt
"""
if not self.system_prompt:
return {}
return {"system_prompt": self.system_prompt}

@staticmethod
def _inject_system_prompt_to_prompt(
prompt: Messages | None,
system_prompt: str,
) -> Messages:
"""Inject or replace system prompt in a prompt payload."""
sys_msg = cast(ChatMessage, {"role": "system", "content": system_prompt})
if prompt is None or (isinstance(prompt, list) and not prompt):
return [sys_msg]
if isinstance(prompt, str):
return f"{system_prompt}\n\n{prompt}"
prompt_list = cast(List[ChatMessage], [dict(m) for m in prompt])
if prompt_list[0].get("role") == "system":
prompt_list[0]["content"] = system_prompt
return prompt_list
return [sys_msg] + prompt_list

async def _cleanup(self, state: State):
"""
Clean up rollout resources.
Expand Down
26 changes: 26 additions & 0 deletions verifiers/envs/experimental/rlm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3648,6 +3648,9 @@ async def setup_state(self, state: State, **kwargs) -> State:
state["rlm_fs_source"] = fs_source
state["rlm_fs_has_data"] = fs_has_data
state["retain_filesystem_after_rollout"] = self.retain_filesystem_after_rollout

# rolloutInput is stored in state so we should look at that info field and
# then set prompts accordingly
if self.custom_system_prompt:
base_system_prompt = self.custom_system_prompt
elif self.repl_language == "bash":
Expand All @@ -3659,6 +3662,14 @@ async def setup_state(self, state: State, **kwargs) -> State:
self.root_prompt_verbosity
]

info = state.get("info")
if isinstance(info, dict):
prompt_components = info.get("prompt_components")
if isinstance(prompt_components, dict):
override_prompt = prompt_components.get("base_system_prompt")
if isinstance(override_prompt, str) and override_prompt:
base_system_prompt = override_prompt

packages_docs = self._generate_packages_documentation()
root_tools_docs = self._generate_root_tools_documentation()
sub_tools_docs = self._generate_sub_tools_documentation()
Expand Down Expand Up @@ -3694,6 +3705,21 @@ async def setup_state(self, state: State, **kwargs) -> State:

return state

def get_prompt_components(self) -> dict[str, str]:
components = super().get_prompt_components()
if self.custom_system_prompt:
base_system_prompt = self.custom_system_prompt
elif self.repl_language == "bash":
base_system_prompt = _RLM_BASH_SYSTEM_PROMPT_STORE[
self.root_prompt_verbosity
]
else:
base_system_prompt = _RLM_PYTHON_SYSTEM_PROMPT_STORE[
self.root_prompt_verbosity
]
components["base_system_prompt"] = base_system_prompt
return components

# =========================================================================
# Code Execution
# =========================================================================
Expand Down
11 changes: 11 additions & 0 deletions verifiers/envs/tool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ def remove_tool(self, tool: Callable):
self.tool_map.pop(tool_name)
self.tool_monitor_rubric.remove_tool_metric(tool)

def get_prompt_components(self) -> dict[str, str]:
components = super().get_prompt_components()
for tool in self.oai_tools or []:
func = tool.get("function", {})
name = func.get("name")
if not name:
continue
description = func.get("description", "")
components[f"tool:{name}"] = description
return components

@vf.stop
async def no_tools_called(self, state: vf.State) -> bool:
if len(state["trajectory"]) == 0:
Expand Down
98 changes: 74 additions & 24 deletions verifiers/gepa/adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence
Expand Down Expand Up @@ -86,7 +87,10 @@ def evaluate(
"""
Run verifiers evaluation with the candidate system prompt.
"""
inputs = _inject_system_prompt(batch, candidate.get("system_prompt", ""))

# Attach prompt components to info for envs that read them,
# leaving prompt mutation to the environment.
inputs = _attach_prompt_components(batch, candidate, self.env)

results = asyncio.get_event_loop().run_until_complete(
self.env.generate(
Expand All @@ -106,10 +110,10 @@ def evaluate(

# Update display if configured
if self.display is not None:
prompt_text = candidate.get("system_prompt", "")
if prompt_text not in self._seen_prompts:
self._seen_prompts[prompt_text] = len(self._seen_prompts)
candidate_idx = self._seen_prompts[prompt_text]
candidate_key = _candidate_key(candidate)
if candidate_key not in self._seen_prompts:
self._seen_prompts[candidate_key] = len(self._seen_prompts)
candidate_idx = self._seen_prompts[candidate_key]

self.display.update_eval(
candidate_idx=candidate_idx,
Expand Down Expand Up @@ -158,39 +162,77 @@ def make_reflective_dataset(

records.append(record)

# we might want this to become more sophisticated, only giving the relevant records for each component
return {comp: records for comp in components_to_update}


def _attach_prompt_components(
    inputs: list[RolloutInput],
    candidate: dict[str, str],
    env: Environment,
) -> list[RolloutInput]:
    """Attach prompt components to each input's info dict.

    Copies the candidate's components into ``info["prompt_components"]`` so
    environments can apply them during rollout. ``tool:<name>`` components
    additionally rewrite the matching tool descriptions from the resolved
    sub-environment's ``oai_tools`` into ``info["oai_tools"]``.

    Inputs are shallow-copied; the caller's payloads are never mutated.
    Returns ``inputs`` unchanged when the candidate is empty.
    """
    if not candidate:
        return inputs

    # Map "tool:<name>" component keys to their override descriptions.
    tool_overrides = {
        key.split(":", 1)[1]: value
        for key, value in candidate.items()
        if key.startswith("tool:") and isinstance(value, str)
    }

    modified = []
    for inp in inputs:
        inp_copy = dict(inp)

        info = inp_copy.get("info")
        if not isinstance(info, dict):
            info = {}
        info = dict(info)
        info["prompt_components"] = dict(candidate)

        # Tool schemas come from the per-task sub-environment (EnvGroup).
        tool_source_env = _resolve_env_for_input(env, inp_copy)
        tool_source = getattr(tool_source_env, "oai_tools", None)
        if tool_overrides and tool_source:
            info["oai_tools"] = _apply_tool_overrides(tool_source, tool_overrides)

        inp_copy["info"] = info
        modified.append(inp_copy)
    return modified


def _apply_tool_overrides(tools, overrides: dict[str, str]) -> list:
    """Return a copy of *tools* with descriptions replaced per *overrides*.

    Malformed entries (non-dict ``function`` or missing name) are kept
    as-is and logged; all tool/function dicts are shallow-copied so the
    environment's own schemas are never mutated.
    """
    new_tools = []
    for tool in tools:
        tool_copy = dict(tool)
        func = tool_copy.get("function")
        if not isinstance(func, dict):
            logger.warning("Skipping tool override: invalid tool function")
            new_tools.append(tool_copy)
            continue
        tool_name = func.get("name")
        if not tool_name:
            logger.warning("Skipping tool override: missing tool name")
            new_tools.append(tool_copy)
            continue
        if tool_name in overrides:
            func_copy = dict(func)
            func_copy["description"] = overrides[tool_name]
            tool_copy["function"] = func_copy
        new_tools.append(tool_copy)
    return new_tools


def _resolve_env_for_input(env: Environment, inp: RolloutInput) -> Environment:
    """Resolve the per-task sub-environment for EnvGroup-like wrappers.

    Falls back to *env* itself when the input carries no task, the wrapper
    does not expose a callable ``get_env_for_task``, resolution raises, or
    the resolved object is not an Environment.
    """
    task = inp.get("task")
    if task is None:
        return env
    get_env_for_task = getattr(env, "get_env_for_task", None)
    if not callable(get_env_for_task):
        return env
    try:
        resolved = get_env_for_task(task)
    except Exception:
        # Best-effort resolution: fall back to the wrapper, but leave a
        # trace instead of swallowing the failure silently.
        logger.debug("get_env_for_task failed for task %r", task, exc_info=True)
        return env
    if isinstance(resolved, Environment):
        return resolved
    return env


def _extract_user_query(prompt: Messages) -> str:
"""Extract user query from prompt, skipping system message."""
if isinstance(prompt, str):
Expand All @@ -202,3 +244,11 @@ def _extract_user_query(prompt: Messages) -> str:
return content
return str(content) if content else ""
return ""


def _candidate_key(candidate: dict[str, str]) -> str:
"""Stable key for multi-component candidates."""
try:
return json.dumps(candidate, sort_keys=True, ensure_ascii=True)
except (TypeError, ValueError):
return str(sorted(candidate.items()))
35 changes: 22 additions & 13 deletions verifiers/gepa/gepa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
Simple artifact saving for GEPA optimization.

Saves:
- pareto_frontier.jsonl: Per valset row, the best prompt(s) and their scores
- best_prompt.txt: The single best overall system prompt
- pareto_frontier.jsonl: Per valset row, the best candidate(s) and their scores
- best_candidate.json: The single best overall candidate (all components)
- best_prompt.txt: The best system prompt (if present)
- metadata.json: Run configuration and summary
"""

Expand Down Expand Up @@ -74,35 +75,43 @@ def save_gepa_results(
best_prompts = [
{
"candidate_idx": cand_idx,
"system_prompt": candidates[cand_idx].get("system_prompt", ""),
"candidate": candidates[cand_idx],
"score": score,
}
for cand_idx, score in row_scores
if score == best_score
]

records.append({
"valset_row": row_idx,
"best_score": best_score,
"num_best_prompts": len(best_prompts),
"best_prompts": best_prompts,
})
records.append(
{
"valset_row": row_idx,
"best_score": best_score,
"num_best_prompts": len(best_prompts),
"best_prompts": best_prompts,
}
)

# Save frontier as JSONL
if records:
frontier_ds = Dataset.from_list(records)
frontier_ds.to_json(run_dir / "pareto_frontier.jsonl")

# Save best prompt as plain text
best_prompt = best_candidate.get("system_prompt", "")
(run_dir / "best_prompt.txt").write_text(best_prompt)
# Save best candidate as JSON
(run_dir / "best_candidate.json").write_text(json.dumps(best_candidate, indent=2))

# Save best system prompt as plain text (if present)
best_prompt = best_candidate.get("system_prompt")
if isinstance(best_prompt, str):
(run_dir / "best_prompt.txt").write_text(best_prompt)

# Build and save metadata
val_scores = getattr(result, "val_aggregate_scores", [])
metadata = {
"num_candidates": len(candidates),
"best_idx": best_idx,
"best_score": float(val_scores[best_idx]) if val_scores and best_idx < len(val_scores) else None,
"best_score": float(val_scores[best_idx])
if val_scores and best_idx < len(val_scores)
else None,
"total_metric_calls": getattr(result, "total_metric_calls", None),
"completed_at": datetime.now().isoformat(),
}
Expand Down
Loading
Loading