From fa18e8acf152325bfc1412ca74f83c4005a3e69c Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Thu, 14 Nov 2024 18:44:44 -0500 Subject: [PATCH] Update task result references --- tests/tasks/test_tasks.py | 12 ++++++------ tests/test_run.py | 2 +- tests/utilities/test_testing.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index 52480232..fa6b2fb7 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -485,14 +485,14 @@ class TestSuccessTool: def test_success_tool(self): task = Task("choose 5", result_type=int) tool = task.get_success_tool() - tool.run(input=dict(task_result=5)) + tool.run(input=dict(result=5)) assert task.is_successful() assert task.result == 5 def test_success_tool_with_list_of_options(self): task = Task('choose "good"', result_type=["bad", "good", "medium"]) tool = task.get_success_tool() - tool.run(input=dict(task_result=1)) + tool.run(input=dict(result=1)) assert task.is_successful() assert task.result == "good" @@ -500,12 +500,12 @@ def test_success_tool_with_list_of_options_requires_int(self): task = Task('choose "good"', result_type=["bad", "good", "medium"]) tool = task.get_success_tool() with pytest.raises(ValueError): - tool.run(input=dict(task_result="good")) + tool.run(input=dict(result="good")) def test_tuple_of_ints_result(self): task = Task("choose 5", result_type=(4, 5, 6)) tool = task.get_success_tool() - tool.run(input=dict(task_result=1)) + tool.run(input=dict(result=1)) assert task.result == 5 def test_tuple_of_pydantic_models_result(self): @@ -518,7 +518,7 @@ class Person(BaseModel): result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)), ) tool = task.get_success_tool() - tool.run(input=dict(task_result=1)) + tool.run(input=dict(result=1)) assert task.result == Person(name="Bob", age=35) assert isinstance(task.result, Person) @@ -604,7 +604,7 @@ def test_invalid_completion_tool(self): def test_manual_success_tool(self): task = Task(objective="Test task", completion_tools=[], result_type=int) success_tool = task.get_success_tool() - success_tool.run(input=dict(task_result=5)) + success_tool.run(input=dict(result=5)) assert task.is_successful() assert task.result == 5 diff --git a/tests/test_run.py b/tests/test_run.py index c8f1cabd..a78e37a0 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -191,7 +191,7 @@ def task(self, default_fake_llm): tool_calls=[ { "name": "mark_task_12345_successful", - "args": {"task_result": "Hello!"}, + "args": {"result": "Hello!"}, "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", "type": "tool_call", } diff --git a/tests/utilities/test_testing.py b/tests/utilities/test_testing.py index 380bdd5f..e4acefd3 100644 --- a/tests/utilities/test_testing.py +++ b/tests/utilities/test_testing.py @@ -20,7 +20,7 @@ def test_record_task_events(default_fake_llm): tool_calls=[ { "name": "mark_task_12345_successful", - "args": {"task_result": "Hello!"}, + "args": {"result": "Hello!"}, "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", "type": "tool_call", } @@ -39,7 +39,7 @@ def test_record_task_events(default_fake_llm): assert events[3].event == "tool-result" assert events[3].tool_result.tool_call == { "name": "mark_task_12345_successful", - "args": {"task_result": "Hello!"}, + "args": {"result": "Hello!"}, "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", "type": "tool_call", }