diff --git a/mcp_server_snowflake/object_manager/tools.py b/mcp_server_snowflake/object_manager/tools.py index 570b506..fba3a45 100644 --- a/mcp_server_snowflake/object_manager/tools.py +++ b/mcp_server_snowflake/object_manager/tools.py @@ -99,6 +99,7 @@ def list_objects( like: str = None, starts_with: str = None, ): + bindvars = [] if object_type == "image_repository": object_name = "image repositories" elif object_type == "compute_pool": @@ -106,21 +107,27 @@ def list_objects( else: object_name = f"{object_type}s" + # Note: SHOW statements do not support variable binding. String formatting is + # safe here because object_type is restricted to a set of whitelisted values. statement = f"SHOW {object_name}" if like: - statement += f" LIKE '%{like.replace('%', '')}%'" + statement += " LIKE ?" + bindvars.extend([f"%{like.replace('%', '')}%"]) if object_type in ["database", "compute_pool", "role", "user"]: pass elif database_name is None and schema_name is None: statement += " IN ACCOUNT" elif database_name and schema_name: - statement += f" IN SCHEMA {database_name}.{schema_name}" + statement += " IN SCHEMA identifier(?)" + bindvars.extend([f"{database_name}.{schema_name}"]) elif database_name: - statement += f" IN DATABASE {database_name}" + statement += " IN DATABASE identifier(?)" + bindvars.extend([database_name]) elif schema_name: - statement += f" IN SCHEMA {schema_name}" + statement += " IN SCHEMA identifier(?)" + bindvars.extend([schema_name]) else: raise SnowflakeException( tool="list_objects", @@ -128,17 +135,19 @@ def list_objects( ) if starts_with: - statement += f" STARTS WITH '{starts_with}'" + # sanitizing string manually because bind variables are not supported here + sanitized_starts_with = starts_with.replace("'", "") + statement += f" STARTS WITH '{sanitized_starts_with}'" try: - result = execute_query(statement, snowflake_service) + result = execute_query(statement, snowflake_service, bindvars) if len(result) > 0: return result[0:1000] # Limit to 1000 results else: return f"No matching {object_name} found." except Exception as e: - raise SnowflakeException(tool="list_semantic_views", message=str(e)) + raise SnowflakeException(tool="list_objects", message=str(e)) def parse_object(target_object: Any, obj_type: supported_objects): diff --git a/mcp_server_snowflake/semantic_manager/tools.py b/mcp_server_snowflake/semantic_manager/tools.py index f8b686f..1c93ee2 100644 --- a/mcp_server_snowflake/semantic_manager/tools.py +++ b/mcp_server_snowflake/semantic_manager/tools.py @@ -19,18 +19,23 @@ def list_semantic_views( starts_with: str = None, ): statement = "SHOW SEMANTIC VIEWS" + bindvars = [] if like: - statement += f" LIKE '%{like.replace('%', '')}%'" + statement += " LIKE ?" + bindvars.extend([f"%{like.replace('%', '')}%"]) if not database_name and not schema_name: statement += " IN ACCOUNT" elif database_name and schema_name: - statement += f" IN SCHEMA {database_name}.{schema_name}" + statement += " IN SCHEMA identifier(?)" + bindvars.extend([f"{database_name}.{schema_name}"]) elif database_name: - statement += f" IN DATABASE {database_name}" + statement += " IN DATABASE identifier(?)" + bindvars.extend([database_name]) elif schema_name: - statement += f" IN SCHEMA {schema_name}" + statement += " IN SCHEMA identifier(?)" + bindvars.extend([schema_name]) else: raise SnowflakeException( tool="list_semantic_views", @@ -38,10 +43,12 @@ def list_semantic_views( ) if starts_with: - statement += f" STARTS WITH '{starts_with}'" + # sanitizing string manually because bind variables are not supported here + sanitized_starts_with = starts_with.replace("'", "") + statement += f" STARTS WITH '{sanitized_starts_with}'" try: - result = execute_query(statement, snowflake_service) + result = execute_query(statement, snowflake_service, bindvars) # Semantic view metadata has unnecessary extension key for item in result: item.pop("extension", None) @@ -62,10 +69,11 @@ def describe_semantic_view( tool="describe_semantic_view", message="Please specify a view name." ) - statement = f"DESCRIBE SEMANTIC VIEW {database_name}.{schema_name}.{view_name}" + statement = "DESCRIBE SEMANTIC VIEW identifier(?)" + bindvars = [f"{database_name}.{schema_name}.{view_name}"] try: - result = execute_query(statement, snowflake_service) + result = execute_query(statement, snowflake_service, bindvars) # Semantic view metadata has ugly extension key, so we need to remove it result = [item for item in result if item.get("object_kind") != "EXTENSION"] return result @@ -82,32 +90,33 @@ def show_semantic_expressions( like: str = None, starts_with: str = None, ): + bindvars = [] + # fstring should be safe here since expression type is restricted to whitelisted value statement = f"SHOW SEMANTIC {expression_type}" if like: - statement += f" LIKE '%{like.replace('%', '')}%'" + statement += " LIKE ?" + bindvars.extend([f"%{like.replace('%', '')}%"]) - if view_name: - statement += " IN" - elif schema_name: - statement += " IN SCHEMA" + if database_name and schema_name and view_name: + statement += " IN identifier(?)" + bindvars.extend([f"{database_name}.{schema_name}.{view_name}"]) + elif database_name and schema_name: + statement += " IN SCHEMA identifier(?)" + bindvars.extend([f"{database_name}.{schema_name}"]) elif database_name: - statement += " IN DATABASE" + statement += " IN DATABASE identifier(?)" + bindvars.extend([f"{database_name}"]) else: statement += " IN ACCOUNT" - if database_name: - statement += f" {database_name}" - if schema_name: - statement += f".{schema_name}" - if view_name: - statement += f".{view_name}" - if starts_with: - statement += f" STARTS WITH '{starts_with}'" + # sanitizing string manually because bind variables are not supported here + sanitized_starts_with = starts_with.replace("'", "") + statement += f" STARTS WITH '{sanitized_starts_with}'" try: - result = execute_query(statement, snowflake_service) + result = execute_query(statement, snowflake_service, bindvars) if not result: return f"No {expression_type.lower()} found." return result @@ -127,10 +136,11 @@ def get_semantic_view_ddl( tool="get_semantic_view_ddl", message="Please specify a view name." ) - statement = f"SELECT GET_DDL('SEMANTIC_VIEW', '{database_name}.{schema_name}.{view_name}', TRUE) as DDL" + statement = "SELECT GET_DDL('SEMANTIC_VIEW', ?, TRUE) as DDL" + bindvars = [f"{database_name}.{schema_name}.{view_name}"] try: - return execute_query(statement, snowflake_service)[0].get("DDL") + return execute_query(statement, snowflake_service, bindvars)[0].get("DDL") except Exception as e: raise SnowflakeException(tool="get_semantic_view_ddl", message=str(e)) @@ -145,7 +155,7 @@ def write_semantic_view_query( where_clause: str = None, order_by: str = None, limit: int | str = None, -): +) -> tuple[str, list[str]]: """ Query a semantic view with comprehensive support for all SEMANTIC_VIEW clauses. @@ -171,23 +181,36 @@ def write_semantic_view_query( message="Cannot specify both FACTS and METRICS in the same SEMANTIC_VIEW query", ) - statement = f"""SELECT * FROM SEMANTIC_VIEW ( - {database_name}.{schema_name}.{view_name} - """ + statement = "SELECT * FROM SEMANTIC_VIEW (identifier(?)" + bindvars = [f"{database_name}.{schema_name}.{view_name}"] # Add clauses in order (affects output column order) if dimensions: - statement += f" DIMENSIONS {', '.join([f'{expr.table}.{expr.name}' for expr in dimensions])}" + statement += " DIMENSIONS" + for index, expr in enumerate(dimensions): + is_last = index == len(dimensions) - 1 + statement += " identifier(?)" + bindvars.extend([f"{expr.table}.{expr.name}"]) + if not is_last: + statement += "," if metrics: - statement += ( - f" METRICS {', '.join([f'{expr.table}.{expr.name}' for expr in metrics])}" - ) + statement += " METRICS" + for index, expr in enumerate(metrics): + is_last = index == len(metrics) - 1 + statement += " identifier(?)" + bindvars.extend([f"{expr.table}.{expr.name}"]) + if not is_last: + statement += "," if facts: - statement += ( - f" FACTS {', '.join([f'{expr.table}.{expr.name}' for expr in facts])}" - ) + statement += " FACTS" + for index, expr in enumerate(facts): + is_last = index == len(facts) - 1 + statement += " identifier(?)" + bindvars.extend([f"{expr.table}.{expr.name}"]) + if not is_last: + statement += "," statement += ")" # Close out the semantic sub-select @@ -202,7 +225,7 @@ def write_semantic_view_query( statement += f" LIMIT {int(limit)}" try: - return statement + return (statement, bindvars) except Exception as e: raise SnowflakeException(tool="write_semantic_view_query", message=str(e)) @@ -220,7 +243,7 @@ def query_semantic_view( limit: int | str = None, ): try: - statement = write_semantic_view_query( + (statement, bindvars) = write_semantic_view_query( view_name, database_name, schema_name, @@ -232,7 +255,7 @@ def query_semantic_view( limit, ) - return execute_query(statement, snowflake_service) + return execute_query(statement, snowflake_service, bindvars) except Exception as e: raise SnowflakeException(tool="query_semantic_view", message=str(e)) diff --git a/mcp_server_snowflake/server.py b/mcp_server_snowflake/server.py index e4cb5e0..e3053f3 100644 --- a/mcp_server_snowflake/server.py +++ b/mcp_server_snowflake/server.py @@ -298,6 +298,7 @@ def _get_persistent_connection( **connection_params, session_parameters=session_parameters, client_session_keep_alive=True, + paramstyle="qmark", ) if connection: # Send zero compute query to capture query tag self.send_initial_query(connection) @@ -362,6 +363,7 @@ def get_connection( **connection_params, session_parameters=session_parameters, client_session_keep_alive=False, + paramstyle="qmark", ) cursor = ( diff --git a/mcp_server_snowflake/utils.py b/mcp_server_snowflake/utils.py index 2f23ff4..297a0ee 100644 --- a/mcp_server_snowflake/utils.py +++ b/mcp_server_snowflake/utils.py @@ -55,7 +55,7 @@ def warn_deprecated_params() -> None: logger.info(f"Deprecated parameters: {', '.join(deprecated_found)}") -def execute_query(statement: str, snowflake_service): +def execute_query(statement: str, snowflake_service, bindvars: list[str] = []): """Execute a Snowflake query and return the results using Python connector dictionary cursor.""" with snowflake_service.get_connection( use_dict_cursor=True, @@ -64,7 +64,7 @@ def execute_query(statement: str, snowflake_service): con, cur, ): - cur.execute(statement) + cur.execute(statement, bindvars) return cur.fetchall()