From c1a2f250a569f2a1b311eb20301bda7c76576a32 Mon Sep 17 00:00:00 2001 From: alejandro Date: Fri, 6 Dec 2024 08:31:38 -0500 Subject: [PATCH] refactor: shifting memory mechanism further down the call stack --- agent_gateway/gateway/gateway.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/agent_gateway/gateway/gateway.py b/agent_gateway/gateway/gateway.py index 01e52a7..101a02e 100644 --- a/agent_gateway/gateway/gateway.py +++ b/agent_gateway/gateway/gateway.py @@ -137,6 +137,7 @@ def __init__( max_retries: int = 2, planner_llm: str = "mistral-large2", agent_llm: str = "mistral-large2", + memory: bool = True, planner_example_prompt: str = SNOWFLAKE_PLANNER_PROMPT, planner_example_prompt_replan: Optional[str] = None, planner_stop: Optional[list[str]] = [END_OF_PLAN], @@ -155,6 +156,7 @@ def __init__( max_retries: Maximum number of replans to do. Defaults to 2. planner_llm: Name of Snowflake Cortex LLM to use for planning. agent_llm: Name of Snowflake Cortex LLM to use for planning. + memory: Boolean to turn on memory mechanism or not. Defaults to False. planner_example_prompt: Example prompt for planning. Defaults to SNOWFLAKE_PLANNER_PROMPT. planner_example_prompt_replan: Example prompt for replanning. Assign this if you want to use different example prompt for replanning. @@ -192,7 +194,9 @@ def __init__( self.max_retries = max_retries # basic memory - self.memory = [] + self.memory = memory + if self.memory: + self.memory_context = [] # callbacks self.planner_callback = None @@ -271,6 +275,7 @@ def _extract_replan_message(self, raw_answer): "rephrasing your request or validate that the provided tools contain " "sufficient information." ) + def _generate_context_for_replanner( self, tasks: Mapping[int, Task], fusion_thought: str ) -> str: @@ -343,9 +348,6 @@ def __call__(self, input: str): result = [] error = [] - if len(self.memory) >= 1: - input = f"My previous question/answer was: {self.memory[0]}\n. If needed, use that context and this {input} to answer my question. Otherwise just give me an answer to: {input} " - thread = threading.Thread(target=self.run_async, args=(input, result, error)) thread.start() thread.join() @@ -356,10 +358,6 @@ def __call__(self, input: str): if not result: raise AgentGatewayError("Unable to retrieve response. Result is empty.") - max_memory = 3 # TODO consider exposing this to users - if len(self.memory) <= max_memory: - self.memory.append({"Question:": input, "Answer": result[0]}) - return result[0] def handle_exception(self, loop, context): @@ -396,12 +394,17 @@ def run_async(self, input, result, error): async def acall( self, input: str, - # inputs: Dict[str, Any] ) -> Dict[str, Any]: contexts = [] fusion_thought = "" agent_scratchpad = "" - inputs = {"input": input} + + if self.memory: + input_with_mem = f"My previous question/answer was: {self.memory_context}\n. If needed, use that context and this {input} to answer my question. Otherwise just give me an answer to: {input} " + inputs = {"input": input_with_mem} + else: + inputs = {"input": input} + for i in range(self.max_retries): is_first_iter = i == 0 is_final_iter = i == self.max_retries - 1 @@ -464,4 +467,9 @@ async def acall( formatted_contexts = self._format_contexts(contexts) inputs["context"] = formatted_contexts + max_memory = 3 # TODO consider exposing this to users + + if len(self.memory_context) <= max_memory: + self.memory_context.append({"Question:": input, "Answer": answer}) + return answer