Skip to content

Commit

Permalink
refactor: filter cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-alherrera committed Dec 5, 2024
1 parent 1a4b4d2 commit 2b3c959
Showing 1 changed file with 1 addition and 33 deletions.
34 changes: 1 addition & 33 deletions agent_gateway/tools/snowflake_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
# limitations under the License.

import asyncio
import contextlib
import inspect
import json
import logging
import re
from typing import Any, Type, Union

from pydantic import BaseModel, Field, ValidationError
from pydantic import BaseModel
from snowflake.connector.connection import SnowflakeConnection
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
Expand All @@ -45,8 +44,6 @@ class CortexSearchTool(Tool):
retrieval_columns: list = []
service_name: str = ""
connection: Union[Session, SnowflakeConnection] = None
auto_filter: bool = False
filter_generator: object = None

def __init__(
self,
Expand All @@ -55,7 +52,6 @@ def __init__(
data_description,
retrieval_columns,
snowflake_connection,
auto_filter=False,
k=5,
):
"""Parameters
Expand Down Expand Up @@ -185,34 +181,6 @@ def get_min_length(model: Type[BaseModel]):
return min_length


class JSONFilter(BaseModel):
answer: str = Field(description="The filter_query in valid JSON format")

@classmethod
def model_validate_json(
cls,
json_data: str,
*,
strict: bool | None = None,
context: dict[str, Any] | None = None,
):
__tracebackhide__ = True
try:
return cls.__pydantic_validator__.validate_json(
json_data, strict=strict, context=context
)
except ValidationError:
min_length = get_min_length(cls)
for substring_length in range(len(json_data), min_length - 1, -1):
for start in range(len(json_data) - substring_length + 1):
substring = json_data[start : start + substring_length]
with contextlib.suppress(ValidationError):
return cls.__pydantic_validator__.validate_json(
substring, strict=strict, context=context
)
raise ValueError("Could not find valid json")


class CortexAnalystTool(Tool):
"""""Cortex Analyst tool for use with Snowflake Agent Gateway""" ""

Expand Down

0 comments on commit 2b3c959

Please sign in to comment.