diff --git a/.gitignore b/.gitignore index ee5a0432f..ad425fb2e 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,9 @@ __pycache__/ # libraries /prime-rl +/packages/tasksets +/packages/harnesses + # outputs wandb/ diff --git a/docs/environments.md b/docs/environments.md index 84f884482..df9d908c6 100644 --- a/docs/environments.md +++ b/docs/environments.md @@ -567,7 +567,7 @@ class MyGameEnv(vf.MultiTurnEnv): return state.get("lives", 1) <= 0 ``` -`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, and `max_turns` by default. +`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, and incomplete response detection by default. Execution order can be controlled with `priority` (higher runs first). This is useful for checking cheap conditions before expensive ones: diff --git a/environments/wiki_search/pyproject.toml b/environments/wiki_search/pyproject.toml index 8bc7680f2..50c758500 100644 --- a/environments/wiki_search/pyproject.toml +++ b/environments/wiki_search/pyproject.toml @@ -3,9 +3,9 @@ name = "wiki-search" description = "Agentic RAG over Wikipedia pages for trivia Q&A" tags = ["wikipedia", "multi-turn", "agentic-search", "rag", "train", "eval", "llm-judge"] requires-python = ">=3.11" -version = "0.1.23" +version = "0.1.24" dependencies = [ - "verifiers>=0.1.9", + "verifiers>=0.1.11.dev0", "chromadb", "datasets", "openai", diff --git a/environments/wiki_search/wiki_search.py b/environments/wiki_search/wiki_search.py index 9ca18dc90..56e9f78a3 100644 --- a/environments/wiki_search/wiki_search.py +++ b/environments/wiki_search/wiki_search.py @@ -260,7 +260,10 @@ async def read_section(section_id: str) -> str: ) async def judge_reward_func(judge, prompt, completion, answer, state) -> float: - judge_response = await judge(prompt, completion, answer, state) + cleaned_completion = [ + {x["role"]: x["content"].split("</think>")[-1] for x in completion} + ] + judge_response = await judge(prompt, 
cleaned_completion, answer, state) if "yes" in judge_response.lower(): return 1.0 else: diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 40ef74cbe..a1ccd3daf 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -156,7 +156,9 @@ def __init__( if dataset is not None: if callable(dataset): - self.dataset_source: DatasetBuilder | None = dataset + self.dataset_source: DatasetBuilder | None = cast( + DatasetBuilder, dataset + ) else: self.dataset_source = lambda ds=dataset: ds self.build_dataset() # Eagerly build for raw datasets (backwards compat) @@ -165,7 +167,9 @@ def __init__( if eval_dataset is not None: if callable(eval_dataset): - self.eval_dataset_source: DatasetBuilder | None = eval_dataset + self.eval_dataset_source: DatasetBuilder | None = cast( + DatasetBuilder, eval_dataset + ) else: self.eval_dataset_source = lambda ds=eval_dataset: ds self.build_eval_dataset() # Eagerly build for raw datasets (backwards compat) diff --git a/verifiers/envs/multiturn_env.py b/verifiers/envs/multiturn_env.py index 4e2c867c8..70507807a 100644 --- a/verifiers/envs/multiturn_env.py +++ b/verifiers/envs/multiturn_env.py @@ -64,6 +64,10 @@ async def has_final_env_response(self, state: State) -> bool: """Check if env_response signaled termination via final_env_response.""" return state.get("final_env_response") is not None + @vf.stop + async def has_incomplete_response(self, state: State) -> bool: + return state.get("incomplete_response", False) + async def setup_state(self, state: State) -> State: """Override to add environment-specific state fields.""" return state @@ -121,9 +125,15 @@ async def add_model_response( ): completion_messages = await parse_response_message(response) tokens = await parse_response_tokens(response, self.max_seq_len) + has_content = bool(response.message.content) + has_tool_calls = bool(response.message.tool_calls) + if not has_content and not has_tool_calls: + state["incomplete_response"] = 
True response_is_truncated = response.message.is_truncated or False - is_truncated = response_is_truncated or ( - tokens is not None and bool(tokens.get("is_truncated")) + is_truncated = ( + response_is_truncated + or (tokens is not None and bool(tokens.get("is_truncated"))) + or state.get("incomplete_response", False) ) trajectory_step = TrajectoryStep( prompt=prompt_messages,