diff --git a/.github/workflows/workflow.yaml b/.github/workflows/workflow.yaml index 4f430cb..d2e50a1 100644 --- a/.github/workflows/workflow.yaml +++ b/.github/workflows/workflow.yaml @@ -23,12 +23,16 @@ jobs: - name: Lint with ruff run: | - poetry run ruff check + poetry run ruff check --config pyproject.toml - name: Check formatting with ruff run: | - poetry run ruff format --check + poetry run ruff format --check --config pyproject.toml - name: Check types with mypy run: | - poetry run mypy . + poetry run mypy . --config-file pyproject.toml + + - name: Check types with pyright + run: | + poetry run pyright diff --git a/environments/alfworld/adapt.ipynb b/environments/alfworld/adapt.ipynb index 21260e8..1e9e55c 100644 --- a/environments/alfworld/adapt.ipynb +++ b/environments/alfworld/adapt.ipynb @@ -22,16 +22,7 @@ }, { "cell_type": "code", - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "import requests\n", "from environments.alfworld.common.environment import ALFWorldEnv\n", @@ -50,12 +41,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-26T09:55:03.331747Z", - "start_time": "2024-04-26T09:55:03.306341Z" + "end_time": "2024-05-22T11:46:56.309298Z", + "start_time": "2024-05-22T11:46:49.025162Z" } }, "id": "917b334526ae1418", - "execution_count": 5 + "execution_count": 1 }, { "cell_type": "markdown", @@ -72,18 +63,18 @@ "outputs": [], "source": [ "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n", - "os.environ[\"LANGCHAIN_PROJECT\"] = \"ALFWorld + ADaPT (new)\"\n", + "os.environ[\"LANGCHAIN_PROJECT\"] = \"test\"\n", "os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.smith.langchain.com\"" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-26T09:55:03.990181Z", - "start_time": "2024-04-26T09:55:03.970463Z" + "end_time": "2024-05-22T11:46:56.326234Z", + "start_time": "2024-05-22T11:46:56.309732Z" } }, "id": "add6277843d9e0d5", - "execution_count": 6 + "execution_count": 2 }, { "cell_type": "markdown", @@ -109,7 +100,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 8810/8810 [00:03<00:00, 2813.19it/s]" + "100%|██████████| 8810/8810 [00:05<00:00, 1712.34it/s]" ] }, { @@ -144,12 +135,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-26T09:55:07.937069Z", - "start_time": "2024-04-26T09:55:04.620750Z" + "end_time": "2024-05-22T11:47:01.752821Z", + "start_time": "2024-05-22T11:46:56.326441Z" } }, "id": "a79748e60e33a68c", - "execution_count": 7 + "execution_count": 3 }, { "cell_type": "code", @@ -173,12 +164,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-26T09:55:08.251850Z", - "start_time": "2024-04-26T09:55:07.937277Z" + "end_time": "2024-05-22T11:47:02.433994Z", + "start_time": "2024-05-22T11:47:01.751585Z" } }, "id": "e6feb0742d5fd134", - "execution_count": 8 + "execution_count": 4 }, { "cell_type": "markdown", @@ -216,12 +207,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-26T09:55:10.085814Z", - "start_time": "2024-04-26T09:55:10.060199Z" + "end_time": "2024-05-22T11:47:02.450759Z", + "start_time": "2024-05-22T11:47:02.433294Z" } }, "id": "95866c282598a829", - "execution_count": 9 + "execution_count": 5 }, { "cell_type": "markdown", @@ -260,12 +251,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-26T09:56:27.197982Z", - "start_time": "2024-04-26T09:56:26.988634Z" + "end_time": "2024-05-22T11:47:02.547042Z", + "start_time": "2024-05-22T11:47:02.450869Z" } }, "id": "979177816823c6a7", - "execution_count": 10 + "execution_count": 6 }, { "cell_type": "markdown", @@ -298,12 +289,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-26T09:56:28.545281Z", - "start_time": "2024-04-26T09:56:28.510826Z" + "end_time": "2024-05-22T11:47:02.574634Z", + "start_time": "2024-05-22T11:47:02.547704Z" } }, "id": "7c15cf73c24ab3f4", - "execution_count": 11 + "execution_count": 7 }, { "cell_type": "code", @@ -320,12 +311,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-26T09:56:29.270904Z", - "start_time": "2024-04-26T09:56:29.237099Z" + "end_time": "2024-05-22T11:47:02.601599Z", + "start_time": "2024-05-22T11:47:02.574821Z" } }, "id": "e682b3c778b7973b", - "execution_count": 12 + "execution_count": 8 }, { "cell_type": "markdown", @@ -352,12 +343,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-26T09:56:35.826852Z", - "start_time": "2024-04-26T09:56:35.801304Z" + "end_time": "2024-05-22T11:47:02.621126Z", + "start_time": "2024-05-22T11:47:02.601875Z" } }, "id": "d7c2708702ec7712", - "execution_count": 13 + "execution_count": 9 }, { "cell_type": "code", @@ -372,13 +363,14 @@ "\n", "\n", "\u001B[1m> Entering new SimpleStrategy chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3mI tried to put the pencil on the shelf, but nothing happened. I will attempt to complete the task again.\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mI encountered an issue while trying to put the pencil on the shelf. Let me try again to complete the task.\u001B[0m\n", "\n", "\u001B[1m> Finished chain.\u001B[0m\n", "\n", "\n", "\u001B[1m> Entering new SimpleStrategy chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3mI have set up the first subtask in the plan, which is to \"Pick up a pencil from the desk 1\" with the requirement to successfully complete this task. The aggregation mode is set to \"and\", meaning all subtasks need to be completed for the overall task to be considered successful. If you have any additional steps or adjustments you'd like to make, feel free to let me know!\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mI have added the first subtask to the plan: \"Go to the shelf 1 and place a pencil on it.\" \n", + "The plan is set to aggregate the results with an \"and\" logic, meaning all subtasks must be successfully completed. \u001B[0m\n", "\n", "\u001B[1m> Finished chain.\u001B[0m\n", "\n", @@ -390,31 +382,37 @@ "\n", "\n", "\u001B[1m> Entering new SimpleStrategy chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3mI have added a subtask to try picking up the pencil from desk 1.\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mThe step-by-step plan to successfully solve the task \"Go to the shelf 1 and place a pencil on it\" with \"and\" logic aggregation is as follows:\n", + "\n", + "1. Go to shelf 1.\n", + "2. Place a pencil on shelf 1.\n", + "\n", + "This plan ensures that both subtasks are executed sequentially, and the task will only be considered successfully solved if both subtasks are completed successfully in the specified order.\u001B[0m\n", "\n", "\u001B[1m> Finished chain.\u001B[0m\n", "\n", "\n", "\u001B[1m> Entering new SimpleStrategy chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3mI couldn't pick up the pencil from desk 1. Let me try again.\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mIt seems like I couldn't go to shelf 1. Let me try that again.\u001B[0m\n", "\n", "\u001B[1m> Finished chain.\u001B[0m\n", "\n", "\n", "\u001B[1m> Entering new SimpleStrategy chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3mI have added the first step to attempt to take the pencil from desk 1. The plan will stop as soon as this subtask is successfully completed.\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mI have added the first step to the plan:\n", + "1. Go to shelf 1\u001B[0m\n", "\n", "\u001B[1m> Finished chain.\u001B[0m\n", "\n", "\n", "\u001B[1m> Entering new SimpleStrategy chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3m\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mI encountered an issue while trying to go to shelf 1. Let me try that again.\u001B[0m\n", "\n", "\u001B[1m> Finished chain.\u001B[0m\n", "\n", "\n", "\u001B[1m> Entering new SimpleStrategy chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3mI have added the first subtask to the plan, which is attempting to go to desk 1. The aggregation logic is set to \"and\", meaning all subtasks need to be completed successfully for the original task to be considered solved.\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mI have added the first step to the plan, which is to \"Move towards shelf 1\". The plan will require the successful completion of each step in sequence for the task to be considered solved.\u001B[0m\n", "\n", "\u001B[1m> Finished chain.\u001B[0m\n", "\u001B[32;1m\u001B[1;3mCouldn't solve the task. Last log: Couldn't solve the task. Last log: Couldn't solve the task. Last log: Couldn't solve the task. Last log: Maximum decomposition depth reached.\u001B[0m\n", @@ -426,7 +424,7 @@ "data": { "text/plain": "{'inputs': {'inputs': '-= Welcome to TextWorld, ALFRED! =-\\n\\nYou are in the middle of a room. Looking quickly around you, you see a bed 1, a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a desk 1, a drawer 2, a drawer 1, a garbagecan 1, a shelf 1, and a sidetable 1.\\n\\nYour task is to: put some pencil on shelf.'},\n 'intermediate_steps': [[]],\n 'finish_log': [\"Couldn't solve the task. Last log: Couldn't solve the task. Last log: Couldn't solve the task. Last log: Couldn't solve the task. Last log: Maximum decomposition depth reached.\"]}" }, - "execution_count": 14, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -439,12 +437,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-26T09:57:02.892981Z", - "start_time": "2024-04-26T09:56:36.221827Z" + "end_time": "2024-05-22T11:47:28.000648Z", + "start_time": "2024-05-22T11:47:02.621338Z" } }, "id": "8d1b5178959c06", - "execution_count": 14 + "execution_count": 10 }, { "cell_type": "code", @@ -461,12 +459,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-24T09:50:32.745288Z", - "start_time": "2024-04-24T09:50:32.719793Z" + "end_time": "2024-05-22T11:48:45.000490Z", + "start_time": "2024-05-22T11:48:44.974383Z" } }, "id": "4efce110755d134e", - "execution_count": 8 + "execution_count": 11 }, { "cell_type": "code", @@ -487,13 +485,19 @@ "\n", "\n", "\u001B[1m> Entering new SimpleStrategy chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3m\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mtask completed\u001B[0m\n", + "\n", + "\u001B[1m> Finished chain.\u001B[0m\n", + "\n", + "\n", + "\u001B[1m> Entering new SimpleStrategy chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mIt seems like I couldn't put the pencil on shelf 1. Let me try again.\u001B[0m\n", "\n", "\u001B[1m> Finished chain.\u001B[0m\n", "\n", "\n", "\u001B[1m> Entering new SimpleStrategy chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3mIt seems I couldn't go to shelf 1. Let me try again.\u001B[0m\n", + "\u001B[32;1m\u001B[1;3m\u001B[0m\n", "\n", "\u001B[1m> Finished chain.\u001B[0m\n", "\n", @@ -509,9 +513,9 @@ }, { "data": { - "text/plain": "{'inputs': {'inputs': '-= Welcome to TextWorld, ALFRED! =-\\n\\nYou are in the middle of a room. Looking quickly around you, you see a bed 1, a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a desk 1, a drawer 2, a drawer 1, a garbagecan 1, a shelf 1, and a sidetable 1.\\n\\nYour task is to: put some pencil on shelf.'},\n 'intermediate_steps': [[]],\n 'finish_log': [\"Couldn't solve the task. Last log: Couldn't solve the task. Last log: Couldn't solve the task. Last log: Couldn't solve the task. Last log: Maximum decomposition depth reached.\"]}" + "text/plain": "{'inputs': {'inputs': '-= Welcome to TextWorld, ALFRED! =-\\n\\nYou are in the middle of a room. Looking quickly around you, you see a bed 1, a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a desk 1, a drawer 2, a drawer 1, a garbagecan 1, a shelf 1, and a sidetable 1.\\n\\nYour task is to: put some pencil on shelf.'},\n 'intermediate_steps': [[(OpenAIToolAgentAction(tool='goto', tool_input={'receptable_type': 'sidetable', 'receptable_id': 1}, log=\"\\nInvoking: `goto` with `{'receptable_type': 'sidetable', 'receptable_id': 1}`\\n\\n\\n\", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_zVKtRn26Dhvg6cESJx9Xfk54', 'function': {'arguments': '{\"receptable_type\":\"sidetable\",\"receptable_id\":1}', 'name': 'goto'}, 'type': 'function'}]})], tool_call_id='call_zVKtRn26Dhvg6cESJx9Xfk54'),\n AgentStep(action=OpenAIToolAgentAction(tool='goto', tool_input={'receptable_type': 'sidetable', 'receptable_id': 1}, log=\"\\nInvoking: `goto` with `{'receptable_type': 'sidetable', 'receptable_id': 1}`\\n\\n\\n\", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_zVKtRn26Dhvg6cESJx9Xfk54', 'function': {'arguments': '{\"receptable_type\":\"sidetable\",\"receptable_id\":1}', 'name': 'goto'}, 'type': 'function'}]})], tool_call_id='call_zVKtRn26Dhvg6cESJx9Xfk54'), observation=('You arrive at loc 15. On the sidetable 1, you see a cd 1, a cellphone 2, a cellphone 1, and a pencil 1.', 0, False, False, {'won': False, 'extra.expert_plan': ['take pencil 1 from sidetable 1'], 'extra.gamefile': None, 'admissible_commands': ['examine sidetable 1', 'go to bed 1', 'go to cabinet 1', 'go to cabinet 2', 'go to cabinet 3', 'go to cabinet 4', 'go to desk 1', 'go to drawer 1', 'go to drawer 2', 'go to garbagecan 1', 'go to shelf 1', 'inventory', 'look', 'take cd 1 from sidetable 1', 'take cellphone 1 from sidetable 1', 'take cellphone 2 from sidetable 1', 'take pencil 1 from sidetable 1'], 'facts': [Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('baseballbattype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('baseballbattype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('pillowtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('inreceptacle', (Variable('pencil 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('box 1', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('laptop 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 1', 'object'), Variable('drawer 2', 'receptacle'))), Proposition('inreceptacle', (Variable('alarmclock 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('laptop 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cellphone 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 2', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('keychain 1', 'object'), Variable('drawer 2', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('alarmclock 1', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cd 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('keychain 2', 'object'), Variable('shelf 1', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cellphone 2', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pencil 2', 'object'), Variable('drawer 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pen 1', 'object'), Variable('shelf 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pencil 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pillow 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('atlocation', (Variable('agent1 ', 'agent'), Variable('loc 15', 'location'))), Proposition('objecttype', (Variable('book 1', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 1', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('book 3', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('laptop 1', 'object'), Variable('laptoptype ', 'otype'))), Proposition('objecttype', (Variable('pillow 1', 'object'), Variable('pillowtype ', 'otype'))), Proposition('objecttype', (Variable('window 1', 'object'), Variable('windowtype ', 'otype'))), Proposition('objecttype', (Variable('cellphone 2', 'object'), Variable('cellphonetype ', 'otype'))), Proposition('objecttype', (Variable('keychain 2', 'object'), Variable('keychaintype ', 'otype'))), Proposition('objecttype', (Variable('baseballbat 1', 'object'), Variable('baseballbattype ', 'otype'))), Proposition('objecttype', (Variable('basketball 1', 'object'), Variable('basketballtype ', 'otype'))), Proposition('objecttype', (Variable('keychain 1', 'object'), Variable('keychaintype ', 'otype'))), Proposition('objecttype', (Variable('pen 1', 'object'), Variable('pentype ', 'otype'))), Proposition('objecttype', (Variable('lightswitch 1', 'object'), Variable('lightswitchtype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 2', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('laptop 2', 'object'), Variable('laptoptype ', 'otype'))), Proposition('objecttype', (Variable('book 2', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('blinds 1', 'object'), Variable('blindstype ', 'otype'))), Proposition('objecttype', (Variable('alarmclock 2', 'object'), Variable('alarmclocktype ', 'otype'))), Proposition('objecttype', (Variable('cd 1', 'object'), Variable('cdtype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 3', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('tennisracket 1', 'object'), Variable('tennisrackettype ', 'otype'))), Proposition('objecttype', (Variable('chair 1', 'object'), Variable('chairtype ', 'otype'))), Proposition('objecttype', (Variable('box 1', 'object'), Variable('boxtype ', 'otype'))), Proposition('objecttype', (Variable('pencil 1', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('pencil 2', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('pencil 3', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('mirror 1', 'object'), Variable('mirrortype ', 'otype'))), Proposition('objecttype', (Variable('cellphone 1', 'object'), Variable('cellphonetype ', 'otype'))), Proposition('objecttype', (Variable('alarmclock 1', 'object'), Variable('alarmclocktype ', 'otype'))), Proposition('objecttype', (Variable('poster 1', 'object'), Variable('postertype ', 'otype'))), Proposition('objectatlocation', (Variable('pillow 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('alarmclock 1', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('pencil 3', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('tennisracket 1', 'object'), Variable('loc 5', 'location'))), Proposition('objectatlocation', (Variable('cellphone 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('lightswitch 1', 'object'), Variable('loc 13', 'location'))), Proposition('objectatlocation', (Variable('keychain 2', 'object'), Variable('loc 9', 'location'))), Proposition('objectatlocation', (Variable('book 3', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('cd 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('creditcard 1', 'object'), Variable('loc 17', 'location'))), Proposition('objectatlocation', (Variable('pencil 2', 'object'), Variable('loc 16', 'location'))), Proposition('objectatlocation', (Variable('chair 1', 'object'), Variable('loc 7', 'location'))), Proposition('objectatlocation', (Variable('laptop 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('book 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('cellphone 2', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('poster 1', 'object'), Variable('loc 1', 'location'))), Proposition('objectatlocation', (Variable('mirror 1', 'object'), Variable('loc 14', 'location'))), Proposition('objectatlocation', (Variable('pen 1', 'object'), Variable('loc 9', 'location'))), Proposition('objectatlocation', (Variable('book 2', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('box 1', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('blinds 1', 'object'), Variable('loc 3', 'location'))), Proposition('objectatlocation', (Variable('laptop 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('creditcard 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('keychain 1', 'object'), Variable('loc 17', 'location'))), Proposition('objectatlocation', (Variable('baseballbat 1', 'object'), Variable('loc 6', 'location'))), Proposition('objectatlocation', (Variable('pencil 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('basketball 1', 'object'), Variable('loc 10', 'location'))), Proposition('objectatlocation', (Variable('alarmclock 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('window 1', 'object'), Variable('loc 4', 'location'))), Proposition('objectatlocation', (Variable('creditcard 3', 'object'), Variable('loc 8', 'location'))), Proposition('pickupable', (Variable('creditcard 1', 'object'),)), Proposition('pickupable', (Variable('laptop 2', 'object'),)), Proposition('pickupable', (Variable('book 1', 'object'),)), Proposition('pickupable', (Variable('basketball 1', 'object'),)), Proposition('pickupable', (Variable('cellphone 1', 'object'),)), Proposition('pickupable', (Variable('pencil 1', 'object'),)), Proposition('pickupable', (Variable('pen 1', 'object'),)), Proposition('pickupable', (Variable('keychain 2', 'object'),)), Proposition('pickupable', (Variable('pencil 3', 'object'),)), Proposition('pickupable', (Variable('cd 1', 'object'),)), Proposition('pickupable', (Variable('cellphone 2', 'object'),)), Proposition('pickupable', (Variable('box 1', 'object'),)), Proposition('pickupable', (Variable('creditcard 2', 'object'),)), Proposition('pickupable', (Variable('alarmclock 2', 'object'),)), Proposition('pickupable', (Variable('book 3', 'object'),)), Proposition('pickupable', (Variable('tennisracket 1', 'object'),)), Proposition('pickupable', (Variable('keychain 1', 'object'),)), Proposition('pickupable', (Variable('pillow 1', 'object'),)), Proposition('pickupable', (Variable('pencil 2', 'object'),)), Proposition('pickupable', (Variable('book 2', 'object'),)), Proposition('pickupable', (Variable('creditcard 3', 'object'),)), Proposition('pickupable', (Variable('baseballbat 1', 'object'),)), Proposition('pickupable', (Variable('alarmclock 1', 'object'),)), Proposition('pickupable', (Variable('laptop 1', 'object'),)), Proposition('receptacletype', (Variable('cabinet 2', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('drawer 2', 'receptacle'), Variable('drawertype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 4', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('shelf 1', 'receptacle'), Variable('shelftype ', 'rtype'))), Proposition('receptacletype', (Variable('drawer 1', 'receptacle'), Variable('drawertype ', 'rtype'))), Proposition('receptacletype', (Variable('garbagecan 1', 'receptacle'), Variable('garbagecantype ', 'rtype'))), Proposition('receptacletype', (Variable('sidetable 1', 'receptacle'), Variable('sidetabletype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 3', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 1', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('bed 1', 'receptacle'), Variable('bedtype ', 'rtype'))), Proposition('receptacletype', (Variable('desk 1', 'receptacle'), Variable('desktype ', 'rtype'))), Proposition('receptacleatlocation', (Variable('drawer 1', 'receptacle'), Variable('loc 16', 'location'))), Proposition('receptacleatlocation', (Variable('garbagecan 1', 'receptacle'), Variable('loc 19', 'location'))), Proposition('receptacleatlocation', (Variable('bed 1', 'receptacle'), Variable('loc 11', 'location'))), Proposition('receptacleatlocation', (Variable('shelf 1', 'receptacle'), Variable('loc 9', 'location'))), Proposition('receptacleatlocation', (Variable('drawer 2', 'receptacle'), Variable('loc 17', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 4', 'receptacle'), Variable('loc 12', 'location'))), Proposition('receptacleatlocation', (Variable('sidetable 1', 'receptacle'), Variable('loc 15', 'location'))), Proposition('receptacleatlocation', (Variable('desk 1', 'receptacle'), Variable('loc 8', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 3', 'receptacle'), Variable('loc 9', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 2', 'receptacle'), Variable('loc 2', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 1', 'receptacle'), Variable('loc 2', 'location'))), Proposition('openable', (Variable('cabinet 4', 'receptacle'),)), Proposition('openable', (Variable('cabinet 2', 'receptacle'),)), Proposition('openable', (Variable('cabinet 3', 'receptacle'),)), Proposition('openable', (Variable('drawer 1', 'receptacle'),)), Proposition('openable', (Variable('cabinet 1', 'receptacle'),)), Proposition('openable', (Variable('drawer 2', 'receptacle'),)), Proposition('isreceptacleobject', (Variable('box 1', 'object'),)), Proposition('not_atlocation', (Variable('agent1 ', 'agent'), Variable('loc 18', 'location'))), Proposition('not_atlocation', (Variable('agent1 ', 'agent'), Variable('loc 9', 'location')))]}))),\n (OpenAIToolAgentAction(tool='goto', tool_input={'receptable_type': 'sidetable', 'receptable_id': 1}, log=\"\\nInvoking: `goto` with `{'receptable_type': 'sidetable', 'receptable_id': 1}`\\n\\n\\n\", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_zVKtRn26Dhvg6cESJx9Xfk54', 'function': {'arguments': '{\"receptable_type\":\"sidetable\",\"receptable_id\":1}', 'name': 'goto'}, 'type': 'function'}]})], tool_call_id='call_zVKtRn26Dhvg6cESJx9Xfk54'),\n AgentStep(action=OpenAIToolAgentAction(tool='goto', tool_input={'receptable_type': 'sidetable', 'receptable_id': 1}, log=\"\\nInvoking: `goto` with `{'receptable_type': 'sidetable', 'receptable_id': 1}`\\n\\n\\n\", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_zVKtRn26Dhvg6cESJx9Xfk54', 'function': {'arguments': '{\"receptable_type\":\"sidetable\",\"receptable_id\":1}', 'name': 'goto'}, 'type': 'function'}]})], tool_call_id='call_zVKtRn26Dhvg6cESJx9Xfk54'), observation=('You arrive at loc 15. On the sidetable 1, you see a cd 1, a cellphone 2, a cellphone 1, and a pencil 1.', 0, False, False, {'won': False, 'extra.expert_plan': ['take pencil 1 from sidetable 1'], 'extra.gamefile': None, 'admissible_commands': ['examine sidetable 1', 'go to bed 1', 'go to cabinet 1', 'go to cabinet 2', 'go to cabinet 3', 'go to cabinet 4', 'go to desk 1', 'go to drawer 1', 'go to drawer 2', 'go to garbagecan 1', 'go to shelf 1', 'inventory', 'look', 'take cd 1 from sidetable 1', 'take cellphone 1 from sidetable 1', 'take cellphone 2 from sidetable 1', 'take pencil 1 from sidetable 1'], 'facts': [Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('baseballbattype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('baseballbattype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('pillowtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('inreceptacle', (Variable('pencil 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('box 1', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('laptop 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 1', 'object'), Variable('drawer 2', 'receptacle'))), Proposition('inreceptacle', (Variable('alarmclock 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('laptop 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cellphone 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 2', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('keychain 1', 'object'), Variable('drawer 2', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('alarmclock 1', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cd 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('keychain 2', 'object'), Variable('shelf 1', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cellphone 2', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pencil 2', 'object'), Variable('drawer 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pen 1', 'object'), Variable('shelf 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pencil 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pillow 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('atlocation', (Variable('agent1 ', 'agent'), Variable('loc 15', 'location'))), Proposition('objecttype', (Variable('book 1', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 1', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('book 3', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('laptop 1', 'object'), Variable('laptoptype ', 'otype'))), Proposition('objecttype', (Variable('pillow 1', 'object'), Variable('pillowtype ', 'otype'))), Proposition('objecttype', (Variable('window 1', 'object'), Variable('windowtype ', 'otype'))), Proposition('objecttype', (Variable('cellphone 2', 'object'), Variable('cellphonetype ', 'otype'))), Proposition('objecttype', (Variable('keychain 2', 'object'), Variable('keychaintype ', 'otype'))), Proposition('objecttype', (Variable('baseballbat 1', 'object'), Variable('baseballbattype ', 'otype'))), Proposition('objecttype', (Variable('basketball 1', 'object'), Variable('basketballtype ', 'otype'))), Proposition('objecttype', (Variable('keychain 1', 'object'), Variable('keychaintype ', 'otype'))), Proposition('objecttype', (Variable('pen 1', 'object'), Variable('pentype ', 'otype'))), Proposition('objecttype', (Variable('lightswitch 1', 'object'), Variable('lightswitchtype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 2', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('laptop 2', 'object'), Variable('laptoptype ', 'otype'))), Proposition('objecttype', (Variable('book 2', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('blinds 1', 'object'), Variable('blindstype ', 'otype'))), Proposition('objecttype', (Variable('alarmclock 2', 'object'), Variable('alarmclocktype ', 'otype'))), Proposition('objecttype', (Variable('cd 1', 'object'), Variable('cdtype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 3', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('tennisracket 1', 'object'), Variable('tennisrackettype ', 'otype'))), Proposition('objecttype', (Variable('chair 1', 'object'), Variable('chairtype ', 'otype'))), Proposition('objecttype', (Variable('box 1', 'object'), Variable('boxtype ', 'otype'))), Proposition('objecttype', (Variable('pencil 1', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('pencil 2', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('pencil 3', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('mirror 1', 'object'), Variable('mirrortype ', 'otype'))), Proposition('objecttype', (Variable('cellphone 1', 'object'), Variable('cellphonetype ', 'otype'))), Proposition('objecttype', (Variable('alarmclock 1', 'object'), Variable('alarmclocktype ', 'otype'))), Proposition('objecttype', (Variable('poster 1', 'object'), Variable('postertype ', 'otype'))), Proposition('objectatlocation', (Variable('pillow 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('alarmclock 1', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('pencil 3', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('tennisracket 1', 'object'), Variable('loc 5', 'location'))), Proposition('objectatlocation', (Variable('cellphone 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('lightswitch 1', 'object'), Variable('loc 13', 'location'))), Proposition('objectatlocation', (Variable('keychain 2', 'object'), Variable('loc 9', 'location'))), Proposition('objectatlocation', (Variable('book 3', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('cd 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('creditcard 1', 'object'), Variable('loc 17', 'location'))), Proposition('objectatlocation', (Variable('pencil 2', 'object'), Variable('loc 16', 'location'))), Proposition('objectatlocation', (Variable('chair 1', 'object'), Variable('loc 7', 'location'))), Proposition('objectatlocation', (Variable('laptop 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('book 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('cellphone 2', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('poster 1', 'object'), Variable('loc 1', 'location'))), Proposition('objectatlocation', (Variable('mirror 1', 'object'), Variable('loc 14', 'location'))), Proposition('objectatlocation', (Variable('pen 1', 'object'), Variable('loc 9', 'location'))), Proposition('objectatlocation', (Variable('book 2', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('box 1', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('blinds 1', 'object'), Variable('loc 3', 'location'))), Proposition('objectatlocation', (Variable('laptop 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('creditcard 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('keychain 1', 'object'), Variable('loc 17', 'location'))), Proposition('objectatlocation', (Variable('baseballbat 1', 'object'), Variable('loc 6', 'location'))), Proposition('objectatlocation', (Variable('pencil 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('basketball 1', 'object'), Variable('loc 10', 'location'))), Proposition('objectatlocation', (Variable('alarmclock 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('window 1', 'object'), Variable('loc 4', 'location'))), Proposition('objectatlocation', (Variable('creditcard 3', 'object'), Variable('loc 8', 'location'))), Proposition('pickupable', (Variable('creditcard 1', 'object'),)), Proposition('pickupable', (Variable('laptop 2', 'object'),)), Proposition('pickupable', (Variable('book 1', 'object'),)), Proposition('pickupable', (Variable('basketball 1', 'object'),)), Proposition('pickupable', (Variable('cellphone 1', 'object'),)), Proposition('pickupable', (Variable('pencil 1', 'object'),)), Proposition('pickupable', (Variable('pen 1', 'object'),)), Proposition('pickupable', (Variable('keychain 2', 'object'),)), Proposition('pickupable', (Variable('pencil 3', 'object'),)), Proposition('pickupable', (Variable('cd 1', 'object'),)), Proposition('pickupable', (Variable('cellphone 2', 'object'),)), Proposition('pickupable', (Variable('box 1', 'object'),)), Proposition('pickupable', (Variable('creditcard 2', 'object'),)), Proposition('pickupable', (Variable('alarmclock 2', 'object'),)), Proposition('pickupable', (Variable('book 3', 'object'),)), Proposition('pickupable', (Variable('tennisracket 1', 'object'),)), Proposition('pickupable', (Variable('keychain 1', 'object'),)), Proposition('pickupable', (Variable('pillow 1', 'object'),)), Proposition('pickupable', (Variable('pencil 2', 'object'),)), Proposition('pickupable', (Variable('book 2', 'object'),)), Proposition('pickupable', (Variable('creditcard 3', 'object'),)), Proposition('pickupable', (Variable('baseballbat 1', 'object'),)), Proposition('pickupable', (Variable('alarmclock 1', 'object'),)), Proposition('pickupable', (Variable('laptop 1', 'object'),)), Proposition('receptacletype', (Variable('cabinet 2', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('drawer 2', 'receptacle'), Variable('drawertype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 4', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('shelf 1', 'receptacle'), Variable('shelftype ', 'rtype'))), Proposition('receptacletype', (Variable('drawer 1', 'receptacle'), Variable('drawertype ', 'rtype'))), Proposition('receptacletype', (Variable('garbagecan 1', 'receptacle'), Variable('garbagecantype ', 'rtype'))), Proposition('receptacletype', (Variable('sidetable 1', 'receptacle'), Variable('sidetabletype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 3', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 1', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('bed 1', 'receptacle'), Variable('bedtype ', 'rtype'))), Proposition('receptacletype', (Variable('desk 1', 'receptacle'), Variable('desktype ', 'rtype'))), Proposition('receptacleatlocation', (Variable('drawer 1', 'receptacle'), Variable('loc 16', 'location'))), Proposition('receptacleatlocation', (Variable('garbagecan 1', 'receptacle'), Variable('loc 19', 'location'))), Proposition('receptacleatlocation', (Variable('bed 1', 'receptacle'), Variable('loc 11', 'location'))), Proposition('receptacleatlocation', (Variable('shelf 1', 'receptacle'), Variable('loc 9', 'location'))), Proposition('receptacleatlocation', (Variable('drawer 2', 'receptacle'), Variable('loc 17', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 4', 'receptacle'), Variable('loc 12', 'location'))), Proposition('receptacleatlocation', (Variable('sidetable 1', 'receptacle'), Variable('loc 15', 'location'))), Proposition('receptacleatlocation', (Variable('desk 1', 'receptacle'), Variable('loc 8', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 3', 'receptacle'), Variable('loc 9', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 2', 'receptacle'), Variable('loc 2', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 1', 'receptacle'), Variable('loc 2', 'location'))), Proposition('openable', (Variable('cabinet 4', 'receptacle'),)), Proposition('openable', (Variable('cabinet 2', 'receptacle'),)), Proposition('openable', (Variable('cabinet 3', 'receptacle'),)), Proposition('openable', (Variable('drawer 1', 'receptacle'),)), Proposition('openable', (Variable('cabinet 1', 'receptacle'),)), Proposition('openable', (Variable('drawer 2', 'receptacle'),)), Proposition('isreceptacleobject', (Variable('box 1', 'object'),)), Proposition('not_atlocation', (Variable('agent1 ', 'agent'), Variable('loc 18', 'location'))), Proposition('not_atlocation', (Variable('agent1 ', 'agent'), Variable('loc 9', 'location')))]}))),\n (OpenAIToolAgentAction(tool='goto', tool_input={'receptable_type': 'sidetable', 'receptable_id': 1}, log=\"\\nInvoking: `goto` with `{'receptable_type': 'sidetable', 'receptable_id': 1}`\\n\\n\\n\", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_zVKtRn26Dhvg6cESJx9Xfk54', 'function': {'arguments': '{\"receptable_type\":\"sidetable\",\"receptable_id\":1}', 'name': 'goto'}, 'type': 'function'}]})], tool_call_id='call_zVKtRn26Dhvg6cESJx9Xfk54'),\n AgentStep(action=OpenAIToolAgentAction(tool='goto', tool_input={'receptable_type': 'sidetable', 'receptable_id': 1}, log=\"\\nInvoking: `goto` with `{'receptable_type': 'sidetable', 'receptable_id': 1}`\\n\\n\\n\", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_zVKtRn26Dhvg6cESJx9Xfk54', 'function': {'arguments': '{\"receptable_type\":\"sidetable\",\"receptable_id\":1}', 'name': 'goto'}, 'type': 'function'}]})], tool_call_id='call_zVKtRn26Dhvg6cESJx9Xfk54'), observation=('You arrive at loc 15. On the sidetable 1, you see a cd 1, a cellphone 2, a cellphone 1, and a pencil 1.', 0, False, False, {'won': False, 'extra.expert_plan': ['take pencil 1 from sidetable 1'], 'extra.gamefile': None, 'admissible_commands': ['examine sidetable 1', 'go to bed 1', 'go to cabinet 1', 'go to cabinet 2', 'go to cabinet 3', 'go to cabinet 4', 'go to desk 1', 'go to drawer 1', 'go to drawer 2', 'go to garbagecan 1', 'go to shelf 1', 'inventory', 'look', 'take cd 1 from sidetable 1', 'take cellphone 1 from sidetable 1', 'take cellphone 2 from sidetable 1', 'take pencil 1 from sidetable 1'], 'facts': [Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('baseballbattype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('baseballbattype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('pillowtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('inreceptacle', (Variable('pencil 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('box 1', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('laptop 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 1', 'object'), Variable('drawer 2', 'receptacle'))), Proposition('inreceptacle', (Variable('alarmclock 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('laptop 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cellphone 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 2', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('keychain 1', 'object'), Variable('drawer 2', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('alarmclock 1', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cd 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('keychain 2', 'object'), Variable('shelf 1', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cellphone 2', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pencil 2', 'object'), Variable('drawer 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pen 1', 'object'), Variable('shelf 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pencil 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pillow 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('atlocation', (Variable('agent1 ', 'agent'), Variable('loc 15', 'location'))), Proposition('objecttype', (Variable('book 1', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 1', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('book 3', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('laptop 1', 'object'), Variable('laptoptype ', 'otype'))), Proposition('objecttype', (Variable('pillow 1', 'object'), Variable('pillowtype ', 'otype'))), Proposition('objecttype', (Variable('window 1', 'object'), Variable('windowtype ', 'otype'))), Proposition('objecttype', (Variable('cellphone 2', 'object'), Variable('cellphonetype ', 'otype'))), Proposition('objecttype', (Variable('keychain 2', 'object'), Variable('keychaintype ', 'otype'))), Proposition('objecttype', (Variable('baseballbat 1', 'object'), Variable('baseballbattype ', 'otype'))), Proposition('objecttype', (Variable('basketball 1', 'object'), Variable('basketballtype ', 'otype'))), Proposition('objecttype', (Variable('keychain 1', 'object'), Variable('keychaintype ', 'otype'))), Proposition('objecttype', (Variable('pen 1', 'object'), Variable('pentype ', 'otype'))), Proposition('objecttype', (Variable('lightswitch 1', 'object'), Variable('lightswitchtype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 2', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('laptop 2', 'object'), Variable('laptoptype ', 'otype'))), Proposition('objecttype', (Variable('book 2', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('blinds 1', 'object'), Variable('blindstype ', 'otype'))), Proposition('objecttype', (Variable('alarmclock 2', 'object'), Variable('alarmclocktype ', 'otype'))), Proposition('objecttype', (Variable('cd 1', 'object'), Variable('cdtype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 3', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('tennisracket 1', 'object'), Variable('tennisrackettype ', 'otype'))), Proposition('objecttype', (Variable('chair 1', 'object'), Variable('chairtype ', 'otype'))), Proposition('objecttype', (Variable('box 1', 'object'), Variable('boxtype ', 'otype'))), Proposition('objecttype', (Variable('pencil 1', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('pencil 2', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('pencil 3', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('mirror 1', 'object'), Variable('mirrortype ', 'otype'))), Proposition('objecttype', (Variable('cellphone 1', 'object'), Variable('cellphonetype ', 'otype'))), Proposition('objecttype', (Variable('alarmclock 1', 'object'), Variable('alarmclocktype ', 'otype'))), Proposition('objecttype', (Variable('poster 1', 'object'), Variable('postertype ', 'otype'))), Proposition('objectatlocation', (Variable('pillow 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('alarmclock 1', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('pencil 3', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('tennisracket 1', 'object'), Variable('loc 5', 'location'))), Proposition('objectatlocation', (Variable('cellphone 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('lightswitch 1', 'object'), Variable('loc 13', 'location'))), Proposition('objectatlocation', (Variable('keychain 2', 'object'), Variable('loc 9', 'location'))), Proposition('objectatlocation', (Variable('book 3', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('cd 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('creditcard 1', 'object'), Variable('loc 17', 'location'))), Proposition('objectatlocation', (Variable('pencil 2', 'object'), Variable('loc 16', 'location'))), Proposition('objectatlocation', (Variable('chair 1', 'object'), Variable('loc 7', 'location'))), Proposition('objectatlocation', (Variable('laptop 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('book 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('cellphone 2', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('poster 1', 'object'), Variable('loc 1', 'location'))), Proposition('objectatlocation', (Variable('mirror 1', 'object'), Variable('loc 14', 'location'))), Proposition('objectatlocation', (Variable('pen 1', 'object'), Variable('loc 9', 'location'))), Proposition('objectatlocation', (Variable('book 2', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('box 1', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('blinds 1', 'object'), Variable('loc 3', 'location'))), Proposition('objectatlocation', (Variable('laptop 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('creditcard 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('keychain 1', 'object'), Variable('loc 17', 'location'))), Proposition('objectatlocation', (Variable('baseballbat 1', 'object'), Variable('loc 6', 'location'))), Proposition('objectatlocation', (Variable('pencil 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('basketball 1', 'object'), Variable('loc 10', 'location'))), Proposition('objectatlocation', (Variable('alarmclock 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('window 1', 'object'), Variable('loc 4', 'location'))), Proposition('objectatlocation', (Variable('creditcard 3', 'object'), Variable('loc 8', 'location'))), Proposition('pickupable', (Variable('creditcard 1', 'object'),)), Proposition('pickupable', (Variable('laptop 2', 'object'),)), Proposition('pickupable', (Variable('book 1', 'object'),)), Proposition('pickupable', (Variable('basketball 1', 'object'),)), Proposition('pickupable', (Variable('cellphone 1', 'object'),)), Proposition('pickupable', (Variable('pencil 1', 'object'),)), Proposition('pickupable', (Variable('pen 1', 'object'),)), Proposition('pickupable', (Variable('keychain 2', 'object'),)), Proposition('pickupable', (Variable('pencil 3', 'object'),)), Proposition('pickupable', (Variable('cd 1', 'object'),)), Proposition('pickupable', (Variable('cellphone 2', 'object'),)), Proposition('pickupable', (Variable('box 1', 'object'),)), Proposition('pickupable', (Variable('creditcard 2', 'object'),)), Proposition('pickupable', (Variable('alarmclock 2', 'object'),)), Proposition('pickupable', (Variable('book 3', 'object'),)), Proposition('pickupable', (Variable('tennisracket 1', 'object'),)), Proposition('pickupable', (Variable('keychain 1', 'object'),)), Proposition('pickupable', (Variable('pillow 1', 'object'),)), Proposition('pickupable', (Variable('pencil 2', 'object'),)), Proposition('pickupable', (Variable('book 2', 'object'),)), Proposition('pickupable', (Variable('creditcard 3', 'object'),)), Proposition('pickupable', (Variable('baseballbat 1', 'object'),)), Proposition('pickupable', (Variable('alarmclock 1', 'object'),)), Proposition('pickupable', (Variable('laptop 1', 'object'),)), Proposition('receptacletype', (Variable('cabinet 2', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('drawer 2', 'receptacle'), Variable('drawertype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 4', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('shelf 1', 'receptacle'), Variable('shelftype ', 'rtype'))), Proposition('receptacletype', (Variable('drawer 1', 'receptacle'), Variable('drawertype ', 'rtype'))), Proposition('receptacletype', (Variable('garbagecan 1', 'receptacle'), Variable('garbagecantype ', 'rtype'))), Proposition('receptacletype', (Variable('sidetable 1', 'receptacle'), Variable('sidetabletype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 3', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 1', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('bed 1', 'receptacle'), Variable('bedtype ', 'rtype'))), Proposition('receptacletype', (Variable('desk 1', 'receptacle'), Variable('desktype ', 'rtype'))), Proposition('receptacleatlocation', (Variable('drawer 1', 'receptacle'), Variable('loc 16', 'location'))), Proposition('receptacleatlocation', (Variable('garbagecan 1', 'receptacle'), Variable('loc 19', 'location'))), Proposition('receptacleatlocation', (Variable('bed 1', 'receptacle'), Variable('loc 11', 'location'))), Proposition('receptacleatlocation', (Variable('shelf 1', 'receptacle'), Variable('loc 9', 'location'))), Proposition('receptacleatlocation', (Variable('drawer 2', 'receptacle'), Variable('loc 17', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 4', 'receptacle'), Variable('loc 12', 'location'))), Proposition('receptacleatlocation', (Variable('sidetable 1', 'receptacle'), Variable('loc 15', 'location'))), Proposition('receptacleatlocation', (Variable('desk 1', 'receptacle'), Variable('loc 8', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 3', 'receptacle'), Variable('loc 9', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 2', 'receptacle'), Variable('loc 2', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 1', 'receptacle'), Variable('loc 2', 'location'))), Proposition('openable', (Variable('cabinet 4', 'receptacle'),)), Proposition('openable', (Variable('cabinet 2', 'receptacle'),)), Proposition('openable', (Variable('cabinet 3', 'receptacle'),)), Proposition('openable', (Variable('drawer 1', 'receptacle'),)), Proposition('openable', (Variable('cabinet 1', 'receptacle'),)), Proposition('openable', (Variable('drawer 2', 'receptacle'),)), Proposition('isreceptacleobject', (Variable('box 1', 'object'),)), Proposition('not_atlocation', (Variable('agent1 ', 'agent'), Variable('loc 18', 'location'))), Proposition('not_atlocation', (Variable('agent1 ', 'agent'), Variable('loc 9', 'location')))]}))),\n (OpenAIToolAgentAction(tool='goto', tool_input={'receptable_type': 'sidetable', 'receptable_id': 1}, log=\"\\nInvoking: `goto` with `{'receptable_type': 'sidetable', 'receptable_id': 1}`\\n\\n\\n\", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_zVKtRn26Dhvg6cESJx9Xfk54', 'function': {'arguments': '{\"receptable_type\":\"sidetable\",\"receptable_id\":1}', 'name': 'goto'}, 'type': 'function'}]})], tool_call_id='call_zVKtRn26Dhvg6cESJx9Xfk54'),\n AgentStep(action=OpenAIToolAgentAction(tool='goto', tool_input={'receptable_type': 'sidetable', 'receptable_id': 1}, log=\"\\nInvoking: `goto` with `{'receptable_type': 'sidetable', 'receptable_id': 1}`\\n\\n\\n\", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_zVKtRn26Dhvg6cESJx9Xfk54', 'function': {'arguments': '{\"receptable_type\":\"sidetable\",\"receptable_id\":1}', 'name': 'goto'}, 'type': 'function'}]})], tool_call_id='call_zVKtRn26Dhvg6cESJx9Xfk54'), observation=('You arrive at loc 15. On the sidetable 1, you see a cd 1, a cellphone 2, a cellphone 1, and a pencil 1.', 0, False, False, {'won': False, 'extra.expert_plan': ['take pencil 1 from sidetable 1'], 'extra.gamefile': None, 'admissible_commands': ['examine sidetable 1', 'go to bed 1', 'go to cabinet 1', 'go to cabinet 2', 'go to cabinet 3', 'go to cabinet 4', 'go to desk 1', 'go to drawer 1', 'go to drawer 2', 'go to garbagecan 1', 'go to shelf 1', 'inventory', 'look', 'take cd 1 from sidetable 1', 'take cellphone 1 from sidetable 1', 'take cellphone 2 from sidetable 1', 'take pencil 1 from sidetable 1'], 'facts': [Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('baseballbattype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('baseballbattype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('pillowtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('cabinettype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('cdtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('basketballtype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('sidetabletype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('alarmclocktype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('shelftype ', 'rtype'), Variable('boxtype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('tennisrackettype ', 'otype'))), Proposition('cancontain', (Variable('bedtype ', 'rtype'), Variable('laptoptype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('pentype ', 'otype'))), Proposition('cancontain', (Variable('garbagecantype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('keychaintype ', 'otype'))), Proposition('cancontain', (Variable('drawertype ', 'rtype'), Variable('cellphonetype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('booktype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('penciltype ', 'otype'))), Proposition('cancontain', (Variable('desktype ', 'rtype'), Variable('creditcardtype ', 'otype'))), Proposition('inreceptacle', (Variable('pencil 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('box 1', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('laptop 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 1', 'object'), Variable('drawer 2', 'receptacle'))), Proposition('inreceptacle', (Variable('alarmclock 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('laptop 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cellphone 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 2', 'object'), Variable('bed 1', 'receptacle'))), Proposition('inreceptacle', (Variable('keychain 1', 'object'), Variable('drawer 2', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 2', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('alarmclock 1', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cd 1', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('keychain 2', 'object'), Variable('shelf 1', 'receptacle'))), Proposition('inreceptacle', (Variable('creditcard 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('cellphone 2', 'object'), Variable('sidetable 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pencil 2', 'object'), Variable('drawer 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pen 1', 'object'), Variable('shelf 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pencil 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('book 3', 'object'), Variable('desk 1', 'receptacle'))), Proposition('inreceptacle', (Variable('pillow 1', 'object'), Variable('bed 1', 'receptacle'))), Proposition('atlocation', (Variable('agent1 ', 'agent'), Variable('loc 15', 'location'))), Proposition('objecttype', (Variable('book 1', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 1', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('book 3', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('laptop 1', 'object'), Variable('laptoptype ', 'otype'))), Proposition('objecttype', (Variable('pillow 1', 'object'), Variable('pillowtype ', 'otype'))), Proposition('objecttype', (Variable('window 1', 'object'), Variable('windowtype ', 'otype'))), Proposition('objecttype', (Variable('cellphone 2', 'object'), Variable('cellphonetype ', 'otype'))), Proposition('objecttype', (Variable('keychain 2', 'object'), Variable('keychaintype ', 'otype'))), Proposition('objecttype', (Variable('baseballbat 1', 'object'), Variable('baseballbattype ', 'otype'))), Proposition('objecttype', (Variable('basketball 1', 'object'), Variable('basketballtype ', 'otype'))), Proposition('objecttype', (Variable('keychain 1', 'object'), Variable('keychaintype ', 'otype'))), Proposition('objecttype', (Variable('pen 1', 'object'), Variable('pentype ', 'otype'))), Proposition('objecttype', (Variable('lightswitch 1', 'object'), Variable('lightswitchtype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 2', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('laptop 2', 'object'), Variable('laptoptype ', 'otype'))), Proposition('objecttype', (Variable('book 2', 'object'), Variable('booktype ', 'otype'))), Proposition('objecttype', (Variable('blinds 1', 'object'), Variable('blindstype ', 'otype'))), Proposition('objecttype', (Variable('alarmclock 2', 'object'), Variable('alarmclocktype ', 'otype'))), Proposition('objecttype', (Variable('cd 1', 'object'), Variable('cdtype ', 'otype'))), Proposition('objecttype', (Variable('creditcard 3', 'object'), Variable('creditcardtype ', 'otype'))), Proposition('objecttype', (Variable('tennisracket 1', 'object'), Variable('tennisrackettype ', 'otype'))), Proposition('objecttype', (Variable('chair 1', 'object'), Variable('chairtype ', 'otype'))), Proposition('objecttype', (Variable('box 1', 'object'), Variable('boxtype ', 'otype'))), Proposition('objecttype', (Variable('pencil 1', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('pencil 2', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('pencil 3', 'object'), Variable('penciltype ', 'otype'))), Proposition('objecttype', (Variable('mirror 1', 'object'), Variable('mirrortype ', 'otype'))), Proposition('objecttype', (Variable('cellphone 1', 'object'), Variable('cellphonetype ', 'otype'))), Proposition('objecttype', (Variable('alarmclock 1', 'object'), Variable('alarmclocktype ', 'otype'))), Proposition('objecttype', (Variable('poster 1', 'object'), Variable('postertype ', 'otype'))), Proposition('objectatlocation', (Variable('pillow 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('alarmclock 1', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('pencil 3', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('tennisracket 1', 'object'), Variable('loc 5', 'location'))), Proposition('objectatlocation', (Variable('cellphone 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('lightswitch 1', 'object'), Variable('loc 13', 'location'))), Proposition('objectatlocation', (Variable('keychain 2', 'object'), Variable('loc 9', 'location'))), Proposition('objectatlocation', (Variable('book 3', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('cd 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('creditcard 1', 'object'), Variable('loc 17', 'location'))), Proposition('objectatlocation', (Variable('pencil 2', 'object'), Variable('loc 16', 'location'))), Proposition('objectatlocation', (Variable('chair 1', 'object'), Variable('loc 7', 'location'))), Proposition('objectatlocation', (Variable('laptop 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('book 1', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('cellphone 2', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('poster 1', 'object'), Variable('loc 1', 'location'))), Proposition('objectatlocation', (Variable('mirror 1', 'object'), Variable('loc 14', 'location'))), Proposition('objectatlocation', (Variable('pen 1', 'object'), Variable('loc 9', 'location'))), Proposition('objectatlocation', (Variable('book 2', 'object'), Variable('loc 11', 'location'))), Proposition('objectatlocation', (Variable('box 1', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('blinds 1', 'object'), Variable('loc 3', 'location'))), Proposition('objectatlocation', (Variable('laptop 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('creditcard 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('keychain 1', 'object'), Variable('loc 17', 'location'))), Proposition('objectatlocation', (Variable('baseballbat 1', 'object'), Variable('loc 6', 'location'))), Proposition('objectatlocation', (Variable('pencil 1', 'object'), Variable('loc 15', 'location'))), Proposition('objectatlocation', (Variable('basketball 1', 'object'), Variable('loc 10', 'location'))), Proposition('objectatlocation', (Variable('alarmclock 2', 'object'), Variable('loc 8', 'location'))), Proposition('objectatlocation', (Variable('window 1', 'object'), Variable('loc 4', 'location'))), Proposition('objectatlocation', (Variable('creditcard 3', 'object'), Variable('loc 8', 'location'))), Proposition('pickupable', (Variable('creditcard 1', 'object'),)), Proposition('pickupable', (Variable('laptop 2', 'object'),)), Proposition('pickupable', (Variable('book 1', 'object'),)), Proposition('pickupable', (Variable('basketball 1', 'object'),)), Proposition('pickupable', (Variable('cellphone 1', 'object'),)), Proposition('pickupable', (Variable('pencil 1', 'object'),)), Proposition('pickupable', (Variable('pen 1', 'object'),)), Proposition('pickupable', (Variable('keychain 2', 'object'),)), Proposition('pickupable', (Variable('pencil 3', 'object'),)), Proposition('pickupable', (Variable('cd 1', 'object'),)), Proposition('pickupable', (Variable('cellphone 2', 'object'),)), Proposition('pickupable', (Variable('box 1', 'object'),)), Proposition('pickupable', (Variable('creditcard 2', 'object'),)), Proposition('pickupable', (Variable('alarmclock 2', 'object'),)), Proposition('pickupable', (Variable('book 3', 'object'),)), Proposition('pickupable', (Variable('tennisracket 1', 'object'),)), Proposition('pickupable', (Variable('keychain 1', 'object'),)), Proposition('pickupable', (Variable('pillow 1', 'object'),)), Proposition('pickupable', (Variable('pencil 2', 'object'),)), Proposition('pickupable', (Variable('book 2', 'object'),)), Proposition('pickupable', (Variable('creditcard 3', 'object'),)), Proposition('pickupable', (Variable('baseballbat 1', 'object'),)), Proposition('pickupable', (Variable('alarmclock 1', 'object'),)), Proposition('pickupable', (Variable('laptop 1', 'object'),)), Proposition('receptacletype', (Variable('cabinet 2', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('drawer 2', 'receptacle'), Variable('drawertype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 4', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('shelf 1', 'receptacle'), Variable('shelftype ', 'rtype'))), Proposition('receptacletype', (Variable('drawer 1', 'receptacle'), Variable('drawertype ', 'rtype'))), Proposition('receptacletype', (Variable('garbagecan 1', 'receptacle'), Variable('garbagecantype ', 'rtype'))), Proposition('receptacletype', (Variable('sidetable 1', 'receptacle'), Variable('sidetabletype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 3', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('cabinet 1', 'receptacle'), Variable('cabinettype ', 'rtype'))), Proposition('receptacletype', (Variable('bed 1', 'receptacle'), Variable('bedtype ', 'rtype'))), Proposition('receptacletype', (Variable('desk 1', 'receptacle'), Variable('desktype ', 'rtype'))), Proposition('receptacleatlocation', (Variable('drawer 1', 'receptacle'), Variable('loc 16', 'location'))), Proposition('receptacleatlocation', (Variable('garbagecan 1', 'receptacle'), Variable('loc 19', 'location'))), Proposition('receptacleatlocation', (Variable('bed 1', 'receptacle'), Variable('loc 11', 'location'))), Proposition('receptacleatlocation', (Variable('shelf 1', 'receptacle'), Variable('loc 9', 'location'))), Proposition('receptacleatlocation', (Variable('drawer 2', 'receptacle'), Variable('loc 17', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 4', 'receptacle'), Variable('loc 12', 'location'))), Proposition('receptacleatlocation', (Variable('sidetable 1', 'receptacle'), Variable('loc 15', 'location'))), Proposition('receptacleatlocation', (Variable('desk 1', 'receptacle'), Variable('loc 8', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 3', 'receptacle'), Variable('loc 9', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 2', 'receptacle'), Variable('loc 2', 'location'))), Proposition('receptacleatlocation', (Variable('cabinet 1', 'receptacle'), Variable('loc 2', 'location'))), Proposition('openable', (Variable('cabinet 4', 'receptacle'),)), Proposition('openable', (Variable('cabinet 2', 'receptacle'),)), Proposition('openable', (Variable('cabinet 3', 'receptacle'),)), Proposition('openable', (Variable('drawer 1', 'receptacle'),)), Proposition('openable', (Variable('cabinet 1', 'receptacle'),)), Proposition('openable', (Variable('drawer 2', 'receptacle'),)), Proposition('isreceptacleobject', (Variable('box 1', 'object'),)), Proposition('not_atlocation', (Variable('agent1 ', 'agent'), Variable('loc 18', 'location'))), Proposition('not_atlocation', (Variable('agent1 ', 'agent'), Variable('loc 9', 'location')))]})))]],\n 'finish_log': [\"Couldn't solve the task. Last log: Couldn't solve the task. Last log: Couldn't solve the task. Last log: Couldn't solve the task. Last log: Maximum decomposition depth reached.\"]}" }, - "execution_count": 9, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -524,12 +528,22 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-04-24T09:50:47.930079Z", - "start_time": "2024-04-24T09:50:32.898494Z" + "end_time": "2024-05-22T11:49:04.835275Z", + "start_time": "2024-05-22T11:48:45.217344Z" } }, "id": "bfcf7662db19240c", - "execution_count": 9 + "execution_count": 12 + }, + { + "cell_type": "code", + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "f5b8866ffd80b2ee", + "execution_count": null } ], "metadata": { diff --git a/environments/alfworld/common/environment.py b/environments/alfworld/common/environment.py index e47e8bd..3e241a6 100644 --- a/environments/alfworld/common/environment.py +++ b/environments/alfworld/common/environment.py @@ -1,17 +1,19 @@ from __future__ import annotations + +from typing import Any, Dict, Optional, Sequence, SupportsFloat, Tuple + import gymnasium as gym +import yaml # type: ignore[import-untyped] from langchain_core.agents import AgentAction +from langchain_core.callbacks import CallbackManager from langchain_core.tools import BaseTool -import alfworld.agents.environment as environment # type: ignore[import-untyped] -import yaml # type: ignore[import-untyped] -from typing import Dict, Any, Tuple, Optional, Sequence -from gymnasium.core import SupportsFloat -from .tools import get_alfworld_tools -from planning_library.action_executors import LangchainActionExecutor from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped] -from langchain_core.callbacks import CallbackManager +import alfworld.agents.environment as environment # type: ignore[import-untyped] from alfworld.agents.environment.alfred_tw_env import AlfredTWEnv # type: ignore[import-untyped] +from planning_library.action_executors import LangchainActionExecutor + +from .tools import get_alfworld_tools class ALFWorldEnv(gym.Env[str, Tuple[AgentAction, Optional[CallbackManager]]]): @@ -21,13 +23,9 @@ def __init__( ): with open(config_path) as reader: config = yaml.safe_load(reader) - self._alfworld_env: AlfredTWEnv = getattr(environment, config["env"]["type"])( - config, train_eval="train" - ) + self._alfworld_env: AlfredTWEnv = getattr(environment, config["env"]["type"])(config, train_eval="train") self.env: TextworldBatchGymEnv = self._alfworld_env.init_env(batch_size=1) - self._action_executor = LangchainActionExecutor( - tools=get_alfworld_tools(env=self.env) - ) + self._action_executor = LangchainActionExecutor(tools=get_alfworld_tools(env=self.env)) @property def tools(self) -> Sequence[BaseTool]: @@ -37,10 +35,10 @@ def seed(self, seed: Optional[int] = None): self.env.seed(seed) def step( - self, inputs: Tuple[AgentAction, Optional[CallbackManager]] + self, action: Tuple[AgentAction, Optional[CallbackManager]] ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: - action, run_manager = inputs - result = self._action_executor.execute(action, run_manager=run_manager) + lc_action, run_manager = action + result = self._action_executor.execute(lc_action, run_manager=run_manager) try: observation, reward, terminated, truncated, info = result.observation except ValueError: @@ -59,9 +57,7 @@ def reset( ) -> Tuple[str, Dict[str, Any]]: if not options or "next_episode" not in options or not options["next_episode"]: self.env = self._alfworld_env.init_env(batch_size=1) - self._action_executor = LangchainActionExecutor( - tools=get_alfworld_tools(env=self.env) - ) + self._action_executor = LangchainActionExecutor(tools=get_alfworld_tools(env=self.env)) obs, infos = self.env.reset() observation = obs[0] diff --git a/environments/alfworld/common/evaluate_output_parser.py b/environments/alfworld/common/evaluate_output_parser.py index 1e40fc2..312c8d4 100644 --- a/environments/alfworld/common/evaluate_output_parser.py +++ b/environments/alfworld/common/evaluate_output_parser.py @@ -15,6 +15,4 @@ def parse(self, text: str) -> float: raise ValueError("The given number is out of (0.0, 1.0) range.") return result except ValueError: - raise OutputParserException( - f"Couldn't convert {text} to float between 0 and 1." - ) + raise OutputParserException(f"Couldn't convert {text} to float between 0 and 1.") diff --git a/environments/alfworld/common/tools.py b/environments/alfworld/common/tools.py index d50cf7e..ffd66b8 100644 --- a/environments/alfworld/common/tools.py +++ b/environments/alfworld/common/tools.py @@ -1,16 +1,16 @@ +from typing import Any, Dict, List, SupportsFloat, Tuple, Type + from langchain.pydantic_v1 import BaseModel from langchain.tools import BaseTool -from typing import Type, Any, Tuple, Dict, List +from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped] -from gymnasium.core import SupportsFloat from .tools_utils import ( BaseALFWorldTool, - ReceptableInput, - ObjectOrReceptableInput, - ObjectAndReceptableInput, EmptyInput, + ObjectAndReceptableInput, + ObjectOrReceptableInput, + ReceptableInput, ) -from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped] def get_alfworld_tools(env: TextworldBatchGymEnv) -> List[BaseTool]: @@ -33,7 +33,7 @@ def get_alfworld_tools(env: TextworldBatchGymEnv) -> List[BaseTool]: class GoToTool(BaseALFWorldTool, BaseTool): name = "goto" description = """Go to the specified receptable (static object).""" - args_schema: Type[BaseModel] = ReceptableInput + args_schema: Type[BaseModel] = ReceptableInput # type: ignore def _run( self, @@ -42,16 +42,14 @@ def _run( *args: Any, **kwargs: Any, ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: - obs, scores, dones, infos = self.env.step( - [f"go to {receptable_type} {receptable_id}"] - ) + obs, scores, dones, infos = self.env.step([f"go to {receptable_type} {receptable_id}"]) return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} class OpenTool(BaseALFWorldTool, BaseTool): name = "open" description = """Open a specified receptable (static object). Only works when you're near a receptable and when it is closed.""" - args_schema: Type[BaseModel] = ReceptableInput + args_schema: Type[BaseModel] = ReceptableInput # type: ignore def _run( self, @@ -60,16 +58,14 @@ def _run( *args: Any, **kwargs: Any, ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: - obs, scores, dones, infos = self.env.step( - [f"open {receptable_type} {receptable_id}"] - ) + obs, scores, dones, infos = self.env.step([f"open {receptable_type} {receptable_id}"]) return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} class CloseTool(BaseALFWorldTool, BaseTool): name = "close" description = """Close a specified receptable (static object). Only available when you're near a receptable and when it is closed.""" - args_schema: Type[BaseModel] = ReceptableInput + args_schema: Type[BaseModel] = ReceptableInput # type: ignore def _run( self, @@ -78,16 +74,14 @@ def _run( *args: Any, **kwargs: Any, ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: - obs, scores, dones, infos = self.env.step( - [f"close {receptable_type} {receptable_id}"] - ) + obs, scores, dones, infos = self.env.step([f"close {receptable_type} {receptable_id}"]) return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} class TakeTool(BaseALFWorldTool, BaseTool): name = "take" description = """Pick up the specified portable object from the specified receptable (static object). Only works when you're near the specified receptable and the specified object is present in/on the receptable.""" - args_schema: Type[BaseModel] = ObjectAndReceptableInput + args_schema: Type[BaseModel] = ObjectAndReceptableInput # type: ignore def _run( self, @@ -107,7 +101,7 @@ def _run( class PutTool(BaseALFWorldTool, BaseTool): name = "put" description = """Put the specified portable object in/on the specified receptable (static object). Only available when you're near the specified receptable and carry the specified portable object in your inventory.""" - args_schema: Type[BaseModel] = ObjectAndReceptableInput + args_schema: Type[BaseModel] = ObjectAndReceptableInput # type: ignore def _run( self, @@ -127,7 +121,7 @@ def _run( class ToggleTool(BaseALFWorldTool, BaseTool): name = "toggle" description = """Toggle the specified object on/off (can be either a portable object or a static receptable). Only available when you're near the specified receptable/portable object or carry the specified portable object.""" - args_schema: Type[BaseModel] = ObjectOrReceptableInput + args_schema: Type[BaseModel] = ObjectOrReceptableInput # type: ignore def _run( self, @@ -143,7 +137,7 @@ def _run( class HeatTool(BaseALFWorldTool, BaseTool): name = "heat" description = """Heat the portable object via the receptable (static object). Only available when you're already near the receptable and the portable object is in/on the receptable.""" - args_schema: Type[BaseModel] = ObjectAndReceptableInput + args_schema: Type[BaseModel] = ObjectAndReceptableInput # type: ignore def _run( self, @@ -163,7 +157,7 @@ def _run( class CoolTool(BaseALFWorldTool, BaseTool): name = "cool" description = """Cool the portable object via the receptable (static object). Only available when you're already near a receptable and the portable object is in/on the receptable.""" - args_schema: Type[BaseModel] = ObjectAndReceptableInput + args_schema: Type[BaseModel] = ObjectAndReceptableInput # type: ignore def _run( self, @@ -183,7 +177,7 @@ def _run( class CleanTool(BaseALFWorldTool, BaseTool): name = "clean" description = """Clean the portable object via the receptable (static object). Only available when you're already near a receptable and the portable object is in/on the receptable.""" - args_schema: Type[BaseModel] = ObjectAndReceptableInput + args_schema: Type[BaseModel] = ObjectAndReceptableInput # type: ignore def _run( self, @@ -203,7 +197,7 @@ def _run( class ExamineTool(BaseALFWorldTool, BaseTool): name = "examine" description = """Examine the specified object (can be either a portable object or a static receptable). Only available when you're near the receptable/portable object or carry the specified portable object.""" - args_schema: Type[BaseModel] = ObjectOrReceptableInput + args_schema: Type[BaseModel] = ObjectOrReceptableInput # type: ignore def _run( self, @@ -219,7 +213,7 @@ def _run( class InventoryTool(BaseALFWorldTool, BaseTool): name = "inventory" description = """Check if you are carrying any portable objects.""" - args_schema: Type[BaseModel] = EmptyInput + args_schema: Type[BaseModel] = EmptyInput # type: ignore def _run( self, @@ -233,7 +227,7 @@ def _run( class LookTool(BaseALFWorldTool, BaseTool): name = "look" description = """Check your surroundings.""" - args_schema: Type[BaseModel] = EmptyInput + args_schema: Type[BaseModel] = EmptyInput # type: ignore def _run( self, diff --git a/environments/alfworld/common/tools_utils.py b/environments/alfworld/common/tools_utils.py index 92e1c3a..93bcd5a 100644 --- a/environments/alfworld/common/tools_utils.py +++ b/environments/alfworld/common/tools_utils.py @@ -7,9 +7,7 @@ class EmptyInput(BaseModel): ... class ObjectInput(BaseModel): - object_type: str = Field( - description="A type of the portable object.", examples=["apple", "mug"] - ) + object_type: str = Field(description="A type of the portable object.", examples=["apple", "mug"]) object_id: int = Field( description="A specific number associated with the object (e.g., when there are " "several mugs in the room, those would be mug 1 and mug 2).", @@ -30,9 +28,7 @@ class ReceptableInput(BaseModel): class ObjectAndReceptableInput(BaseModel): - object_type: str = Field( - description="A type of the portable object.", examples=["apple", "mug"] - ) + object_type: str = Field(description="A type of the portable object.", examples=["apple", "mug"]) object_id: int = Field( description="A specific number associated with the object (e.g., when there are " "several mugs in the room, those would be mug 1 and mug 2).", diff --git a/environments/frozen_lake/common/__init__.py b/environments/frozen_lake/common/__init__.py index c51d3e3..ff6ebcf 100644 --- a/environments/frozen_lake/common/__init__.py +++ b/environments/frozen_lake/common/__init__.py @@ -1,6 +1,6 @@ -from .tools import MoveTool, CheckMapTool, CheckPositionTool from .environment import FrozenLakeEnvWrapper from .evaluate_output_parser import FrozenMapEvaluateOutputParser +from .tools import CheckMapTool, CheckPositionTool, MoveTool __all__ = [ "MoveTool", diff --git a/environments/frozen_lake/common/environment.py b/environments/frozen_lake/common/environment.py index adac441..ba05cb8 100644 --- a/environments/frozen_lake/common/environment.py +++ b/environments/frozen_lake/common/environment.py @@ -1,16 +1,17 @@ from __future__ import annotations -from typing import Any, Dict, Tuple, Sequence, Optional + +from typing import Any, Dict, Optional, Sequence, SupportsFloat, Tuple import gymnasium as gym -from gymnasium.core import ObsType, SupportsFloat from gymnasium.envs.toy_text.frozen_lake import FrozenLakeEnv from langchain_core.agents import AgentAction -from langchain_core.tools import BaseTool from langchain_core.callbacks import CallbackManager +from langchain_core.tools import BaseTool -from .tools import MoveTool from planning_library.action_executors import LangchainActionExecutor +from .tools import MoveTool + class FrozenLakeEnvWrapper(gym.Wrapper): def __init__(self, env: FrozenLakeEnv): @@ -22,10 +23,10 @@ def tools(self) -> Sequence[BaseTool]: return self._action_executor.tools def step( - self, inputs: Tuple[AgentAction, Optional[CallbackManager]] + self, action: Tuple[AgentAction, Optional[CallbackManager]] ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: - action, run_manager = inputs - result = self._action_executor.execute(action) + lc_action, run_manager = action + result = self._action_executor.execute(lc_action, run_manager=run_manager) return result.observation def reset( @@ -33,7 +34,7 @@ def reset( *, seed: int | None = None, options: Dict[str, Any] | None = None, - ) -> Tuple[ObsType, Dict[str, Any]]: + ) -> Tuple[str, Dict[str, Any]]: observation, info = self.env.reset(seed=seed, options=options) if options is not None and "trajectory" in options: diff --git a/environments/frozen_lake/common/evaluate_output_parser.py b/environments/frozen_lake/common/evaluate_output_parser.py index a232b0d..6b6b76e 100644 --- a/environments/frozen_lake/common/evaluate_output_parser.py +++ b/environments/frozen_lake/common/evaluate_output_parser.py @@ -15,6 +15,4 @@ def parse(self, text: str) -> float: raise ValueError("The given number is out of (0.0, 1.0) range.") return result except ValueError: - raise OutputParserException( - f"Couldn't convert {text} to float between 0 and 1." - ) + raise OutputParserException(f"Couldn't convert {text} to float between 0 and 1.") diff --git a/environments/frozen_lake/common/tools.py b/environments/frozen_lake/common/tools.py index cc93094..d60a94f 100644 --- a/environments/frozen_lake/common/tools.py +++ b/environments/frozen_lake/common/tools.py @@ -1,10 +1,9 @@ from textwrap import dedent -from typing import Any, Literal, Tuple, Type, Dict +from typing import Any, Dict, Literal, SupportsFloat, Tuple, Type import gymnasium as gym from langchain.pydantic_v1 import BaseModel, Field from langchain.tools import BaseTool -from gymnasium.core import SupportsFloat class BaseFrozenLakeTool(BaseModel): @@ -19,9 +18,7 @@ class Config(BaseTool.Config): class MoveInput(BaseModel): - direction: Literal["left", "right", "down", "up"] = Field( - description="Which direction to move." - ) + direction: Literal["left", "right", "down", "up"] = Field(description="Which direction to move.") class MoveTool(BaseFrozenLakeTool, BaseTool): @@ -35,12 +32,10 @@ class MoveTool(BaseFrozenLakeTool, BaseTool): * truncated: if True, the time limit has been exceeded; * info: probability of moving in the wrong direction for the current cell (ice is slippery!)""" ) - args_schema: Type[BaseModel] = MoveInput + args_schema: Type[BaseModel] = MoveInput # type: ignore @staticmethod - def _convert_frozenlake_observation_to_position( - observation: int, nrow: int - ) -> Tuple[int, int]: + def _convert_frozenlake_observation_to_position(observation: int, nrow: int) -> Tuple[int, int]: # FrozenLake: observation = current_row * nrow + current_col current_row, current_col = observation // nrow, observation % nrow return current_col, current_row @@ -68,16 +63,12 @@ def _run( MoveTool._convert_direction_to_frozenlake(direction) ) nrow = self.env.get_wrapper_attr("nrow") - observation = MoveTool._convert_frozenlake_observation_to_position( - observation=_observation, nrow=nrow - ) + observation = MoveTool._convert_frozenlake_observation_to_position(observation=_observation, nrow=nrow) return observation, reward, terminated, truncated, info class LookInput(BaseModel): - direction: Literal["left", "right", "down", "up"] = Field( - description="Which direction to look at." - ) + direction: Literal["left", "right", "down", "up"] = Field(description="Which direction to look at.") class LookTool(BaseFrozenLakeTool, BaseTool): @@ -90,7 +81,7 @@ class LookTool(BaseFrozenLakeTool, BaseTool): * F - frozen cell; * G - goal. """) - args_schema: Type[BaseModel] = LookInput + args_schema: Type[BaseModel] = LookInput # type: ignore def _run( self, @@ -113,9 +104,7 @@ def _run( elif direction == "up": observation = "out of bounds" if y == 0 else board[x][y - 1].decode() else: - raise ValueError( - "Wrong direction; expected one of: 'left', 'right', 'down', 'up'." - ) + raise ValueError("Wrong direction; expected one of: 'left', 'right', 'down', 'up'.") info: Dict[str, Any] reward, terminated, truncated, info = ( @@ -147,7 +136,7 @@ class CheckMapTool(BaseFrozenLakeTool, BaseTool): SH FG """) - args_schema: Type[BaseModel] = CheckMapInput + args_schema: Type[BaseModel] = CheckMapInput # type: ignore def _run( self, @@ -156,10 +145,7 @@ def _run( ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: info: Dict[str, Any] observation, reward, terminated, truncated, info = ( - "\n".join( - "".join(x.decode() for x in y) - for y in self.env.get_wrapper_attr("desc") - ), + "\n".join("".join(x.decode() for x in y) for y in self.env.get_wrapper_attr("desc")), 0, False, False, @@ -174,7 +160,7 @@ class CheckPositionInput(BaseModel): ... class CheckPositionTool(BaseFrozenLakeTool, BaseTool): name = "check_position" description = """Peeks at current position map without changing its state.""" - args_schema: Type[BaseModel] = CheckMapInput + args_schema: Type[BaseModel] = CheckMapInput # type: ignore def _run( self, diff --git a/environments/frozen_lake/reflexion/self_reflection_prompts.py b/environments/frozen_lake/reflexion/self_reflection_prompts.py index e6223e4..f0c75ae 100644 --- a/environments/frozen_lake/reflexion/self_reflection_prompts.py +++ b/environments/frozen_lake/reflexion/self_reflection_prompts.py @@ -1,9 +1,10 @@ +from textwrap import dedent + from langchain.prompts import ( ChatPromptTemplate, MessagesPlaceholder, ) from langchain_core.output_parsers import BaseOutputParser -from textwrap import dedent self_reflection_prompt = ChatPromptTemplate.from_messages( [ diff --git a/environments/frozen_lake/tot_dfs.ipynb b/environments/frozen_lake/tot_dfs.ipynb index 4aace6a..3c07893 100644 --- a/environments/frozen_lake/tot_dfs.ipynb +++ b/environments/frozen_lake/tot_dfs.ipynb @@ -810,7 +810,8 @@ "metadata": { "collapsed": false }, - "id": "c82623bf5d35dafd" + "id": "c82623bf5d35dafd", + "execution_count": null } ], "metadata": { diff --git a/environments/game_of_24/environment.py b/environments/game_of_24/environment.py index 8ac81a8..8d1f431 100644 --- a/environments/game_of_24/environment.py +++ b/environments/game_of_24/environment.py @@ -1,16 +1,17 @@ from __future__ import annotations + from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, Sequence +from typing import Any, Dict, List, Optional, Sequence, SupportsFloat, Tuple import gymnasium as gym -from gymnasium.core import SupportsFloat from langchain_core.agents import AgentAction -from langchain_core.tools import BaseTool from langchain_core.callbacks import CallbackManager +from langchain_core.tools import BaseTool -from .tools import AddTool, MultiplyTool, SubtractTool, DivideTool from planning_library.action_executors import LangchainActionExecutor +from .tools import AddTool, DivideTool, MultiplyTool, SubtractTool + class GameOf24Env(gym.Env[str, Tuple[AgentAction, Optional[CallbackManager]]]): def __init__(self, numbers: Optional[List[float | int]] = None): @@ -30,9 +31,7 @@ def __init__(self, numbers: Optional[List[float | int]] = None): @property def numbers(self) -> str: - return " ".join( - [str(key) for key, value in self._numbers.items() for _ in range(value)] - ) + return " ".join([str(key) for key, value in self._numbers.items() for _ in range(value)]) @numbers.setter def numbers(self, numbers: List[float | int]): @@ -74,10 +73,10 @@ def verify_arguments(self, number1: float, number2: float) -> bool: ) def step( - self, inputs: Tuple[AgentAction, Optional[CallbackManager]] + self, action: Tuple[AgentAction, Optional[CallbackManager]] ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: - action, run_manager = inputs - result = self._action_executor.execute(action, run_manager=run_manager) + lc_action, run_manager = action + result = self._action_executor.execute(lc_action, run_manager=run_manager) return result.observation def reset( @@ -94,9 +93,7 @@ def reset( if options is not None and "trajectory" in options: for action, step in options["trajectory"]: - assert isinstance( - action, AgentAction - ), f"Expected AgentAction, got {action}" + assert isinstance(action, AgentAction), f"Expected AgentAction, got {action}" observation, reward, terminated, truncated, info = self.step( ( action, diff --git a/environments/game_of_24/tools.py b/environments/game_of_24/tools.py index 952a108..2419b1f 100644 --- a/environments/game_of_24/tools.py +++ b/environments/game_of_24/tools.py @@ -1,11 +1,10 @@ +from abc import ABC, abstractmethod from textwrap import dedent -from typing import Any, Tuple, Type, Dict +from typing import Any, Dict, SupportsFloat, Tuple, Type +import gymnasium as gym from langchain.pydantic_v1 import BaseModel, Field from langchain.tools import BaseTool -from gymnasium.core import SupportsFloat -import gymnasium as gym -from abc import ABC, abstractmethod class BaseGameof24Tool(BaseModel, ABC): @@ -57,12 +56,8 @@ def _run( class CalculatorInput(BaseModel): - number1: float = Field( - description="The first argument in an arithmetical operation." - ) - number2: float = Field( - description="The second argument in an arithmetical operation." - ) + number1: float = Field(description="The first argument in an arithmetical operation.") + number2: float = Field(description="The second argument in an arithmetical operation.") class AddTool(BaseGameof24Tool, BaseTool): @@ -74,7 +69,7 @@ class AddTool(BaseGameof24Tool, BaseTool): * terminated: if True, the game has ended: there's no possible actions anymore; * truncated: if True, the time limit has been exceeded; * info: the remaining numbers""") - args_schema: Type[BaseModel] = CalculatorInput + args_schema: Type[BaseModel] = CalculatorInput # type: ignore def _operation(self, number1: float, number2: float) -> float: return number1 + number2 @@ -89,7 +84,7 @@ class SubtractTool(BaseGameof24Tool, BaseTool): * terminated: if True, the game has ended: there's no possible actions anymore; * truncated: if True, the time limit has been exceeded; * info: the remaining numbers""") - args_schema: Type[BaseModel] = CalculatorInput + args_schema: Type[BaseModel] = CalculatorInput # type: ignore def _operation(self, number1: float, number2: float) -> float: return number1 - number2 @@ -104,7 +99,7 @@ class MultiplyTool(BaseGameof24Tool, BaseTool): * terminated: if True, the game has ended: there's no possible actions anymore; * truncated: if True, the time limit has been exceeded; * info: the remaining numbers""") - args_schema: Type[BaseModel] = CalculatorInput + args_schema: Type[BaseModel] = CalculatorInput # type: ignore def _operation(self, number1: float, number2: float) -> float: return number1 * number2 @@ -119,7 +114,7 @@ class DivideTool(BaseGameof24Tool, BaseTool): * terminated: if True, the game has ended: there's no possible actions anymore; * truncated: if True, the time limit has been exceeded; * info: the remaining numbers""") - args_schema: Type[BaseModel] = CalculatorInput + args_schema: Type[BaseModel] = CalculatorInput # type: ignore def _operation(self, number1: float, number2: float) -> float: return number1 / number2 diff --git a/planning_library/action_executors/base_action_executor.py b/planning_library/action_executors/base_action_executor.py index 840c74f..fdb99ff 100644 --- a/planning_library/action_executors/base_action_executor.py +++ b/planning_library/action_executors/base_action_executor.py @@ -1,10 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, overload, Sequence, Optional +from typing import List, Optional, Sequence, overload + from langchain_core.agents import AgentAction, AgentStep +from langchain_core.callbacks import AsyncCallbackManager, CallbackManager from langchain_core.tools import BaseTool -from langchain_core.callbacks import CallbackManager, AsyncCallbackManager class BaseActionExecutor(ABC): @@ -33,6 +34,7 @@ async def areset( ... @overload + @abstractmethod def execute( self, actions: List[AgentAction], @@ -41,6 +43,7 @@ def execute( ) -> List[AgentStep]: ... @overload + @abstractmethod def execute( self, actions: AgentAction, @@ -67,6 +70,7 @@ def execute( ... @overload + @abstractmethod async def aexecute( self, actions: List[AgentAction], @@ -75,6 +79,7 @@ async def aexecute( ) -> List[AgentStep]: ... @overload + @abstractmethod async def aexecute( self, actions: AgentAction, diff --git a/planning_library/action_executors/default_action_executor.py b/planning_library/action_executors/default_action_executor.py index 33bd36f..4cadec5 100644 --- a/planning_library/action_executors/default_action_executor.py +++ b/planning_library/action_executors/default_action_executor.py @@ -1,25 +1,23 @@ from __future__ import annotations -from typing import List, overload, Sequence, Optional + +from typing import List, Optional, Sequence, overload from langchain_core.agents import AgentAction, AgentStep -from langchain_core.tools import BaseTool -from langgraph.prebuilt.tool_executor import ToolExecutor # type: ignore[import-untyped] -from .base_action_executor import BaseActionExecutor from langchain_core.callbacks import ( - CallbackManager, AsyncCallbackManager, + CallbackManager, ) +from langchain_core.tools import BaseTool +from langgraph.prebuilt.tool_executor import ToolExecutor # type: ignore[import-untyped] + +from .base_action_executor import BaseActionExecutor from .meta_tools import MetaTools class LangchainActionExecutor(BaseActionExecutor): - def __init__( - self, tools: Sequence[BaseTool], meta_tools: Optional[MetaTools] = None - ): + def __init__(self, tools: Sequence[BaseTool], meta_tools: Optional[MetaTools] = None): self._tool_executor = ToolExecutor(tools) - self._meta_tool_executor = ( - ToolExecutor(meta_tools.tools) if meta_tools else None - ) + self._meta_tool_executor = ToolExecutor(meta_tools.tools) if meta_tools else None self._meta_tool_names = meta_tools.tool_names_map if meta_tools else {} @property @@ -48,7 +46,7 @@ def reset( log="Invoking reset tool.", ) ], - tool_executor=self._meta_tool_executor, + tool_executor=self._meta_tool_executor, # type: ignore[reportArgumentType] run_manager=run_manager, ) if actions: @@ -70,6 +68,14 @@ def execute( **kwargs, ) -> AgentStep: ... + def execute( + self, + actions: List[AgentAction] | AgentAction, + run_manager: Optional[CallbackManager] = None, + **kwargs, + ) -> List[AgentStep] | AgentStep: + return self._execute(actions, self._tool_executor, run_manager) + def _execute( self, actions: List[AgentAction] | AgentAction, @@ -92,14 +98,6 @@ def _execute( ) return AgentStep(action=actions, observation=observation) - def execute( - self, - actions: List[AgentAction] | AgentAction, - run_manager: Optional[CallbackManager] = None, - **kwargs, - ) -> List[AgentStep] | AgentStep: - return self._execute(actions, self._tool_executor, run_manager) - async def areset( self, actions: Optional[List[AgentAction]] = None, @@ -116,7 +114,7 @@ async def areset( log="Invoking reset tool.", ) ], - tool_executor=self._meta_tool_executor, + tool_executor=self._meta_tool_executor, # type: ignore[reportArgumentType] run_manager=run_manager, ) if actions: @@ -138,6 +136,14 @@ async def aexecute( **kwargs, ) -> AgentStep: ... + async def aexecute( + self, + actions: List[AgentAction] | AgentAction, + run_manager: Optional[AsyncCallbackManager] = None, + **kwargs, + ) -> List[AgentStep] | AgentStep: + return await self._aexecute(actions, self._tool_executor, run_manager) + async def _aexecute( self, actions: List[AgentAction] | AgentAction, @@ -160,11 +166,3 @@ async def _aexecute( config={"callbacks": run_manager} if run_manager else {}, ) return AgentStep(action=actions, observation=observation) - - async def aexecute( - self, - actions: List[AgentAction] | AgentAction, - run_manager: Optional[AsyncCallbackManager] = None, - **kwargs, - ) -> List[AgentStep] | AgentStep: - return await self._aexecute(actions, self._tool_executor, run_manager) diff --git a/planning_library/action_executors/meta_tools.py b/planning_library/action_executors/meta_tools.py index 8b85a4c..2f990a9 100644 --- a/planning_library/action_executors/meta_tools.py +++ b/planning_library/action_executors/meta_tools.py @@ -1,6 +1,7 @@ -from langchain_core.tools import BaseTool -from typing import Optional, List, Dict from dataclasses import dataclass, fields +from typing import Dict, List, Optional + +from langchain_core.tools import BaseTool @dataclass diff --git a/planning_library/components/__init__.py b/planning_library/components/__init__.py index 8d5aaca..02196c4 100644 --- a/planning_library/components/__init__.py +++ b/planning_library/components/__init__.py @@ -1,6 +1,6 @@ +from .agent_component import AgentComponent from .base_component import BaseComponent from .runnable_component import RunnableComponent -from .agent_component import AgentComponent __all__ = [ "BaseComponent", diff --git a/planning_library/components/agent_component.py b/planning_library/components/agent_component.py index 0127603..5b362ee 100644 --- a/planning_library/components/agent_component.py +++ b/planning_library/components/agent_component.py @@ -1,22 +1,25 @@ from __future__ import annotations -from typing import Optional, List, Union, Sequence, Callable, Dict, Awaitable + +from typing import Awaitable, Callable, Dict, List, Optional, Sequence, Union + +from langchain.agents.agent import BaseMultiActionAgent, BaseSingleActionAgent, RunnableAgent, RunnableMultiActionAgent from langchain_core.agents import AgentAction, AgentFinish -from langchain_core.callbacks import CallbackManager, AsyncCallbackManager -from .base_component import InputType, BaseComponent -from langchain.agents.agent import BaseMultiActionAgent, BaseSingleActionAgent +from langchain_core.callbacks import AsyncCallbackManager, CallbackManager from langchain_core.language_models import BaseChatModel -from langchain_core.tools import BaseTool from langchain_core.prompts import ChatPromptTemplate -from planning_library.utils import ( - convert_runnable_to_agent, -) -from langchain.agents.agent import RunnableAgent, RunnableMultiActionAgent -from langchain_core.runnables import RunnableLambda, Runnable +from langchain_core.runnables import Runnable, RunnableLambda +from langchain_core.tools import BaseTool + from planning_library.function_calling_parsers import ( - BaseFunctionCallingSingleActionParser, BaseFunctionCallingMultiActionParser, + BaseFunctionCallingSingleActionParser, ParserRegistry, ) +from planning_library.utils import ( + convert_runnable_to_agent, +) + +from .base_component import BaseComponent, InputType class AgentFactory: @@ -35,26 +38,17 @@ def create_agent( ) -> Union[RunnableAgent, RunnableMultiActionAgent]: if parser is None: if parser_name is None: - raise ValueError( - "Either parser or parser_name should be provided to instantiate an agent." - ) + raise ValueError("Either parser or parser_name should be provided to instantiate an agent.") parser = ParserRegistry.get_parser(parser_name) llm_with_tools = parser.prepare_llm(llm=llm, tools=tools) - runnable: Runnable = ( - RunnableLambda(parser.format_inputs) - | prompt - | llm_with_tools - | parser.output_parser - ) + runnable: Runnable = RunnableLambda(parser.format_inputs) | prompt | llm_with_tools | parser.output_parser agent = convert_runnable_to_agent(runnable) return agent -class AgentComponent( - BaseComponent[InputType, Union[List[AgentAction], AgentAction, AgentFinish]] -): +class AgentComponent(BaseComponent[InputType, Union[List[AgentAction], AgentAction, AgentFinish]]): def __init__( self, agent: BaseSingleActionAgent | BaseMultiActionAgent, @@ -75,9 +69,7 @@ def add_input_preprocessing( apreprocess: Optional[Callable[[InputType], Awaitable[Dict]]] = None, ) -> None: if hasattr(self.agent, "runnable"): - self.agent.runnable = ( - RunnableLambda(preprocess, afunc=apreprocess) | self.agent.runnable - ) + self.agent.runnable = RunnableLambda(preprocess, afunc=apreprocess) | self.agent.runnable # type: ignore[reportAttributeAccessIssue] def add_output_preprocessing( self, @@ -93,15 +85,14 @@ def add_output_preprocessing( ] = None, ) -> None: if hasattr(self.agent, "runnable"): - self.agent.runnable = self.agent.runnable | RunnableLambda( - preprocess, afunc=apreprocess - ) + self.agent.runnable = self.agent.runnable | RunnableLambda(preprocess, afunc=apreprocess) # type: ignore[reportAttributeAccessIssue] def invoke( self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs ) -> Union[List[AgentAction], AgentAction, AgentFinish]: # TODO: no way to pass name to plan? - return self.agent.plan(**inputs, callbacks=run_manager) + # TODO: intermediate_steps? + return self.agent.plan(**inputs, callbacks=run_manager) # type: ignore[reportCallIssue] async def ainvoke( self, @@ -110,7 +101,7 @@ async def ainvoke( **kwargs, ) -> Union[List[AgentAction], AgentAction, AgentFinish]: # TODO: no way to pass name to plan? - outputs = await self.agent.aplan(**inputs, callbacks=run_manager) + outputs = await self.agent.aplan(**inputs, callbacks=run_manager) # type: ignore[reportCallIssue] return outputs @classmethod @@ -129,9 +120,7 @@ def create( ] = None, parser_name: Optional[str] = None, ) -> "AgentComponent[InputType]": - prompt = cls._process_prompt( - prompt=prompt, user_message=user_message, system_message=system_message - ) + prompt = cls._process_prompt(prompt=prompt, user_message=user_message, system_message=system_message) return cls( agent=AgentFactory.create_agent( diff --git a/planning_library/components/base_component.py b/planning_library/components/base_component.py index 0d44050..c4895df 100644 --- a/planning_library/components/base_component.py +++ b/planning_library/components/base_component.py @@ -1,6 +1,7 @@ -from typing import Generic, TypeVar, Optional, Mapping, Set, Dict, Callable, Awaitable from abc import ABC, abstractmethod -from langchain_core.callbacks import CallbackManager, AsyncCallbackManager +from typing import Awaitable, Callable, Dict, Generic, Mapping, Optional, Set, TypeVar + +from langchain_core.callbacks import AsyncCallbackManager, CallbackManager from langchain_core.prompts import ChatPromptTemplate InputType = TypeVar("InputType", bound=Mapping) @@ -12,9 +13,7 @@ class BaseComponent(Generic[InputType, OutputType], ABC): required_prompt_input_vars: Set[str] = set() @classmethod - def _create_default_prompt( - cls, system_message: Optional[str], user_message: str, **kwargs - ) -> ChatPromptTemplate: + def _create_default_prompt(cls, system_message: Optional[str], user_message: str, **kwargs) -> ChatPromptTemplate: raise NotImplementedError( f"Default prompt is not supported for {cls.__name__}. Please, provide `prompt` instead of `user_message`." ) @@ -29,18 +28,12 @@ def _process_prompt( ) -> ChatPromptTemplate: if prompt is None: if user_message is None: - raise ValueError( - "Either `prompt` or `user_message` are required to create an agent." - ) - prompt = cls._create_default_prompt( - system_message=system_message, user_message=user_message, **kwargs - ) + raise ValueError("Either `prompt` or `user_message` are required to create an agent.") + prompt = cls._create_default_prompt(system_message=system_message, user_message=user_message, **kwargs) missing_vars = cls.required_prompt_input_vars.difference(prompt.input_variables) if missing_vars: - raise ValueError( - f"Prompt for {cls.__name__} missing required variables: {missing_vars}" - ) + raise ValueError(f"Prompt for {cls.__name__} missing required variables: {missing_vars}") return prompt @@ -57,9 +50,7 @@ def add_output_preprocessing( ) -> None: ... @abstractmethod - def invoke( - self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs - ) -> OutputType: ... + def invoke(self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs) -> OutputType: ... @abstractmethod async def ainvoke( diff --git a/planning_library/components/evaluation/__init__.py b/planning_library/components/evaluation/__init__.py index 4e4b70d..e181818 100644 --- a/planning_library/components/evaluation/__init__.py +++ b/planning_library/components/evaluation/__init__.py @@ -1,4 +1,4 @@ from .evaluator_component import EvaluatorComponent -from .threshold_judge import LeqThresholdJudge, GeqThresholdJudge +from .threshold_judge import GeqThresholdJudge, LeqThresholdJudge __all__ = ["EvaluatorComponent", "LeqThresholdJudge", "GeqThresholdJudge"] diff --git a/planning_library/components/evaluation/evaluator_component.py b/planning_library/components/evaluation/evaluator_component.py index a586c35..c392654 100644 --- a/planning_library/components/evaluation/evaluator_component.py +++ b/planning_library/components/evaluation/evaluator_component.py @@ -1,19 +1,19 @@ -from typing import Optional, Dict, Generic, Type, Callable, Awaitable +from typing import Awaitable, Callable, Dict, Generic, Optional, Type -from langchain_core.callbacks import CallbackManager, AsyncCallbackManager -from langchain_core.output_parsers import BaseOutputParser -from langchain_core.runnables import Runnable +from langchain_core.callbacks import AsyncCallbackManager, CallbackManager from langchain_core.language_models import BaseChatModel +from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import ChatPromptTemplate -from ..base_component import BaseComponent, InputType, OutputType -from planning_library.primitives.output_parsers import SimpleEvaluateOutputParser -from .threshold_judge import LeqThresholdJudge, GeqThresholdJudge +from langchain_core.runnables import Runnable + from planning_library.components.runnable_component import RunnableComponent +from planning_library.primitives.output_parsers import SimpleEvaluateOutputParser + +from ..base_component import BaseComponent, InputType, OutputType +from .threshold_judge import GeqThresholdJudge, LeqThresholdJudge -class EvaluatorComponent( - Generic[InputType, OutputType], BaseComponent[InputType, bool] -): +class EvaluatorComponent(Generic[InputType, OutputType], BaseComponent[InputType, bool]): def __init__( self, backbone: BaseComponent[InputType, OutputType], @@ -36,16 +36,12 @@ def add_output_preprocessing( ) -> None: self.judge.add_output_preprocessing(preprocess, apreprocess) - def invoke( - self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs - ) -> bool: + def invoke(self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs) -> bool: if "run_name" not in kwargs and self.name: kwargs["run_name"] = self.name backbone_output = self.backbone.invoke(inputs, run_manager, **kwargs) - should_continue = self.judge.invoke( - {"backbone_output": backbone_output}, run_manager - ) + should_continue = self.judge.invoke({"backbone_output": backbone_output}, run_manager) return should_continue async def ainvoke( @@ -58,9 +54,7 @@ async def ainvoke( kwargs["run_name"] = self.name backbone_output = await self.backbone.ainvoke(inputs, run_manager, **kwargs) - should_continue = await self.judge.ainvoke( - {"backbone_output": backbone_output}, run_manager - ) + should_continue = await self.judge.ainvoke({"backbone_output": backbone_output}, run_manager) return should_continue @classmethod @@ -74,9 +68,7 @@ def create_threshold_evaluator( system_message: Optional[str] = None, output_parser: Optional[BaseOutputParser[float]] = None, ) -> "EvaluatorComponent[InputType, float]": - prompt = cls._process_prompt( - prompt=prompt, user_message=user_message, system_message=system_message - ) + prompt = cls._process_prompt(prompt=prompt, user_message=user_message, system_message=system_message) if output_parser is None: output_parser = SimpleEvaluateOutputParser() @@ -95,15 +87,11 @@ def create_threshold_evaluator_from_runnable( threshold_mode: str, ) -> "EvaluatorComponent[InputType, float]": if threshold_mode == "leq": - judge: BaseComponent[Dict[str, float], bool] = LeqThresholdJudge( - threshold=threshold - ) + judge: BaseComponent[Dict[str, float], bool] = LeqThresholdJudge(threshold=threshold) elif threshold_mode == "geq": judge = GeqThresholdJudge(threshold=threshold) else: - raise ValueError( - f"Unknown `threshold_mode` {threshold_mode} when initializing {cls.__name__}." - ) + raise ValueError(f"Unknown `threshold_mode` {threshold_mode} when initializing {cls.__name__}.") backbone = RunnableComponent(runnable) return cls(backbone=backbone, judge=judge) diff --git a/planning_library/components/evaluation/threshold_judge.py b/planning_library/components/evaluation/threshold_judge.py index 84bd780..2e26eb3 100644 --- a/planning_library/components/evaluation/threshold_judge.py +++ b/planning_library/components/evaluation/threshold_judge.py @@ -1,16 +1,17 @@ -from ..base_component import BaseComponent from typing import Optional -from langchain_core.callbacks import CallbackManager, AsyncCallbackManager + +from langchain_core.callbacks import AsyncCallbackManager, CallbackManager + from planning_library.components.base_component import InputType +from ..base_component import BaseComponent + class LeqThresholdJudge(BaseComponent[InputType, bool]): def __init__(self, threshold: float): self.threshold = threshold - def invoke( - self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs - ) -> bool: + def invoke(self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs) -> bool: return inputs["backbone_output"] <= self.threshold async def ainvoke( @@ -26,9 +27,7 @@ class GeqThresholdJudge(BaseComponent[InputType, bool]): def __init__(self, threshold: float): self.threshold = threshold - def invoke( - self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs - ) -> bool: + def invoke(self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs) -> bool: return inputs["backbone_output"] >= self.threshold async def ainvoke( diff --git a/planning_library/components/runnable_component.py b/planning_library/components/runnable_component.py index e8dd0ae..ee901c7 100644 --- a/planning_library/components/runnable_component.py +++ b/planning_library/components/runnable_component.py @@ -1,10 +1,12 @@ -from typing import Optional, Dict, Callable, Awaitable -from langchain_core.runnables import Runnable, RunnableLambda -from langchain_core.callbacks import CallbackManager, AsyncCallbackManager -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.output_parsers import BaseOutputParser +from typing import Awaitable, Callable, Dict, Optional + +from langchain_core.callbacks import AsyncCallbackManager, CallbackManager from langchain_core.language_models import BaseChatModel -from .base_component import InputType, OutputType, BaseComponent +from langchain_core.output_parsers import BaseOutputParser +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import Runnable, RunnableLambda + +from .base_component import BaseComponent, InputType, OutputType class RunnableComponent(BaseComponent[InputType, OutputType]): @@ -20,9 +22,7 @@ def create_from_steps( user_message: Optional[str] = None, system_message: Optional[str] = None, ) -> "RunnableComponent": - prompt = cls._process_prompt( - prompt=prompt, user_message=user_message, system_message=system_message - ) + prompt = cls._process_prompt(prompt=prompt, user_message=user_message, system_message=system_message) runnable = prompt | llm if output_parser is not None: runnable = runnable | output_parser @@ -42,9 +42,7 @@ def add_output_preprocessing( ) -> None: self.runnable = self.runnable | RunnableLambda(preprocess, afunc=apreprocess) - def invoke( - self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs - ) -> OutputType: + def invoke(self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs) -> OutputType: config = kwargs if "callbacks" not in config and run_manager: config["callbacks"] = run_manager diff --git a/planning_library/function_calling_parsers/__init__.py b/planning_library/function_calling_parsers/__init__.py index 29d4f3d..1cc3570 100644 --- a/planning_library/function_calling_parsers/__init__.py +++ b/planning_library/function_calling_parsers/__init__.py @@ -1,6 +1,6 @@ from .base_parser import ( - BaseFunctionCallingSingleActionParser, BaseFunctionCallingMultiActionParser, + BaseFunctionCallingSingleActionParser, ) from .openai_functions_parser import OpenAIFunctionsParser from .openai_tools_parser import OpenAIToolsParser diff --git a/planning_library/function_calling_parsers/base_parser.py b/planning_library/function_calling_parsers/base_parser.py index 78a1cbe..2e82b8e 100644 --- a/planning_library/function_calling_parsers/base_parser.py +++ b/planning_library/function_calling_parsers/base_parser.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod -from typing import List, Tuple, Sequence, Dict, Any -from langchain_core.language_models import BaseChatModel +from typing import Any, Dict, List, Sequence, Tuple + from langchain.agents.agent import AgentOutputParser, MultiActionAgentOutputParser from langchain_core.agents import AgentAction -from langchain_core.tools import BaseTool -from langchain_core.runnables import Runnable +from langchain_core.language_models import BaseChatModel from langchain_core.messages import BaseMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool from typing_extensions import TypedDict @@ -23,9 +24,7 @@ class BaseFunctionCallingParser(ABC): name: str @abstractmethod - def prepare_llm( - self, llm: BaseChatModel, tools: Sequence[BaseTool] - ) -> Runnable: ... + def prepare_llm(self, llm: BaseChatModel, tools: Sequence[BaseTool]) -> Runnable: ... @abstractmethod def format_inputs(self, inputs: AgentInputs) -> ProcessedAgentInputs: ... diff --git a/planning_library/function_calling_parsers/openai_functions_parser.py b/planning_library/function_calling_parsers/openai_functions_parser.py index 5004dbd..56c9a9f 100644 --- a/planning_library/function_calling_parsers/openai_functions_parser.py +++ b/planning_library/function_calling_parsers/openai_functions_parser.py @@ -1,17 +1,18 @@ -from typing import List, Tuple, Sequence -from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser +from typing import List, Sequence, Tuple + from langchain.agents.format_scratchpad import format_to_openai_function_messages -from langchain_core.messages import BaseMessage +from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser from langchain_core.agents import AgentAction from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_function from planning_library.function_calling_parsers.base_parser import ( + AgentInputs, BaseFunctionCallingSingleActionParser, ProcessedAgentInputs, - AgentInputs, ) from planning_library.function_calling_parsers.parser_registry import ParserRegistry @@ -33,9 +34,7 @@ def format_inputs( ) -> ProcessedAgentInputs: return { **inputs, # type: ignore[typeddict-unknown-key] - "agent_scratchpad": self._format_intermediate_steps( - inputs["intermediate_steps"] - ), + "agent_scratchpad": self._format_intermediate_steps(inputs["intermediate_steps"]), } def prepare_llm(self, llm: BaseChatModel, tools: Sequence[BaseTool]) -> Runnable: diff --git a/planning_library/function_calling_parsers/openai_tools_parser.py b/planning_library/function_calling_parsers/openai_tools_parser.py index c62adcd..095096e 100644 --- a/planning_library/function_calling_parsers/openai_tools_parser.py +++ b/planning_library/function_calling_parsers/openai_tools_parser.py @@ -1,17 +1,19 @@ -from typing import List, Tuple, Sequence -from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser +from typing import List, Sequence, Tuple + from langchain.agents.format_scratchpad.openai_tools import ( format_to_openai_tool_messages, ) -from langchain_core.messages import BaseMessage +from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser from langchain_core.agents import AgentAction -from langchain_core.tools import BaseTool -from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_tool + from planning_library.function_calling_parsers.base_parser import ( - BaseFunctionCallingMultiActionParser, AgentInputs, + BaseFunctionCallingMultiActionParser, ProcessedAgentInputs, ) from planning_library.function_calling_parsers.parser_registry import ParserRegistry @@ -34,9 +36,7 @@ def format_inputs( ) -> ProcessedAgentInputs: return { **inputs, # type: ignore[typeddict-unknown-key] - "agent_scratchpad": self._format_intermediate_steps( - inputs["intermediate_steps"] - ), + "agent_scratchpad": self._format_intermediate_steps(inputs["intermediate_steps"]), } def prepare_llm(self, llm: BaseChatModel, tools: Sequence[BaseTool]) -> Runnable: diff --git a/planning_library/function_calling_parsers/parser_registry.py b/planning_library/function_calling_parsers/parser_registry.py index 32d5652..d340089 100644 --- a/planning_library/function_calling_parsers/parser_registry.py +++ b/planning_library/function_calling_parsers/parser_registry.py @@ -1,30 +1,25 @@ -from typing import List, Union, Dict, Type +from typing import Dict, List, Type, Union + from planning_library.function_calling_parsers import ( - BaseFunctionCallingSingleActionParser, BaseFunctionCallingMultiActionParser, + BaseFunctionCallingSingleActionParser, ) class ParserRegistry: registry: Dict[ str, - Union[ - BaseFunctionCallingSingleActionParser, BaseFunctionCallingMultiActionParser - ], + Union[BaseFunctionCallingSingleActionParser, BaseFunctionCallingMultiActionParser], ] = {} @classmethod def get_parser( cls, parser_name - ) -> Union[ - BaseFunctionCallingSingleActionParser, BaseFunctionCallingMultiActionParser - ]: + ) -> Union[BaseFunctionCallingSingleActionParser, BaseFunctionCallingMultiActionParser]: try: return cls.registry[parser_name] except KeyError: - raise ValueError( - f"Unknown parser {parser_name}. Currently available are: {cls.get_available_parsers()}" - ) + raise ValueError(f"Unknown parser {parser_name}. Currently available are: {cls.get_available_parsers()}") @classmethod def get_available_parsers(cls) -> List[str]: diff --git a/planning_library/primitives/output_parsers/evaluation_output_parser.py b/planning_library/primitives/output_parsers/evaluation_output_parser.py index 6701393..c433886 100644 --- a/planning_library/primitives/output_parsers/evaluation_output_parser.py +++ b/planning_library/primitives/output_parsers/evaluation_output_parser.py @@ -17,6 +17,4 @@ def parse(self, text: str) -> float: raise ValueError("The given number is out of (0.0, 1.0) range.") return result except ValueError: - raise OutputParserException( - f"Couldn't convert {text} to float between 0 and 1." - ) + raise OutputParserException(f"Couldn't convert {text} to float between 0 and 1.") diff --git a/planning_library/strategies/__init__.py b/planning_library/strategies/__init__.py index d449331..a1a95e0 100644 --- a/planning_library/strategies/__init__.py +++ b/planning_library/strategies/__init__.py @@ -1,8 +1,8 @@ +from .adapt import ADaPTStrategy from .base_strategy import BaseCustomStrategy, BaseLangGraphStrategy from .reflexion import ReflexionStrategy -from .tot_dfs import TreeOfThoughtsDFSStrategy from .simple import SimpleStrategy -from .adapt import ADaPTStrategy +from .tot_dfs import TreeOfThoughtsDFSStrategy __all__ = [ "BaseCustomStrategy", diff --git a/planning_library/strategies/adapt/adapt_strategy.py b/planning_library/strategies/adapt/adapt_strategy.py index eab9af4..41fe588 100644 --- a/planning_library/strategies/adapt/adapt_strategy.py +++ b/planning_library/strategies/adapt/adapt_strategy.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Dict, Iterator, List, Optional, Tuple, AsyncIterator, Any +from dataclasses import asdict +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import ( @@ -8,13 +9,12 @@ CallbackManagerForChainRun, ) - -from ..base_strategy import BaseCustomStrategy from planning_library.action_executors import LangchainActionExecutor, MetaTools from planning_library.strategies.adapt.components import ADaPTExecutor, ADaPTPlanner from planning_library.strategies.adapt.components.executor import ADaPTExecutorConfig from planning_library.strategies.adapt.components.planner import ADaPTPlannerConfig -from dataclasses import asdict + +from ..base_strategy import BaseCustomStrategy class ADaPTStrategy(BaseCustomStrategy): @@ -28,7 +28,7 @@ class ADaPTStrategy(BaseCustomStrategy): max_depth: int @staticmethod - def create( + def create( # type: ignore[reportIncompatibleMethodOverride] meta_tools: Optional[MetaTools] = None, return_intermediate_steps: bool = False, return_finish_log: bool = False, @@ -52,9 +52,7 @@ def create( verbose: True to print extra information during execution. """ # TODO: runnable component vs strategy component? - assert ( - executor_config is not None - ), "Default ADaPT executor is currently not supported." + assert executor_config is not None, "Default ADaPT executor is currently not supported." if executor_config.runnable is not None: executor = ADaPTExecutor( @@ -73,9 +71,7 @@ def create( verbose=verbose, ) - assert ( - planner_config is not None - ), "Default ADaPT planner is currently not supported." + assert planner_config is not None, "Default ADaPT planner is currently not supported." if planner_config.runnable is not None: planner = ADaPTPlanner(planner_config.runnable) @@ -139,18 +135,14 @@ def _adapt_step( if depth > self.max_depth: return ( False, - AgentFinish( - return_values={}, log="Maximum decomposition depth reached." - ), + AgentFinish(return_values={}, log="Maximum decomposition depth reached."), intermediate_steps, ) # 2: run task through executor executor_output = self.executor.invoke( - inputs, # type: ignore[arg-type] - run_manager=run_manager.get_child(tag=f"executor:depth_{depth}") - if run_manager - else None, + inputs, # type: ignore + run_manager=run_manager.get_child(tag=f"executor:depth_{depth}") if run_manager else None, ) is_completed, cur_agent_outcome, cur_intermediate_steps = ( @@ -167,34 +159,26 @@ def _adapt_step( # 3.2: otherwise: self.executor.reset( actions=[a[0] for a in intermediate_steps], - run_manager=run_manager.get_child(tag="clean_env") - if run_manager - else None, + run_manager=run_manager.get_child(tag="clean_env") if run_manager else None, ) # call a planner to further decompose a current task plan = self.planner.invoke( dict( - inputs=inputs["inputs"], + inputs=inputs, # type: ignore[reportArgumentType] executor_agent_outcome=cur_agent_outcome, executor_intermediate_steps=cur_intermediate_steps, ), - run_manager=run_manager.get_child(tag=f"executor:depth_{depth}") - if run_manager - else None, + run_manager=run_manager.get_child(tag=f"executor:depth_{depth}") if run_manager else None, ) # when AND logic is given, execute tasks sequentially if plan["aggregation_mode"] == "and": for task_inputs in plan["subtasks"]: - cur_is_completed, cur_agent_outcome, cur_intermediate_steps = ( - self._adapt_step( - inputs={ - "inputs": {"inputs": task_inputs} - }, # TODO: hard-coded inputs key is ugly - depth=depth + 1, - run_manager=run_manager, - intermediate_steps=intermediate_steps, - ) + cur_is_completed, cur_agent_outcome, cur_intermediate_steps = self._adapt_step( + inputs={"inputs": task_inputs}, + depth=depth + 1, + run_manager=run_manager, + intermediate_steps=intermediate_steps, ) if not cur_is_completed: @@ -206,9 +190,7 @@ def _adapt_step( else: intermediate_steps.extend(cur_intermediate_steps) - agent_outcome = AgentFinish( - return_values={}, log="Task solved successfully!" - ) + agent_outcome = AgentFinish(return_values={}, log="Task solved successfully!") return True, agent_outcome, intermediate_steps elif plan["aggregation_mode"] == "or": for task_inputs in plan["subtasks"]: @@ -217,18 +199,14 @@ def _adapt_step( cur_agent_outcome, cur_intermediate_steps, ) = self._adapt_step( - inputs={ - "inputs": {"inputs": task_inputs} - }, # TODO: hard-coded inputs key is ugly + inputs={"inputs": task_inputs}, # TODO: hard-coded inputs key is ugly depth=depth + 1, run_manager=run_manager, intermediate_steps=intermediate_steps, ) if cur_is_completed: - agent_outcome = AgentFinish( - return_values={}, log="Task solved successfully!" - ) + agent_outcome = AgentFinish(return_values={}, log="Task solved successfully!") return True, agent_outcome, intermediate_steps else: intermediate_steps.extend(cur_intermediate_steps) @@ -239,9 +217,7 @@ def _adapt_step( ) return False, agent_outcome, intermediate_steps - raise NotImplementedError( - "Currently, only `and` and `or` aggregation logic is supported." - ) + raise NotImplementedError("Currently, only `and` and `or` aggregation logic is supported.") def _run_strategy( self, @@ -275,20 +251,16 @@ async def _adapt_astep( if depth > self.max_depth: return ( False, - AgentFinish( - return_values={}, log="Maximum decomposition depth reached." - ), + AgentFinish(return_values={}, log="Maximum decomposition depth reached."), intermediate_steps, ) # 2: run task through executor executor_output = await self.executor.ainvoke( dict( - inputs=inputs, + inputs=inputs, # type: ignore ), - run_manager=run_manager.get_child(tag=f"executor:depth_{depth}") - if run_manager - else None, + run_manager=run_manager.get_child(tag=f"executor:depth_{depth}") if run_manager else None, ) is_completed, cur_agent_outcome, cur_intermediate_steps = ( executor_output["is_completed"], @@ -304,20 +276,16 @@ async def _adapt_astep( # 3.2: otherwise: await self.executor.areset( actions=[a[0] for a in intermediate_steps], - run_manager=run_manager.get_child(tag="clean_env") - if run_manager - else None, + run_manager=run_manager.get_child(tag="clean_env") if run_manager else None, ) plan = await self.planner.ainvoke( dict( - inputs=inputs, + inputs=inputs, # type: ignore[reportArgumentType] executor_agent_outcome=cur_agent_outcome, executor_intermediate_steps=cur_intermediate_steps, ), - run_manager=run_manager.get_child(tag=f"executor:depth_{depth}") - if run_manager - else None, + run_manager=run_manager.get_child(tag=f"executor:depth_{depth}") if run_manager else None, ) # when AND logic is given, execute tasks sequentially if plan["aggregation_mode"] == "and": @@ -327,7 +295,7 @@ async def _adapt_astep( cur_agent_outcome, cur_intermediate_steps, ) = await self._adapt_astep( - inputs={"inputs": {"inputs": task_inputs}}, + inputs={"inputs": task_inputs}, depth=depth + 1, run_manager=run_manager, intermediate_steps=intermediate_steps, @@ -342,9 +310,7 @@ async def _adapt_astep( else: intermediate_steps.extend(cur_intermediate_steps) - agent_outcome = AgentFinish( - return_values={}, log="Task solved successfully!" - ) + agent_outcome = AgentFinish(return_values={}, log="Task solved successfully!") return True, agent_outcome, intermediate_steps elif plan["aggregation_mode"] == "or": for task_inputs in plan["subtasks"]: @@ -353,16 +319,14 @@ async def _adapt_astep( cur_agent_outcome, cur_intermediate_steps, ) = await self._adapt_astep( - inputs={"inputs": {"inputs": task_inputs}}, + inputs={"inputs": task_inputs}, depth=depth + 1, run_manager=run_manager, intermediate_steps=intermediate_steps, ) if cur_is_completed: - agent_outcome = AgentFinish( - return_values={}, log="Task solved successfully!" - ) + agent_outcome = AgentFinish(return_values={}, log="Task solved successfully!") return True, agent_outcome, intermediate_steps else: intermediate_steps.extend(cur_intermediate_steps) @@ -373,9 +337,7 @@ async def _adapt_astep( ) return False, agent_outcome, intermediate_steps - raise NotImplementedError( - "Currently, only `and` and `or` aggregation logic is supported." - ) + raise NotImplementedError("Currently, only `and` and `or` aggregation logic is supported.") async def _arun_strategy( self, diff --git a/planning_library/strategies/adapt/components/executor.py b/planning_library/strategies/adapt/components/executor.py index 8edc52d..c08de1e 100644 --- a/planning_library/strategies/adapt/components/executor.py +++ b/planning_library/strategies/adapt/components/executor.py @@ -1,8 +1,16 @@ from __future__ import annotations -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from dataclasses import dataclass from textwrap import dedent -from typing import Optional +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import AsyncCallbackManager, CallbackManager +from langchain_core.language_models import BaseChatModel +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import Runnable, RunnableLambda +from langchain_core.tools import BaseTool +from typing_extensions import TypedDict from planning_library.action_executors import ( BaseActionExecutor, @@ -11,20 +19,11 @@ ) from planning_library.components import RunnableComponent from planning_library.components.agent_component import AgentFactory -from planning_library.strategies import SimpleStrategy -from typing import Dict, Any, Tuple, List -from langchain_core.agents import AgentAction, AgentFinish -from langchain_core.runnables import Runnable, RunnableLambda -from langchain_core.callbacks import CallbackManager, AsyncCallbackManager -from typing_extensions import TypedDict -from typing import Union, Sequence -from langchain_core.language_models import BaseChatModel -from langchain_core.tools import BaseTool from planning_library.function_calling_parsers import ( - BaseFunctionCallingSingleActionParser, BaseFunctionCallingMultiActionParser, + BaseFunctionCallingSingleActionParser, ) -from dataclasses import dataclass +from planning_library.strategies import SimpleStrategy class ADaPTExecutorInput(TypedDict): @@ -75,9 +74,7 @@ def __init__( self._action_executor = action_executor @classmethod - def _create_default_prompt( - cls, system_message: Optional[str], user_message: str, **kwargs - ) -> ChatPromptTemplate: + def _create_default_prompt(cls, system_message: Optional[str], user_message: str, **kwargs) -> ChatPromptTemplate: if system_message is None: system_message = "You are an advanced reasoning agent." @@ -107,8 +104,7 @@ def _process_outputs(outputs: Dict[str, Any]) -> ADaPTExecutorOutput: return_values={ key: value[0] for key, value in outputs.items() - if isinstance(key, list) - and key not in ["finish_log", "intermediate_steps"] + if isinstance(key, list) and key not in ["finish_log", "intermediate_steps"] }, log=outputs["finish_log"][0], ) @@ -151,13 +147,9 @@ def _preprocess_input( assert action_executor is not None, "Either pass tools or action executor." tools = action_executor.tools - prompt = cls._process_prompt( - prompt=prompt, user_message=user_message, system_message=system_message - ) + prompt = cls._process_prompt(prompt=prompt, user_message=user_message, system_message=system_message) - agent = AgentFactory.create_agent( - llm=llm, tools=tools, prompt=prompt, parser=parser, parser_name=parser_name - ) + agent = AgentFactory.create_agent(llm=llm, tools=tools, prompt=prompt, parser=parser, parser_name=parser_name) strategy = SimpleStrategy.create( tools=tools, @@ -189,6 +181,4 @@ async def areset( run_manager: Optional[AsyncCallbackManager] = None, **kwargs, ) -> None: - await self._action_executor.areset( - actions=actions, run_manager=run_manager, **kwargs - ) + await self._action_executor.areset(actions=actions, run_manager=run_manager, **kwargs) diff --git a/planning_library/strategies/adapt/components/planner.py b/planning_library/strategies/adapt/components/planner.py index e69b947..f7d4c90 100644 --- a/planning_library/strategies/adapt/components/planner.py +++ b/planning_library/strategies/adapt/components/planner.py @@ -1,38 +1,34 @@ from __future__ import annotations -from langchain_core.callbacks import CallbackManager, AsyncCallbackManager -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from dataclasses import dataclass from textwrap import dedent -from typing import Optional +from typing import Any, Dict, Optional, Sequence, Union -from planning_library.action_executors import LangchainActionExecutor -from planning_library.strategies.adapt.utils import get_adapt_planner_tools -from planning_library.strategies.adapt.utils.planner_tools import BaseADaPTPlannerTool +from langchain.agents.agent import RunnableAgent, RunnableMultiActionAgent +from langchain_core.callbacks import AsyncCallbackManager, CallbackManager +from langchain_core.language_models import BaseChatModel from langchain_core.output_parsers import BaseOutputParser -from planning_library.components import BaseComponent, RunnableComponent -from planning_library.components.agent_component import AgentFactory -from planning_library.strategies import SimpleStrategy -from typing import Dict, Any +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import Runnable, RunnableLambda -from typing import Union, Sequence -from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool -from planning_library.utils import format_thought -from langchain.agents.agent import RunnableAgent, RunnableMultiActionAgent +from planning_library.action_executors import LangchainActionExecutor +from planning_library.components import BaseComponent, RunnableComponent +from planning_library.components.agent_component import AgentFactory from planning_library.function_calling_parsers import ( - BaseFunctionCallingSingleActionParser, BaseFunctionCallingMultiActionParser, + BaseFunctionCallingSingleActionParser, ParserRegistry, ) -from dataclasses import dataclass - - +from planning_library.strategies import SimpleStrategy from planning_library.strategies.adapt.utils import ( ADaPTPlannerInput, ADaPTPlannerOutput, SimplePlannerOutputParser, + get_adapt_planner_tools, ) +from planning_library.strategies.adapt.utils.planner_tools import BaseADaPTPlannerTool +from planning_library.utils import format_thought @dataclass @@ -104,12 +100,10 @@ def _create_default_prompt( ] if mode == "agent": - return ChatPromptTemplate.from_messages( - base_messages - + [ - ( - "human", - dedent(""" + base_messages += [ + ( + "human", + dedent(""" The trial above was unsuccessful. Your goal is to construct a step-by-step plan to successfully solve the original task. Plan is a list of subtasks with the logic of how the subtasks' results should be aggregated. @@ -124,18 +118,15 @@ def _create_default_prompt( You are given access to a set of tools to help with the plan construction. ALWAYS use tools, refrain from using tools only when you are done. """), - ), - MessagesPlaceholder("agent_scratchpad"), - ] - ) + ), + MessagesPlaceholder("agent_scratchpad"), + ] elif mode == "simple": - return ChatPromptTemplate.from_messages( - base_messages - + [ - ( - "human", - dedent(""" + base_messages += [ + ( + "human", + dedent(""" The trial above was unsuccessful. Your goal is to construct a step-by-step plan to successfully solve the original task. Plan is a list of subtasks with the logic of how the subtasks' results should be aggregated. @@ -158,10 +149,14 @@ def _create_default_prompt( "aggregation_mode": ""}} ``` """), - ), - ] - ) - raise NotImplementedError("Currently, only agentic planner is supported.") + ), + ] + else: + raise NotImplementedError(f"Unsupported mode {mode}.") + + return ChatPromptTemplate.from_messages( + base_messages # type: ignore[arg-type] + ) def invoke( self, @@ -170,22 +165,16 @@ def invoke( **kwargs, ) -> ADaPTPlannerOutput: if self.mode == "agent": - assert ( - self.tools is not None and len(self.tools) > 0 - ), "Tools have to be defined for agentic mode." + assert self.tools is not None and len(self.tools) > 0, "Tools have to be defined for agentic mode." plan = self.tools[0].plan plan.clear() - _ = self.runnable.invoke( - inputs, run_manager=run_manager, run_name=self.name - ) + _ = self.runnable.invoke(inputs, run_manager=run_manager, run_name=self.name) return { "subtasks": plan.subtasks, "aggregation_mode": plan.aggregation_mode, } if self.mode == "simple": - return self.runnable.invoke( - inputs, run_manager=run_manager, run_name=self.name - ) + return self.runnable.invoke(inputs, run_manager=run_manager, run_name=self.name) raise NotImplementedError( "Currently, only `agent` (with tools) and `simple` (a single call to llm) modes for the planner are supported." @@ -198,22 +187,16 @@ async def ainvoke( **kwargs, ) -> ADaPTPlannerOutput: if self.mode == "agent": - assert ( - self.tools is not None and len(self.tools) > 0 - ), "Tools have to be defined for agentic mode." + assert self.tools is not None and len(self.tools) > 0, "Tools have to be defined for agentic mode." plan = self.tools[0].plan plan.clear() - _ = await self.runnable.ainvoke( - inputs, run_manager=run_manager, run_name=self.name - ) + _ = await self.runnable.ainvoke(inputs, run_manager=run_manager, run_name=self.name) return { "subtasks": plan.subtasks, "aggregation_mode": plan.aggregation_mode, } if self.mode == "simple": - return await self.runnable.ainvoke( - inputs, run_manager=run_manager, run_name=self.name - ) + return await self.runnable.ainvoke(inputs, run_manager=run_manager, run_name=self.name) raise NotImplementedError("Currently, only agentic planner is supported.") @@ -233,13 +216,9 @@ def _create_agent( ] = None, parser_name: Optional[str] = None, ) -> Union[RunnableAgent, RunnableMultiActionAgent]: - prompt = cls._process_prompt( - prompt=prompt, user_message=user_message, system_message=system_message - ) + prompt = cls._process_prompt(prompt=prompt, user_message=user_message, system_message=system_message) - return AgentFactory.create_agent( - llm=llm, tools=tools, prompt=prompt, parser=parser, parser_name=parser_name - ) + return AgentFactory.create_agent(llm=llm, tools=tools, prompt=prompt, parser=parser, parser_name=parser_name) @classmethod def create_agent_planner( @@ -287,9 +266,7 @@ def _preprocess_input( return { **inputs["inputs"], - "executor_agent_outcome": format_thought( - inputs["executor_agent_outcome"] - ), + "executor_agent_outcome": format_thought(inputs["executor_agent_outcome"]), "executor_intermediate_steps": executor_intermediate_steps, } @@ -353,9 +330,7 @@ def _preprocess_input( return { **inputs["inputs"], - "executor_agent_outcome": format_thought( - inputs["executor_agent_outcome"] - ), + "executor_agent_outcome": format_thought(inputs["executor_agent_outcome"]), "executor_intermediate_steps": executor_intermediate_steps, } @@ -369,8 +344,6 @@ def _preprocess_input( if output_parser is None: output_parser = SimplePlannerOutputParser() - runnable = RunnableComponent.create_from_steps( - prompt=prompt, llm=llm, output_parser=output_parser - ) + runnable = RunnableComponent.create_from_steps(prompt=prompt, llm=llm, output_parser=output_parser) runnable.add_input_preprocessing(_preprocess_input) return cls(runnable=runnable, mode="simple") diff --git a/planning_library/strategies/adapt/utils/__init__.py b/planning_library/strategies/adapt/utils/__init__.py index 32135e9..107c59b 100644 --- a/planning_library/strategies/adapt/utils/__init__.py +++ b/planning_library/strategies/adapt/utils/__init__.py @@ -1,6 +1,6 @@ -from .typing_utils import ADaPTPlan, ADaPTPlannerOutput, ADaPTPlannerInput -from .planner_tools import get_adapt_planner_tools from .planner_output_parser import SimplePlannerOutputParser +from .planner_tools import get_adapt_planner_tools +from .typing_utils import ADaPTPlan, ADaPTPlannerInput, ADaPTPlannerOutput __all__ = [ "ADaPTPlan", diff --git a/planning_library/strategies/adapt/utils/planner_output_parser.py b/planning_library/strategies/adapt/utils/planner_output_parser.py index 700df01..3bba253 100644 --- a/planning_library/strategies/adapt/utils/planner_output_parser.py +++ b/planning_library/strategies/adapt/utils/planner_output_parser.py @@ -1,13 +1,12 @@ from langchain_core.output_parsers import BaseOutputParser from langchain_core.output_parsers.json import parse_and_check_json_markdown + from .typing_utils import ADaPTPlannerOutput class SimplePlannerOutputParser(BaseOutputParser[ADaPTPlannerOutput]): def parse(self, text: str) -> ADaPTPlannerOutput: - output = parse_and_check_json_markdown( - text, expected_keys=["subtasks", "aggregation_mode"] - ) + output = parse_and_check_json_markdown(text, expected_keys=["subtasks", "aggregation_mode"]) return { "subtasks": output["subtasks"], "aggregation_mode": output["aggregation_mode"], diff --git a/planning_library/strategies/adapt/utils/planner_tools.py b/planning_library/strategies/adapt/utils/planner_tools.py index af54171..edb5121 100644 --- a/planning_library/strategies/adapt/utils/planner_tools.py +++ b/planning_library/strategies/adapt/utils/planner_tools.py @@ -1,9 +1,11 @@ +from abc import ABC from textwrap import dedent -from typing import Any, Type, Optional, List, Literal -from planning_library.strategies.adapt.utils import ADaPTPlan +from typing import Any, List, Literal, Optional, Type + from langchain.pydantic_v1 import BaseModel, Field from langchain.tools import BaseTool -from abc import ABC + +from planning_library.strategies.adapt.utils import ADaPTPlan class BaseADaPTPlannerTool(BaseTool, BaseModel, ABC): @@ -25,27 +27,21 @@ class CheckPlanTool(BaseADaPTPlannerTool): class CheckPlanInput(BaseModel): ... - args_schema: Type[BaseModel] = CheckPlanInput + args_schema: Type[BaseModel] = CheckPlanInput # type: ignore def _run(self, *args: Any, **kwargs: Any) -> str: if len(self.plan.subtasks) == 0: observation = ["Currently, the plan doesn't contain any subtasks."] else: - observation = [ - f"Currently, the plan contains {len(self.plan.subtasks)} subtasks." - ] + observation = [f"Currently, the plan contains {len(self.plan.subtasks)} subtasks."] for i, subtask in enumerate(self.plan.subtasks): observation.append(f"{i + 1}. {subtask}") if self.plan.aggregation_mode is None: - observation.append( - "The current subtasks results aggregation mode is not defined yet." - ) + observation.append("The current subtasks results aggregation mode is not defined yet.") else: - observation.append( - f"The current subtasks results aggregation mode is set to {self.plan.aggregation_mode}." - ) + observation.append(f"The current subtasks results aggregation mode is set to {self.plan.aggregation_mode}.") return "\n".join(observation) @@ -65,7 +61,7 @@ class AddTaskInput(BaseModel): default=None, ) - args_schema: Type[BaseModel] = AddTaskInput + args_schema: Type[BaseModel] = AddTaskInput # type: ignore def _run( self, @@ -87,9 +83,7 @@ def _run( class EditTaskTool(BaseADaPTPlannerTool): name = "edit_task" - description = dedent( - """Changes the formulation of the existing subtask in the current plan.""" - ) + description = dedent("""Changes the formulation of the existing subtask in the current plan.""") class EditTaskInput(BaseModel): task_inputs: str = Field( @@ -101,11 +95,9 @@ class EditTaskInput(BaseModel): Note that the indexing starts with 0."""), ) - args_schema: Type[BaseModel] = EditTaskInput + args_schema: Type[BaseModel] = EditTaskInput # type: ignore - def _run( - self, task_inputs: str, task_position: int, *args: Any, **kwargs: Any - ) -> str: + def _run(self, task_inputs: str, task_position: int, *args: Any, **kwargs: Any) -> str: try: self.plan.subtasks[task_position] = task_inputs return f"Successfully edited the subtask at position {task_position} in the current plan." @@ -122,7 +114,7 @@ class RemoveTaskInput(BaseModel): description="The position in the plan to delete the subtask from. Note that the indexing starts with 0." ) - args_schema: Type[BaseModel] = RemoveTaskInput + args_schema: Type[BaseModel] = RemoveTaskInput # type: ignore def _run(self, task_position: int, *args: Any, **kwargs: Any) -> str: try: @@ -146,11 +138,9 @@ class DefineAggregationModeInput(BaseModel): """) ) - args_schema: Type[BaseModel] = DefineAggregationModeInput + args_schema: Type[BaseModel] = DefineAggregationModeInput # type: ignore - def _run( - self, aggregation_mode: Literal["and", "or"], *args: Any, **kwargs: Any - ) -> str: + def _run(self, aggregation_mode: Literal["and", "or"], *args: Any, **kwargs: Any) -> str: self.plan.aggregation_mode = aggregation_mode return f"Successfully set aggregation mode to {aggregation_mode}" diff --git a/planning_library/strategies/adapt/utils/typing_utils.py b/planning_library/strategies/adapt/utils/typing_utils.py index d79500a..bc6a8a8 100644 --- a/planning_library/strategies/adapt/utils/typing_utils.py +++ b/planning_library/strategies/adapt/utils/typing_utils.py @@ -1,7 +1,8 @@ -from typing import List, Literal, Dict, Any, Tuple +from typing import Any, Dict, List, Literal, Tuple + from langchain_core.agents import AgentAction, AgentFinish -from typing_extensions import TypedDict from langchain_core.pydantic_v1 import BaseModel, Field +from typing_extensions import TypedDict class ADaPTPlannerInput(TypedDict): diff --git a/planning_library/strategies/base_strategy.py b/planning_library/strategies/base_strategy.py index 2e4a3b7..4844825 100644 --- a/planning_library/strategies/base_strategy.py +++ b/planning_library/strategies/base_strategy.py @@ -1,3 +1,4 @@ +import asyncio from abc import ABC, abstractmethod from typing import ( Any, @@ -56,12 +57,17 @@ def _run_strategy( run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Iterator[Tuple[AgentFinish, List[Tuple[AgentAction, str]]]]: ... - @abstractmethod - def _arun_strategy( + async def _arun_strategy( self, inputs: Dict[str, str], run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> AsyncIterator[Tuple[AgentFinish, List[Tuple[AgentAction, str]]]]: ... + ) -> AsyncIterator[Tuple[AgentFinish, List[Tuple[AgentAction, str]]]]: + loop = asyncio.get_event_loop() + + sync_run_manager = run_manager.get_sync() if run_manager is not None else None + result = await loop.run_in_executor(None, self._run_strategy, inputs, sync_run_manager) + for item in result: + yield item def _return( self, @@ -85,9 +91,7 @@ async def _areturn( run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Dict[str, Any]: if run_manager: - await run_manager.on_agent_finish( - output, color="green", verbose=self.verbose - ) + await run_manager.on_agent_finish(output, color="green", verbose=self.verbose) final_output = output.return_values if self.return_intermediate_steps: final_output["intermediate_steps"] = intermediate_steps @@ -126,9 +130,7 @@ async def _acall( outputs = [] async for _output, _intermediate_steps in _outputs: - output = await self._areturn( - _output, _intermediate_steps, run_manager=run_manager - ) + output = await self._areturn(_output, _intermediate_steps, run_manager=run_manager) outputs.append(output) return {key: [output[key] for output in outputs] for key in outputs[0]} diff --git a/planning_library/strategies/reflexion/components/actor.py b/planning_library/strategies/reflexion/components/actor.py index 0addfe1..8f4be1e 100644 --- a/planning_library/strategies/reflexion/components/actor.py +++ b/planning_library/strategies/reflexion/components/actor.py @@ -1,16 +1,15 @@ from __future__ import annotations -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from textwrap import dedent -from typing import Optional - +from typing import Any, Dict, List, Optional, Sequence, Tuple -from planning_library.components import AgentComponent -from typing import Tuple, Dict, Any, List, Sequence from langchain_core.agents import AgentAction from langchain_core.messages import BaseMessage +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from typing_extensions import TypedDict +from planning_library.components import AgentComponent + class ReflexionActorInput(TypedDict): inputs: Dict[str, Any] @@ -35,9 +34,7 @@ class ReflexionActor(AgentComponent[ReflexionActorInput]): } | {"agent_scratchpad"} @classmethod - def _create_default_prompt( - cls, system_message: Optional[str], user_message: str, **kwargs - ) -> ChatPromptTemplate: + def _create_default_prompt(cls, system_message: Optional[str], user_message: str, **kwargs) -> ChatPromptTemplate: if system_message is None: system_message = "You are an advanced reasoning agent that can improve based on self-reflection." diff --git a/planning_library/strategies/reflexion/components/evaluator.py b/planning_library/strategies/reflexion/components/evaluator.py index 560e009..e6da7d3 100644 --- a/planning_library/strategies/reflexion/components/evaluator.py +++ b/planning_library/strategies/reflexion/components/evaluator.py @@ -1,19 +1,22 @@ from __future__ import annotations -from planning_library.components.evaluation import EvaluatorComponent -from typing import Tuple, Dict, Any, List, Optional, Generic, Type -from planning_library.components.base_component import OutputType + +from textwrap import dedent +from typing import Any, Dict, Generic, List, Optional, Tuple, Type + from langchain_core.agents import AgentAction, AgentFinish from langchain_core.language_models import BaseChatModel -from langchain_core.output_parsers import BaseOutputParser from langchain_core.messages import BaseMessage +from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from textwrap import dedent +from typing_extensions import TypedDict + +from planning_library.components.base_component import OutputType +from planning_library.components.evaluation import EvaluatorComponent from planning_library.function_calling_parsers import ( - ParserRegistry, BaseFunctionCallingMultiActionParser, BaseFunctionCallingSingleActionParser, + ParserRegistry, ) -from typing_extensions import TypedDict class ReflexionEvaluatorInput(TypedDict): @@ -27,21 +30,17 @@ class PreprocessedReflexionEvaluatorInput(TypedDict): agent_outcome: str -class ReflexionEvaluator( - Generic[OutputType], EvaluatorComponent[ReflexionEvaluatorInput, OutputType] -): +class ReflexionEvaluator(Generic[OutputType], EvaluatorComponent[ReflexionEvaluatorInput, OutputType]): name = "Evaluator" - required_prompt_input_vars = set(ReflexionEvaluatorInput.__annotations__) - { - "inputs" - } + required_prompt_input_vars = set(ReflexionEvaluatorInput.__annotations__) - {"inputs"} @classmethod - def _create_default_prompt( - cls, system_message: Optional[str], user_message: str, **kwargs - ) -> ChatPromptTemplate: + def _create_default_prompt(cls, system_message: Optional[str], user_message: str, **kwargs) -> ChatPromptTemplate: if system_message is None: - system_message = "You are an advanced reasoning assistant that judges whether the episodes result in success or failure." + system_message = ( + "You are an advanced reasoning assistant that judges whether the episodes result in success or failure." + ) return ChatPromptTemplate.from_messages( [ @@ -73,9 +72,7 @@ def create( user_message: Optional[str] = None, system_message: Optional[str] = None, output_parser: Optional[BaseOutputParser[float]] = None, - parser: Optional[ - BaseFunctionCallingSingleActionParser | BaseFunctionCallingMultiActionParser - ] = None, + parser: Optional[BaseFunctionCallingSingleActionParser | BaseFunctionCallingMultiActionParser] = None, parser_name: Optional[str] = None, ) -> "ReflexionEvaluator[float]": def _preprocess_input( @@ -90,9 +87,7 @@ def _preprocess_input( preprocessed_inputs = parser.format_inputs(inputs) return { **preprocessed_inputs["inputs"], - "agent_outcome": preprocessed_inputs["agent_outcome"].return_values[ # type: ignore[typeddict-item] - "output" - ], + "agent_outcome": preprocessed_inputs["agent_outcome"].return_values["output"], # type: ignore[typeddict-item] "intermediate_steps": preprocessed_inputs["agent_scratchpad"], } diff --git a/planning_library/strategies/reflexion/components/self_reflection.py b/planning_library/strategies/reflexion/components/self_reflection.py index 5e17a93..838e0ba 100644 --- a/planning_library/strategies/reflexion/components/self_reflection.py +++ b/planning_library/strategies/reflexion/components/self_reflection.py @@ -1,16 +1,19 @@ from __future__ import annotations -from planning_library.components import RunnableComponent + +from textwrap import dedent +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type + from langchain_core.agents import AgentAction, AgentFinish -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.language_models import BaseChatModel -from textwrap import dedent -from typing import Optional, Sequence, Tuple, Dict, Any, List, Type from langchain_core.messages import BaseMessage +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from typing_extensions import TypedDict + +from planning_library.components import RunnableComponent from planning_library.function_calling_parsers import ( - ParserRegistry, BaseFunctionCallingMultiActionParser, BaseFunctionCallingSingleActionParser, + ParserRegistry, ) @@ -25,19 +28,13 @@ class PreprocessedReflexionSelfReflectionInput(TypedDict): agent_outcome: str -class ReflexionSelfReflection( - RunnableComponent[ReflexionSelfReflectionInput, Sequence[BaseMessage]] -): +class ReflexionSelfReflection(RunnableComponent[ReflexionSelfReflectionInput, Sequence[BaseMessage]]): name = "Self-Reflection" - required_prompt_input_vars = set(ReflexionSelfReflectionInput.__annotations__) - { - "inputs" - } + required_prompt_input_vars = set(ReflexionSelfReflectionInput.__annotations__) - {"inputs"} @classmethod - def _create_default_prompt( - cls, system_message: Optional[str], user_message: str, **kwRGS - ) -> ChatPromptTemplate: + def _create_default_prompt(cls, system_message: Optional[str], user_message: str, **kwRGS) -> ChatPromptTemplate: if system_message is None: system_message = "You are an advanced reasoning agent that can self-reflect on their shortcomings when solving reasoning tasks." @@ -68,9 +65,7 @@ def create( prompt: Optional[ChatPromptTemplate] = None, user_message: Optional[str] = None, system_message: Optional[str] = None, - parser: Optional[ - BaseFunctionCallingSingleActionParser | BaseFunctionCallingMultiActionParser - ] = None, + parser: Optional[BaseFunctionCallingSingleActionParser | BaseFunctionCallingMultiActionParser] = None, parser_name: Optional[str] = None, ) -> "ReflexionSelfReflection": def _preprocess_input( @@ -84,15 +79,11 @@ def _preprocess_input( preprocessed_inputs = parser.format_inputs(inputs) return { **preprocessed_inputs["inputs"], - "agent_outcome": preprocessed_inputs["agent_outcome"].return_values[ # type: ignore[typeddict-item] - "output" - ], + "agent_outcome": preprocessed_inputs["agent_outcome"].return_values["output"], # type: ignore[typeddict-item] "intermediate_steps": preprocessed_inputs["agent_scratchpad"], } - prompt = cls._process_prompt( - prompt=prompt, user_message=user_message, system_message=system_message - ) + prompt = cls._process_prompt(prompt=prompt, user_message=user_message, system_message=system_message) # TODO: figure out typing here self_reflection: ReflexionSelfReflection = cls.create_from_steps( # type: ignore[assignment] diff --git a/planning_library/strategies/reflexion/reflexion_graph.py b/planning_library/strategies/reflexion/reflexion_graph.py index 661535f..1f95e8c 100644 --- a/planning_library/strategies/reflexion/reflexion_graph.py +++ b/planning_library/strategies/reflexion/reflexion_graph.py @@ -6,16 +6,16 @@ List, Literal, Optional, + Sequence, Tuple, TypedDict, - Sequence, Union, ) -from langchain.memory import ChatMessageHistory -from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import BaseMessage, AIMessage +from langchain.memory import ChatMessageHistory from langchain_core.agents import AgentAction, AgentFinish, AgentStep +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import AIMessage, BaseMessage from langchain_core.runnables import RunnableLambda from langgraph.graph import END, StateGraph # type: ignore[import] from langgraph.pregel import Pregel # type: ignore[import-untyped] @@ -23,10 +23,10 @@ from ...action_executors import BaseActionExecutor from .components import ( ReflexionActor, - ReflexionSelfReflection, + ReflexionActorInput, ReflexionEvaluator, ReflexionEvaluatorInput, - ReflexionActorInput, + ReflexionSelfReflection, ReflexionSelfReflectionInput, ) @@ -58,15 +58,11 @@ def _format_self_reflections( return result @staticmethod - def init( - state: ReflexionState, memory: Optional[BaseChatMessageHistory] = None - ) -> ReflexionState: + def init(state: ReflexionState, memory: Optional[BaseChatMessageHistory] = None) -> ReflexionState: """The entry node in the graph. Initializes the state correctly.""" state["agent_outcome"] = None state["evaluator_should_continue"] = None - state["self_reflection_memory"] = ( - ChatMessageHistory() if memory is None else memory - ) + state["self_reflection_memory"] = ChatMessageHistory() if memory is None else memory state["self_reflections"] = [] state["intermediate_steps"] = [] state["iteration"] = 1 @@ -125,9 +121,7 @@ def execute_actions( action_executor: BaseActionExecutor, ) -> ReflexionState: """Synchronous version of executing actions as previously requested by an agent.""" - assert ( - state["agent_outcome"] is not None - ), "Agent outcome should be defined on the tool execution step." + assert state["agent_outcome"] is not None, "Agent outcome should be defined on the tool execution step." assert not isinstance( state["agent_outcome"], AgentFinish ), "Agent outcome should not be AgentFinish on the tool execution step." @@ -137,13 +131,9 @@ def execute_actions( ) if isinstance(observation, AgentStep): - state["intermediate_steps"].append( - (observation.action, observation.observation) - ) + state["intermediate_steps"].append((observation.action, observation.observation)) else: - state["intermediate_steps"].extend( - [(obs.action, obs.observation) for obs in observation] - ) + state["intermediate_steps"].extend([(obs.action, obs.observation) for obs in observation]) return state @staticmethod @@ -152,9 +142,7 @@ async def aexecute_actions( action_executor: BaseActionExecutor, ) -> ReflexionState: """Asynchronous version of executing tools as previously requested by an agent.""" - assert ( - state["agent_outcome"] is not None - ), "Agent outcome should be defined on the tool execution step." + assert state["agent_outcome"] is not None, "Agent outcome should be defined on the tool execution step." assert not isinstance( state["agent_outcome"], AgentFinish ), "Agent outcome should not be AgentFinish on the tool execution step." @@ -164,19 +152,13 @@ async def aexecute_actions( ) if isinstance(observation, AgentStep): - state["intermediate_steps"].append( - (observation.action, observation.observation) - ) + state["intermediate_steps"].append((observation.action, observation.observation)) else: - state["intermediate_steps"].extend( - [(obs.action, obs.observation) for obs in observation] - ) + state["intermediate_steps"].extend([(obs.action, obs.observation) for obs in observation]) return state @staticmethod - def evaluate( - state: ReflexionState, evaluator: ReflexionEvaluator - ) -> ReflexionState: + def evaluate(state: ReflexionState, evaluator: ReflexionEvaluator) -> ReflexionState: """Synchronous version of evaluating the outcome of the current trial.""" assert isinstance( state["agent_outcome"], AgentFinish @@ -192,9 +174,7 @@ def evaluate( return state @staticmethod - async def aevaluate( - state: ReflexionState, evaluator: ReflexionEvaluator - ) -> ReflexionState: + async def aevaluate(state: ReflexionState, evaluator: ReflexionEvaluator) -> ReflexionState: """Asynchronous version of evaluating the outcome of the current trial.""" assert isinstance( state["agent_outcome"], AgentFinish @@ -211,9 +191,7 @@ async def aevaluate( return state @staticmethod - def self_reflect( - state: ReflexionState, self_reflection: ReflexionSelfReflection - ) -> ReflexionState: + def self_reflect(state: ReflexionState, self_reflection: ReflexionSelfReflection) -> ReflexionState: """Synchronous version of self-reflecting on the current trial.""" assert isinstance( state["agent_outcome"], AgentFinish @@ -230,9 +208,7 @@ def self_reflect( return state @staticmethod - async def aself_reflect( - state: ReflexionState, self_reflection: ReflexionSelfReflection - ) -> ReflexionState: + async def aself_reflect(state: ReflexionState, self_reflection: ReflexionSelfReflection) -> ReflexionState: """Asynchronous version of self-reflecting on the current trial.""" assert isinstance( state["agent_outcome"], AgentFinish @@ -331,9 +307,7 @@ def create_reflexion_graph( "self_reflect", RunnableLambda( partial(ReflexionNodes.self_reflect, self_reflection=self_reflection), - afunc=partial( - ReflexionNodes.aself_reflect, self_reflection=self_reflection - ), + afunc=partial(ReflexionNodes.aself_reflect, self_reflection=self_reflection), ), ) @@ -356,9 +330,7 @@ def create_reflexion_graph( ) builder.add_conditional_edges( "self_reflect", - partial( - ReflexionEdges.should_continue_num_iterations, max_iterations=max_iterations - ), + partial(ReflexionEdges.should_continue_num_iterations, max_iterations=max_iterations), {"yes": "re_init", "no": END}, ) builder.add_edge("re_init", "act") diff --git a/planning_library/strategies/reflexion/reflexion_strategy.py b/planning_library/strategies/reflexion/reflexion_strategy.py index c17be26..3e05baf 100644 --- a/planning_library/strategies/reflexion/reflexion_strategy.py +++ b/planning_library/strategies/reflexion/reflexion_strategy.py @@ -1,10 +1,9 @@ from typing import Any, Callable, Dict, Optional - from ...action_executors import BaseActionExecutor from ..base_strategy import BaseLangGraphStrategy +from .components import ReflexionActor, ReflexionEvaluator, ReflexionSelfReflection from .reflexion_graph import create_reflexion_graph -from .components import ReflexionActor, ReflexionSelfReflection, ReflexionEvaluator class ReflexionStrategy(BaseLangGraphStrategy): diff --git a/planning_library/strategies/simple/simple_strategy.py b/planning_library/strategies/simple/simple_strategy.py index 259564a..8631519 100644 --- a/planning_library/strategies/simple/simple_strategy.py +++ b/planning_library/strategies/simple/simple_strategy.py @@ -1,21 +1,21 @@ from __future__ import annotations -from langchain_core.agents import AgentFinish, AgentAction, AgentStep +from typing import AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union + +from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent +from langchain_core.agents import AgentAction, AgentFinish, AgentStep from langchain_core.callbacks import ( - CallbackManagerForChainRun, AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, ) +from langchain_core.tools import BaseTool -from planning_library.strategies import BaseCustomStrategy from planning_library.action_executors import ( BaseActionExecutor, LangchainActionExecutor, MetaTools, ) - -from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent -from langchain_core.tools import BaseTool -from typing import Union, Sequence, Optional, Dict, Iterator, Tuple, List, AsyncIterator +from planning_library.strategies import BaseCustomStrategy class SimpleStrategy(BaseCustomStrategy): @@ -93,20 +93,15 @@ def _run_strategy( ) if isinstance(action_results, AgentStep): - intermediate_steps.append( - (action_results.action, action_results.observation) - ) + intermediate_steps.append((action_results.action, action_results.observation)) else: intermediate_steps.extend( - (_action_results.action, _action_results.observation) - for _action_results in action_results + (_action_results.action, _action_results.observation) for _action_results in action_results ) cur_iteration += 1 - stopped_outcome = AgentFinish( - {"output": "Agent stopped due to iteration limit."}, "" - ) + stopped_outcome = AgentFinish({"output": "Agent stopped due to iteration limit."}, "") yield stopped_outcome, intermediate_steps return @@ -133,20 +128,15 @@ async def _arun_strategy( action_results = await self.action_executor.aexecute(agent_outcome) if isinstance(action_results, AgentStep): - intermediate_steps.append( - (action_results.action, action_results.observation) - ) + intermediate_steps.append((action_results.action, action_results.observation)) else: intermediate_steps.extend( - (_action_results.action, _action_results.observation) - for _action_results in action_results + (_action_results.action, _action_results.observation) for _action_results in action_results ) cur_iteration += 1 - stopped_outcome = AgentFinish( - {"output": "Agent stopped due to iteration limit."}, "" - ) + stopped_outcome = AgentFinish({"output": "Agent stopped due to iteration limit."}, "") yield stopped_outcome, intermediate_steps return diff --git a/planning_library/strategies/tot_dfs/components/__init__.py b/planning_library/strategies/tot_dfs/components/__init__.py index 12faf5b..c61de6e 100644 --- a/planning_library/strategies/tot_dfs/components/__init__.py +++ b/planning_library/strategies/tot_dfs/components/__init__.py @@ -1,14 +1,14 @@ -from .thought_generator import ( - ThoughtGeneratorInput, - ThoughtGenerator, - ThoughtGeneratorConfig, -) -from .thought_sorter import ThoughtSorterInput, ThoughtSorter, ThoughtSorterConfig from .thought_evaluator import ( - ThoughtEvaluatorInput, ThoughtEvaluator, ThoughtEvaluatorConfig, + ThoughtEvaluatorInput, +) +from .thought_generator import ( + ThoughtGenerator, + ThoughtGeneratorConfig, + ThoughtGeneratorInput, ) +from .thought_sorter import ThoughtSorter, ThoughtSorterConfig, ThoughtSorterInput __all__ = [ "ThoughtGeneratorInput", diff --git a/planning_library/strategies/tot_dfs/components/thought_evaluator.py b/planning_library/strategies/tot_dfs/components/thought_evaluator.py index 7c474fe..035c8a6 100644 --- a/planning_library/strategies/tot_dfs/components/thought_evaluator.py +++ b/planning_library/strategies/tot_dfs/components/thought_evaluator.py @@ -1,23 +1,26 @@ from __future__ import annotations -from planning_library.components.evaluation import EvaluatorComponent -from typing import Tuple, Dict, Any, List, Optional, Generic, Type, Union -from planning_library.components.base_component import OutputType + +from dataclasses import dataclass +from textwrap import dedent +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union + from langchain_core.agents import AgentAction, AgentFinish from langchain_core.language_models import BaseChatModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import Runnable -from planning_library.utils import ( - format_thought, -) -from textwrap import dedent +from typing_extensions import TypedDict + +from planning_library.components.base_component import OutputType +from planning_library.components.evaluation import EvaluatorComponent from planning_library.function_calling_parsers import ( - ParserRegistry, BaseFunctionCallingMultiActionParser, BaseFunctionCallingSingleActionParser, + ParserRegistry, +) +from planning_library.utils import ( + format_thought, ) -from typing_extensions import TypedDict -from dataclasses import dataclass @dataclass @@ -49,17 +52,13 @@ class ThoughtEvaluatorInput(TypedDict): next_thought: List[AgentAction] | AgentAction | AgentFinish -class ThoughtEvaluator( - Generic[OutputType], EvaluatorComponent[ThoughtEvaluatorInput, OutputType] -): +class ThoughtEvaluator(Generic[OutputType], EvaluatorComponent[ThoughtEvaluatorInput, OutputType]): name = "Evaluate Thoughts" required_prompt_input_vars = set(ThoughtEvaluatorInput.__annotations__) - {"inputs"} @classmethod - def _create_default_prompt( - cls, system_message: Optional[str], user_message: str, **kwargs - ) -> ChatPromptTemplate: + def _create_default_prompt(cls, system_message: Optional[str], user_message: str, **kwargs) -> ChatPromptTemplate: if system_message is None: system_message = ( "You are an advanced reasoning assistant that judges the plausability of " @@ -104,9 +103,7 @@ def create( user_message: Optional[str] = None, system_message: Optional[str] = None, output_parser: Optional[BaseOutputParser[float]] = None, - parser: Optional[ - BaseFunctionCallingSingleActionParser | BaseFunctionCallingMultiActionParser - ] = None, + parser: Optional[BaseFunctionCallingSingleActionParser | BaseFunctionCallingMultiActionParser] = None, parser_name: Optional[str] = None, ) -> "ThoughtEvaluator[float]": def _preprocess_input( diff --git a/planning_library/strategies/tot_dfs/components/thought_generator.py b/planning_library/strategies/tot_dfs/components/thought_generator.py index 08f9374..dca56a0 100644 --- a/planning_library/strategies/tot_dfs/components/thought_generator.py +++ b/planning_library/strategies/tot_dfs/components/thought_generator.py @@ -1,22 +1,24 @@ from __future__ import annotations -from typing import Tuple, Dict, Any, List, Union, Optional, Sequence -from langchain_core.callbacks import CallbackManager, AsyncCallbackManager -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import AsyncCallbackManager, CallbackManager from langchain_core.language_models import BaseChatModel +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.tools import BaseTool -from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent from typing_extensions import TypedDict -from planning_library.components import BaseComponent, AgentComponent + +from planning_library.components import AgentComponent, BaseComponent from planning_library.function_calling_parsers import ( - BaseFunctionCallingSingleActionParser, BaseFunctionCallingMultiActionParser, + BaseFunctionCallingSingleActionParser, ) from planning_library.utils import ( format_thoughts, ) -from dataclasses import dataclass @dataclass @@ -50,11 +52,7 @@ class ThoughtGeneratorAgentInput(ThoughtGeneratorInput): previous_thoughts: List[List[AgentAction] | AgentAction | AgentFinish] -class ThoughtGenerator( - BaseComponent[ - ThoughtGeneratorInput, List[Union[List[AgentAction], AgentAction, AgentFinish]] - ] -): +class ThoughtGenerator(BaseComponent[ThoughtGeneratorInput, List[Union[List[AgentAction], AgentAction, AgentFinish]]]): name = "Generate Thoughts" required_prompt_input_vars = set(ThoughtGeneratorInput.__annotations__) - { @@ -73,9 +71,7 @@ def __init__( self.max_num_thoughts = max_num_thoughts @classmethod - def _create_default_prompt( - cls, system_message: Optional[str], user_message: str, **kwargs - ) -> ChatPromptTemplate: + def _create_default_prompt(cls, system_message: Optional[str], user_message: str, **kwargs) -> ChatPromptTemplate: if system_message is None: system_message = "You are an advanced reasoning agent that can improve based on self-reflection." @@ -138,9 +134,7 @@ async def ainvoke( @classmethod def create_from_config(cls, config: ThoughtGeneratorConfig) -> ThoughtGenerator: if config.agent is not None: - return ThoughtGenerator( - agent=config.agent, max_num_thoughts=config.max_num_thoughts - ) + return ThoughtGenerator(agent=config.agent, max_num_thoughts=config.max_num_thoughts) if config.llm is None: raise ValueError("`llm` must be provided when `agent` is None.") @@ -176,9 +170,7 @@ def create( ] = None, parser_name: Optional[str] = None, ) -> ThoughtGenerator: - prompt = cls._process_prompt( - prompt=prompt, user_message=user_message, system_message=system_message - ) + prompt = cls._process_prompt(prompt=prompt, user_message=user_message, system_message=system_message) agent: AgentComponent = AgentComponent.create( llm=llm, @@ -190,11 +182,7 @@ def create( agent.add_input_preprocessing( preprocess=lambda inputs: { - **{ - key: value - for key, value in inputs.items() - if key != "previous_thoughts" - }, + **{key: value for key, value in inputs.items() if key != "previous_thoughts"}, "previous_thoughts": format_thoughts(inputs["previous_thoughts"]), } ) diff --git a/planning_library/strategies/tot_dfs/components/thought_sorter.py b/planning_library/strategies/tot_dfs/components/thought_sorter.py index ac2baa4..3b5bd36 100644 --- a/planning_library/strategies/tot_dfs/components/thought_sorter.py +++ b/planning_library/strategies/tot_dfs/components/thought_sorter.py @@ -1,24 +1,26 @@ from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from itertools import combinations from textwrap import dedent -from typing import Dict, List, Optional, Tuple, Union, Any +from typing import Any, Dict, List, Optional, Tuple, Union from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import AsyncCallbackManager, CallbackManager from langchain_core.language_models import BaseChatModel -from langchain_core.output_parsers import BaseOutputParser from langchain_core.messages import BaseMessage +from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import Runnable -from itertools import combinations -from langchain_core.callbacks import AsyncCallbackManager, CallbackManager +from typing_extensions import TypedDict + from planning_library.components import BaseComponent, RunnableComponent from planning_library.function_calling_parsers import ( BaseFunctionCallingMultiActionParser, BaseFunctionCallingSingleActionParser, ParserRegistry, ) -from typing_extensions import TypedDict -from dataclasses import dataclass -from collections import defaultdict from planning_library.utils import ( format_thought, ) @@ -57,11 +59,7 @@ class ThoughtSorterRunnableInput(TypedDict): thought2: List[BaseMessage] -class ThoughtSorter( - BaseComponent[ - ThoughtSorterInput, List[Union[List[AgentAction], AgentAction, AgentFinish]] - ] -): +class ThoughtSorter(BaseComponent[ThoughtSorterInput, List[Union[List[AgentAction], AgentAction, AgentFinish]]]): """ ToT+DFS component responsible for sorting the candidate thought on each DFS step. @@ -77,17 +75,14 @@ class ThoughtSorter( def __init__( self, - runnable: Runnable[ThoughtSorterRunnableInput, str] - | RunnableComponent[ThoughtSorterRunnableInput, str], + runnable: Runnable[ThoughtSorterRunnableInput, str] | RunnableComponent[ThoughtSorterRunnableInput, str], ): if not isinstance(runnable, RunnableComponent): runnable = RunnableComponent(runnable) self.runnable = runnable @classmethod - def _create_default_prompt( - cls, system_message: Optional[str], user_message: str, **kwargs - ) -> ChatPromptTemplate: + def _create_default_prompt(cls, system_message: Optional[str], user_message: str, **kwargs) -> ChatPromptTemplate: if system_message is None: system_message = ( "You are an advanced reasoning assistant that compares the " @@ -168,12 +163,8 @@ def invoke( run_manager: Optional[CallbackManager] = None, **kwargs, ) -> List[Union[List[AgentAction], AgentAction, AgentFinish]]: - scores: Dict[Union[List[AgentAction], AgentAction, AgentFinish], float] = ( - defaultdict(float) - ) - for thought1, thought2 in [ - pair for pair in combinations(inputs["thoughts"], 2) - ]: + scores: Dict[Union[List[AgentAction], AgentAction, AgentFinish], float] = defaultdict(float) + for thought1, thought2 in [pair for pair in combinations(inputs["thoughts"], 2)]: result = self._compare_pairwise( inputs=inputs["inputs"], intermediate_steps=inputs["intermediate_steps"], @@ -200,12 +191,8 @@ async def ainvoke( run_manager: Optional[AsyncCallbackManager] = None, **kwargs, ) -> List[Union[List[AgentAction], AgentAction, AgentFinish]]: - scores: Dict[Union[List[AgentAction], AgentAction, AgentFinish], float] = ( - defaultdict(float) - ) - for thought1, thought2 in [ - pair for pair in combinations(inputs["thoughts"], 2) - ]: + scores: Dict[Union[List[AgentAction], AgentAction, AgentFinish], float] = defaultdict(float) + for thought1, thought2 in [pair for pair in combinations(inputs["thoughts"], 2)]: result = await self._acompare_pairwise( inputs=inputs["inputs"], intermediate_steps=inputs["intermediate_steps"], diff --git a/planning_library/strategies/tot_dfs/tot_strategy.py b/planning_library/strategies/tot_dfs/tot_strategy.py index c0bd161..6c38701 100644 --- a/planning_library/strategies/tot_dfs/tot_strategy.py +++ b/planning_library/strategies/tot_dfs/tot_strategy.py @@ -1,4 +1,5 @@ from __future__ import annotations + from collections import deque from typing import ( AsyncIterator, @@ -16,22 +17,19 @@ CallbackManagerForChainRun, ) - from ...action_executors import BaseActionExecutor, LangchainActionExecutor, MetaTools from ..base_strategy import BaseCustomStrategy - from .components import ( - ThoughtSorter, + ThoughtEvaluator, + ThoughtEvaluatorConfig, + ThoughtEvaluatorInput, ThoughtGenerator, + ThoughtGeneratorConfig, ThoughtGeneratorInput, + ThoughtSorter, ThoughtSorterConfig, - ThoughtEvaluatorConfig, - ThoughtGeneratorConfig, - ThoughtEvaluator, ThoughtSorterInput, - ThoughtEvaluatorInput, ) - from .utils import ToTNode @@ -47,9 +45,7 @@ class TreeOfThoughtsDFSStrategy(BaseCustomStrategy): thought_generator: ThoughtGenerator thought_evaluator: ThoughtEvaluator thought_sorter: Optional[ThoughtSorter] = None - do_sorting: bool = ( - False # True for DFS (Tree of Thoughts), False for DFSDT (ToolLLM) - ) + do_sorting: bool = False # True for DFS (Tree of Thoughts), False for DFSDT (ToolLLM) root: Optional[ToTNode] = None terminals: List[ToTNode] = [] @@ -97,28 +93,20 @@ def create( max_iterations: Maximum number of iterations. """ if generator_config is None: - raise ValueError( - "Default thought generator config is currently not supported." - ) + raise ValueError("Default thought generator config is currently not supported.") if evaluator_config is None: - raise ValueError( - "Default thought evaluator config is currently not supported." - ) + raise ValueError("Default thought evaluator config is currently not supported.") if do_sorting and sorter_config is None: - raise ValueError( - "Default thought sorter config is currently not supported." - ) + raise ValueError("Default thought sorter config is currently not supported.") generator = ThoughtGenerator.create_from_config(generator_config) evaluator = ThoughtEvaluator.create_from_config(evaluator_config) sorter = ThoughtSorter.create_from_config(sorter_config) if do_sorting else None # type: ignore[arg-type] if action_executor is None: - action_executor = LangchainActionExecutor( - tools=generator_config.tools, meta_tools=meta_tools - ) + action_executor = LangchainActionExecutor(tools=generator_config.tools, meta_tools=meta_tools) return cls( thought_generator=generator, @@ -155,23 +143,15 @@ def _dfs_step( # 1: generate k possible next steps thoughts = self.thought_generator.invoke( ThoughtGeneratorInput(inputs=inputs, intermediate_steps=trajectory), - run_manager=run_manager.get_child(tag="generate_thoughts") - if run_manager - else None, + run_manager=run_manager.get_child(tag="generate_thoughts") if run_manager else None, ) # 2: (optional) sort them if self.do_sorting: - assert ( - self.thought_sorter is not None - ), "Sorting enabled, but thought sorter was not passed." + assert self.thought_sorter is not None, "Sorting enabled, but thought sorter was not passed." thoughts = self.thought_sorter.invoke( - ThoughtSorterInput( - thoughts=thoughts, inputs=inputs, intermediate_steps=trajectory - ), - run_manager=run_manager.get_child(tag="sort_thoughts") - if run_manager - else None, + ThoughtSorterInput(thoughts=thoughts, inputs=inputs, intermediate_steps=trajectory), + run_manager=run_manager.get_child(tag="sort_thoughts") if run_manager else None, ) for cur_thought in thoughts: @@ -182,9 +162,7 @@ def _dfs_step( intermediate_steps=trajectory, next_thought=cur_thought, ), - run_manager=run_manager.get_child(tag="evaluate_thought") - if run_manager - else None, + run_manager=run_manager.get_child(tag="evaluate_thought") if run_manager else None, ) # 4: proceed only with thoughts with value above a certain threshold @@ -241,9 +219,7 @@ def _run_strategy( run_manager=run_manager.get_child() if run_manager else None, ) - new_node = ToTNode( - parent=cur_node, thought=new_thought, observation=observation - ) + new_node = ToTNode(parent=cur_node, thought=new_thought, observation=observation) cur_node.children.append(new_node) if isinstance(new_thought, AgentFinish): @@ -254,9 +230,7 @@ def _run_strategy( cur_step += 1 for node in self.terminals: - assert isinstance( - node.thought, AgentFinish - ), "Terminal nodes are expected to contain AgentFinish." + assert isinstance(node.thought, AgentFinish), "Terminal nodes are expected to contain AgentFinish." yield node.thought, node.trajectory async def _adfs_step( @@ -284,23 +258,15 @@ async def _adfs_step( # 1: generate k possible next steps thoughts = await self.thought_generator.ainvoke( ThoughtGeneratorInput(inputs=inputs, intermediate_steps=trajectory), - run_manager=run_manager.get_child(tag="generate_thoughts") - if run_manager - else None, + run_manager=run_manager.get_child(tag="generate_thoughts") if run_manager else None, ) # 2: (optional) sort them if self.do_sorting: - assert ( - self.thought_sorter is not None - ), "Sorting enabled, but thought sorter was not passed." + assert self.thought_sorter is not None, "Sorting enabled, but thought sorter was not passed." thoughts = await self.thought_sorter.ainvoke( - ThoughtSorterInput( - thoughts=thoughts, inputs=inputs, intermediate_steps=trajectory - ), - run_manager=run_manager.get_child(tag="sort_thoughts") - if run_manager - else None, + ThoughtSorterInput(thoughts=thoughts, inputs=inputs, intermediate_steps=trajectory), + run_manager=run_manager.get_child(tag="sort_thoughts") if run_manager else None, ) for cur_thought in thoughts: @@ -311,9 +277,7 @@ async def _adfs_step( intermediate_steps=trajectory, next_thought=cur_thought, ), - run_manager=run_manager.get_child(tag="evaluate_thought") - if run_manager - else None, + run_manager=run_manager.get_child(tag="evaluate_thought") if run_manager else None, ) # 4: proceed only with thoughts with value above a certain threshold @@ -370,9 +334,7 @@ async def _arun_strategy( run_manager=run_manager.get_child() if run_manager else None, ) - new_node = ToTNode( - parent=cur_node, thought=new_thought, observation=observation - ) + new_node = ToTNode(parent=cur_node, thought=new_thought, observation=observation) cur_node.children.append(new_node) if isinstance(new_thought, AgentFinish): self.terminals.append(new_node) @@ -382,7 +344,5 @@ async def _arun_strategy( cur_step += 1 for node in self.terminals: - assert isinstance( - node.thought, AgentFinish - ), "Terminal nodes are expected to contain AgentFinish." + assert isinstance(node.thought, AgentFinish), "Terminal nodes are expected to contain AgentFinish." yield node.thought, node.trajectory diff --git a/planning_library/strategies/tot_dfs/utils/tot_node.py b/planning_library/strategies/tot_dfs/utils/tot_node.py index 0821437..af12121 100644 --- a/planning_library/strategies/tot_dfs/utils/tot_node.py +++ b/planning_library/strategies/tot_dfs/utils/tot_node.py @@ -26,18 +26,13 @@ def trajectory(self) -> List[Tuple[AgentAction, str]]: trajectory_actions: List[Tuple[AgentAction, str]] = [] while node is not None: if isinstance(node.thought, list): - assert isinstance(node.observation, list) and len(node.thought) == len( - node.observation - ) + assert isinstance(node.observation, list) and len(node.thought) == len(node.observation) trajectory_actions.extend( - (observation.action, observation.observation) - for observation in node.observation + (observation.action, observation.observation) for observation in node.observation ) elif isinstance(node.thought, AgentAction): assert isinstance(node.observation, AgentStep) - trajectory_actions.append( - (node.observation.action, node.observation.observation) - ) + trajectory_actions.append((node.observation.action, node.observation.observation)) elif isinstance(node.thought, AgentFinish) and node is not self: raise ValueError("AgentFinish detected as non-terminal node.") diff --git a/planning_library/utils/actions_utils.py b/planning_library/utils/actions_utils.py index 7cd2b69..375f597 100644 --- a/planning_library/utils/actions_utils.py +++ b/planning_library/utils/actions_utils.py @@ -14,9 +14,7 @@ def get_tools_maps( tools: Sequence[BaseTool], ) -> Tuple[Dict[str, BaseTool], Dict[str, str]]: name_to_tool_map = {tool.name: tool for tool in tools} - color_mapping = get_color_mapping( - [tool.name for tool in tools], excluded_colors=["green"] - ) + color_mapping = get_color_mapping([tool.name for tool in tools], excluded_colors=["green"]) return name_to_tool_map, color_mapping diff --git a/planning_library/utils/format_agent_outputs.py b/planning_library/utils/format_agent_outputs.py index 9a8bdb6..93fc1d8 100644 --- a/planning_library/utils/format_agent_outputs.py +++ b/planning_library/utils/format_agent_outputs.py @@ -1,8 +1,9 @@ from __future__ import annotations from typing import List + from langchain_core.agents import AgentAction, AgentFinish -from langchain_core.messages import BaseMessage, AIMessage +from langchain_core.messages import AIMessage, BaseMessage def format_thought( @@ -14,17 +15,9 @@ def format_thought( messages.extend(format_thought(action)) return messages elif isinstance(thought, AgentAction): - return [ - AIMessage( - content=f"Call tool `{thought.tool}` with arguments `{thought.tool_input}`" - ) - ] + return [AIMessage(content=f"Call tool `{thought.tool}` with arguments `{thought.tool_input}`")] elif isinstance(thought, AgentFinish): - return [ - AIMessage( - content=f"Finish execution with return values `{thought.return_values}`" - ) - ] + return [AIMessage(content=f"Finish execution with return values `{thought.return_values}`")] raise ValueError(f"Unexpected type for `thought`: {type(thought)}") diff --git a/planning_library/utils/gym_env_reset_tool.py b/planning_library/utils/gym_env_reset_tool.py index e758a0f..5d560eb 100644 --- a/planning_library/utils/gym_env_reset_tool.py +++ b/planning_library/utils/gym_env_reset_tool.py @@ -1,15 +1,17 @@ from __future__ import annotations +from typing import Any, Dict, Generic, Optional, Tuple, TypeVar + +import gymnasium as gym from langchain.pydantic_v1 import BaseModel, Field from langchain.tools import BaseTool -import gymnasium as gym -from typing import Tuple, Optional, Any, Dict from langchain_core.agents import AgentAction from langchain_core.callbacks import CallbackManager -from gymnasium.core import ObsType + +ObsType = TypeVar("ObsType") -class GymEnvResetTool(BaseTool, BaseModel): +class GymEnvResetTool(BaseTool, BaseModel, Generic[ObsType]): env: gym.Env[ObsType, Tuple[AgentAction, Optional[CallbackManager]]] = Field( # type: ignore[valid-type] exclude=True ) @@ -26,4 +28,4 @@ def _run( seed: int | None = None, options: Dict[str, Any] | None = None, ) -> Tuple[ObsType, Dict[str, Any]]: - return self.env.reset(seed=seed, options=options) + return self.env.reset(seed=seed, options=options) # type: ignore[reportReturnType] diff --git a/poetry.lock b/poetry.lock index ebb2581..09f9c0c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3039,6 +3039,17 @@ doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9. extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + [[package]] name = "notebook" version = "7.1.1" @@ -4063,6 +4074,24 @@ files = [ [package.extras] diagrams = ["jinja2", "railroad-diagrams"] +[[package]] +name = "pyright" +version = "1.1.368" +description = "Command line wrapper for pyright" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyright-1.1.368-py3-none-any.whl", hash = "sha256:4a86e34b61c755b43b367af7fbf927fc6466fff6b81a9dcea07d42416c640af3"}, + {file = "pyright-1.1.368.tar.gz", hash = "sha256:9b2aa48142d9d9fc9a6aedff743c76873cc4e615f3297cdbf893d5793f75b306"}, +] + +[package.dependencies] +nodeenv = ">=1.6.0" + +[package.extras] +all = ["twine (>=3.4.1)"] +dev = ["twine (>=3.4.1)"] + [[package]] name = "pytest" version = "7.4.4" @@ -6431,4 +6460,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "c40c924f76e30cdcf182270d2d3b9bad453b709805a925759678051b9296acbf" +content-hash = "b9d87557dfc9f09fc25847e5fdb715058ab9d68d01e444717e23e5eb7b5ea218" diff --git a/pyproject.toml b/pyproject.toml index 0caf836..c223570 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,26 +25,29 @@ moviepy = "^1.0.3" alfworld = {extras = ["full"], version = "^0.3.3"} [tool.poetry.group.dev.dependencies] -black = {extras = ["jupyter"], version = "^23.7.0"} isort = "^5.12.0" mypy = "^1.5.0" pytest = "^7.4.0" wandb = "^0.16.3" grandalf = "^0.8" ruff = "^0.3.2" +pyright = "^1.1.368" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" -[tool.black] +[tool.ruff] line-length = 120 -target-version = ["py310"] +target-version = "py310" + +[tool.ruff.lint] +extend-select = ["I"] [tool.isort] -line_length = 120 -py_version = 310 profile = "black" +force_sort_within_sections = true +order_by_type = true [tool.mypy] python_version = "3.9"