diff --git a/README.md b/README.md index 3535240..066bc03 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ print(dumped_memory['memory']) Clear the memory of this session(`session_id=0` by default): ```python -agent.memory.reset() +agent.reset() ``` ### Custom Message Aggregation diff --git a/lagent/agents/__init__.py b/lagent/agents/__init__.py index f06972c..0a995d2 100644 --- a/lagent/agents/__init__.py +++ b/lagent/agents/__init__.py @@ -1,9 +1,33 @@ -from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential +from .agent import ( + Agent, + AgentDict, + AgentList, + AsyncAgent, + AsyncSequential, + AsyncStreamingAgent, + AsyncStreamingSequential, + Sequential, + StreamingAgent, + StreamingSequential, +) from .react import AsyncReAct, ReAct from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder __all__ = [ - 'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM', - 'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct', - 'AsyncReAct', 'Sequential', 'AsyncSequential' + 'Agent', + 'AgentDict', + 'AgentList', + 'AsyncAgent', + 'AgentForInternLM', + 'AsyncAgentForInternLM', + 'MathCoder', + 'AsyncMathCoder', + 'ReAct', + 'AsyncReAct', + 'Sequential', + 'AsyncSequential', + 'StreamingAgent', + 'StreamingSequential', + 'AsyncStreamingAgent', + 'AsyncStreamingSequential', ] diff --git a/lagent/agents/agent.py b/lagent/agents/agent.py index b1e941b..9707d7b 100644 --- a/lagent/agents/agent.py +++ b/lagent/agents/agent.py @@ -3,7 +3,7 @@ from collections import OrderedDict, UserDict, UserList, abc from functools import wraps from itertools import chain, repeat -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union +from typing import Any, AsyncGenerator, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union from lagent.agents.aggregator import DefaultAggregator from lagent.hooks import Hook, RemovableHandle @@ -11,7 +11,7 @@ from lagent.memory import Memory, MemoryManager from lagent.prompts.parsers import StrParser from lagent.prompts.prompt_template import PromptTemplate -from lagent.schema import AgentMessage +from lagent.schema import AgentMessage, ModelStatusCode from lagent.utils import create_object @@ -63,29 +63,17 @@ def update_memory(self, message, session_id=0): if self.memory: self.memory.add(message, session_id=session_id) - def __call__( - self, - *message: Union[str, AgentMessage, List[AgentMessage]], - session_id=0, - **kwargs, - ) -> AgentMessage: + def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: # message.receiver = self.name - message = [ - AgentMessage(sender='user', content=m) - if isinstance(m, str) else copy.deepcopy(m) for m in message - ] + message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] for hook in self._hooks.values(): result = hook.before_agent(self, message, session_id) if result: message = result self.update_memory(message, session_id=session_id) - response_message = self.forward( - *message, session_id=session_id, **kwargs) + response_message = self.forward(*message, session_id=session_id, **kwargs) if not isinstance(response_message, AgentMessage): - response_message = AgentMessage( - sender=self.name, - content=response_message, - ) + response_message = AgentMessage(sender=self.name, content=response_message) self.update_memory(response_message, session_id=session_id) response_message = copy.deepcopy(response_message) for hook in self._hooks.values(): @@ -94,25 +82,14 @@ def __call__( response_message = result return response_message - def forward(self, - *message: AgentMessage, - session_id=0, - **kwargs) -> Union[AgentMessage, str]: + def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]: formatted_messages = self.aggregator.aggregate( - self.memory.get(session_id), - self.name, - self.output_format, - self.template, + self.memory.get(session_id), self.name, self.output_format, self.template ) llm_response = self.llm.chat(formatted_messages, **kwargs) if self.output_format: - formatted_messages = self.output_format.parse_response( - llm_response) - return AgentMessage( - sender=self.name, - content=llm_response, - formatted=formatted_messages, - ) + formatted_messages = self.output_format.parse_response(llm_response) + return AgentMessage(sender=self.name, content=llm_response, formatted=formatted_messages) return llm_response def __setattr__(self, __name: str, __value: Any) -> None: @@ -165,12 +142,8 @@ def register_hook(self, hook: Callable): self._hooks[handle.id] = hook return handle - def reset(self, - session_id=0, - keypath: Optional[str] = None, - recursive: bool = False): - assert not (keypath and - recursive), 'keypath and recursive can\'t be used together' + def reset(self, session_id=0, keypath: Optional[str] = None, recursive: bool = False): + assert not (keypath and recursive), 'keypath and recursive can\'t be used together' if keypath: keys, agent = keypath.split('.'), self for key in keys: @@ -189,15 +162,13 @@ def reset(self, def __repr__(self): def _rcsv_repr(agent, n_indent=1): - res = agent.__class__.__name__ + (f"(name='{agent.name}')" - if agent.name else '') + res = agent.__class__.__name__ + (f"(name='{agent.name}')" if agent.name else '') modules = [ f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}" for name, agent in getattr(agent, '_agents', {}).items() ] if modules: - res += '(\n' + '\n'.join( - modules) + f'\n{(n_indent - 1) * " "})' + res += '(\n' + '\n'.join(modules) + f'\n{(n_indent - 1) * " "})' elif not res.endswith(')'): res += '()' return res @@ -205,28 +176,18 @@ def _rcsv_repr(agent, n_indent=1): return _rcsv_repr(self) -class AsyncAgent(Agent): +class AsyncAgentMixin: - async def __call__(self, - *message: AgentMessage | List[AgentMessage], - session_id=0, - **kwargs) -> AgentMessage: - message = [ - AgentMessage(sender='user', content=m) - if isinstance(m, str) else copy.deepcopy(m) for m in message - ] + async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: + message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] for hook in self._hooks.values(): result = hook.before_agent(self, message, session_id) if result: message = result self.update_memory(message, session_id=session_id) - response_message = await self.forward( - *message, session_id=session_id, **kwargs) + response_message = await self.forward(*message, session_id=session_id, **kwargs) if not isinstance(response_message, AgentMessage): - response_message = AgentMessage( - sender=self.name, - content=response_message, - ) + response_message = AgentMessage(sender=self.name, content=response_message) self.update_memory(response_message, session_id=session_id) response_message = copy.deepcopy(response_message) for hook in self._hooks.values(): @@ -235,40 +196,133 @@ async def __call__(self, response_message = result return response_message - async def forward(self, - *message: AgentMessage, - session_id=0, - **kwargs) -> Union[AgentMessage, str]: + async def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]: formatted_messages = self.aggregator.aggregate( - self.memory.get(session_id), - self.name, - self.output_format, - self.template, + self.memory.get(session_id), self.name, self.output_format, self.template ) - llm_response = await self.llm.chat(formatted_messages, session_id, - **kwargs) + llm_response = await self.llm.chat(formatted_messages, session_id, **kwargs) if self.output_format: - formatted_messages = self.output_format.parse_response( - llm_response) - return AgentMessage( - sender=self.name, - content=llm_response, - formatted=formatted_messages, - ) + formatted_messages = self.output_format.parse_response(llm_response) + return AgentMessage(sender=self.name, content=llm_response, formatted=formatted_messages) return llm_response +class AsyncAgent(AsyncAgentMixin, Agent): + """Asynchronous variant of the Agent class""" + + pass + + +class StreamingAgentMixin: + """Component that makes agent calling output a streaming response.""" + + def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Generator[AgentMessage, None, None]: + message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] + for hook in self._hooks.values(): + result = hook.before_agent(self, message, session_id) + if result: + message = result + self.update_memory(message, session_id=session_id) + response_message = AgentMessage(sender=self.name, content="") + for response_message in self.forward(*message, session_id=session_id, **kwargs): + if not isinstance(response_message, AgentMessage): + model_state, response = response_message + response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state) + yield response_message.model_copy() + self.update_memory(response_message, session_id=session_id) + response_message = copy.deepcopy(response_message) + for hook in self._hooks.values(): + result = hook.after_agent(self, response_message, session_id) + if result: + response_message = result + yield response_message + + def forward( + self, *message: AgentMessage, session_id=0, **kwargs + ) -> Generator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None, None]: + formatted_messages = self.aggregator.aggregate( + self.memory.get(session_id), self.name, self.output_format, self.template + ) + for model_state, response, *_ in self.llm.stream_chat(formatted_messages, session_id=session_id, **kwargs): + yield ( + AgentMessage( + sender=self.name, + content=response, + formatted=self.output_format.parse_response(response), + stream_state=model_state, + ) + if self.output_format + else (model_state, response) + ) + + +class AsyncStreamingAgentMixin: + """Component that makes asynchronous agent calling output a streaming response.""" + + async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AsyncGenerator[AgentMessage, None]: + message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] + for hook in self._hooks.values(): + result = hook.before_agent(self, message, session_id) + if result: + message = result + self.update_memory(message, session_id=session_id) + response_message = AgentMessage(sender=self.name, content="") + async for response_message in self.forward(*message, session_id=session_id, **kwargs): + if not isinstance(response_message, AgentMessage): + model_state, response = response_message + response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state) + yield response_message.model_copy() + self.update_memory(response_message, session_id=session_id) + response_message = copy.deepcopy(response_message) + for hook in self._hooks.values(): + result = hook.after_agent(self, response_message, session_id) + if result: + response_message = result + yield response_message + + async def forward( + self, *message: AgentMessage, session_id=0, **kwargs + ) -> AsyncGenerator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None]: + formatted_messages = self.aggregator.aggregate( + self.memory.get(session_id), self.name, self.output_format, self.template + ) + async for model_state, response, *_ in self.llm.stream_chat( + formatted_messages, session_id=session_id, **kwargs + ): + yield ( + AgentMessage( + sender=self.name, + content=response, + formatted=self.output_format.parse_response(response), + stream_state=model_state, + ) + if self.output_format + else (model_state, response) + ) + + +class StreamingAgent(StreamingAgentMixin, Agent): + """Streaming variant of the Agent class""" + + pass + + +class AsyncStreamingAgent(AsyncStreamingAgentMixin, Agent): + """Streaming variant of the AsyncAgent class""" + + pass + + class Sequential(Agent): - """Sequential is an agent container that forwards messages to each agent + """Sequential is an agent container that forwards messages to each agent in the order they are added.""" - def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs): + def __init__(self, *agents: Union[Agent, Iterable], **kwargs): super().__init__(**kwargs) self._agents = OrderedDict() if not agents: raise ValueError('At least one agent should be provided') - if isinstance(agents[0], - Iterable) and not isinstance(agents[0], Agent): + if isinstance(agents[0], Iterable) and not isinstance(agents[0], Agent): if not agents[0]: raise ValueError('At least one agent should be provided') agents = agents[0] @@ -279,17 +333,11 @@ def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs): key, agent = agent self.add_agent(key, agent) - def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]): - assert isinstance( - agent, (Agent, AsyncAgent - )), f'{type(agent)} is not an Agent or AsyncAgent subclass' + def add_agent(self, name: str, agent: Agent): + assert isinstance(agent, Agent), f'{type(agent)} is not an Agent subclass' self._agents[str(name)] = agent - def forward(self, - *message: AgentMessage, - session_id=0, - exit_at: Optional[int] = None, - **kwargs) -> AgentMessage: + def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs) -> AgentMessage: assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' if exit_at is None: exit_at = len(self) - 1 @@ -297,7 +345,7 @@ def forward(self, for _ in range(exit_at + 1): agent = next(iterator) if isinstance(message, AgentMessage): - message = (message, ) + message = (message,) message = agent(*message, session_id=session_id, **kwargs) return message @@ -311,13 +359,11 @@ def __len__(self): return len(self._agents) -class AsyncSequential(Sequential, AsyncAgent): +class AsyncSequential(AsyncAgentMixin, Sequential): - async def forward(self, - *message: AgentMessage, - session_id=0, - exit_at: Optional[int] = None, - **kwargs) -> AgentMessage: + async def forward( + self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs + ) -> AgentMessage: assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' if exit_at is None: exit_at = len(self) - 1 @@ -325,11 +371,43 @@ async def forward(self, for _ in range(exit_at + 1): agent = next(iterator) if isinstance(message, AgentMessage): - message = (message, ) + message = (message,) message = await agent(*message, session_id=session_id, **kwargs) return message +class StreamingSequential(StreamingAgentMixin, Sequential): + """Streaming variant of the Sequential class""" + + def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs): + assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' + if exit_at is None: + exit_at = len(self) - 1 + iterator = chain.from_iterable(repeat(self._agents.values())) + for _ in range(exit_at + 1): + agent = next(iterator) + if isinstance(message, AgentMessage): + message = (message,) + for message in agent(*message, session_id=session_id, **kwargs): + yield message + + +class AsyncStreamingSequential(AsyncStreamingAgentMixin, Sequential): + """Streaming variant of the AsyncSequential class""" + + async def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs): + assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' + if exit_at is None: + exit_at = len(self) - 1 + iterator = chain.from_iterable(repeat(self._agents.values())) + for _ in range(exit_at + 1): + agent = next(iterator) + if isinstance(message, AgentMessage): + message = (message,) + async for message in agent(*message, session_id=session_id, **kwargs): + yield message + + class AgentContainerMixin: def __init_subclass__(cls): @@ -349,33 +427,28 @@ def _backup(d): ret = func(self, *args, **kwargs) agents = OrderedDict() - for k, item in (self.data.items() if isinstance( - self.data, abc.Mapping) else enumerate(self.data)): - if isinstance(self.data, - abc.Mapping) and not isinstance(k, str): + for k, item in self.data.items() if isinstance(self.data, abc.Mapping) else enumerate(self.data): + if isinstance(self.data, abc.Mapping) and not isinstance(k, str): _backup(data) - raise KeyError( - f'agent name should be a string, got {type(k)}') + raise KeyError(f'agent name should be a string, got {type(k)}') if isinstance(k, str) and '.' in k: _backup(data) - raise KeyError( - f'agent name can\'t contain ".", got {k}') - if not isinstance(item, (Agent, AsyncAgent)): + raise KeyError(f'agent name can\'t contain ".", got {k}') + if not isinstance(item, Agent): _backup(data) - raise TypeError( - f'{type(item)} is not an Agent or AsyncAgent subclass' - ) + raise TypeError(f'{type(item)} is not an Agent subclass') agents[str(k)] = item self._agents = agents return ret return wrapped_func + # fmt: off for method in [ - 'append', 'sort', 'reverse', 'pop', 'clear', 'update', - 'insert', 'extend', 'remove', '__init__', '__setitem__', - '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__', - '__imul__', '__rmul__' + 'append', 'sort', 'reverse', 'pop', 'clear', 'update', + 'insert', 'extend', 'remove', '__init__', '__setitem__', + '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__', + '__imul__', '__rmul__' ]: if hasattr(cls, method): setattr(cls, method, wrap_api(getattr(cls, method))) @@ -383,8 +456,7 @@ def _backup(d): class AgentList(Agent, UserList, AgentContainerMixin): - def __init__(self, - agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None): + def __init__(self, agents: Optional[Iterable[Agent]] = None): Agent.__init__(self, memory=None) UserList.__init__(self, agents) self.name = None @@ -392,9 +464,7 @@ def __init__(self, class AgentDict(Agent, UserDict, AgentContainerMixin): - def __init__(self, - agents: Optional[Mapping[str, Union[Agent, - AsyncAgent]]] = None): + def __init__(self, agents: Optional[Mapping[str, Agent]] = None): Agent.__init__(self, memory=None) UserDict.__init__(self, agents) self.name = None diff --git a/lagent/agents/react.py b/lagent/agents/react.py index 41d2414..4a942a0 100644 --- a/lagent/agents/react.py +++ b/lagent/agents/react.py @@ -12,7 +12,6 @@ from lagent.prompts.parsers.json_parser import JSONParser from lagent.prompts.prompt_template import PromptTemplate from lagent.schema import AgentMessage -from lagent.utils import create_object select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括: {action_info} @@ -28,96 +27,88 @@ class ReAct(Agent): - def __init__(self, - llm: Union[BaseLLM, Dict], - actions: Union[BaseAction, List[BaseAction]], - template: Union[PromptTemplate, str] = None, - memory: Dict = dict(type=Memory), - output_format: Dict = dict(type=JSONParser), - aggregator: Dict = dict(type=DefaultAggregator), - hooks: List = [dict(type=ActionPreprocessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: - 'conclusion' in m.content or 'conclusion' in m.formatted, - max_turn: int = 5, - **kwargs): + def __init__( + self, + llm: Union[BaseLLM, Dict], + actions: Union[BaseAction, List[BaseAction]], + template: Union[PromptTemplate, str] = None, + memory: Dict = dict(type=Memory), + output_format: Dict = dict(type=JSONParser), + aggregator: Dict = dict(type=DefaultAggregator), + hooks: List = [dict(type=ActionPreprocessor)], + finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content + or 'conclusion' in m.formatted, + max_turn: int = 5, + **kwargs + ): self.max_turn = max_turn self.finish_condition = finish_condition - actions = dict( - type=ActionExecutor, - actions=actions, - hooks=hooks, - ) - self.actions: ActionExecutor = create_object(actions) - select_agent = dict( - type=Agent, + self.actions = ActionExecutor(actions=actions, hooks=hooks) + self.select_agent = Agent( llm=llm, template=template.format( - action_info=json.dumps(self.actions.description()), - output_format=output_format.format_instruction()), + action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction() + ), output_format=output_format, memory=memory, aggregator=aggregator, hooks=hooks, ) - self.select_agent = create_object(select_agent) super().__init__(**kwargs) - def forward(self, message: AgentMessage, **kwargs) -> AgentMessage: + def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: for _ in range(self.max_turn): - message = self.select_agent(message) + message = self.select_agent(message, session_id=session_id, **kwargs) if self.finish_condition(message): return message - message = self.actions(message) + message = self.actions(message, session_id=session_id) return message class AsyncReAct(AsyncAgent): - def __init__(self, - llm: Union[BaseLLM, Dict], - actions: Union[BaseAction, List[BaseAction]], - template: Union[PromptTemplate, str] = None, - memory: Dict = dict(type=Memory), - output_format: Dict = dict(type=JSONParser), - aggregator: Dict = dict(type=DefaultAggregator), - hooks: List = [dict(type=ActionPreprocessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: - 'conclusion' in m.content or 'conclusion' in m.formatted, - max_turn: int = 5, - **kwargs): + def __init__( + self, + llm: Union[BaseLLM, Dict], + actions: Union[BaseAction, List[BaseAction]], + template: Union[PromptTemplate, str] = None, + memory: Dict = dict(type=Memory), + output_format: Dict = dict(type=JSONParser), + aggregator: Dict = dict(type=DefaultAggregator), + hooks: List = [dict(type=ActionPreprocessor)], + finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content + or 'conclusion' in m.formatted, + max_turn: int = 5, + **kwargs + ): self.max_turn = max_turn self.finish_condition = finish_condition - actions = dict( - type=AsyncActionExecutor, - actions=actions, - hooks=hooks, - ) - self.actions: AsyncActionExecutor = create_object(actions) - select_agent = dict( - type=AsyncAgent, + self.actions = AsyncActionExecutor(actions=actions, hooks=hooks) + self.select_agent = AsyncAgent( llm=llm, template=template.format( - action_info=json.dumps(self.actions.description()), - output_format=output_format.format_instruction()), + action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction() + ), output_format=output_format, memory=memory, aggregator=aggregator, hooks=hooks, ) - self.select_agent = create_object(select_agent) super().__init__(**kwargs) - async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage: + async def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: for _ in range(self.max_turn): - message = await self.select_agent(message) + message = await self.select_agent(message, session_id=session_id, **kwargs) if self.finish_condition(message): return message - message = await self.actions(message) + message = await self.actions(message, session_id=session_id) return message if __name__ == '__main__': - from lagent.llms import GPTAPI + import asyncio + + from lagent.llms import GPTAPI, AsyncGPTAPI class ActionCall(BaseModel): name: str = Field(description='调用的函数名称') @@ -125,37 +116,49 @@ class ActionCall(BaseModel): class ActionFormat(BaseModel): thought_process: str = Field( - description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。') + description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。' + ) action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。') class FinishFormat(BaseModel): thought_process: str = Field( - description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。') + description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。' + ) conclusion: str = Field(description='总结当前的搜索结果,回答问题。') prompt_template = PromptTemplate(select_action_template) - output_format = JSONParser( - output_format_template, - function_format=ActionFormat, - finish_format=FinishFormat) - - llm = dict( - type=GPTAPI, - model_type='gpt-4o-2024-05-13', - key=None, - max_new_tokens=4096, - proxies=dict(), - retry=1000) + output_format = JSONParser(output_format_template, function_format=ActionFormat, finish_format=FinishFormat) agent = ReAct( - llm=llm, + llm=dict( + type=GPTAPI, + model_type='gpt-4o-2024-05-13', + max_new_tokens=4096, + proxies=dict(), + retry=1000, + ), template=prompt_template, output_format=output_format, - aggregator=dict(type='DefaultAggregator'), - actions=[dict(type='PythonInterpreter')], + aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'), + actions=[dict(type='lagent.actions.PythonInterpreter')], ) - response = agent( - AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5')) + response = agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5')) print(response) response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢')) print(response) + + async_agent = AsyncReAct( + llm=dict( + type=AsyncGPTAPI, + model_type='gpt-4o-2024-05-13', + max_new_tokens=4096, + proxies=dict(), + retry=1000, + ), + template=prompt_template, + output_format=output_format, + aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'), + actions=[dict(type='lagent.actions.AsyncPythonInterpreter')], + ) + response = asyncio.run(async_agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))) + print(async_agent.state_dict()) diff --git a/lagent/agents/stream.py b/lagent/agents/stream.py index 512250f..ba79fdc 100644 --- a/lagent/agents/stream.py +++ b/lagent/agents/stream.py @@ -15,22 +15,27 @@ API_PREFIX = ( "This is the subfunction for tool '{tool_name}', you can use this tool. " - 'The description of this function is: \n{description}') + 'The description of this function is: \n{description}' +) -META_CN = ('当开启工具以及代码时,根据需求选择合适的工具进行调用') +META_CN = '当开启工具以及代码时,根据需求选择合适的工具进行调用' -INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。' - '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。' - '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),' - '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),' - '文本处理和分析(比如文本解析和自然语言处理),' - '机器学习和数据科学(用于展示模型训练和数据可视化),' - '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。') +INTERPRETER_CN = ( + '你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。' + '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。' + '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),' + '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),' + '文本处理和分析(比如文本解析和自然语言处理),' + '机器学习和数据科学(用于展示模型训练和数据可视化),' + '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。' +) -PLUGIN_CN = ('你可以使用如下工具:' - '\n{prompt}\n' - '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! ' - '同时注意你可以使用的工具,不要随意捏造!') +PLUGIN_CN = ( + '你可以使用如下工具:' + '\n{prompt}\n' + '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! ' + '同时注意你可以使用的工具,不要随意捏造!' +) def get_plugin_prompt(actions, api_desc_template=API_PREFIX): @@ -41,19 +46,15 @@ def get_plugin_prompt(actions, api_desc_template=API_PREFIX): if action.is_toolkit: for api in action_desc['api_list']: api['name'] = f"{action.name}.{api['name']}" - api['description'] = api_desc_template.format( - tool_name=action.name, description=api['description']) - api['parameters'] = [ - param for param in api['parameters'] - if param['name'] in api['required'] - ] + api['description'] = api_desc_template.format(tool_name=action.name, description=api['description']) + api['parameters'] = [param for param in api['parameters'] if param['name'] in api['required']] plugin_descriptions.append(api) else: action_desc['description'] = api_desc_template.format( - tool_name=action.name, description=action_desc['description']) + tool_name=action.name, description=action_desc['description'] + ) action_desc['parameters'] = [ - param for param in action_desc['parameters'] - if param['name'] in action_desc['required'] + param for param in action_desc['parameters'] if param['name'] in action_desc['required'] ] plugin_descriptions.append(action_desc) return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4) @@ -76,17 +77,15 @@ def __init__( parsers=[ dict(type=PluginParser, template=PLUGIN_CN), dict(type=InterpreterParser, template=INTERPRETER_CN), - ]), + ], + ), aggregator: Dict = dict(type=InternLMToolAggregator), action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[ - [AgentMessage], - bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 4, **kwargs, ): - agent = dict( - type=self._INTERNAL_AGENT_CLS, + self.agent = self._INTERNAL_AGENT_CLS( llm=llm, template=template, output_format=output_format, @@ -94,22 +93,18 @@ def __init__( aggregator=aggregator, hooks=kwargs.pop('hooks', None), ) - self.agent = create_object(agent) - self.plugin_executor = plugins and ActionExecutor( - plugins, hooks=action_hooks) - self.interpreter_executor = interpreter and ActionExecutor( - interpreter, hooks=action_hooks) + self.plugin_executor = plugins and ActionExecutor(plugins, hooks=action_hooks) + self.interpreter_executor = interpreter and ActionExecutor(interpreter, hooks=action_hooks) if not (self.plugin_executor or self.interpreter_executor): warnings.warn( 'Neither plugin nor interpreter executor is initialized. ' - 'An exception will be thrown when the agent call a tool.') + 'An exception will be thrown when the agent call a tool.' + ) self.finish_condition = finish_condition self.max_turn = max_turn super().__init__(**kwargs) def forward(self, message: AgentMessage, session_id=0, **kwargs): - if isinstance(message, str): - message = AgentMessage(sender='user', content=message) for _ in range(self.max_turn): message = self.agent(message, session_id=session_id, **kwargs) assert isinstance(message.formatted, dict) @@ -127,15 +122,10 @@ def get_steps(self, session_id=0): steps, tool_type = [], None for msg in self.agent.memory.get_memory(session_id): if msg.sender == self.agent.name: - steps.append( - dict(role='thought', content=msg.formatted['thought'])) + steps.append(dict(role='thought', content=msg.formatted['thought'])) if msg.formatted['tool_type']: tool_type = msg.formatted['tool_type'] - steps.append( - dict( - role='tool', - content=msg.formatted['action'], - name=tool_type)) + steps.append(dict(role='tool', content=msg.formatted['action'], name=tool_type)) elif msg.sender != 'user': feedback = dict(role='environment', content=msg.content) if tool_type: @@ -149,23 +139,22 @@ class MathCoder(AgentForInternLM): def __init__( self, llm: Union[BaseLLM, Dict], - interpreter: dict = dict( - type=IPythonInteractive, timeout=20, max_out_len=8192), + interpreter: dict = dict(type=IPythonInteractive, timeout=20, max_out_len=8192), template: Union[str, dict, List[dict]] = None, memory: Dict = dict(type=Memory), output_format: Dict = dict( type=InterpreterParser, - template= - ('Integrate step-by-step reasoning and Python code to solve math problems ' - 'using the following guidelines:\n' - '- Analyze the question and write jupyter code to solve the problem;\n' - r"- Present the final result in LaTeX using a '\boxed{{}}' without any " - 'units. \n')), + template=( + 'Integrate step-by-step reasoning and Python code to solve math problems ' + 'using the following guidelines:\n' + '- Analyze the question and write jupyter code to solve the problem;\n' + r"- Present the final result in LaTeX using a '\boxed{{}}' without any " + 'units. \n' + ), + ), aggregator: Dict = dict(type=InternLMToolAggregator), action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[ - [AgentMessage], - bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 6, **kwargs, ): @@ -180,7 +169,8 @@ def __init__( action_hooks=action_hooks, finish_condition=finish_condition, max_turn=max_turn, - **kwargs) + **kwargs, + ) class AsyncAgentForInternLM(AsyncAgent): @@ -200,17 +190,15 @@ def __init__( parsers=[ dict(type=PluginParser, template=PLUGIN_CN), dict(type=InterpreterParser, template=INTERPRETER_CN), - ]), + ], + ), aggregator: Dict = dict(type=InternLMToolAggregator), action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[ - [AgentMessage], - bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 4, **kwargs, ): - agent = dict( - type=self._INTERNAL_AGENT_CLS, + self.agent = self._INTERNAL_AGENT_CLS( llm=llm, template=template, output_format=output_format, @@ -218,25 +206,20 @@ def __init__( aggregator=aggregator, hooks=kwargs.pop('hooks', None), ) - self.agent = create_object(agent) - self.plugin_executor = plugins and AsyncActionExecutor( - plugins, hooks=action_hooks) - self.interpreter_executor = interpreter and AsyncActionExecutor( - interpreter, hooks=action_hooks) + self.plugin_executor = plugins and AsyncActionExecutor(plugins, hooks=action_hooks) + self.interpreter_executor = interpreter and AsyncActionExecutor(interpreter, hooks=action_hooks) if not (self.plugin_executor or self.interpreter_executor): warnings.warn( 'Neither plugin nor interpreter executor is initialized. ' - 'An exception will be thrown when the agent call a tool.') + 'An exception will be thrown when the agent call a tool.' + ) self.finish_condition = finish_condition self.max_turn = max_turn super().__init__(**kwargs) async def forward(self, message: AgentMessage, session_id=0, **kwargs): - if isinstance(message, str): - message = AgentMessage(sender='user', content=message) for _ in range(self.max_turn): - message = await self.agent( - message, session_id=session_id, **kwargs) + message = await self.agent(message, session_id=session_id, **kwargs) assert isinstance(message.formatted, dict) if self.finish_condition(message): return message @@ -252,15 +235,10 @@ def get_steps(self, session_id=0): steps, tool_type = [], None for msg in self.agent.memory.get_memory(session_id): if msg.sender == self.agent.name: - steps.append( - dict(role='thought', content=msg.formatted['thought'])) + steps.append(dict(role='thought', content=msg.formatted['thought'])) if msg.formatted['tool_type']: tool_type = msg.formatted['tool_type'] - steps.append( - dict( - role='tool', - content=msg.formatted['action'], - name=tool_type)) + steps.append(dict(role='tool', content=msg.formatted['action'], name=tool_type)) elif msg.sender != 'user': feedback = dict(role='environment', content=msg.content) if tool_type: @@ -279,17 +257,17 @@ def __init__( memory: Dict = dict(type=Memory), output_format: Dict = dict( type=InterpreterParser, - template= - ('Integrate step-by-step reasoning and Python code to solve math problems ' - 'using the following guidelines:\n' - '- Analyze the question and write jupyter code to solve the problem;\n' - r"- Present the final result in LaTeX using a '\boxed{{}}' without any " - 'units. \n')), + template=( + 'Integrate step-by-step reasoning and Python code to solve math problems ' + 'using the following guidelines:\n' + '- Analyze the question and write jupyter code to solve the problem;\n' + r"- Present the final result in LaTeX using a '\boxed{{}}' without any " + 'units. \n' + ), + ), aggregator: Dict = dict(type=InternLMToolAggregator), action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[ - [AgentMessage], - bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, max_turn: int = 6, **kwargs, ): @@ -304,13 +282,13 @@ def __init__( action_hooks=action_hooks, finish_condition=finish_condition, max_turn=max_turn, - **kwargs) + **kwargs, + ) async def forward(self, message: AgentMessage, session_id=0, **kwargs): try: return await super().forward(message, session_id, **kwargs) finally: - interpreter = next( - iter(self.interpreter_executor.actions.values())) + interpreter = next(iter(self.interpreter_executor.actions.values())) if interpreter.name == 'AsyncIPythonInterpreter': await interpreter.close_session(session_id) diff --git a/lagent/prompts/parsers/str_parser.py b/lagent/prompts/parsers/str_parser.py index 6af7aa6..be997bc 100644 --- a/lagent/prompts/parsers/str_parser.py +++ b/lagent/prompts/parsers/str_parser.py @@ -1,3 +1,4 @@ +import string from typing import Any @@ -8,14 +9,17 @@ def __init__( template: str = '', **format_field, ): + fields = {item[1] for item in string.Formatter().parse(template) if item[1] is not None} + if not fields.issubset(format_field.keys()): + raise ValueError( + 'not all required fields of "template" are provided, missing ' + f'{fields - format_field.keys()}. Please pass them as keyword arguments.' + ) self.template = template self.format_field = format_field def format_instruction(self) -> Any: - format_data = { - key: self.format_to_string(value) - for key, value in self.format_field.items() - } + format_data = {key: self.format_to_string(value) for key, value in self.format_field.items()} return self.template.format(**format_data) def format_to_string(self, format_model: Any) -> str: diff --git a/lagent/prompts/parsers/tool_parser.py b/lagent/prompts/parsers/tool_parser.py index 5343312..a8ffea3 100644 --- a/lagent/prompts/parsers/tool_parser.py +++ b/lagent/prompts/parsers/tool_parser.py @@ -23,29 +23,24 @@ class ToolStatusCode(IntEnum): class ToolParser(StrParser): - def __init__(self, - tool_type: str, - template: str = '', - begin: str = '\n', - end: str = '\n', - validate: Callable[[str], Any] = None, - **kwargs): + def __init__( + self, + tool_type: str, + template: str = '', + begin: str = '\n', + end: str = '\n', + validate: Callable[[str], Any] = None, + **kwargs + ): super().__init__(template, begin=begin, end=end, **kwargs) self.template = template self.tool_type = tool_type - # self.pattern = re.compile( - # '(.*?){}(.*)({})?'.format(re.escape(begin), re.escape(end)), - # re.DOTALL) - self.validate = load_class_from_string(validate) if isinstance( - validate, str) else validate + # self.pattern = re.compile('(.*?){}(.*)({})?'.format(re.escape(begin), re.escape(end)), re.DOTALL) + self.validate = load_class_from_string(validate) if isinstance(validate, str) else validate def parse_response(self, data: str) -> dict: if self.format_field['begin'] not in data: - return dict( - tool_type=None, - thought=data, - action=None, - status=ToolStatusCode.NO_TOOL) + return dict(tool_type=None, thought=data, action=None, status=ToolStatusCode.NO_TOOL) thought, action, *_ = data.split(self.format_field["begin"]) action = action.split(self.format_field['end'])[0] status = ToolStatusCode.VALID_TOOL @@ -54,11 +49,7 @@ def parse_response(self, data: str) -> dict: action = self.validate(action) except Exception: status = ToolStatusCode.PARSING_ERROR - return dict( - tool_type=self.tool_type, - thought=thought, - action=action, - status=status) + return dict(tool_type=self.tool_type, thought=thought, action=action, status=status) def format_response(self, parsed: dict) -> str: if parsed['action'] is None: @@ -68,41 +59,40 @@ def format_response(self, parsed: dict) -> str: action = json.dumps(parsed['action'], ensure_ascii=False) else: action = str(parsed['action']) - return parsed['thought'] + self.format_field[ - 'begin'] + action + self.format_field['end'] + return parsed['thought'] + self.format_field['begin'] + action + self.format_field['end'] class InterpreterParser(ToolParser): - def __init__(self, - tool_type: str = 'interpreter', - template: str = '', - begin: str = '<|action_start|><|interpreter|>\n', - end: str = '<|action_end|>\n', - validate: Callable[[str], Any] = None, - **kwargs): + def __init__( + self, + tool_type: str = 'interpreter', + template: str = '', + begin: str = '<|action_start|><|interpreter|>\n', + end: str = '<|action_end|>\n', + validate: Callable[[str], Any] = None, + **kwargs + ): super().__init__(tool_type, template, begin, end, validate, **kwargs) class PluginParser(ToolParser): - def __init__(self, - tool_type: str = 'plugin', - template: str = '', - begin: str = '<|action_start|><|plugin|>\n', - end: str = '<|action_end|>\n', - validate: Callable[[str], Any] = default_plugin_validate, - **kwargs): + def __init__( + self, + tool_type: str = 'plugin', + template: str = '', + begin: str = '<|action_start|><|plugin|>\n', + end: str = '<|action_end|>\n', + validate: Callable[[str], Any] = default_plugin_validate, + **kwargs + ): super().__init__(tool_type, template, begin, end, validate, **kwargs) class MixedToolParser(StrParser): - def __init__(self, - tool_type: Optional[str] = None, - template='', - parsers: List[ToolParser] = None, - **format_field): + def __init__(self, tool_type: Optional[str] = None, template='', parsers: List[ToolParser] = None, **format_field): self.parsers = {} self.tool_type = tool_type for parser in parsers or []: @@ -125,11 +115,7 @@ def format_instruction(self) -> List[dict]: return inst def parse_response(self, data: str) -> dict: - res = dict( - tool_type=None, - thought=data, - action=None, - status=ToolStatusCode.NO_TOOL) + res = dict(tool_type=None, thought=data, action=None, status=ToolStatusCode.NO_TOOL) for name, parser in self.parsers.items(): res = parser.parse_response(data) if res['tool_type'] == name: