diff --git a/README.md b/README.md
index 3535240..066bc03 100644
--- a/README.md
+++ b/README.md
@@ -96,7 +96,7 @@ print(dumped_memory['memory'])
Clear the memory of this session(`session_id=0` by default):
```python
-agent.memory.reset()
+agent.reset()
```
### Custom Message Aggregation
diff --git a/lagent/agents/__init__.py b/lagent/agents/__init__.py
index f06972c..0a995d2 100644
--- a/lagent/agents/__init__.py
+++ b/lagent/agents/__init__.py
@@ -1,9 +1,33 @@
-from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential
+from .agent import (
+ Agent,
+ AgentDict,
+ AgentList,
+ AsyncAgent,
+ AsyncSequential,
+ AsyncStreamingAgent,
+ AsyncStreamingSequential,
+ Sequential,
+ StreamingAgent,
+ StreamingSequential,
+)
from .react import AsyncReAct, ReAct
from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder
__all__ = [
- 'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM',
- 'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct',
- 'AsyncReAct', 'Sequential', 'AsyncSequential'
+ 'Agent',
+ 'AgentDict',
+ 'AgentList',
+ 'AsyncAgent',
+ 'AgentForInternLM',
+ 'AsyncAgentForInternLM',
+ 'MathCoder',
+ 'AsyncMathCoder',
+ 'ReAct',
+ 'AsyncReAct',
+ 'Sequential',
+ 'AsyncSequential',
+ 'StreamingAgent',
+ 'StreamingSequential',
+ 'AsyncStreamingAgent',
+ 'AsyncStreamingSequential',
]
diff --git a/lagent/agents/agent.py b/lagent/agents/agent.py
index b1e941b..9707d7b 100644
--- a/lagent/agents/agent.py
+++ b/lagent/agents/agent.py
@@ -3,7 +3,7 @@
from collections import OrderedDict, UserDict, UserList, abc
from functools import wraps
from itertools import chain, repeat
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
+from typing import Any, AsyncGenerator, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union
from lagent.agents.aggregator import DefaultAggregator
from lagent.hooks import Hook, RemovableHandle
@@ -11,7 +11,7 @@
from lagent.memory import Memory, MemoryManager
from lagent.prompts.parsers import StrParser
from lagent.prompts.prompt_template import PromptTemplate
-from lagent.schema import AgentMessage
+from lagent.schema import AgentMessage, ModelStatusCode
from lagent.utils import create_object
@@ -63,29 +63,17 @@ def update_memory(self, message, session_id=0):
if self.memory:
self.memory.add(message, session_id=session_id)
- def __call__(
- self,
- *message: Union[str, AgentMessage, List[AgentMessage]],
- session_id=0,
- **kwargs,
- ) -> AgentMessage:
+ def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessage:
# message.receiver = self.name
- message = [
- AgentMessage(sender='user', content=m)
- if isinstance(m, str) else copy.deepcopy(m) for m in message
- ]
+ message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message]
for hook in self._hooks.values():
result = hook.before_agent(self, message, session_id)
if result:
message = result
self.update_memory(message, session_id=session_id)
- response_message = self.forward(
- *message, session_id=session_id, **kwargs)
+ response_message = self.forward(*message, session_id=session_id, **kwargs)
if not isinstance(response_message, AgentMessage):
- response_message = AgentMessage(
- sender=self.name,
- content=response_message,
- )
+ response_message = AgentMessage(sender=self.name, content=response_message)
self.update_memory(response_message, session_id=session_id)
response_message = copy.deepcopy(response_message)
for hook in self._hooks.values():
@@ -94,25 +82,14 @@ def __call__(
response_message = result
return response_message
- def forward(self,
- *message: AgentMessage,
- session_id=0,
- **kwargs) -> Union[AgentMessage, str]:
+ def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]:
formatted_messages = self.aggregator.aggregate(
- self.memory.get(session_id),
- self.name,
- self.output_format,
- self.template,
+ self.memory.get(session_id), self.name, self.output_format, self.template
)
llm_response = self.llm.chat(formatted_messages, **kwargs)
if self.output_format:
- formatted_messages = self.output_format.parse_response(
- llm_response)
- return AgentMessage(
- sender=self.name,
- content=llm_response,
- formatted=formatted_messages,
- )
+ formatted_messages = self.output_format.parse_response(llm_response)
+ return AgentMessage(sender=self.name, content=llm_response, formatted=formatted_messages)
return llm_response
def __setattr__(self, __name: str, __value: Any) -> None:
@@ -165,12 +142,8 @@ def register_hook(self, hook: Callable):
self._hooks[handle.id] = hook
return handle
- def reset(self,
- session_id=0,
- keypath: Optional[str] = None,
- recursive: bool = False):
- assert not (keypath and
- recursive), 'keypath and recursive can\'t be used together'
+ def reset(self, session_id=0, keypath: Optional[str] = None, recursive: bool = False):
+ assert not (keypath and recursive), 'keypath and recursive can\'t be used together'
if keypath:
keys, agent = keypath.split('.'), self
for key in keys:
@@ -189,15 +162,13 @@ def reset(self,
def __repr__(self):
def _rcsv_repr(agent, n_indent=1):
- res = agent.__class__.__name__ + (f"(name='{agent.name}')"
- if agent.name else '')
+ res = agent.__class__.__name__ + (f"(name='{agent.name}')" if agent.name else '')
modules = [
f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}"
for name, agent in getattr(agent, '_agents', {}).items()
]
if modules:
- res += '(\n' + '\n'.join(
- modules) + f'\n{(n_indent - 1) * " "})'
+ res += '(\n' + '\n'.join(modules) + f'\n{(n_indent - 1) * " "})'
elif not res.endswith(')'):
res += '()'
return res
@@ -205,28 +176,18 @@ def _rcsv_repr(agent, n_indent=1):
return _rcsv_repr(self)
-class AsyncAgent(Agent):
+class AsyncAgentMixin:
- async def __call__(self,
- *message: AgentMessage | List[AgentMessage],
- session_id=0,
- **kwargs) -> AgentMessage:
- message = [
- AgentMessage(sender='user', content=m)
- if isinstance(m, str) else copy.deepcopy(m) for m in message
- ]
+ async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessage:
+ message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message]
for hook in self._hooks.values():
result = hook.before_agent(self, message, session_id)
if result:
message = result
self.update_memory(message, session_id=session_id)
- response_message = await self.forward(
- *message, session_id=session_id, **kwargs)
+ response_message = await self.forward(*message, session_id=session_id, **kwargs)
if not isinstance(response_message, AgentMessage):
- response_message = AgentMessage(
- sender=self.name,
- content=response_message,
- )
+ response_message = AgentMessage(sender=self.name, content=response_message)
self.update_memory(response_message, session_id=session_id)
response_message = copy.deepcopy(response_message)
for hook in self._hooks.values():
@@ -235,40 +196,133 @@ async def __call__(self,
response_message = result
return response_message
- async def forward(self,
- *message: AgentMessage,
- session_id=0,
- **kwargs) -> Union[AgentMessage, str]:
+ async def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]:
formatted_messages = self.aggregator.aggregate(
- self.memory.get(session_id),
- self.name,
- self.output_format,
- self.template,
+ self.memory.get(session_id), self.name, self.output_format, self.template
)
- llm_response = await self.llm.chat(formatted_messages, session_id,
- **kwargs)
+ llm_response = await self.llm.chat(formatted_messages, session_id, **kwargs)
if self.output_format:
- formatted_messages = self.output_format.parse_response(
- llm_response)
- return AgentMessage(
- sender=self.name,
- content=llm_response,
- formatted=formatted_messages,
- )
+ formatted_messages = self.output_format.parse_response(llm_response)
+ return AgentMessage(sender=self.name, content=llm_response, formatted=formatted_messages)
return llm_response
+class AsyncAgent(AsyncAgentMixin, Agent):
+ """Asynchronous variant of the Agent class"""
+
+ pass
+
+
+class StreamingAgentMixin:
+ """Component that makes agent calling output a streaming response."""
+
+ def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Generator[AgentMessage, None, None]:
+ message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message]
+ for hook in self._hooks.values():
+ result = hook.before_agent(self, message, session_id)
+ if result:
+ message = result
+ self.update_memory(message, session_id=session_id)
+ response_message = AgentMessage(sender=self.name, content="")
+ for response_message in self.forward(*message, session_id=session_id, **kwargs):
+ if not isinstance(response_message, AgentMessage):
+ model_state, response = response_message
+ response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state)
+ yield response_message.model_copy()
+ self.update_memory(response_message, session_id=session_id)
+ response_message = copy.deepcopy(response_message)
+ for hook in self._hooks.values():
+ result = hook.after_agent(self, response_message, session_id)
+ if result:
+ response_message = result
+ yield response_message
+
+ def forward(
+ self, *message: AgentMessage, session_id=0, **kwargs
+ ) -> Generator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None, None]:
+ formatted_messages = self.aggregator.aggregate(
+ self.memory.get(session_id), self.name, self.output_format, self.template
+ )
+ for model_state, response, *_ in self.llm.stream_chat(formatted_messages, session_id=session_id, **kwargs):
+ yield (
+ AgentMessage(
+ sender=self.name,
+ content=response,
+ formatted=self.output_format.parse_response(response),
+ stream_state=model_state,
+ )
+ if self.output_format
+ else (model_state, response)
+ )
+
+
+class AsyncStreamingAgentMixin:
+ """Component that makes asynchronous agent calling output a streaming response."""
+
+ async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AsyncGenerator[AgentMessage, None]:
+ message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message]
+ for hook in self._hooks.values():
+ result = hook.before_agent(self, message, session_id)
+ if result:
+ message = result
+ self.update_memory(message, session_id=session_id)
+ response_message = AgentMessage(sender=self.name, content="")
+ async for response_message in self.forward(*message, session_id=session_id, **kwargs):
+ if not isinstance(response_message, AgentMessage):
+ model_state, response = response_message
+ response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state)
+ yield response_message.model_copy()
+ self.update_memory(response_message, session_id=session_id)
+ response_message = copy.deepcopy(response_message)
+ for hook in self._hooks.values():
+ result = hook.after_agent(self, response_message, session_id)
+ if result:
+ response_message = result
+ yield response_message
+
+ async def forward(
+ self, *message: AgentMessage, session_id=0, **kwargs
+ ) -> AsyncGenerator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None]:
+ formatted_messages = self.aggregator.aggregate(
+ self.memory.get(session_id), self.name, self.output_format, self.template
+ )
+ async for model_state, response, *_ in self.llm.stream_chat(
+ formatted_messages, session_id=session_id, **kwargs
+ ):
+ yield (
+ AgentMessage(
+ sender=self.name,
+ content=response,
+ formatted=self.output_format.parse_response(response),
+ stream_state=model_state,
+ )
+ if self.output_format
+ else (model_state, response)
+ )
+
+
+class StreamingAgent(StreamingAgentMixin, Agent):
+ """Streaming variant of the Agent class"""
+
+ pass
+
+
+class AsyncStreamingAgent(AsyncStreamingAgentMixin, Agent):
+ """Streaming variant of the AsyncAgent class"""
+
+ pass
+
+
class Sequential(Agent):
- """Sequential is an agent container that forwards messages to each agent
+ """Sequential is an agent container that forwards messages to each agent
in the order they are added."""
- def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
+ def __init__(self, *agents: Union[Agent, Iterable], **kwargs):
super().__init__(**kwargs)
self._agents = OrderedDict()
if not agents:
raise ValueError('At least one agent should be provided')
- if isinstance(agents[0],
- Iterable) and not isinstance(agents[0], Agent):
+ if isinstance(agents[0], Iterable) and not isinstance(agents[0], Agent):
if not agents[0]:
raise ValueError('At least one agent should be provided')
agents = agents[0]
@@ -279,17 +333,11 @@ def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
key, agent = agent
self.add_agent(key, agent)
- def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]):
- assert isinstance(
- agent, (Agent, AsyncAgent
- )), f'{type(agent)} is not an Agent or AsyncAgent subclass'
+ def add_agent(self, name: str, agent: Agent):
+ assert isinstance(agent, Agent), f'{type(agent)} is not an Agent subclass'
self._agents[str(name)] = agent
- def forward(self,
- *message: AgentMessage,
- session_id=0,
- exit_at: Optional[int] = None,
- **kwargs) -> AgentMessage:
+ def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs) -> AgentMessage:
assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
if exit_at is None:
exit_at = len(self) - 1
@@ -297,7 +345,7 @@ def forward(self,
for _ in range(exit_at + 1):
agent = next(iterator)
if isinstance(message, AgentMessage):
- message = (message, )
+ message = (message,)
message = agent(*message, session_id=session_id, **kwargs)
return message
@@ -311,13 +359,11 @@ def __len__(self):
return len(self._agents)
-class AsyncSequential(Sequential, AsyncAgent):
+class AsyncSequential(AsyncAgentMixin, Sequential):
- async def forward(self,
- *message: AgentMessage,
- session_id=0,
- exit_at: Optional[int] = None,
- **kwargs) -> AgentMessage:
+ async def forward(
+ self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs
+ ) -> AgentMessage:
assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
if exit_at is None:
exit_at = len(self) - 1
@@ -325,11 +371,43 @@ async def forward(self,
for _ in range(exit_at + 1):
agent = next(iterator)
if isinstance(message, AgentMessage):
- message = (message, )
+ message = (message,)
message = await agent(*message, session_id=session_id, **kwargs)
return message
+class StreamingSequential(StreamingAgentMixin, Sequential):
+ """Streaming variant of the Sequential class"""
+
+ def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs):
+ assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
+ if exit_at is None:
+ exit_at = len(self) - 1
+ iterator = chain.from_iterable(repeat(self._agents.values()))
+ for _ in range(exit_at + 1):
+ agent = next(iterator)
+ if isinstance(message, AgentMessage):
+ message = (message,)
+ for message in agent(*message, session_id=session_id, **kwargs):
+ yield message
+
+
+class AsyncStreamingSequential(AsyncStreamingAgentMixin, Sequential):
+ """Streaming variant of the AsyncSequential class"""
+
+ async def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs):
+ assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
+ if exit_at is None:
+ exit_at = len(self) - 1
+ iterator = chain.from_iterable(repeat(self._agents.values()))
+ for _ in range(exit_at + 1):
+ agent = next(iterator)
+ if isinstance(message, AgentMessage):
+ message = (message,)
+ async for message in agent(*message, session_id=session_id, **kwargs):
+ yield message
+
+
class AgentContainerMixin:
def __init_subclass__(cls):
@@ -349,33 +427,28 @@ def _backup(d):
ret = func(self, *args, **kwargs)
agents = OrderedDict()
- for k, item in (self.data.items() if isinstance(
- self.data, abc.Mapping) else enumerate(self.data)):
- if isinstance(self.data,
- abc.Mapping) and not isinstance(k, str):
+ for k, item in self.data.items() if isinstance(self.data, abc.Mapping) else enumerate(self.data):
+ if isinstance(self.data, abc.Mapping) and not isinstance(k, str):
_backup(data)
- raise KeyError(
- f'agent name should be a string, got {type(k)}')
+ raise KeyError(f'agent name should be a string, got {type(k)}')
if isinstance(k, str) and '.' in k:
_backup(data)
- raise KeyError(
- f'agent name can\'t contain ".", got {k}')
- if not isinstance(item, (Agent, AsyncAgent)):
+ raise KeyError(f'agent name can\'t contain ".", got {k}')
+ if not isinstance(item, Agent):
_backup(data)
- raise TypeError(
- f'{type(item)} is not an Agent or AsyncAgent subclass'
- )
+ raise TypeError(f'{type(item)} is not an Agent subclass')
agents[str(k)] = item
self._agents = agents
return ret
return wrapped_func
+ # fmt: off
for method in [
- 'append', 'sort', 'reverse', 'pop', 'clear', 'update',
- 'insert', 'extend', 'remove', '__init__', '__setitem__',
- '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
- '__imul__', '__rmul__'
+ 'append', 'sort', 'reverse', 'pop', 'clear', 'update',
+ 'insert', 'extend', 'remove', '__init__', '__setitem__',
+ '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
+ '__imul__', '__rmul__'
]:
if hasattr(cls, method):
setattr(cls, method, wrap_api(getattr(cls, method)))
@@ -383,8 +456,7 @@ def _backup(d):
class AgentList(Agent, UserList, AgentContainerMixin):
- def __init__(self,
- agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None):
+ def __init__(self, agents: Optional[Iterable[Agent]] = None):
Agent.__init__(self, memory=None)
UserList.__init__(self, agents)
self.name = None
@@ -392,9 +464,7 @@ def __init__(self,
class AgentDict(Agent, UserDict, AgentContainerMixin):
- def __init__(self,
- agents: Optional[Mapping[str, Union[Agent,
- AsyncAgent]]] = None):
+ def __init__(self, agents: Optional[Mapping[str, Agent]] = None):
Agent.__init__(self, memory=None)
UserDict.__init__(self, agents)
self.name = None
diff --git a/lagent/agents/react.py b/lagent/agents/react.py
index 41d2414..4a942a0 100644
--- a/lagent/agents/react.py
+++ b/lagent/agents/react.py
@@ -12,7 +12,6 @@
from lagent.prompts.parsers.json_parser import JSONParser
from lagent.prompts.prompt_template import PromptTemplate
from lagent.schema import AgentMessage
-from lagent.utils import create_object
select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:
{action_info}
@@ -28,96 +27,88 @@
class ReAct(Agent):
- def __init__(self,
- llm: Union[BaseLLM, Dict],
- actions: Union[BaseAction, List[BaseAction]],
- template: Union[PromptTemplate, str] = None,
- memory: Dict = dict(type=Memory),
- output_format: Dict = dict(type=JSONParser),
- aggregator: Dict = dict(type=DefaultAggregator),
- hooks: List = [dict(type=ActionPreprocessor)],
- finish_condition: Callable[[AgentMessage], bool] = lambda m:
- 'conclusion' in m.content or 'conclusion' in m.formatted,
- max_turn: int = 5,
- **kwargs):
+ def __init__(
+ self,
+ llm: Union[BaseLLM, Dict],
+ actions: Union[BaseAction, List[BaseAction]],
+ template: Union[PromptTemplate, str] = None,
+ memory: Dict = dict(type=Memory),
+ output_format: Dict = dict(type=JSONParser),
+ aggregator: Dict = dict(type=DefaultAggregator),
+ hooks: List = [dict(type=ActionPreprocessor)],
+ finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content
+ or 'conclusion' in m.formatted,
+ max_turn: int = 5,
+ **kwargs
+ ):
self.max_turn = max_turn
self.finish_condition = finish_condition
- actions = dict(
- type=ActionExecutor,
- actions=actions,
- hooks=hooks,
- )
- self.actions: ActionExecutor = create_object(actions)
- select_agent = dict(
- type=Agent,
+ self.actions = ActionExecutor(actions=actions, hooks=hooks)
+ self.select_agent = Agent(
llm=llm,
template=template.format(
- action_info=json.dumps(self.actions.description()),
- output_format=output_format.format_instruction()),
+ action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction()
+ ),
output_format=output_format,
memory=memory,
aggregator=aggregator,
hooks=hooks,
)
- self.select_agent = create_object(select_agent)
super().__init__(**kwargs)
- def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
+ def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage:
for _ in range(self.max_turn):
- message = self.select_agent(message)
+ message = self.select_agent(message, session_id=session_id, **kwargs)
if self.finish_condition(message):
return message
- message = self.actions(message)
+ message = self.actions(message, session_id=session_id)
return message
class AsyncReAct(AsyncAgent):
- def __init__(self,
- llm: Union[BaseLLM, Dict],
- actions: Union[BaseAction, List[BaseAction]],
- template: Union[PromptTemplate, str] = None,
- memory: Dict = dict(type=Memory),
- output_format: Dict = dict(type=JSONParser),
- aggregator: Dict = dict(type=DefaultAggregator),
- hooks: List = [dict(type=ActionPreprocessor)],
- finish_condition: Callable[[AgentMessage], bool] = lambda m:
- 'conclusion' in m.content or 'conclusion' in m.formatted,
- max_turn: int = 5,
- **kwargs):
+ def __init__(
+ self,
+ llm: Union[BaseLLM, Dict],
+ actions: Union[BaseAction, List[BaseAction]],
+ template: Union[PromptTemplate, str] = None,
+ memory: Dict = dict(type=Memory),
+ output_format: Dict = dict(type=JSONParser),
+ aggregator: Dict = dict(type=DefaultAggregator),
+ hooks: List = [dict(type=ActionPreprocessor)],
+ finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content
+ or 'conclusion' in m.formatted,
+ max_turn: int = 5,
+ **kwargs
+ ):
self.max_turn = max_turn
self.finish_condition = finish_condition
- actions = dict(
- type=AsyncActionExecutor,
- actions=actions,
- hooks=hooks,
- )
- self.actions: AsyncActionExecutor = create_object(actions)
- select_agent = dict(
- type=AsyncAgent,
+ self.actions = AsyncActionExecutor(actions=actions, hooks=hooks)
+ self.select_agent = AsyncAgent(
llm=llm,
template=template.format(
- action_info=json.dumps(self.actions.description()),
- output_format=output_format.format_instruction()),
+ action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction()
+ ),
output_format=output_format,
memory=memory,
aggregator=aggregator,
hooks=hooks,
)
- self.select_agent = create_object(select_agent)
super().__init__(**kwargs)
- async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
+ async def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage:
for _ in range(self.max_turn):
- message = await self.select_agent(message)
+ message = await self.select_agent(message, session_id=session_id, **kwargs)
if self.finish_condition(message):
return message
- message = await self.actions(message)
+ message = await self.actions(message, session_id=session_id)
return message
if __name__ == '__main__':
- from lagent.llms import GPTAPI
+ import asyncio
+
+ from lagent.llms import GPTAPI, AsyncGPTAPI
class ActionCall(BaseModel):
name: str = Field(description='调用的函数名称')
@@ -125,37 +116,49 @@ class ActionCall(BaseModel):
class ActionFormat(BaseModel):
thought_process: str = Field(
- description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
+ description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。'
+ )
action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')
class FinishFormat(BaseModel):
thought_process: str = Field(
- description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
+ description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。'
+ )
conclusion: str = Field(description='总结当前的搜索结果,回答问题。')
prompt_template = PromptTemplate(select_action_template)
- output_format = JSONParser(
- output_format_template,
- function_format=ActionFormat,
- finish_format=FinishFormat)
-
- llm = dict(
- type=GPTAPI,
- model_type='gpt-4o-2024-05-13',
- key=None,
- max_new_tokens=4096,
- proxies=dict(),
- retry=1000)
+ output_format = JSONParser(output_format_template, function_format=ActionFormat, finish_format=FinishFormat)
agent = ReAct(
- llm=llm,
+ llm=dict(
+ type=GPTAPI,
+ model_type='gpt-4o-2024-05-13',
+ max_new_tokens=4096,
+ proxies=dict(),
+ retry=1000,
+ ),
template=prompt_template,
output_format=output_format,
- aggregator=dict(type='DefaultAggregator'),
- actions=[dict(type='PythonInterpreter')],
+ aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'),
+ actions=[dict(type='lagent.actions.PythonInterpreter')],
)
- response = agent(
- AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))
+ response = agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))
print(response)
response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢'))
print(response)
+
+ async_agent = AsyncReAct(
+ llm=dict(
+ type=AsyncGPTAPI,
+ model_type='gpt-4o-2024-05-13',
+ max_new_tokens=4096,
+ proxies=dict(),
+ retry=1000,
+ ),
+ template=prompt_template,
+ output_format=output_format,
+ aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'),
+ actions=[dict(type='lagent.actions.AsyncPythonInterpreter')],
+ )
+ response = asyncio.run(async_agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5')))
+ print(async_agent.state_dict())
diff --git a/lagent/agents/stream.py b/lagent/agents/stream.py
index 512250f..ba79fdc 100644
--- a/lagent/agents/stream.py
+++ b/lagent/agents/stream.py
@@ -15,22 +15,27 @@
API_PREFIX = (
"This is the subfunction for tool '{tool_name}', you can use this tool. "
- 'The description of this function is: \n{description}')
+ 'The description of this function is: \n{description}'
+)
-META_CN = ('当开启工具以及代码时,根据需求选择合适的工具进行调用')
+META_CN = '当开启工具以及代码时,根据需求选择合适的工具进行调用'
-INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
- '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
- '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
- '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
- '文本处理和分析(比如文本解析和自然语言处理),'
- '机器学习和数据科学(用于展示模型训练和数据可视化),'
- '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。')
+INTERPRETER_CN = (
+ '你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
+ '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
+ '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
+ '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
+ '文本处理和分析(比如文本解析和自然语言处理),'
+ '机器学习和数据科学(用于展示模型训练和数据可视化),'
+ '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。'
+)
-PLUGIN_CN = ('你可以使用如下工具:'
- '\n{prompt}\n'
- '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
- '同时注意你可以使用的工具,不要随意捏造!')
+PLUGIN_CN = (
+ '你可以使用如下工具:'
+ '\n{prompt}\n'
+ '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
+ '同时注意你可以使用的工具,不要随意捏造!'
+)
def get_plugin_prompt(actions, api_desc_template=API_PREFIX):
@@ -41,19 +46,15 @@ def get_plugin_prompt(actions, api_desc_template=API_PREFIX):
if action.is_toolkit:
for api in action_desc['api_list']:
api['name'] = f"{action.name}.{api['name']}"
- api['description'] = api_desc_template.format(
- tool_name=action.name, description=api['description'])
- api['parameters'] = [
- param for param in api['parameters']
- if param['name'] in api['required']
- ]
+ api['description'] = api_desc_template.format(tool_name=action.name, description=api['description'])
+ api['parameters'] = [param for param in api['parameters'] if param['name'] in api['required']]
plugin_descriptions.append(api)
else:
action_desc['description'] = api_desc_template.format(
- tool_name=action.name, description=action_desc['description'])
+ tool_name=action.name, description=action_desc['description']
+ )
action_desc['parameters'] = [
- param for param in action_desc['parameters']
- if param['name'] in action_desc['required']
+ param for param in action_desc['parameters'] if param['name'] in action_desc['required']
]
plugin_descriptions.append(action_desc)
return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4)
@@ -76,17 +77,15 @@ def __init__(
parsers=[
dict(type=PluginParser, template=PLUGIN_CN),
dict(type=InterpreterParser, template=INTERPRETER_CN),
- ]),
+ ],
+ ),
aggregator: Dict = dict(type=InternLMToolAggregator),
action_hooks: List = [dict(type=InternLMActionProcessor)],
- finish_condition: Callable[
- [AgentMessage],
- bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
+ finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
max_turn: int = 4,
**kwargs,
):
- agent = dict(
- type=self._INTERNAL_AGENT_CLS,
+ self.agent = self._INTERNAL_AGENT_CLS(
llm=llm,
template=template,
output_format=output_format,
@@ -94,22 +93,18 @@ def __init__(
aggregator=aggregator,
hooks=kwargs.pop('hooks', None),
)
- self.agent = create_object(agent)
- self.plugin_executor = plugins and ActionExecutor(
- plugins, hooks=action_hooks)
- self.interpreter_executor = interpreter and ActionExecutor(
- interpreter, hooks=action_hooks)
+ self.plugin_executor = plugins and ActionExecutor(plugins, hooks=action_hooks)
+ self.interpreter_executor = interpreter and ActionExecutor(interpreter, hooks=action_hooks)
if not (self.plugin_executor or self.interpreter_executor):
warnings.warn(
'Neither plugin nor interpreter executor is initialized. '
- 'An exception will be thrown when the agent call a tool.')
+ 'An exception will be thrown when the agent call a tool.'
+ )
self.finish_condition = finish_condition
self.max_turn = max_turn
super().__init__(**kwargs)
def forward(self, message: AgentMessage, session_id=0, **kwargs):
- if isinstance(message, str):
- message = AgentMessage(sender='user', content=message)
for _ in range(self.max_turn):
message = self.agent(message, session_id=session_id, **kwargs)
assert isinstance(message.formatted, dict)
@@ -127,15 +122,10 @@ def get_steps(self, session_id=0):
steps, tool_type = [], None
for msg in self.agent.memory.get_memory(session_id):
if msg.sender == self.agent.name:
- steps.append(
- dict(role='thought', content=msg.formatted['thought']))
+ steps.append(dict(role='thought', content=msg.formatted['thought']))
if msg.formatted['tool_type']:
tool_type = msg.formatted['tool_type']
- steps.append(
- dict(
- role='tool',
- content=msg.formatted['action'],
- name=tool_type))
+ steps.append(dict(role='tool', content=msg.formatted['action'], name=tool_type))
elif msg.sender != 'user':
feedback = dict(role='environment', content=msg.content)
if tool_type:
@@ -149,23 +139,22 @@ class MathCoder(AgentForInternLM):
def __init__(
self,
llm: Union[BaseLLM, Dict],
- interpreter: dict = dict(
- type=IPythonInteractive, timeout=20, max_out_len=8192),
+ interpreter: dict = dict(type=IPythonInteractive, timeout=20, max_out_len=8192),
template: Union[str, dict, List[dict]] = None,
memory: Dict = dict(type=Memory),
output_format: Dict = dict(
type=InterpreterParser,
- template=
- ('Integrate step-by-step reasoning and Python code to solve math problems '
- 'using the following guidelines:\n'
- '- Analyze the question and write jupyter code to solve the problem;\n'
- r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
- 'units. \n')),
+ template=(
+ 'Integrate step-by-step reasoning and Python code to solve math problems '
+ 'using the following guidelines:\n'
+ '- Analyze the question and write jupyter code to solve the problem;\n'
+ r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
+ 'units. \n'
+ ),
+ ),
aggregator: Dict = dict(type=InternLMToolAggregator),
action_hooks: List = [dict(type=InternLMActionProcessor)],
- finish_condition: Callable[
- [AgentMessage],
- bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
+ finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
max_turn: int = 6,
**kwargs,
):
@@ -180,7 +169,8 @@ def __init__(
action_hooks=action_hooks,
finish_condition=finish_condition,
max_turn=max_turn,
- **kwargs)
+ **kwargs,
+ )
class AsyncAgentForInternLM(AsyncAgent):
@@ -200,17 +190,15 @@ def __init__(
parsers=[
dict(type=PluginParser, template=PLUGIN_CN),
dict(type=InterpreterParser, template=INTERPRETER_CN),
- ]),
+ ],
+ ),
aggregator: Dict = dict(type=InternLMToolAggregator),
action_hooks: List = [dict(type=InternLMActionProcessor)],
- finish_condition: Callable[
- [AgentMessage],
- bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
+ finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
max_turn: int = 4,
**kwargs,
):
- agent = dict(
- type=self._INTERNAL_AGENT_CLS,
+ self.agent = self._INTERNAL_AGENT_CLS(
llm=llm,
template=template,
output_format=output_format,
@@ -218,25 +206,20 @@ def __init__(
aggregator=aggregator,
hooks=kwargs.pop('hooks', None),
)
- self.agent = create_object(agent)
- self.plugin_executor = plugins and AsyncActionExecutor(
- plugins, hooks=action_hooks)
- self.interpreter_executor = interpreter and AsyncActionExecutor(
- interpreter, hooks=action_hooks)
+ self.plugin_executor = plugins and AsyncActionExecutor(plugins, hooks=action_hooks)
+ self.interpreter_executor = interpreter and AsyncActionExecutor(interpreter, hooks=action_hooks)
if not (self.plugin_executor or self.interpreter_executor):
warnings.warn(
'Neither plugin nor interpreter executor is initialized. '
- 'An exception will be thrown when the agent call a tool.')
+ 'An exception will be thrown when the agent call a tool.'
+ )
self.finish_condition = finish_condition
self.max_turn = max_turn
super().__init__(**kwargs)
async def forward(self, message: AgentMessage, session_id=0, **kwargs):
- if isinstance(message, str):
- message = AgentMessage(sender='user', content=message)
for _ in range(self.max_turn):
- message = await self.agent(
- message, session_id=session_id, **kwargs)
+ message = await self.agent(message, session_id=session_id, **kwargs)
assert isinstance(message.formatted, dict)
if self.finish_condition(message):
return message
@@ -252,15 +235,10 @@ def get_steps(self, session_id=0):
steps, tool_type = [], None
for msg in self.agent.memory.get_memory(session_id):
if msg.sender == self.agent.name:
- steps.append(
- dict(role='thought', content=msg.formatted['thought']))
+ steps.append(dict(role='thought', content=msg.formatted['thought']))
if msg.formatted['tool_type']:
tool_type = msg.formatted['tool_type']
- steps.append(
- dict(
- role='tool',
- content=msg.formatted['action'],
- name=tool_type))
+ steps.append(dict(role='tool', content=msg.formatted['action'], name=tool_type))
elif msg.sender != 'user':
feedback = dict(role='environment', content=msg.content)
if tool_type:
@@ -279,17 +257,17 @@ def __init__(
memory: Dict = dict(type=Memory),
output_format: Dict = dict(
type=InterpreterParser,
- template=
- ('Integrate step-by-step reasoning and Python code to solve math problems '
- 'using the following guidelines:\n'
- '- Analyze the question and write jupyter code to solve the problem;\n'
- r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
- 'units. \n')),
+ template=(
+ 'Integrate step-by-step reasoning and Python code to solve math problems '
+ 'using the following guidelines:\n'
+ '- Analyze the question and write jupyter code to solve the problem;\n'
+ r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
+ 'units. \n'
+ ),
+ ),
aggregator: Dict = dict(type=InternLMToolAggregator),
action_hooks: List = [dict(type=InternLMActionProcessor)],
- finish_condition: Callable[
- [AgentMessage],
- bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
+ finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
max_turn: int = 6,
**kwargs,
):
@@ -304,13 +282,13 @@ def __init__(
action_hooks=action_hooks,
finish_condition=finish_condition,
max_turn=max_turn,
- **kwargs)
+ **kwargs,
+ )
async def forward(self, message: AgentMessage, session_id=0, **kwargs):
try:
return await super().forward(message, session_id, **kwargs)
finally:
- interpreter = next(
- iter(self.interpreter_executor.actions.values()))
+ interpreter = next(iter(self.interpreter_executor.actions.values()))
if interpreter.name == 'AsyncIPythonInterpreter':
await interpreter.close_session(session_id)
diff --git a/lagent/prompts/parsers/str_parser.py b/lagent/prompts/parsers/str_parser.py
index 6af7aa6..be997bc 100644
--- a/lagent/prompts/parsers/str_parser.py
+++ b/lagent/prompts/parsers/str_parser.py
@@ -1,3 +1,4 @@
+import string
from typing import Any
@@ -8,14 +9,17 @@ def __init__(
template: str = '',
**format_field,
):
+ fields = {item[1] for item in string.Formatter().parse(template) if item[1] is not None}
+ if not fields.issubset(format_field.keys()):
+ raise ValueError(
+ 'not all required fields of "template" are provided, missing '
+ f'{fields - format_field.keys()}. Please pass them as keyword arguments.'
+ )
self.template = template
self.format_field = format_field
def format_instruction(self) -> Any:
- format_data = {
- key: self.format_to_string(value)
- for key, value in self.format_field.items()
- }
+ format_data = {key: self.format_to_string(value) for key, value in self.format_field.items()}
return self.template.format(**format_data)
def format_to_string(self, format_model: Any) -> str:
diff --git a/lagent/prompts/parsers/tool_parser.py b/lagent/prompts/parsers/tool_parser.py
index 5343312..a8ffea3 100644
--- a/lagent/prompts/parsers/tool_parser.py
+++ b/lagent/prompts/parsers/tool_parser.py
@@ -23,29 +23,24 @@ class ToolStatusCode(IntEnum):
class ToolParser(StrParser):
- def __init__(self,
- tool_type: str,
- template: str = '',
- begin: str = '\n',
- end: str = '\n',
- validate: Callable[[str], Any] = None,
- **kwargs):
+ def __init__(
+ self,
+ tool_type: str,
+ template: str = '',
+ begin: str = '\n',
+ end: str = '\n',
+ validate: Callable[[str], Any] = None,
+ **kwargs
+ ):
super().__init__(template, begin=begin, end=end, **kwargs)
self.template = template
self.tool_type = tool_type
- # self.pattern = re.compile(
- # '(.*?){}(.*)({})?'.format(re.escape(begin), re.escape(end)),
- # re.DOTALL)
- self.validate = load_class_from_string(validate) if isinstance(
- validate, str) else validate
+ # self.pattern = re.compile('(.*?){}(.*)({})?'.format(re.escape(begin), re.escape(end)), re.DOTALL)
+ self.validate = load_class_from_string(validate) if isinstance(validate, str) else validate
def parse_response(self, data: str) -> dict:
if self.format_field['begin'] not in data:
- return dict(
- tool_type=None,
- thought=data,
- action=None,
- status=ToolStatusCode.NO_TOOL)
+ return dict(tool_type=None, thought=data, action=None, status=ToolStatusCode.NO_TOOL)
thought, action, *_ = data.split(self.format_field["begin"])
action = action.split(self.format_field['end'])[0]
status = ToolStatusCode.VALID_TOOL
@@ -54,11 +49,7 @@ def parse_response(self, data: str) -> dict:
action = self.validate(action)
except Exception:
status = ToolStatusCode.PARSING_ERROR
- return dict(
- tool_type=self.tool_type,
- thought=thought,
- action=action,
- status=status)
+ return dict(tool_type=self.tool_type, thought=thought, action=action, status=status)
def format_response(self, parsed: dict) -> str:
if parsed['action'] is None:
@@ -68,41 +59,40 @@ def format_response(self, parsed: dict) -> str:
action = json.dumps(parsed['action'], ensure_ascii=False)
else:
action = str(parsed['action'])
- return parsed['thought'] + self.format_field[
- 'begin'] + action + self.format_field['end']
+ return parsed['thought'] + self.format_field['begin'] + action + self.format_field['end']
class InterpreterParser(ToolParser):
- def __init__(self,
- tool_type: str = 'interpreter',
- template: str = '',
- begin: str = '<|action_start|><|interpreter|>\n',
- end: str = '<|action_end|>\n',
- validate: Callable[[str], Any] = None,
- **kwargs):
+ def __init__(
+ self,
+ tool_type: str = 'interpreter',
+ template: str = '',
+ begin: str = '<|action_start|><|interpreter|>\n',
+ end: str = '<|action_end|>\n',
+ validate: Callable[[str], Any] = None,
+ **kwargs
+ ):
super().__init__(tool_type, template, begin, end, validate, **kwargs)
class PluginParser(ToolParser):
- def __init__(self,
- tool_type: str = 'plugin',
- template: str = '',
- begin: str = '<|action_start|><|plugin|>\n',
- end: str = '<|action_end|>\n',
- validate: Callable[[str], Any] = default_plugin_validate,
- **kwargs):
+ def __init__(
+ self,
+ tool_type: str = 'plugin',
+ template: str = '',
+ begin: str = '<|action_start|><|plugin|>\n',
+ end: str = '<|action_end|>\n',
+ validate: Callable[[str], Any] = default_plugin_validate,
+ **kwargs
+ ):
super().__init__(tool_type, template, begin, end, validate, **kwargs)
class MixedToolParser(StrParser):
- def __init__(self,
- tool_type: Optional[str] = None,
- template='',
- parsers: List[ToolParser] = None,
- **format_field):
+ def __init__(self, tool_type: Optional[str] = None, template='', parsers: List[ToolParser] = None, **format_field):
self.parsers = {}
self.tool_type = tool_type
for parser in parsers or []:
@@ -125,11 +115,7 @@ def format_instruction(self) -> List[dict]:
return inst
def parse_response(self, data: str) -> dict:
- res = dict(
- tool_type=None,
- thought=data,
- action=None,
- status=ToolStatusCode.NO_TOOL)
+ res = dict(tool_type=None, thought=data, action=None, status=ToolStatusCode.NO_TOOL)
for name, parser in self.parsers.items():
res = parser.parse_response(data)
if res['tool_type'] == name: