forked from ajar98/vocode-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
speller_agent.py
80 lines (61 loc) · 3.13 KB
/
speller_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import typing
from typing import Optional, Tuple
from vocode.streaming.agent.abstract_factory import AbstractAgentFactory
from vocode.streaming.agent.base_agent import BaseAgent, RespondAgent
from vocode.streaming.agent.chat_gpt_agent import ChatGPTAgent
from vocode.streaming.models.agent import AgentConfig, AgentType, ChatGPTAgentConfig
class SpellerAgentConfig(AgentConfig, type="agent_speller"):
    """Agent configuration tagged with the ``"agent_speller"`` type.

    Declares no fields of its own; every setting comes from ``AgentConfig``.
    """
class SpellerAgent(RespondAgent[SpellerAgentConfig]):
    """Agent that answers by spelling the human's input back out.

    Each character of the input is emitted followed by a single space.
    """

    def __init__(self, agent_config: SpellerAgentConfig):
        """Build the agent.

        Args:
            agent_config (SpellerAgentConfig): Passed through unchanged to the
                ``RespondAgent`` constructor.
        """
        super().__init__(agent_config=agent_config)

    async def respond(
        self,
        human_input: str,
        conversation_id: str,
        is_interrupt: bool = False,
    ) -> Tuple[Optional[str], bool]:
        """Spell out ``human_input``.

        Args:
            human_input (str): Text received from the human.
            conversation_id (str): Conversation identifier (not used here).
            is_interrupt (bool): Whether the agent was interrupted (not used here).

        Returns:
            Tuple[Optional[str], bool]: ``(reply, should_stop)``.
            ``reply`` is every character of ``human_input`` followed by one
            space, so non-empty input yields a trailing space and empty input
            yields ``""``. ``should_stop`` is always ``False``, meaning the
            agent should not stop.
        """
        # Every character gets its own trailing space, including the last one.
        spelled = "".join(f"{char} " for char in human_input)
        return spelled, False
class SpellerAgentFactory(AbstractAgentFactory):
    """Agent factory that builds a ChatGPTAgent or a SpellerAgent from a config."""

    def create_agent(self, agent_config: AgentConfig) -> BaseAgent:
        """Create the agent selected by ``agent_config.type``.

        Args:
            agent_config (AgentConfig): Configuration whose ``type`` chooses
                the agent class.

        Returns:
            BaseAgent: A ``ChatGPTAgent`` when the type is
            ``AgentType.CHAT_GPT``, or a ``SpellerAgent`` when the type is
            ``"agent_speller"``.

        Raises:
            ValueError: If ``agent_config.type`` is not one of the supported
                types.
        """
        if agent_config.type == AgentType.CHAT_GPT:
            # typing.cast only narrows the static type and does no runtime
            # check; assumes a CHAT_GPT-tagged config is a ChatGPTAgentConfig.
            return ChatGPTAgent(
                agent_config=typing.cast(ChatGPTAgentConfig, agent_config)
            )
        # This literal must match the type="agent_speller" tag declared on
        # SpellerAgentConfig.
        if agent_config.type == "agent_speller":
            return SpellerAgent(
                agent_config=typing.cast(SpellerAgentConfig, agent_config)
            )
        # ValueError subclasses Exception, so callers catching Exception still
        # work. Naming the rejected type makes misconfiguration easier to debug.
        raise ValueError(f"Invalid agent config: {agent_config.type!r}")