Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions mcp_server_snowflake/object_manager/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,46 +99,55 @@ def list_objects(
like: str = None,
starts_with: str = None,
):
bindvars = []
if object_type == "image_repository":
object_name = "image repositories"
elif object_type == "compute_pool":
object_name = "compute pools"
else:
object_name = f"{object_type}s"

# Note: SHOW statements do not support variable binding. String formatting is
# safe here because object_type is restricted to a set of whitelisted values.
statement = f"SHOW {object_name}"

if like:
statement += f" LIKE '%{like.replace('%', '')}%'"
statement += " LIKE ?"
bindvars.extend([f"%{like.replace('%', '')}%"])

if object_type in ["database", "compute_pool", "role", "user"]:
pass
elif database_name is None and schema_name is None:
statement += " IN ACCOUNT"
elif database_name and schema_name:
statement += f" IN SCHEMA {database_name}.{schema_name}"
statement += " IN SCHEMA identifier(?)"
bindvars.extend([f"{database_name}.{schema_name}"])
elif database_name:
statement += f" IN DATABASE {database_name}"
statement += " IN DATABASE identifier(?)"
bindvars.extend([database_name])
elif schema_name:
statement += f" IN SCHEMA {schema_name}"
statement += " IN SCHEMA identifier(?)"
bindvars.extend([schema_name])
else:
raise SnowflakeException(
tool="list_objects",
message="Please specify a database, database + schema, or neither to query the account.",
)

if starts_with:
statement += f" STARTS WITH '{starts_with}'"
# sanitizing string manually because bind variables are not supported here
sanitized_starts_with = starts_with.replace("'", "")
statement += f" STARTS WITH '{sanitized_starts_with}'"

try:
result = execute_query(statement, snowflake_service)
result = execute_query(statement, snowflake_service, bindvars)

if len(result) > 0:
return result[0:1000] # Limit to 1000 results
else:
return f"No matching {object_name} found."
except Exception as e:
raise SnowflakeException(tool="list_semantic_views", message=str(e))
raise SnowflakeException(tool="list_objects", message=str(e))


def parse_object(target_object: Any, obj_type: supported_objects):
Expand Down
101 changes: 62 additions & 39 deletions mcp_server_snowflake/semantic_manager/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,36 @@ def list_semantic_views(
starts_with: str = None,
):
statement = "SHOW SEMANTIC VIEWS"
bindvars = []

if like:
statement += f" LIKE '%{like.replace('%', '')}%'"
statement += " LIKE ?"
bindvars.extend([f"%{like.replace('%', '')}%"])

if not database_name and not schema_name:
statement += " IN ACCOUNT"
elif database_name and schema_name:
statement += f" IN SCHEMA {database_name}.{schema_name}"
statement += " IN SCHEMA identifier(?)"
bindvars.extend([f"{database_name}.{schema_name}"])
elif database_name:
statement += f" IN DATABASE {database_name}"
statement += " IN DATABASE identifier(?)"
bindvars.extend([database_name])
elif schema_name:
statement += f" IN SCHEMA {schema_name}"
statement += " IN SCHEMA identifier(?)"
bindvars.extend([schema_name])
else:
raise SnowflakeException(
tool="list_semantic_views",
message="Please specify a database, database + schema, or neither to query the account.",
)

if starts_with:
statement += f" STARTS WITH '{starts_with}'"
# sanitizing string manually because bind variables are not supported here
sanitized_starts_with = starts_with.replace("'", "")
statement += f" STARTS WITH '{sanitized_starts_with}'"

try:
result = execute_query(statement, snowflake_service)
result = execute_query(statement, snowflake_service, bindvars)
# Semantic view metadata has unnecessary extension key
for item in result:
item.pop("extension", None)
Expand All @@ -62,10 +69,11 @@ def describe_semantic_view(
tool="describe_semantic_view", message="Please specify a view name."
)

statement = f"DESCRIBE SEMANTIC VIEW {database_name}.{schema_name}.{view_name}"
statement = "DESCRIBE SEMANTIC VIEW identifier(?)"
bindvars = [f"{database_name}.{schema_name}.{view_name}"]

try:
result = execute_query(statement, snowflake_service)
result = execute_query(statement, snowflake_service, bindvars)
# Semantic view metadata has ugly extension key, so we need to remove it
result = [item for item in result if item.get("object_kind") != "EXTENSION"]
return result
Expand All @@ -82,32 +90,33 @@ def show_semantic_expressions(
like: str = None,
starts_with: str = None,
):
bindvars = []
# fstring should be safe here since expression type is restricted to whitelisted value
statement = f"SHOW SEMANTIC {expression_type}"

if like:
statement += f" LIKE '%{like.replace('%', '')}%'"
statement += " LIKE ?"
bindvars.extend([f"%{like.replace('%', '')}%"])

if view_name:
statement += " IN"
elif schema_name:
statement += " IN SCHEMA"
if database_name and schema_name and view_name:
statement += " IN identifier(?)"
bindvars.extend([f"{database_name}.{schema_name}.{view_name}"])
elif database_name and schema_name:
statement += " IN SCHEMA identifier(?)"
bindvars.extend([f"{database_name}.{schema_name}"])
elif database_name:
statement += " IN DATABASE"
statement += " IN DATABASE identifier(?)"
bindvars.extend([f"{database_name}"])
else:
statement += " IN ACCOUNT"

if database_name:
statement += f" {database_name}"
if schema_name:
statement += f".{schema_name}"
if view_name:
statement += f".{view_name}"

if starts_with:
statement += f" STARTS WITH '{starts_with}'"
# sanitizing string manually because bind variables are not supported here
sanitized_starts_with = starts_with.replace("'", "")
statement += f" STARTS WITH '{sanitized_starts_with}'"

try:
result = execute_query(statement, snowflake_service)
result = execute_query(statement, snowflake_service, bindvars)
if not result:
return f"No {expression_type.lower()} found."
return result
Expand All @@ -127,10 +136,11 @@ def get_semantic_view_ddl(
tool="get_semantic_view_ddl", message="Please specify a view name."
)

statement = f"SELECT GET_DDL('SEMANTIC_VIEW', '{database_name}.{schema_name}.{view_name}', TRUE) as DDL"
statement = "SELECT GET_DDL('SEMANTIC_VIEW', ?, TRUE) as DDL"
bindvars = [f"{database_name}.{schema_name}.{view_name}"]

try:
return execute_query(statement, snowflake_service)[0].get("DDL")
return execute_query(statement, snowflake_service, bindvars)[0].get("DDL")
except Exception as e:
raise SnowflakeException(tool="get_semantic_view_ddl", message=str(e))

Expand All @@ -145,7 +155,7 @@ def write_semantic_view_query(
where_clause: str = None,
order_by: str = None,
limit: int | str = None,
):
) -> tuple[str, list[str]]:
"""
Query a semantic view with comprehensive support for all SEMANTIC_VIEW clauses.

Expand All @@ -171,23 +181,36 @@ def write_semantic_view_query(
message="Cannot specify both FACTS and METRICS in the same SEMANTIC_VIEW query",
)

statement = f"""SELECT * FROM SEMANTIC_VIEW (
{database_name}.{schema_name}.{view_name}
"""
statement = "SELECT * FROM SEMANTIC_VIEW (identifier(?)"
bindvars = [f"{database_name}.{schema_name}.{view_name}"]

# Add clauses in order (affects output column order)
if dimensions:
statement += f" DIMENSIONS {', '.join([f'{expr.table}.{expr.name}' for expr in dimensions])}"
statement += " DIMENSIONS"
for index, expr in enumerate(dimensions):
is_last = index == len(dimensions) - 1
statement += " identifier(?)"
bindvars.extend([f"{expr.table}.{expr.name}"])
if not is_last:
statement += ","

if metrics:
statement += (
f" METRICS {', '.join([f'{expr.table}.{expr.name}' for expr in metrics])}"
)
statement += " METRICS"
for index, expr in enumerate(metrics):
is_last = index == len(metrics) - 1
statement += " identifier(?)"
bindvars.extend([f"{expr.table}.{expr.name}"])
if not is_last:
statement += ","

if facts:
statement += (
f" FACTS {', '.join([f'{expr.table}.{expr.name}' for expr in facts])}"
)
statement += " FACTS"
for index, expr in enumerate(facts):
is_last = index == len(facts) - 1
statement += " identifier(?)"
bindvars.extend([f"{expr.table}.{expr.name}"])
if not is_last:
statement += ","

statement += ")" # Close out the semantic sub-select

Expand All @@ -202,7 +225,7 @@ def write_semantic_view_query(
statement += f" LIMIT {int(limit)}"

try:
return statement
return (statement, bindvars)
except Exception as e:
raise SnowflakeException(tool="write_semantic_view_query", message=str(e))

Expand All @@ -220,7 +243,7 @@ def query_semantic_view(
limit: int | str = None,
):
try:
statement = write_semantic_view_query(
(statement, bindvars) = write_semantic_view_query(
view_name,
database_name,
schema_name,
Expand All @@ -232,7 +255,7 @@ def query_semantic_view(
limit,
)

return execute_query(statement, snowflake_service)
return execute_query(statement, snowflake_service, bindvars)
except Exception as e:
raise SnowflakeException(tool="query_semantic_view", message=str(e))

Expand Down
2 changes: 2 additions & 0 deletions mcp_server_snowflake/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def _get_persistent_connection(
**connection_params,
session_parameters=session_parameters,
client_session_keep_alive=True,
paramstyle="qmark",
)
if connection: # Send zero compute query to capture query tag
self.send_initial_query(connection)
Expand Down Expand Up @@ -362,6 +363,7 @@ def get_connection(
**connection_params,
session_parameters=session_parameters,
client_session_keep_alive=False,
paramstyle="qmark",
)

cursor = (
Expand Down
4 changes: 2 additions & 2 deletions mcp_server_snowflake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def warn_deprecated_params() -> None:
logger.info(f"Deprecated parameters: {', '.join(deprecated_found)}")


def execute_query(statement: str, snowflake_service):
def execute_query(statement: str, snowflake_service, bindvars: list[str] = []):
"""Execute a Snowflake query and return the results using Python connector dictionary cursor."""
with snowflake_service.get_connection(
use_dict_cursor=True,
Expand All @@ -64,7 +64,7 @@ def execute_query(statement: str, snowflake_service):
con,
cur,
):
cur.execute(statement)
cur.execute(statement, bindvars)
return cur.fetchall()


Expand Down
Loading