Skip to content

Commit

Permalink
Merge pull request #2 from JetBrains-Research/tigina/default-action-e…
Browse files Browse the repository at this point in the history
…xecutor-fixes

Fix actions invocation for async executor calls
  • Loading branch information
saridormi authored May 6, 2024
2 parents 629c766 + e92dd6c commit 260100c
Showing 1 changed file with 47 additions and 28 deletions.
75 changes: 47 additions & 28 deletions planning_library/action_executors/default_action_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ def reset(
) -> None:
"""Resets the current state. If actions are passed, will also execute them."""
if self.reset_tool_name is not None:
self.execute(
self._execute(
actions=[
AgentAction(
tool=self.reset_tool_name,
tool_input={},
log="Invoking reset tool.",
)
],
tool_executor=self._meta_tool_executor,
run_manager=run_manager,
)
if actions:
Expand All @@ -69,30 +70,35 @@ def execute(
**kwargs,
) -> AgentStep: ...

def execute(
def _execute(
self,
actions: List[AgentAction] | AgentAction,
tool_executor: ToolExecutor,
run_manager: Optional[CallbackManager] = None,
**kwargs,
) -> List[AgentStep] | AgentStep:
if isinstance(actions, list):
observations = [
self.execute(action, run_manager=run_manager) for action in actions
]
else:
observations = self._tool_executor.invoke(
actions,
config={"callbacks": run_manager} if run_manager else {},
)
if isinstance(observations, list):
assert isinstance(actions, list)
return [
AgentStep(action=action, observation=observation)
for action, observation in zip(actions, observations)
]
steps = []
for action in actions:
assert isinstance(action, AgentAction)
observation = self.execute(action, run_manager=run_manager)
steps.append(AgentStep(action=action, observation=observation))
return steps

assert isinstance(actions, AgentAction)
return AgentStep(action=actions, observation=observations)
observation = tool_executor.invoke(
actions,
config={"callbacks": run_manager} if run_manager else {},
)
return AgentStep(action=actions, observation=observation)

def execute(
self,
actions: List[AgentAction] | AgentAction,
run_manager: Optional[CallbackManager] = None,
**kwargs,
) -> List[AgentStep] | AgentStep:
return self._execute(actions, self._tool_executor, run_manager)

async def areset(
self,
Expand All @@ -102,14 +108,15 @@ async def areset(
) -> None:
"""Resets the current state. If actions are passed, will also execute them."""
if self.reset_tool_name is not None:
await self.aexecute(
await self._aexecute(
actions=[
AgentAction(
tool=self.reset_tool_name,
tool_input={},
log="Invoking reset tool.",
)
],
tool_executor=self._meta_tool_executor,
run_manager=run_manager,
)
if actions:
Expand All @@ -131,21 +138,33 @@ async def aexecute(
**kwargs,
) -> AgentStep: ...

async def aexecute(
async def _aexecute(
self,
actions: List[AgentAction] | AgentAction,
tool_executor: ToolExecutor,
run_manager: Optional[AsyncCallbackManager] = None,
**kwargs,
) -> List[AgentStep] | AgentStep:
observations = await self._tool_executor.ainvoke(
if isinstance(actions, list):
steps = []
for action in actions:
observation = await tool_executor.ainvoke(
action,
config={"callbacks": run_manager} if run_manager else {},
)
steps.append(AgentStep(action=action, observation=observation))
return steps
assert isinstance(actions, AgentAction)
observation = await tool_executor.ainvoke(
actions,
config={"callbacks": run_manager} if run_manager else {},
)
if isinstance(observations, list):
assert isinstance(actions, list)
return [
AgentStep(action=action, observation=observation)
for action, observation in zip(actions, observations)
]
assert isinstance(actions, AgentAction)
return AgentStep(action=actions, observation=observations)
return AgentStep(action=actions, observation=observation)

async def aexecute(
self,
actions: List[AgentAction] | AgentAction,
run_manager: Optional[AsyncCallbackManager] = None,
**kwargs,
) -> List[AgentStep] | AgentStep:
return await self._aexecute(actions, self._tool_executor, run_manager)

0 comments on commit 260100c

Please sign in to comment.