Skip to content

Commit

Permalink
refactor: eliminate DSPy dependency (#76)
Browse files Browse the repository at this point in the history
* refactor: eliminating dspy dependency

* chore: updating uv.lock file

* refactor: filter cleanup

* refactor: filter cleanup
  • Loading branch information
sfc-gh-alherrera authored Dec 5, 2024
1 parent 9f7a01a commit 0558d8f
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 196 deletions.
2 changes: 1 addition & 1 deletion agent_gateway/gateway/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ class Planner:
def __init__(
self,
session: object,
llm: str, # point to dspy
llm: str,
example_prompt: str,
example_prompt_replan: str,
tools: Sequence[Union[Tool, StructuredTool]],
Expand Down
162 changes: 13 additions & 149 deletions agent_gateway/tools/snowflake_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@
# limitations under the License.

import asyncio
import contextlib
import inspect
import json
import logging
import re
from typing import Any, Type, Union

import dspy
from pydantic import BaseModel, Field, ValidationError
from pydantic import BaseModel
from snowflake.connector.connection import SnowflakeConnection
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
Expand All @@ -46,8 +44,6 @@ class CortexSearchTool(Tool):
retrieval_columns: list = []
service_name: str = ""
connection: Union[Session, SnowflakeConnection] = None
auto_filter: bool = False
filter_generator: object = None

def __init__(
self,
Expand All @@ -56,7 +52,6 @@ def __init__(
data_description,
retrieval_columns,
snowflake_connection,
auto_filter=False,
k=5,
):
"""Parameters
Expand All @@ -67,7 +62,6 @@ def __init__(
data_description (str): description of the source data that has been indexed.
retrieval_columns (list): list of columns to include in Cortex Search results.
snowflake_connection (object): snowpark connection object
auto_filter (bool): automatically generate filter based on user's query or not.
k: number of records to include in results
"""
tool_name = f"{service_name.lower()}_cortexsearch"
Expand All @@ -79,13 +73,7 @@ def __init__(
super().__init__(
name=tool_name, description=tool_description, func=self.asearch
)
self.auto_filter = auto_filter
self.connection = _get_connection(snowflake_connection)
if self.auto_filter:
self.filter_generator = SmartSearch()
lm = dspy.Snowflake(session=self.session, model="mixtral-8x7b")
dspy.settings.configure(lm=lm)

self.k = k
self.retrieval_columns = retrieval_columns
self.service_name = service_name
Expand Down Expand Up @@ -113,27 +101,11 @@ def _prepare_request(self, query):
self.connection.schema,
self.service_name,
)
if self.auto_filter:
search_attributes, sample_vals = self._get_sample_values(
snowflake_connection=Session.builder.config(
"connection", self.connection
),
cortex_search_service=self.service_name,
)
raw_filter = self.filter_generator(
query=query,
attributes=str(search_attributes),
sample_values=str(sample_vals),
)["answer"]
filter = json.loads(raw_filter)
else:
filter = None

data = {
"query": query,
"columns": self.retrieval_columns,
"limit": self.k,
"filter": filter,
}

return headers, url, data
Expand Down Expand Up @@ -209,82 +181,6 @@ def get_min_length(model: Type[BaseModel]):
return min_length


# Pydantic model holding a filter query as a JSON string in ``answer``.
# Documented with comments rather than a class docstring on purpose: a
# docstring on a pydantic model is emitted into its JSON schema, which would
# change what consumers of that schema see.
class JSONFilter(BaseModel):
    answer: str = Field(description="The filter_query in valid JSON format")

    @classmethod
    def model_validate_json(
        cls,
        json_data: str,
        *,
        strict: Union[bool, None] = None,
        context: Union[dict[str, Any], None] = None,
    ):
        """Validate ``json_data``, falling back to the longest valid substring.

        The whole string is validated first. If that fails, every substring is
        tried from longest to shortest (never shorter than
        ``get_min_length(cls)``) and the first one that validates is returned.

        Args:
            json_data: Raw text expected to contain the JSON payload.
            strict: Forwarded unchanged to the pydantic validator.
            context: Forwarded unchanged to the pydantic validator.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If neither the full text nor any substring validates.
        """
        # ``Union[...]`` is used instead of ``X | None`` because annotations are
        # evaluated at definition time here, and the ``|`` form raises
        # TypeError on Python 3.9, which the project declares as supported.
        __tracebackhide__ = True
        try:
            return cls.__pydantic_validator__.validate_json(
                json_data, strict=strict, context=context
            )
        except ValidationError as err:
            min_length = get_min_length(cls)
            # Longest candidates first, so the widest valid payload wins.
            for substring_length in range(len(json_data), min_length - 1, -1):
                for start in range(len(json_data) - substring_length + 1):
                    substring = json_data[start : start + substring_length]
                    with contextlib.suppress(ValidationError):
                        return cls.__pydantic_validator__.validate_json(
                            substring, strict=strict, context=context
                        )
            # Keep the original validation failure attached for debugging.
            raise ValueError("Could not find valid json") from err


# Signature declaring three inputs (query, attributes, sample_values) and one
# output (answer, annotated as JSONFilter) for generating a search filter.
# The docstring below holds the instructions plus four worked examples; it is
# presumably consumed by dspy as prompt text, so it is left untouched here.
# NOTE(review): the 2nd example ("based in New York") has exactly the same
# Answer as the 3rd ("outside of California") - confirm whether the 2nd should
# filter on HQ "NY, US" instead.
# NOTE(review): the 1st example's Answer filters on "year", which is not among
# the listed Attributes (industry,hq,date) - confirm this is intended.
class GenerateFilter(dspy.Signature):
"""Given a query, attributes in the data, and example values of each attribute, generate a filter in valid JSON format.
Ensure the filter only uses valid operators: @eq, @contains,@and,@or,@not
Ensure only the valid JSON is output with no other reasoning.
---
Query: What was the sentiment of CEOs between 2021 and 2024?
Attributes: industry,hq,date
Sample Values: {"industry":["biotechnology","healthcare","agriculture"],"HQ":["NY, US","CA,US","FL,US"],"date":["01/01,1999","01/01/2024"]}
Answer: {"@or":[{"@eq":{"year":"2021"}},{"@eq":{"year":"2022"}},{"@eq":{"year":"2023"}},{"@eq":{"year":"2024"}}]}
Query: What is the sentiment of Biotech CEOs of companies based in New York?
Attributes: industry,hq,date
Sample Values: {"industry":["biotechnology","healthcare","agriculture"],"HQ":["NY, US","CA,US","FL,US"],"date":["01/01,1999","01/01/2024"]}
Answer: {"@and":[{ "@eq": { "industry": "biotechnology" } },{"@not":{"@eq":{"HQ":"CA,US"}}}]}
Query: What is the sentiment of Biotech CEOs outside of California?
Attributes: industry,hq,date
Sample Values: {"industry":["biotechnology","healthcare","agriculture"],"HQ":["NY, US","CA,US","FL,US"],"date":["01/01,1999","01/01/2024"]}
Answer: {"@and":[{ "@eq": { "industry": "biotechnology" } },{"@not":{"@eq":{"HQ":"CA,US"}}}]}
Query: What is sentiment towards ag and biotech companies based outside of the US?
Attributes: industry,hq,date
Sample Values: {"industry":["biotechnology","healthcare","agriculture"],"COUNTRY":["United States","Ireland","Russia","Georgia","Spain"],"month":["01","02","03","06","11","12"],"year":["2022","2023","2024"]}
Answer: {"@and": [{ "@or": [{"@eq":{ "industry": "biotechnology" } },{"@eq":{"industry":"agriculture"}}]},{ "@not": {"@eq": { "COUNTRY": "United States" } }}]}
"""

# Input fields: the user's question, the filterable attribute names, and
# example values for each attribute.
query = dspy.InputField(desc="user query")
attributes = dspy.InputField(desc="attributes to filter on")
sample_values = dspy.InputField(desc="examples of values per attribute")
# Output field: the generated filter, annotated with the JSONFilter model.
answer: JSONFilter = dspy.OutputField(
desc="filter query in valid JSON format. ONLY output the filter query in JSON, no reasoning"
)


class SmartSearch(dspy.Module):
    """Module that wraps a ``dspy.ChainOfThought`` over ``GenerateFilter``."""

    def __init__(self):
        """Build the chain-of-thought predictor used by ``forward``."""
        super().__init__()
        self.filter_gen = dspy.ChainOfThought(GenerateFilter)

    def forward(self, query, attributes, sample_values):
        """Run the predictor on the given inputs and return its result as-is."""
        predictor_inputs = {
            "query": query,
            "attributes": attributes,
            "sample_values": sample_values,
        }
        return self.filter_gen(**predictor_inputs)


class CortexAnalystTool(Tool):
"""""Cortex Analyst tool for use with Snowflake Agent Gateway""" ""

Expand Down Expand Up @@ -329,42 +225,21 @@ def __call__(self, prompt) -> Any:
async def asearch(self, query):
gateway_logger.log(logging.DEBUG, f"Cortex Analyst Prompt:{query}")

for _ in range(3):
current_query = query
url, headers, data = self._prepare_analyst_request(prompt=query)
url, headers, data = self._prepare_analyst_request(prompt=query)

response_text = await post_cortex_request(
url=url, headers=headers, data=data
)
json_response = json.loads(response_text)
response_text = await post_cortex_request(url=url, headers=headers, data=data)
json_response = json.loads(response_text)

gateway_logger.log(
logging.DEBUG, f"Cortex Analyst Raw Response:{json_response}"
)
gateway_logger.log(
logging.DEBUG, f"Cortex Analyst Raw Response:{json_response}"
)

try:
query_response = self._process_analyst_message(
json_response["message"]["content"]
)

if "Unable to generate valid SQL Query" in query_response:
lm = dspy.Snowflake(
session=Session.builder.config(
"connection", self.connection
).getOrCreate(),
model="llama3.2-1b",
)
dspy.settings.configure(lm=lm)
rephrase_prompt = dspy.ChainOfThought(PromptRephrase)
prompt = f"Original Query: {current_query}. Previous Response Context: {query_response}"
current_query = rephrase_prompt(user_prompt=prompt)[
"rephrased_prompt"
]
else:
break

except Exception:
raise SnowflakeError(message=json_response["message"])
try:
query_response = self._process_analyst_message(
json_response["message"]["content"]
)
except Exception:
raise SnowflakeError(message=json_response["message"])

return query_response

Expand Down Expand Up @@ -426,17 +301,6 @@ def _prepare_analyst_description(
return base_analyst_description


# Signature declaring one input (user_prompt) and one output
# (rephrased_prompt) for rewriting a prompt into a clearer question.
# The docstring is presumably consumed by dspy as prompt text, so it is left
# untouched here.
# NOTE(review): the docstring reads "into to a single concise" - likely a typo
# for "into a single concise"; fixing it would change the prompt text.
class PromptRephrase(dspy.Signature):
"""Takes in a prompt and rephrases it using context into to a single concise, and specific question.
If there are references to entities that are not clear or consistent with the question being asked, make the references more appropriate.
"""

# Input: the prompt text to be rephrased.
user_prompt = dspy.InputField(desc="original user prompt")
# Output: the rephrased prompt.
rephrased_prompt = dspy.OutputField(
desc="rephrased prompt with more clear and specific intent"
)


class PythonTool(Tool):
python_callable: object = None

Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@ version = "0.1.0"
requires-python = ">=3.9"
description = "Multi-agent framework for Snowflake"
authors = [
{ name = "Alejandro Ferrera", email = "alejandro.herrera@snowflake.com" },
{ name = "Alejandro Herrera", email = "alejandro.herrera@snowflake.com" },
]
readme = "README.md"

dependencies = [
"snowflake-snowpark-python>=1.22.1",
"dspy-ai>=2.5.3",
"langchain>=0.3.2",
"asyncio>=3.4.3",
"aiohttp>=3.10.9",
Expand Down
44 changes: 0 additions & 44 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0558d8f

Please sign in to comment.