From 6888016d840931fa751ab9bdc7c4c4e9b2e2bd30 Mon Sep 17 00:00:00 2001 From: alejandro Date: Fri, 6 Dec 2024 11:08:07 -0500 Subject: [PATCH 1/3] adding memory test --- tests/test_quickstart.py | 53 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/test_quickstart.py b/tests/test_quickstart.py index d455ae9..0345755 100644 --- a/tests/test_quickstart.py +++ b/tests/test_quickstart.py @@ -142,3 +142,56 @@ def get_news(_) -> dict: ) response = agent(question) assert answer_contains in response + + +@pytest.mark.parametrize( + "question, answer_contains", + [ + pytest.param( + "What is the market cap of Apple?", + "$3,019,131,060,224", + id="market_cap", + ), + pytest.param( + "When is Apple releasing a new chip?", + "May 7", + id="product_revenue", + ), + ], +) +def test_gateway_agent_without_memory(session, question, answer_contains): + search_config = { + "service_name": "SEC_SEARCH_SERVICE", + "service_topic": "Snowflake's business,product offerings,and performance", + "data_description": "Snowflake annual reports", + "retrieval_columns": ["CHUNK"], + "snowflake_connection": session, + } + analyst_config = { + "semantic_model": "sp500_semantic_model.yaml", + "stage": "ANALYST", + "service_topic": "S&P500 company and stock metrics", + "data_description": "a table with stock and financial metrics about S&P500 companies ", + "snowflake_connection": session, + } + + def get_news(_) -> dict: + with open("tests/data/response.json") as f: + d = json.load(f) + return d + + python_config = { + "tool_description": "searches for relevant news based on user query", + "output_description": "relevant articles", + "python_func": get_news, + } + annual_reports = CortexSearchTool(**search_config) + sp500 = CortexAnalystTool(**analyst_config) + news_search = PythonTool(**python_config) + agent = Agent( + snowflake_connection=session, + tools=[annual_reports, sp500, news_search], + memory=False, + ) + response = agent(question) + assert answer_contains in response From 9937b97a57af832e044028951760569467d198d6 Mon Sep 17 00:00:00 2001 From: alejandro Date: Fri, 6 Dec 2024 11:16:55 -0500 Subject: [PATCH 2/3] fix: only access memory_context if memory arg is True --- agent_gateway/gateway/gateway.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/agent_gateway/gateway/gateway.py b/agent_gateway/gateway/gateway.py index 485a02f..ebc704f 100644 --- a/agent_gateway/gateway/gateway.py +++ b/agent_gateway/gateway/gateway.py @@ -468,8 +468,8 @@ async def acall( inputs["context"] = formatted_contexts max_memory = 3 # TODO consider exposing this to users - - if len(self.memory_context) <= max_memory: - self.memory_context.append({"Question:": input, "Answer": answer}) + if self.memory: + if len(self.memory_context) <= max_memory: + self.memory_context.append({"Question:": input, "Answer": answer}) return answer From 232cfa5b835a1c80051831fde65120805c53580c Mon Sep 17 00:00:00 2001 From: Tyler White Date: Fri, 6 Dec 2024 17:27:38 -0500 Subject: [PATCH 3/3] style(ruff): sort imports --- tests/test_quickstart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_quickstart.py b/tests/test_quickstart.py index 0345755..4cab29c 100644 --- a/tests/test_quickstart.py +++ b/tests/test_quickstart.py @@ -16,7 +16,7 @@ import pytest from agent_gateway import Agent -from agent_gateway.tools import CortexSearchTool, CortexAnalystTool, PythonTool +from agent_gateway.tools import CortexAnalystTool, CortexSearchTool, PythonTool @pytest.mark.parametrize(