Skip to content

Commit

Permalink
Fix input formatting in ToT+DFS and add a way to prettify components in LangSmith (albeit not working for agents?? 🥹)
Browse files: browse the repository at this point in the history
  • Loading branch information
saridormi committed Apr 17, 2024
1 parent 85bc609 commit d16358d
Show file tree
Hide file tree
Showing 20 changed files with 244 additions and 176 deletions.
6 changes: 4 additions & 2 deletions environments/game_of_24/common/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ def reset(
observation, info = "", {"numbers": self.numbers}

if options is not None and "trajectory" in options:
for action in options["trajectory"]:
assert isinstance(action, AgentAction)
for action, step in options["trajectory"]:
assert isinstance(
action, AgentAction
), f"Expected AgentAction, got {action}"
observation, reward, terminated, truncated, info = self.step(
(
action,
Expand Down
119 changes: 55 additions & 64 deletions environments/game_of_24/reflexion.ipynb

Large diffs are not rendered by default.

82 changes: 28 additions & 54 deletions environments/game_of_24/tot_dfs.ipynb

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion planning_library/action_executors/base_action_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ class BaseActionExecutor(ABC):
def tools(self) -> Sequence[BaseTool]: ...

@abstractmethod
def reset(self, actions: Optional[List[AgentAction]] = None, **kwargs) -> None:
def reset(
self,
actions: Optional[List[AgentAction]] = None,
run_manager: Optional[CallbackManager] = None,
**kwargs,
) -> None:
"""Resets the current state. If actions are passed, will also execute them."""
...

Expand Down
7 changes: 6 additions & 1 deletion planning_library/action_executors/default_action_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ class DefaultActionExecutor(BaseActionExecutor):
def __init__(self, tools: Sequence[BaseTool]):
self._tool_executor = ToolExecutor(tools)

def reset(self, actions: Optional[List[AgentAction]] = None, **kwargs) -> None:
def reset(
self,
actions: Optional[List[AgentAction]] = None,
run_manager: Optional[CallbackManager] = None,
**kwargs,
) -> None:
"""Resets the current state. If actions are passed, will also execute them.
This action executor doesn't have a state by default, so this method doesn't do anything.
Expand Down
10 changes: 9 additions & 1 deletion planning_library/action_executors/gymnasium_action_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,21 @@ def __init__(
def tools(self) -> Sequence[BaseTool]:
return self._env.get_wrapper_attr("tools")

def reset(self, actions: Optional[List[AgentAction]] = None, **kwargs) -> None:
def reset(
self,
actions: Optional[List[AgentAction]] = None,
run_manager: Optional[CallbackManager] = None,
**kwargs,
) -> None:
"""Resets the environment. If actions are passed, will also execute them."""

options = kwargs
if actions:
options["trajectory"] = actions

if run_manager:
options["run_manager"] = run_manager

self._env.reset(seed=self._seed, options=options)

@overload
Expand Down
7 changes: 4 additions & 3 deletions planning_library/components/agent_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,18 @@ def add_output_preprocessing(
)

def invoke(
self,
inputs: InputType,
run_manager: Optional[CallbackManager] = None,
self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs
) -> Union[List[AgentAction], AgentAction, AgentFinish]:
# TODO: no way to pass name to plan?
return self.agent.plan(**inputs, callbacks=run_manager)

async def ainvoke(
self,
inputs: InputType,
run_manager: Optional[AsyncCallbackManager] = None,
**kwargs,
) -> Union[List[AgentAction], AgentAction, AgentFinish]:
# Asynchronously asks the wrapped agent to plan its next step(s).
# `inputs` is unpacked into keyword arguments for `self.agent.aplan`, and
# `run_manager` is forwarded to it as the `callbacks` argument.
# NOTE(review): `**kwargs` is accepted for signature compatibility but is not
# forwarded to `aplan`, so extra options (e.g. a `run_name`) are dropped here.
# TODO: no way to pass name to plan?
outputs = await self.agent.aplan(**inputs, callbacks=run_manager)
return outputs

Expand Down
6 changes: 3 additions & 3 deletions planning_library/components/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


class BaseComponent(Generic[InputType, OutputType], ABC):
name: Optional[str] = None
required_prompt_input_vars: Set[str] = set()

@classmethod
Expand Down Expand Up @@ -56,14 +57,13 @@ def add_output_preprocessing(

@abstractmethod
def invoke(
self,
inputs: InputType,
run_manager: Optional[CallbackManager] = None,
self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs
) -> OutputType: ...

# Asynchronous entry point that every concrete component must implement.
# Receives the component inputs, an optional async callback manager and
# arbitrary extra keyword arguments, and returns the component's output.
@abstractmethod
async def ainvoke(
self,
inputs: InputType,
run_manager: Optional[AsyncCallbackManager] = None,
**kwargs,
) -> OutputType: ...
15 changes: 10 additions & 5 deletions planning_library/components/evaluation/evaluator_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ def add_output_preprocessing(
self.judge.add_output_preprocessing(preprocess, apreprocess)

def invoke(
self,
inputs: InputType,
run_manager: Optional[CallbackManager] = None,
self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs
) -> bool:
backbone_output = self.backbone.invoke(inputs, run_manager)
if "run_name" not in kwargs and self.name:
kwargs["run_name"] = self.name

backbone_output = self.backbone.invoke(inputs, run_manager, **kwargs)
should_continue = self.judge.invoke(
{"backbone_output": backbone_output}, run_manager
)
Expand All @@ -51,8 +52,12 @@ async def ainvoke(
self,
inputs: InputType,
run_manager: Optional[AsyncCallbackManager] = None,
**kwargs,
) -> bool:
backbone_output = await self.backbone.ainvoke(inputs, run_manager)
if "run_name" not in kwargs and self.name:
kwargs["run_name"] = self.name

backbone_output = await self.backbone.ainvoke(inputs, run_manager, **kwargs)
should_continue = await self.judge.ainvoke(
{"backbone_output": backbone_output}, run_manager
)
Expand Down
10 changes: 4 additions & 6 deletions planning_library/components/evaluation/threshold_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@ def __init__(self, threshold: float):
self.threshold = threshold

def invoke(
self,
inputs: InputType,
run_manager: Optional[CallbackManager] = None,
self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs
) -> bool:
return inputs["backbone_output"] <= self.threshold

async def ainvoke(
self,
inputs: InputType,
run_manager: Optional[AsyncCallbackManager] = None,
**kwargs,
) -> bool:
# Async variant of the threshold check: returns True when the value stored
# under "backbone_output" is less than or equal to `self.threshold`.
# `run_manager` and `**kwargs` are not used by this method.
return inputs["backbone_output"] <= self.threshold

Expand All @@ -28,15 +27,14 @@ def __init__(self, threshold: float):
self.threshold = threshold

def invoke(
self,
inputs: InputType,
run_manager: Optional[CallbackManager] = None,
self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs
) -> bool:
return inputs["backbone_output"] >= self.threshold

async def ainvoke(
self,
inputs: InputType,
run_manager: Optional[AsyncCallbackManager] = None,
**kwargs,
) -> bool:
# Async variant of the threshold check: returns True when the value stored
# under "backbone_output" is greater than or equal to `self.threshold`.
# `run_manager` and `**kwargs` are not used by this method.
return inputs["backbone_output"] >= self.threshold
23 changes: 18 additions & 5 deletions planning_library/components/runnable_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,36 @@ def add_output_preprocessing(
self.runnable = self.runnable | RunnableLambda(preprocess, afunc=apreprocess)

def invoke(
self,
inputs: InputType,
run_manager: Optional[CallbackManager] = None,
self, inputs: InputType, run_manager: Optional[CallbackManager] = None, **kwargs
) -> OutputType:
config = kwargs
if "callbacks" not in config and run_manager:
config["callbacks"] = run_manager

if "run_name" not in config and self.name:
config["run_name"] = self.name

outputs = self.runnable.invoke(
inputs,
config={"callbacks": run_manager} if run_manager else {},
config=config, # type: ignore[arg-type]
)
return outputs

async def ainvoke(
self,
inputs: InputType,
run_manager: Optional[AsyncCallbackManager] = None,
**kwargs,
) -> OutputType:
config = kwargs
if "callbacks" not in config and run_manager:
config["callbacks"] = run_manager

if "run_name" not in config and self.name:
config["run_name"] = self.name

outputs = await self.runnable.ainvoke(
inputs,
config={"callbacks": run_manager} if run_manager else {},
config=config, # type: ignore[arg-type]
)
return outputs
2 changes: 2 additions & 0 deletions planning_library/strategies/reflexion/components/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class ReflexionActor(AgentComponent[ReflexionActorInput]):
be initialized with only a single user message.
"""

name = "Actor"

required_prompt_input_vars = set(ReflexionActorInput.__annotations__) - {
"inputs",
"intermediate_steps",
Expand Down
2 changes: 2 additions & 0 deletions planning_library/strategies/reflexion/components/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class PreprocessedReflexionEvaluatorInput(TypedDict):
class ReflexionEvaluator(
Generic[OutputType], EvaluatorComponent[ReflexionEvaluatorInput, OutputType]
):
name = "Evaluator"

required_prompt_input_vars = set(ReflexionEvaluatorInput.__annotations__) - {
"inputs"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class PreprocessedReflexionSelfReflectionInput(TypedDict):
class ReflexionSelfReflection(
RunnableComponent[ReflexionSelfReflectionInput, Sequence[BaseMessage]]
):
name = "Self-Reflection"

required_prompt_input_vars = set(ReflexionSelfReflectionInput.__annotations__) - {
"inputs"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable

from planning_library.strategies.tot_dfs.utils.format_agent_outputs import (
format_thought,
)
from textwrap import dedent
from planning_library.function_calling_parsers import (
ParserRegistry,
Expand Down Expand Up @@ -45,6 +47,8 @@ class ThoughtEvaluatorInput(TypedDict):
class ThoughtEvaluator(
Generic[OutputType], EvaluatorComponent[ThoughtEvaluatorInput, OutputType]
):
name = "Evaluate Thoughts"

required_prompt_input_vars = set(ThoughtEvaluatorInput.__annotations__) - {"inputs"}

@classmethod
Expand All @@ -65,12 +69,11 @@ def _create_default_prompt(
user_message,
),
MessagesPlaceholder("intermediate_steps"),
("human", "Here is the proposed next step:" ""),
MessagesPlaceholder("next_thought"),
(
"human",
dedent("""
Here is the proposed next step:
{next_thought}
Your goal is to judge whether this proposal should be followed or discarded,
how likely it is to lead to the success.
Expand Down Expand Up @@ -114,7 +117,7 @@ def _preprocess_input(

return {
**inputs["inputs"],
"next_thought": inputs["next_thought"],
"next_thought": format_thought(inputs["next_thought"]),
"intermediate_steps": intermediate_steps,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
BaseFunctionCallingSingleActionParser,
BaseFunctionCallingMultiActionParser,
)
from planning_library.strategies.tot_dfs.utils.format_agent_outputs import (
format_thoughts,
)
from dataclasses import dataclass


Expand Down Expand Up @@ -48,6 +51,8 @@ class ThoughtGenerator(
ThoughtGeneratorInput, List[Union[List[AgentAction], AgentAction, AgentFinish]]
]
):
name = "Generate Thoughts"

required_prompt_input_vars = set(ThoughtGeneratorInput.__annotations__) - {
"inputs",
"intermediate_steps",
Expand Down Expand Up @@ -94,12 +99,14 @@ def invoke(
self,
inputs: ThoughtGeneratorInput,
run_manager: Optional[CallbackManager] = None,
**kwargs,
) -> List[List[AgentAction] | AgentAction | AgentFinish]:
results: List[List[AgentAction] | AgentAction | AgentFinish] = []
for _ in range(self.max_num_thoughts):
cur_result = self.agent.invoke(
{**inputs, "previous_thoughts": results},
run_manager=run_manager,
**kwargs,
)
# TODO: how to fix mypy warning properly here?
results.append(cur_result) # type: ignore[arg-type]
Expand All @@ -110,12 +117,14 @@ async def ainvoke(
self,
inputs: ThoughtGeneratorInput,
run_manager: Optional[AsyncCallbackManager] = None,
**kwargs,
) -> List[List[AgentAction] | AgentAction | AgentFinish]:
results: List[List[AgentAction] | AgentAction | AgentFinish] = []
for _ in range(self.max_num_thoughts):
cur_result = await self.agent.ainvoke(
{**inputs, "previous_thoughts": results},
run_manager=run_manager,
**kwargs,
)
# TODO: how to fix mypy warning properly here?
results.append(cur_result) # type: ignore[arg-type]
Expand Down Expand Up @@ -174,4 +183,16 @@ def create(
parser=parser,
parser_name=parser_name,
)

agent.add_input_preprocessing(
preprocess=lambda inputs: {
**{
key: value
for key, value in inputs.items()
if key != "previous_thoughts"
},
"previous_thoughts": format_thoughts(inputs["previous_thoughts"]),
}
)

return ThoughtGenerator(agent=agent, max_num_thoughts=max_num_thoughts)
Loading

0 comments on commit d16358d

Please sign in to comment.