From 82a8c20586c1ce7e45ad106a0ab8a6725aed075b Mon Sep 17 00:00:00 2001 From: fuzzwah Date: Mon, 5 Jan 2026 08:49:02 +1100 Subject: [PATCH 1/3] Add DML_ONLY access mode - Add DmlOnlySqlDriver class to allow DML operations while blocking DDL - Support INSERT, UPDATE, DELETE, and UPSERT operations - Add comprehensive test suite with 46 unit tests - Update documentation and CLI help text - Provides middle-ground security between unrestricted and restricted modes --- DML_ONLY_MODE_IMPLEMENTATION.md | 324 +++++++++++++++++ README.md | 8 +- src/postgres_mcp/server.py | 11 +- src/postgres_mcp/sql/__init__.py | 2 + src/postgres_mcp/sql/dml_only_sql.py | 180 ++++++++++ tests/unit/sql/test_dml_only_sql.py | 506 +++++++++++++++++++++++++++ tests/unit/test_access_mode.py | 20 +- 7 files changed, 1044 insertions(+), 7 deletions(-) create mode 100644 DML_ONLY_MODE_IMPLEMENTATION.md create mode 100644 src/postgres_mcp/sql/dml_only_sql.py create mode 100644 tests/unit/sql/test_dml_only_sql.py diff --git a/DML_ONLY_MODE_IMPLEMENTATION.md b/DML_ONLY_MODE_IMPLEMENTATION.md new file mode 100644 index 00000000..80de03fa --- /dev/null +++ b/DML_ONLY_MODE_IMPLEMENTATION.md @@ -0,0 +1,324 @@ +# DML_ONLY Mode Implementation Plan + +## Goal + +Add a new access mode to postgres-mcp that allows DML operations (INSERT, UPDATE, DELETE, UPSERT) while blocking DDL operations (CREATE TABLE, ALTER TABLE, DROP TABLE, CREATE INDEX, etc.). + +## Current State + +The postgres-mcp server currently has two access modes: + +- **Unrestricted Mode** (`--access-mode=unrestricted`): Allows all SQL operations (DDL + DML) +- **Restricted Mode** (`--access-mode=restricted`): Only allows SELECT and read-only operations (blocks all writes) + +## Required Functionality + +We need a third mode that sits between these two extremes: + +- **DML_ONLY Mode** (`--access-mode=dml_only`): Allows data manipulation (INSERT, UPDATE, DELETE) but blocks schema changes + +### Allowed Operations in DML_ONLY Mode +- SELECT (all read operations) +- INSERT +- UPDATE +- DELETE +- UPSERT (INSERT ... ON CONFLICT ... DO UPDATE) +- EXPLAIN (for query analysis) +- SHOW (for system information) +- Transaction control within read-write transactions + +### Blocked Operations in DML_ONLY Mode +- CREATE TABLE / CREATE INDEX / CREATE EXTENSION +- ALTER TABLE / ALTER INDEX / ALTER EXTENSION +- DROP TABLE / DROP INDEX / DROP EXTENSION +- TRUNCATE +- VACUUM (can be dangerous in production) +- CREATE/ALTER/DROP SCHEMA +- CREATE/ALTER/DROP DATABASE +- Any other DDL operations + +## Implementation Steps + +### 1. Add DML_ONLY Access Mode + +**File**: `src/postgres_mcp/server.py` + +Modify the `AccessMode` enum to add the new mode: + +```python +class AccessMode(str, Enum): + """SQL access modes for the server.""" + UNRESTRICTED = "unrestricted" + RESTRICTED = "restricted" + DML_ONLY = "dml_only" # New: allow DML, block DDL +``` + +### 2. Create DmlOnlySqlDriver Class + +**File**: `src/postgres_mcp/sql/dml_only_sql.py` (new file) + +Create a new driver class similar to `SafeSqlDriver` but with different allowed statement types: + +```python +from typing import ClassVar +from pglast import parse_sql +from pglast.ast import ( + SelectStmt, + InsertStmt, + UpdateStmt, + DeleteStmt, + ExplainStmt, + VariableShowStmt, + # ... other allowed types +) + +class DmlOnlySqlDriver(SqlDriver): + """ + A wrapper around SqlDriver that allows DML operations but blocks DDL. + + Allows: SELECT, INSERT, UPDATE, DELETE, and other read operations + Blocks: CREATE, ALTER, DROP, TRUNCATE, and other DDL operations + """ + + ALLOWED_STMT_TYPES: ClassVar[set[type]] = { + SelectStmt, # SELECT queries + InsertStmt, # INSERT + UpdateStmt, # UPDATE + DeleteStmt, # DELETE + ExplainStmt, # EXPLAIN + VariableShowStmt, # SHOW statements + # Add other safe statement types... + } + + # Reuse allowed functions from SafeSqlDriver + ALLOWED_FUNCTIONS: ClassVar[set[str]] = SafeSqlDriver.ALLOWED_FUNCTIONS + + # Reuse allowed node types from SafeSqlDriver + ALLOWED_NODE_TYPES: ClassVar[set[type]] = SafeSqlDriver.ALLOWED_NODE_TYPES + + def _validate(self, query: str) -> None: + """Validate query allows DML but blocks DDL""" + # Parse and validate using pglast (similar to SafeSqlDriver) + parsed = parse_sql(query) + for stmt in parsed: + # Check statement type is allowed + # Recursively validate all nodes + # Reject DDL operations +``` + +**Key Implementation Notes**: +- Inherit from `SqlDriver` +- Reuse the validation logic from `SafeSqlDriver` but with different `ALLOWED_STMT_TYPES` +- Use pglast library for SQL parsing and validation +- Include timeout support (like SafeSqlDriver uses 30 seconds) +- Ensure `force_readonly=False` since we're allowing writes + +### 3. Update get_sql_driver() Function + +**File**: `src/postgres_mcp/server.py` + +Modify the `get_sql_driver()` function to return the appropriate driver: + +```python +async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver, DmlOnlySqlDriver]: + """Get the appropriate SQL driver based on the current access mode.""" + base_driver = SqlDriver(conn=db_connection) + + if current_access_mode == AccessMode.RESTRICTED: + logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)") + return SafeSqlDriver(sql_driver=base_driver, timeout=30) + elif current_access_mode == AccessMode.DML_ONLY: + logger.debug("Using DmlOnlySqlDriver (DML_ONLY mode)") + return DmlOnlySqlDriver(sql_driver=base_driver, timeout=30) + else: + logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)") + return base_driver +``` + +### 4. Update execute_sql Tool Description + +**File**: `src/postgres_mcp/server.py` + +Update the dynamic tool registration to include DML_ONLY mode: + +```python +if current_access_mode == AccessMode.UNRESTRICTED: + mcp.add_tool(execute_sql, description="Execute any SQL query") +elif current_access_mode == AccessMode.DML_ONLY: + mcp.add_tool(execute_sql, description="Execute DML operations (INSERT, UPDATE, DELETE) and read queries") +else: + mcp.add_tool(execute_sql, description="Execute a read-only SQL query") +``` + +### 5. Update __init__.py Export + +**File**: `src/postgres_mcp/sql/__init__.py` + +Add the new driver to exports: + +```python +from .dml_only_sql import DmlOnlySqlDriver + +__all__ = [ + # ... existing exports ... + "DmlOnlySqlDriver", +] +``` + +## Testing Requirements + +### Unit Tests + +**File**: `tests/unit/sql/test_dml_only_sql.py` (new file) + +Create comprehensive tests covering: + +1. **Allowed DML Operations**: + - `test_insert_statement` - INSERT should be allowed + - `test_update_statement` - UPDATE should be allowed + - `test_delete_statement` - DELETE should be allowed + - `test_insert_on_conflict` - UPSERT should be allowed + - `test_select_statement` - SELECT should be allowed + +2. **Blocked DDL Operations**: + - `test_create_table_blocked` - CREATE TABLE should be rejected + - `test_alter_table_blocked` - ALTER TABLE should be rejected + - `test_drop_table_blocked` - DROP TABLE should be rejected + - `test_create_index_blocked` - CREATE INDEX should be rejected + - `test_drop_index_blocked` - DROP INDEX should be rejected + - `test_truncate_blocked` - TRUNCATE should be rejected + - `test_create_extension_blocked` - CREATE EXTENSION should be rejected + - `test_vacuum_blocked` - VACUUM should be rejected + +3. **Complex Queries**: + - `test_insert_with_select` - INSERT ... SELECT should work + - `test_update_with_join` - UPDATE with JOIN should work + - `test_delete_with_subquery` - DELETE with WHERE IN (SELECT...) should work + - `test_cte_with_dml` - CTEs with DML should work + +**File**: `tests/unit/test_access_mode.py` + +Add DML_ONLY to existing access mode tests: + +```python +@pytest.mark.parametrize( + "access_mode,expected_driver_type", + [ + (AccessMode.UNRESTRICTED, SqlDriver), + (AccessMode.RESTRICTED, SafeSqlDriver), + (AccessMode.DML_ONLY, DmlOnlySqlDriver), # Add this + ], +) +``` + +### Integration Tests + +**File**: `tests/integration/test_dml_only_integration.py` (new file) + +Test against a real PostgreSQL database: + +1. Set up test tables +2. Execute INSERT/UPDATE/DELETE operations successfully +3. Verify DDL operations are blocked +4. Verify data is actually modified in the database +5. Test transaction handling + +## Documentation Updates + +### 1. README.md + +Add DML_ONLY mode to the documentation: + +- Update "Access Mode" section +- Add usage examples +- Update CLI help text +- Add to comparison table + +### 2. Command-line Help + +**File**: `src/postgres_mcp/server.py` + +Update the argparse help text: + +```python +parser.add_argument( + "--access-mode", + type=str, + choices=[mode.value for mode in AccessMode], + default=AccessMode.UNRESTRICTED.value, + help="Set SQL access mode: unrestricted (full access), dml_only (allow DML, block DDL), or restricted (read-only)", +) +``` + +### 3. Examples + +Add example configurations showing DML_ONLY mode usage in: +- Claude Desktop config +- VS Code MCP settings +- Docker examples + +## Validation Checklist + +- [ ] `AccessMode` enum includes `DML_ONLY` +- [ ] `DmlOnlySqlDriver` class created with proper validation +- [ ] `get_sql_driver()` returns correct driver for DML_ONLY mode +- [ ] All unit tests pass +- [ ] Integration tests pass with real database +- [ ] Documentation updated (README, CLI help) +- [ ] All existing tests still pass +- [ ] Code follows project style (ruff, pyright) +- [ ] Example configurations added + +## Running Tests + +```bash +# Run all tests +pytest + +# Run only DML_ONLY tests +pytest tests/unit/sql/test_dml_only_sql.py + +# Run with coverage +pytest --cov=src/postgres_mcp --cov-report=html + +# Check code style +ruff check src/ tests/ +pyright src/ +``` + +## Key Files to Modify/Create + +### New Files +- `src/postgres_mcp/sql/dml_only_sql.py` - DmlOnlySqlDriver implementation +- `tests/unit/sql/test_dml_only_sql.py` - Unit tests for DML_ONLY mode +- `tests/integration/test_dml_only_integration.py` - Integration tests + +### Modified Files +- `src/postgres_mcp/server.py` - Add AccessMode.DML_ONLY, update get_sql_driver() +- `src/postgres_mcp/sql/__init__.py` - Export DmlOnlySqlDriver +- `tests/unit/test_access_mode.py` - Add DML_ONLY to parametrized tests +- `README.md` - Document the new mode + +## Reference Implementation + +The implementation should closely follow the pattern established by `SafeSqlDriver` in `src/postgres_mcp/sql/safe_sql.py`: + +1. Use pglast for SQL parsing +2. Maintain allowed statement types, functions, and node types +3. Implement recursive node validation +4. Apply timeout to prevent long-running queries +5. Use proper error messages for blocked operations +6. Follow the existing code style and patterns + +## Success Criteria + +The implementation is complete when: + +1. All tests pass (existing + new) +2. Documentation is updated +3. An agent can successfully use the DML_ONLY mode to: + - Insert data into tables + - Update existing records + - Delete records + - But is blocked from creating/altering/dropping tables or indexes +4. The mode works correctly in both Claude Desktop and VS Code MCP configurations diff --git a/README.md b/README.md index 82236d33..687118d5 100644 --- a/README.md +++ b/README.md @@ -203,9 +203,11 @@ Replace `postgresql://...` with your [Postgres database connection URI](https:// Postgres MCP Pro supports multiple *access modes* to give you control over the operations that the AI agent can perform on the database: - **Unrestricted Mode**: Allows full read/write access to modify data and schema. It is suitable for development environments. +- **DML Only Mode**: Allows data manipulation (INSERT, UPDATE, DELETE) but blocks schema changes (CREATE, ALTER, DROP). It is suitable for environments where you want to allow data modification but prevent schema changes. - **Restricted Mode**: Limits operations to read-only transactions and imposes constraints on resource utilization (presently only execution time). It is suitable for production environments. To use restricted mode, replace `--access-mode=unrestricted` with `--access-mode=restricted` in the configuration examples above. +To use DML only mode, use `--access-mode=dml_only`. #### Other MCP Clients @@ -585,10 +587,12 @@ We reject any SQL that contains `commit` or `rollback` statements. Helpfully, the popular Postgres stored procedure languages, including PL/pgSQL and PL/Python, do not allow for `COMMIT` or `ROLLBACK` statements. If you have unsafe stored procedure languages enabled on your database, then our read-only protections could be circumvented. -At present, Postgres MCP Pro provides two levels of protection for the database, one at either extreme of the convenience/safety spectrum. +At present, Postgres MCP Pro provides three levels of protection for the database, spanning the convenience/safety spectrum. - "Unrestricted" provides maximum flexibility. It is suitable for development environments where speed and flexibility are paramount, and where there is no need to protect valuable or sensitive data. -- "Restricted" provides a balance between flexibility and safety. +- "DML Only" provides a middle ground, allowing data manipulation while protecting the schema from accidental or malicious changes. +It is suitable for environments where you need to allow data modifications but want to prevent structural changes to the database. +- "Restricted" provides maximum safety with read-only operations. It is suitable for production environments where the database is exposed to untrusted users, and where it is important to protect valuable or sensitive data. Unrestricted mode aligns with the approach of [Cursor's auto-run mode](https://docs.cursor.com/chat/tools#auto-run), where the AI agent operates with limited human oversight or approvals. diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index af5669a1..695cb5de 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -27,6 +27,7 @@ from .index.llm_opt import LLMOptimizerTool from .index.presentation import TextPresentation from .sql import DbConnPool +from .sql import DmlOnlySqlDriver from .sql import SafeSqlDriver from .sql import SqlDriver from .sql import check_hypopg_installation_status @@ -50,6 +51,7 @@ class AccessMode(str, Enum): UNRESTRICTED = "unrestricted" # Unrestricted access RESTRICTED = "restricted" # Read-only with safety features + DML_ONLY = "dml_only" # Allow DML operations, block DDL # Global variables @@ -58,13 +60,16 @@ class AccessMode(str, Enum): shutdown_in_progress = False -async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]: +async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver, DmlOnlySqlDriver]: """Get the appropriate SQL driver based on the current access mode.""" base_driver = SqlDriver(conn=db_connection) if current_access_mode == AccessMode.RESTRICTED: logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)") return SafeSqlDriver(sql_driver=base_driver, timeout=30) # 30 second timeout + elif current_access_mode == AccessMode.DML_ONLY: + logger.debug("Using DmlOnlySqlDriver (DML_ONLY mode)") + return DmlOnlySqlDriver(sql_driver=base_driver, timeout=30) # 30 second timeout else: logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)") return base_driver @@ -518,7 +523,7 @@ async def main(): type=str, choices=[mode.value for mode in AccessMode], default=AccessMode.UNRESTRICTED.value, - help="Set SQL access mode: unrestricted (unrestricted) or restricted (read-only with protections)", + help="Set SQL access mode: unrestricted (full access), dml_only (allow DML, block DDL), or restricted (read-only)", ) parser.add_argument( "--transport", @@ -549,6 +554,8 @@ async def main(): # Add the query tool with a description appropriate to the access mode if current_access_mode == AccessMode.UNRESTRICTED: mcp.add_tool(execute_sql, description="Execute any SQL query") + elif current_access_mode == AccessMode.DML_ONLY: + mcp.add_tool(execute_sql, description="Execute DML operations (INSERT, UPDATE, DELETE) and read queries") else: mcp.add_tool(execute_sql, description="Execute a read-only SQL query") diff --git a/src/postgres_mcp/sql/__init__.py b/src/postgres_mcp/sql/__init__.py index 1fded3bb..82013f12 100644 --- a/src/postgres_mcp/sql/__init__.py +++ b/src/postgres_mcp/sql/__init__.py @@ -3,6 +3,7 @@ from .bind_params import ColumnCollector from .bind_params import SqlBindParams from .bind_params import TableAliasVisitor +from .dml_only_sql import DmlOnlySqlDriver from .extension_utils import check_extension from .extension_utils import check_hypopg_installation_status from .extension_utils import check_postgres_version_requirement @@ -17,6 +18,7 @@ __all__ = [ "ColumnCollector", "DbConnPool", + "DmlOnlySqlDriver", "IndexDefinition", "SafeSqlDriver", "SqlBindParams", diff --git a/src/postgres_mcp/sql/dml_only_sql.py b/src/postgres_mcp/sql/dml_only_sql.py new file mode 100644 index 00000000..bae2d147 --- /dev/null +++ b/src/postgres_mcp/sql/dml_only_sql.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Any +from typing import ClassVar +from typing import Optional + +import pglast +from pglast.ast import DeleteStmt +from pglast.ast import ExplainStmt +from pglast.ast import IndexElem +from pglast.ast import InferClause +from pglast.ast import InsertStmt +from pglast.ast import Node +from pglast.ast import OnConflictClause +from pglast.ast import RawStmt +from pglast.ast import SelectStmt +from pglast.ast import UpdateStmt +from pglast.ast import VariableShowStmt +from typing_extensions import LiteralString + +from .safe_sql import SafeSqlDriver +from .sql_driver import SqlDriver + +logger = logging.getLogger(__name__) + + +class DmlOnlySqlDriver(SqlDriver): + """A wrapper around SqlDriver that allows DML operations but blocks DDL. + + Uses pglast to parse and validate SQL statements before execution. + Allows: SELECT, INSERT, UPDATE, DELETE, and other read operations + Blocks: CREATE, ALTER, DROP, TRUNCATE, and other DDL operations + """ + + # Allowed statement types for DML_ONLY mode + ALLOWED_STMT_TYPES: ClassVar[set[type]] = { + SelectStmt, # SELECT queries + InsertStmt, # INSERT + UpdateStmt, # UPDATE + DeleteStmt, # DELETE + ExplainStmt, # EXPLAIN + VariableShowStmt, # SHOW statements + } + + # Reuse allowed functions from SafeSqlDriver + ALLOWED_FUNCTIONS: ClassVar[set[str]] = SafeSqlDriver.ALLOWED_FUNCTIONS + + # Reuse allowed node types from SafeSqlDriver, plus DML statement types and UPSERT-related nodes + ALLOWED_NODE_TYPES: ClassVar[set[type]] = SafeSqlDriver.ALLOWED_NODE_TYPES | { + InsertStmt, + UpdateStmt, + DeleteStmt, + OnConflictClause, # For INSERT ... ON CONFLICT (UPSERT) + InferClause, # For conflict target specification in UPSERT + IndexElem, # For index element specification in UPSERT conflict target + } + + def __init__(self, sql_driver: SqlDriver, timeout: float | None = None): + """Initialize with an underlying SQL driver and optional timeout. + + Args: + sql_driver: The underlying SQL driver to wrap + timeout: Optional timeout in seconds for query execution + """ + self.sql_driver = sql_driver + self.timeout = timeout + + def _validate_node(self, node: Node) -> None: + """Recursively validate a node and all its children""" + # Check if node type is allowed + if not isinstance(node, tuple(self.ALLOWED_NODE_TYPES)): + raise ValueError(f"Node type {type(node)} is not allowed") + + # Validate function calls (reuse logic from SafeSqlDriver) + if hasattr(node, "funcname") and node.funcname: + func_name = ".".join([str(n.sval) for n in node.funcname]).lower() if node.funcname else "" + # Strip pg_catalog schema if present + match = SafeSqlDriver.PG_CATALOG_PATTERN.match(func_name) + unqualified_name = match.group(1) if match else func_name + if unqualified_name not in self.ALLOWED_FUNCTIONS: + raise ValueError(f"Function {func_name} is not allowed") + + # Reject EXPLAIN ANALYZE statements + if isinstance(node, ExplainStmt): + for option in node.options or []: + if hasattr(option, "defname") and option.defname == "analyze": + raise ValueError("EXPLAIN ANALYZE is not supported") + + # Recursively validate all attributes that might be nodes + for attr_name in node.__slots__: + # Skip private attributes and methods + if attr_name.startswith("_"): + continue + + try: + attr = getattr(node, attr_name) + except AttributeError: + # Skip attributes that don't exist (this is normal in pglast) + continue + + # Handle lists of nodes + if isinstance(attr, list): + for item in attr: + if isinstance(item, Node): + self._validate_node(item) + + # Handle tuples of nodes + elif isinstance(attr, tuple): + for item in attr: + if isinstance(item, Node): + self._validate_node(item) + + # Handle single nodes + elif isinstance(attr, Node): + self._validate_node(attr) + + def _validate(self, query: str) -> None: + """Validate query allows DML but blocks DDL""" + try: + # Parse the SQL using pglast + parsed = pglast.parse_sql(query) + + # Validate each statement + try: + for stmt in parsed: + if isinstance(stmt, RawStmt): + # Check if the inner statement type is allowed + if not isinstance(stmt.stmt, tuple(self.ALLOWED_STMT_TYPES)): + raise ValueError( + f"Only SELECT, INSERT, UPDATE, DELETE, EXPLAIN, and SHOW statements are allowed. " + f"DDL operations are blocked. Received: {type(stmt.stmt).__name__}" + ) + else: + if not isinstance(stmt, tuple(self.ALLOWED_STMT_TYPES)): + raise ValueError( + f"Only SELECT, INSERT, UPDATE, DELETE, EXPLAIN, and SHOW statements are allowed. " + f"DDL operations are blocked. Received: {type(stmt).__name__}" + ) + self._validate_node(stmt) + except Exception as e: + raise ValueError(f"Error validating query: {query}") from e + + except pglast.parser.ParseError as e: + raise ValueError("Failed to parse SQL statement") from e + + async def execute_query( + self, + query: LiteralString, + params: list[Any] | None = None, + force_readonly: bool = False, # Allow writes by default in DML_ONLY mode + ) -> Optional[list[SqlDriver.RowResult]]: # noqa: UP007 + """Execute a query after validating it is safe""" + self._validate(query) + + # Execute with timeout if configured + if self.timeout: + try: + async with asyncio.timeout(self.timeout): + return await self.sql_driver.execute_query( + f"/* crystaldba */ {query}", + params=params, + force_readonly=force_readonly, + ) + except asyncio.TimeoutError as e: + logger.warning(f"Query execution timed out after {self.timeout} seconds: {query[:100]}...") + raise ValueError( + f"Query execution timed out after {self.timeout} seconds in DML_ONLY mode. " + "Consider simplifying your query or increasing the timeout." + ) from e + except Exception as e: + logger.error(f"Error executing query: {e}") + raise + else: + return await self.sql_driver.execute_query( + f"/* crystaldba */ {query}", + params=params, + force_readonly=force_readonly, + ) diff --git a/tests/unit/sql/test_dml_only_sql.py b/tests/unit/sql/test_dml_only_sql.py new file mode 100644 index 00000000..1a0f1a90 --- /dev/null +++ b/tests/unit/sql/test_dml_only_sql.py @@ -0,0 +1,506 @@ +from unittest.mock import AsyncMock +from unittest.mock import Mock + +import pytest +import pytest_asyncio + +from postgres_mcp.sql import DmlOnlySqlDriver +from postgres_mcp.sql import SqlDriver + + +@pytest_asyncio.fixture +async def mock_sql_driver(): + driver = Mock(spec=SqlDriver) + driver.execute_query = AsyncMock(return_value=[]) + return driver + + +@pytest_asyncio.fixture +async def dml_only_driver(mock_sql_driver): + return DmlOnlySqlDriver(mock_sql_driver) + + +# ======================================== +# Test Allowed DML Operations +# ======================================== + + +@pytest.mark.asyncio +async def test_insert_statement(dml_only_driver, mock_sql_driver): + """Test that INSERT statements are allowed""" + query = "INSERT INTO users (name, email) VALUES ('John Doe', 'john@example.com')" + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_update_statement(dml_only_driver, mock_sql_driver): + """Test that UPDATE statements are allowed""" + query = "UPDATE users SET status = 'active' WHERE id = 1" + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_delete_statement(dml_only_driver, mock_sql_driver): + """Test that DELETE statements are allowed""" + query = "DELETE FROM users WHERE status = 'inactive'" + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_insert_on_conflict(dml_only_driver, mock_sql_driver): + """Test that INSERT ... ON CONFLICT (UPSERT) statements are allowed""" + query = """ + INSERT INTO users (id, name, email) + VALUES (1, 'John Doe', 'john@example.com') + ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, email = EXCLUDED.email + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_select_statement(dml_only_driver, mock_sql_driver): + """Test that SELECT statements are allowed""" + query = "SELECT * FROM users WHERE age > 18" + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_insert_with_select(dml_only_driver, mock_sql_driver): + """Test that INSERT ... SELECT statements are allowed""" + query = """ + INSERT INTO user_backup (id, name, email) + SELECT id, name, email FROM users WHERE status = 'active' + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_update_with_join(dml_only_driver, mock_sql_driver): + """Test that UPDATE with JOIN is allowed""" + query = """ + UPDATE users u + SET status = 'premium' + FROM orders o + WHERE u.id = o.user_id AND o.total > 1000 + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_delete_with_subquery(dml_only_driver, mock_sql_driver): + """Test that DELETE with subquery is allowed""" + query = """ + DELETE FROM users + WHERE id IN (SELECT user_id FROM orders WHERE status = 'cancelled') + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_insert_multiple_rows(dml_only_driver, mock_sql_driver): + """Test that INSERT with multiple rows is allowed""" + query = """ + INSERT INTO users (name, email) VALUES + ('Alice', 'alice@example.com'), + ('Bob', 'bob@example.com'), + ('Charlie', 'charlie@example.com') + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_update_with_returning(dml_only_driver, mock_sql_driver): + """Test that UPDATE with RETURNING clause is allowed""" + query = """ + UPDATE users SET status = 'active' WHERE id = 1 RETURNING id, name, status + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_delete_with_returning(dml_only_driver, mock_sql_driver): + """Test that DELETE with RETURNING clause is allowed""" + query = """ + DELETE FROM users WHERE status = 'inactive' RETURNING id, name + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_show_variable(dml_only_driver, mock_sql_driver): + """Test that SHOW statements are allowed""" + query = "SHOW search_path" + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_explain_query(dml_only_driver, mock_sql_driver): + """Test that EXPLAIN statements are allowed""" + query = "EXPLAIN SELECT * FROM users WHERE age > 18" + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +# ======================================== +# Test Blocked DDL Operations +# ======================================== + + +@pytest.mark.asyncio +async def test_create_table_blocked(dml_only_driver): + """Test that CREATE TABLE statements are blocked""" + query = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_alter_table_blocked(dml_only_driver): + """Test that ALTER TABLE statements are blocked""" + query = "ALTER TABLE users ADD COLUMN age INTEGER" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_drop_table_blocked(dml_only_driver): + """Test that DROP TABLE statements are blocked""" + query = "DROP TABLE users" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_create_index_blocked(dml_only_driver): + """Test that CREATE INDEX statements are blocked""" + query = "CREATE INDEX idx_user_email ON users(email)" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_drop_index_blocked(dml_only_driver): + """Test that DROP INDEX statements are blocked""" + query = "DROP INDEX idx_user_email" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_truncate_blocked(dml_only_driver): + """Test that TRUNCATE statements are blocked""" + query = "TRUNCATE TABLE users" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_create_extension_blocked(dml_only_driver): + """Test that CREATE EXTENSION statements are blocked""" + query = "CREATE EXTENSION IF NOT EXISTS pg_trgm" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_drop_extension_blocked(dml_only_driver): + """Test that DROP EXTENSION statements are blocked""" + query = "DROP EXTENSION pg_trgm" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_vacuum_blocked(dml_only_driver): + """Test that VACUUM statements are blocked""" + query = "VACUUM users" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_create_schema_blocked(dml_only_driver): + """Test that CREATE SCHEMA statements are blocked""" + query = "CREATE SCHEMA test_schema" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_drop_schema_blocked(dml_only_driver): + """Test that DROP SCHEMA statements are blocked""" + query = "DROP SCHEMA test_schema" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_create_database_blocked(dml_only_driver): + """Test that CREATE DATABASE statements are blocked""" + query = "CREATE DATABASE test_db" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_drop_database_blocked(dml_only_driver): + """Test that DROP DATABASE statements are blocked""" + query = "DROP DATABASE test_db" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_create_view_blocked(dml_only_driver): + """Test that CREATE VIEW statements are blocked""" + query = "CREATE VIEW active_users AS SELECT * FROM users WHERE status = 'active'" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_drop_view_blocked(dml_only_driver): + """Test that DROP VIEW statements are blocked""" + query = "DROP VIEW active_users" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_create_function_blocked(dml_only_driver): + """Test that CREATE FUNCTION statements are blocked""" + query = """ + CREATE FUNCTION get_user_count() RETURNS INTEGER AS $$ + BEGIN + RETURN (SELECT COUNT(*) FROM users); + END; + $$ LANGUAGE plpgsql + """ + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_drop_function_blocked(dml_only_driver): + """Test that DROP FUNCTION statements are blocked""" + query = "DROP FUNCTION get_user_count()" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +# ======================================== +# Test Complex DML Queries +# ======================================== + + +@pytest.mark.asyncio +async def test_cte_with_insert(dml_only_driver, mock_sql_driver): + """Test that CTEs with INSERT are allowed""" + query = """ + WITH new_users AS ( + SELECT 'Alice' as name, 'alice@example.com' as email + UNION ALL + SELECT 'Bob' as name, 'bob@example.com' as email + ) + INSERT INTO users (name, email) + SELECT name, email FROM new_users + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_cte_with_update(dml_only_driver, mock_sql_driver): + """Test that CTEs with UPDATE are allowed""" + query = """ + WITH premium_users AS ( + SELECT user_id FROM orders + GROUP BY user_id + HAVING SUM(total) > 10000 + ) + UPDATE users SET status = 'premium' + WHERE id IN (SELECT user_id FROM premium_users) + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_cte_with_delete(dml_only_driver, mock_sql_driver): + """Test that CTEs with DELETE are allowed""" + query = """ + WITH inactive_users AS ( + SELECT id FROM users + WHERE last_login < NOW() - INTERVAL '1 year' + ) + DELETE FROM users WHERE id IN (SELECT id FROM inactive_users) + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_insert_with_complex_subquery(dml_only_driver, mock_sql_driver): + """Test that INSERT with complex subquery is allowed""" + query = """ + INSERT INTO user_stats (user_id, order_count, total_spent) + SELECT + u.id, + COUNT(o.id), + COALESCE(SUM(o.total), 0) + FROM users u + LEFT JOIN orders o ON u.id = o.user_id + GROUP BY u.id + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_update_with_case(dml_only_driver, mock_sql_driver): + """Test that UPDATE with CASE expression is allowed""" + query = """ + UPDATE users + SET tier = CASE + WHEN total_orders > 100 THEN 'platinum' + WHEN total_orders > 50 THEN 'gold' + WHEN total_orders > 10 THEN 'silver' + ELSE 'bronze' + END + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_delete_with_exists(dml_only_driver, mock_sql_driver): + """Test that DELETE with EXISTS clause is allowed""" + query = """ + DELETE FROM users u + WHERE NOT EXISTS ( + SELECT 1 FROM orders o + WHERE o.user_id = u.id + AND o.created_at > NOW() - INTERVAL '1 year' + ) + """ + await dml_only_driver.execute_query(query) + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +# ======================================== +# Test Error Handling +# ======================================== + + +@pytest.mark.asyncio +async def test_invalid_sql_syntax(dml_only_driver): + """Test that queries with invalid SQL syntax are blocked""" + query = "INSERT INTO users (name email) VALUES ('John', 'john@example.com')" + with pytest.raises(ValueError, match="Failed to parse SQL statement"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_sql_injection_attempt(dml_only_driver): + """Test that SQL injection attempts are blocked""" + query = """ + SELECT * FROM users; DROP TABLE users; + """ + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_explain_analyze_blocked(dml_only_driver): + """Test that EXPLAIN ANALYZE is blocked""" + query = "EXPLAIN ANALYZE SELECT * FROM users" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_set_statement_blocked(dml_only_driver): + """Test that SET statements are blocked""" + query = "SET search_path TO public" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_begin_transaction_blocked(dml_only_driver): + """Test that BEGIN TRANSACTION is blocked""" + query = "BEGIN" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_commit_blocked(dml_only_driver): + """Test that COMMIT is blocked""" + query = "COMMIT" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_rollback_blocked(dml_only_driver): + """Test that ROLLBACK is blocked""" + query = "ROLLBACK" + with pytest.raises(ValueError, match="Error validating query"): + await dml_only_driver.execute_query(query) + + +# ======================================== +# Test Timeout Handling +# ======================================== + + +@pytest.mark.asyncio +async def test_timeout_configuration(mock_sql_driver): + """Test that timeout is properly configured""" + driver_with_timeout = DmlOnlySqlDriver(mock_sql_driver, timeout=30) + assert driver_with_timeout.timeout == 30 + + driver_without_timeout = DmlOnlySqlDriver(mock_sql_driver) + assert driver_without_timeout.timeout is None + + +# ======================================== +# Test Force Readonly Parameter +# ======================================== + + +@pytest.mark.asyncio +async def test_force_readonly_false_by_default(dml_only_driver, mock_sql_driver): + """Test that force_readonly is False by default for DML operations""" + query = "INSERT INTO users (name) VALUES ('Test')" + await dml_only_driver.execute_query(query) + # Verify that force_readonly=False is passed to the underlying driver + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) + + +@pytest.mark.asyncio +async def test_force_readonly_can_be_set(dml_only_driver, mock_sql_driver): + """Test that force_readonly can be explicitly set""" + query = "SELECT * FROM users" + await dml_only_driver.execute_query(query, force_readonly=True) + # Verify that force_readonly=True is passed to the underlying driver + mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=True) diff --git a/tests/unit/test_access_mode.py b/tests/unit/test_access_mode.py index f7d3b803..afb019c5 100644 --- a/tests/unit/test_access_mode.py +++ b/tests/unit/test_access_mode.py @@ -7,6 +7,7 @@ from postgres_mcp.server import AccessMode from postgres_mcp.server import get_sql_driver +from postgres_mcp.sql.dml_only_sql import DmlOnlySqlDriver from postgres_mcp.sql.safe_sql import SafeSqlDriver from postgres_mcp.sql.sql_driver import DbConnPool from postgres_mcp.sql.sql_driver import SqlDriver @@ -25,6 +26,7 @@ def mock_db_connection(): [ (AccessMode.UNRESTRICTED, SqlDriver), (AccessMode.RESTRICTED, SafeSqlDriver), + (AccessMode.DML_ONLY, DmlOnlySqlDriver), ], ) @pytest.mark.asyncio @@ -37,9 +39,8 @@ async def test_get_sql_driver_returns_correct_driver(access_mode, expected_drive driver = await get_sql_driver() assert isinstance(driver, expected_driver_type) - # When in RESTRICTED mode, verify timeout is set - if access_mode == AccessMode.RESTRICTED: - assert isinstance(driver, SafeSqlDriver) + # When in RESTRICTED or DML_ONLY mode, verify timeout is set + if access_mode in (AccessMode.RESTRICTED, AccessMode.DML_ONLY): assert driver.timeout == 30 @@ -68,6 +69,19 @@ async def test_get_sql_driver_in_unrestricted_mode_no_timeout(mock_db_connection assert not hasattr(driver, "timeout") +@pytest.mark.asyncio +async def test_get_sql_driver_sets_timeout_in_dml_only_mode(mock_db_connection): + """Test that get_sql_driver sets the timeout in DML_ONLY mode.""" + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.DML_ONLY), + patch("postgres_mcp.server.db_connection", mock_db_connection), + ): + driver = await get_sql_driver() + assert isinstance(driver, DmlOnlySqlDriver) + assert driver.timeout == 30 + assert hasattr(driver, "sql_driver") + + @pytest.mark.asyncio async def test_command_line_parsing(): """Test that command-line arguments correctly set the access mode.""" From d34bc0c8bb478af06878e897b01afd47c6509fc9 Mon Sep 17 00:00:00 2001 From: fuzzwah Date: Mon, 5 Jan 2026 09:01:57 +1100 Subject: [PATCH 2/3] Fix code review issues in DML_ONLY implementation Addressed all critical issues from code review: 1. RawStmt validation: Now validates inner statement (stmt.stmt) instead of wrapper 2. Error messages: Removed generic wrapper, now returns specific detailed errors 3. Function validation: Use isinstance(node, FuncCall) before accessing funcname 4. Timeout handling: Changed from ValueError to TimeoutError for better error distinction 5. Type safety: Added proper DefElem and FuncCall imports and type checking - Mirror SafeSqlDriver implementation patterns exactly - Improve error message clarity (e.g., 'Statement type CreateStmt not allowed in DML_ONLY mode') - Update all 46 tests to match new error message patterns - All tests passing (53/53) - Code style checks passing (ruff) --- src/postgres_mcp/sql/dml_only_sql.py | 54 ++++++++++++++-------------- tests/unit/sql/test_dml_only_sql.py | 48 ++++++++++++------------- 2 files changed, 52 insertions(+), 50 deletions(-) diff --git a/src/postgres_mcp/sql/dml_only_sql.py b/src/postgres_mcp/sql/dml_only_sql.py index bae2d147..d6747537 100644 --- a/src/postgres_mcp/sql/dml_only_sql.py +++ b/src/postgres_mcp/sql/dml_only_sql.py @@ -7,8 +7,10 @@ from typing import Optional import pglast +from pglast.ast import DefElem from pglast.ast import DeleteStmt from pglast.ast import ExplainStmt +from pglast.ast import FuncCall from pglast.ast import IndexElem from pglast.ast import InferClause from pglast.ast import InsertStmt @@ -71,22 +73,22 @@ def _validate_node(self, node: Node) -> None: """Recursively validate a node and all its children""" # Check if node type is allowed if not isinstance(node, tuple(self.ALLOWED_NODE_TYPES)): - raise ValueError(f"Node type {type(node)} is not allowed") + raise ValueError(f"Node type {type(node).__name__} is not allowed") - # Validate function calls (reuse logic from SafeSqlDriver) - if hasattr(node, "funcname") and node.funcname: + # Validate function calls (mirror SafeSqlDriver implementation) + if isinstance(node, FuncCall): func_name = ".".join([str(n.sval) for n in node.funcname]).lower() if node.funcname else "" # Strip pg_catalog schema if present match = SafeSqlDriver.PG_CATALOG_PATTERN.match(func_name) unqualified_name = match.group(1) if match else func_name if unqualified_name not in self.ALLOWED_FUNCTIONS: - raise ValueError(f"Function {func_name} is not allowed") + raise ValueError(f"Function {func_name} is not allowed in DML_ONLY mode") # Reject EXPLAIN ANALYZE statements if isinstance(node, ExplainStmt): for option in node.options or []: - if hasattr(option, "defname") and option.defname == "analyze": - raise ValueError("EXPLAIN ANALYZE is not supported") + if isinstance(option, DefElem) and option.defname == "analyze": + raise ValueError("EXPLAIN ANALYZE is not supported in DML_ONLY mode") # Recursively validate all attributes that might be nodes for attr_name in node.__slots__: @@ -123,27 +125,27 @@ def _validate(self, query: str) -> None: parsed = pglast.parse_sql(query) # Validate each statement - try: - for stmt in parsed: - if isinstance(stmt, RawStmt): - # Check if the inner statement type is allowed - if not isinstance(stmt.stmt, tuple(self.ALLOWED_STMT_TYPES)): - raise ValueError( - f"Only SELECT, INSERT, UPDATE, DELETE, EXPLAIN, and SHOW statements are allowed. " - f"DDL operations are blocked. Received: {type(stmt.stmt).__name__}" - ) - else: - if not isinstance(stmt, tuple(self.ALLOWED_STMT_TYPES)): - raise ValueError( - f"Only SELECT, INSERT, UPDATE, DELETE, EXPLAIN, and SHOW statements are allowed. " - f"DDL operations are blocked. Received: {type(stmt).__name__}" - ) + for stmt in parsed: + if isinstance(stmt, RawStmt): + # Check if the inner statement type is allowed + if not isinstance(stmt.stmt, tuple(self.ALLOWED_STMT_TYPES)): + raise ValueError( + f"Statement type {type(stmt.stmt).__name__} not allowed in DML_ONLY mode. " + f"Only SELECT, INSERT, UPDATE, DELETE, EXPLAIN, and SHOW are permitted." + ) + # Validate the inner statement node, not the RawStmt wrapper + self._validate_node(stmt.stmt) + else: + # Direct statement (not wrapped) + if not isinstance(stmt, tuple(self.ALLOWED_STMT_TYPES)): + raise ValueError( + f"Statement type {type(stmt).__name__} not allowed in DML_ONLY mode. " + f"Only SELECT, INSERT, UPDATE, DELETE, EXPLAIN, and SHOW are permitted." + ) self._validate_node(stmt) - except Exception as e: - raise ValueError(f"Error validating query: {query}") from e except pglast.parser.ParseError as e: - raise ValueError("Failed to parse SQL statement") from e + raise ValueError(f"SQL parsing failed: {e!s}") from e async def execute_query( self, @@ -165,8 +167,8 @@ async def execute_query( ) except asyncio.TimeoutError as e: logger.warning(f"Query execution timed out after {self.timeout} seconds: {query[:100]}...") - raise ValueError( - f"Query execution timed out after {self.timeout} seconds in DML_ONLY mode. " + raise TimeoutError( + f"Query timed out after {self.timeout} seconds in DML_ONLY mode. " "Consider simplifying your query or increasing the timeout." ) from e except Exception as e: diff --git a/tests/unit/sql/test_dml_only_sql.py b/tests/unit/sql/test_dml_only_sql.py index 1a0f1a90..64852a44 100644 --- a/tests/unit/sql/test_dml_only_sql.py +++ b/tests/unit/sql/test_dml_only_sql.py @@ -168,7 +168,7 @@ async def test_create_table_blocked(dml_only_driver): created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """ - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -176,7 +176,7 @@ async def test_create_table_blocked(dml_only_driver): async def test_alter_table_blocked(dml_only_driver): """Test that ALTER TABLE statements are blocked""" query = "ALTER TABLE users ADD COLUMN age INTEGER" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -184,7 +184,7 @@ async def test_alter_table_blocked(dml_only_driver): async def test_drop_table_blocked(dml_only_driver): """Test that DROP TABLE statements are blocked""" query = "DROP TABLE users" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -192,7 +192,7 @@ async def test_drop_table_blocked(dml_only_driver): async def test_create_index_blocked(dml_only_driver): """Test that CREATE INDEX statements are blocked""" query = "CREATE INDEX idx_user_email ON users(email)" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -200,7 +200,7 @@ async def test_create_index_blocked(dml_only_driver): async def test_drop_index_blocked(dml_only_driver): """Test that DROP INDEX statements are blocked""" query = "DROP INDEX idx_user_email" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -208,7 +208,7 @@ async def test_drop_index_blocked(dml_only_driver): async def test_truncate_blocked(dml_only_driver): """Test that TRUNCATE statements are blocked""" query = "TRUNCATE TABLE users" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -216,7 +216,7 @@ async def test_truncate_blocked(dml_only_driver): async def test_create_extension_blocked(dml_only_driver): """Test that CREATE EXTENSION statements are blocked""" query = "CREATE EXTENSION IF NOT EXISTS pg_trgm" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -224,7 +224,7 @@ async def test_create_extension_blocked(dml_only_driver): async def test_drop_extension_blocked(dml_only_driver): """Test that DROP EXTENSION statements are blocked""" query = "DROP EXTENSION pg_trgm" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -232,7 +232,7 @@ async def test_drop_extension_blocked(dml_only_driver): async def test_vacuum_blocked(dml_only_driver): """Test that VACUUM statements are blocked""" query = "VACUUM users" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -240,7 +240,7 @@ async def test_vacuum_blocked(dml_only_driver): async def test_create_schema_blocked(dml_only_driver): """Test that CREATE SCHEMA statements are blocked""" query = "CREATE SCHEMA test_schema" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -248,7 +248,7 @@ async def test_create_schema_blocked(dml_only_driver): async def test_drop_schema_blocked(dml_only_driver): """Test that DROP SCHEMA statements are blocked""" query = "DROP SCHEMA test_schema" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -256,7 +256,7 @@ async def test_drop_schema_blocked(dml_only_driver): async def test_create_database_blocked(dml_only_driver): """Test that CREATE DATABASE statements are blocked""" query = "CREATE DATABASE test_db" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -264,7 +264,7 @@ async def test_create_database_blocked(dml_only_driver): async def test_drop_database_blocked(dml_only_driver): """Test that DROP DATABASE statements are blocked""" query = "DROP DATABASE test_db" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -272,7 +272,7 @@ async def test_drop_database_blocked(dml_only_driver): async def test_create_view_blocked(dml_only_driver): """Test that CREATE VIEW statements are blocked""" query = "CREATE VIEW active_users AS SELECT * FROM users WHERE status = 'active'" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -280,7 +280,7 @@ async def test_create_view_blocked(dml_only_driver): async def test_drop_view_blocked(dml_only_driver): """Test that DROP VIEW statements are blocked""" query = "DROP VIEW active_users" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -294,7 +294,7 @@ async def test_create_function_blocked(dml_only_driver): END; $$ LANGUAGE plpgsql """ - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -302,7 +302,7 @@ async def test_create_function_blocked(dml_only_driver): async def test_drop_function_blocked(dml_only_driver): """Test that DROP FUNCTION statements are blocked""" query = "DROP FUNCTION get_user_count()" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -414,7 +414,7 @@ async def test_delete_with_exists(dml_only_driver, mock_sql_driver): async def test_invalid_sql_syntax(dml_only_driver): """Test that queries with invalid SQL syntax are blocked""" query = "INSERT INTO users (name email) VALUES ('John', 'john@example.com')" - with pytest.raises(ValueError, match="Failed to parse SQL statement"): + with pytest.raises(ValueError, match="SQL parsing failed"): await dml_only_driver.execute_query(query) @@ -424,7 +424,7 @@ async def test_sql_injection_attempt(dml_only_driver): query = """ SELECT * FROM users; DROP TABLE users; """ - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -432,7 +432,7 @@ async def test_sql_injection_attempt(dml_only_driver): async def test_explain_analyze_blocked(dml_only_driver): """Test that EXPLAIN ANALYZE is blocked""" query = "EXPLAIN ANALYZE SELECT * FROM users" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="EXPLAIN ANALYZE is not supported in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -440,7 +440,7 @@ async def test_explain_analyze_blocked(dml_only_driver): async def test_set_statement_blocked(dml_only_driver): """Test that SET statements are blocked""" query = "SET search_path TO public" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -448,7 +448,7 @@ async def test_set_statement_blocked(dml_only_driver): async def test_begin_transaction_blocked(dml_only_driver): """Test that BEGIN TRANSACTION is blocked""" query = "BEGIN" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -456,7 +456,7 @@ async def test_begin_transaction_blocked(dml_only_driver): async def test_commit_blocked(dml_only_driver): """Test that COMMIT is blocked""" query = "COMMIT" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) @@ -464,7 +464,7 @@ async def test_commit_blocked(dml_only_driver): async def test_rollback_blocked(dml_only_driver): """Test that ROLLBACK is blocked""" query = "ROLLBACK" - with pytest.raises(ValueError, match="Error validating query"): + with pytest.raises(ValueError, match="not allowed in DML_ONLY mode"): await dml_only_driver.execute_query(query) From 086bde6c7b814c49b29f60d242cd2ab97e0a1fb0 Mon Sep 17 00:00:00 2001 From: fuzzwah Date: Mon, 5 Jan 2026 09:12:52 +1100 Subject: [PATCH 3/3] Require WHERE clause for UPDATE and DELETE statements Add critical safety feature to prevent accidental data loss: - UPDATE and DELETE statements now require WHERE clause - Prevents accidental modification/deletion of all rows - Added 2 new tests for WHERE clause validation - Updated test_update_with_case to include WHERE clause - Updated PR description to reflect this design decision All 55 tests passing (48 DML_ONLY + 7 access mode). --- PR_DESCRIPTION.md | 88 ++++++++++++++++++++++++++++ src/postgres_mcp/sql/dml_only_sql.py | 16 +++++ tests/unit/sql/test_dml_only_sql.py | 22 +++++++ 3 files changed, 126 insertions(+) create mode 100644 PR_DESCRIPTION.md diff --git a/PR_DESCRIPTION.md b/PR_DESCRIPTION.md new file mode 100644 index 00000000..ff3fa249 --- /dev/null +++ b/PR_DESCRIPTION.md @@ -0,0 +1,88 @@ +# Add DML_ONLY Access Mode + +## Overview +Adds a new `DML_ONLY` access mode to postgres-mcp that allows data manipulation (INSERT, UPDATE, DELETE, SELECT) while blocking schema changes and other DDL operations. This provides a middle ground between `UNRESTRICTED` (allows everything) and `RESTRICTED` (read-only) modes. + +## Motivation +Users often need to perform data modifications without having the ability to alter database schema. The existing access modes didn't support this use case: +- `UNRESTRICTED`: Too permissive, allows DDL operations +- `RESTRICTED`: Too restrictive, blocks all writes including DML + +`DML_ONLY` mode enables safe data manipulation for use cases like: +- Application agents that need to insert/update data but shouldn't modify schema +- Data import/migration scripts that should be isolated from structural changes +- Development environments where schema changes require explicit approval + +## Implementation Details + +**New Components:** +- `DmlOnlySqlDriver` class in `src/postgres_mcp/sql/dml_only_sql.py` + - Wraps underlying SqlDriver with validation layer + - Uses pglast to parse and validate SQL AST before execution + - Reuses `SafeSqlDriver.ALLOWED_FUNCTIONS` and extends `ALLOWED_NODE_TYPES` + +**Allowed Operations:** +- ✅ `SELECT` - Read queries +- ✅ `INSERT` - Including UPSERT with `ON CONFLICT` +- ✅ `UPDATE` - With all standard clauses (WHERE, RETURNING, etc.) +- ✅ `DELETE` - With all standard clauses +- ✅ `EXPLAIN` - Query planning (but not `EXPLAIN ANALYZE`) +- ✅ `SHOW` - View configuration variables +- ✅ Complex queries (CTEs, subqueries, JOINs, CASE expressions) + +**Blocked Operations:** +- ❌ `CREATE/ALTER/DROP TABLE` +- ❌ `CREATE/DROP INDEX` +- ❌ `CREATE/DROP VIEW/FUNCTION/SCHEMA/DATABASE` +- ❌ `TRUNCATE` +- ❌ `VACUUM` +- ❌ `CREATE/DROP EXTENSION` +- ❌ `SET` (configuration changes) +- ❌ `BEGIN/COMMIT/ROLLBACK` (transaction control) +- ❌ `EXPLAIN ANALYZE` (can impact performance) + +**Usage:** +```bash +# Start server with DML_ONLY mode +mcp-server-postgres postgres://user:pass@localhost/dbname --access-mode dml_only + +# With timeout (recommended) +mcp-server-postgres postgres://user:pass@localhost/dbname \ + --access-mode dml_only \ + --query-timeout 30 +``` + +## Testing +- **46 unit tests** for DML_ONLY driver covering: + - 13 tests for allowed DML operations + - 18 tests for blocked DDL operations + - 6 tests for complex queries (CTEs, subqueries, etc.) + - 9 tests for error handling and edge cases +- **7 integration tests** for access mode selection +- All existing tests continue to pass (no regression) + +## Design Decisions + +1. **RawStmt Validation**: Validates inner statement (`stmt.stmt`) to properly check the actual SQL command, not just the wrapper node + +2. **Function Allow-list**: Reuses `SafeSqlDriver.ALLOWED_FUNCTIONS` to maintain consistency with read-only mode's security model + +3. **Error Messages**: Provides specific, actionable error messages (e.g., "Statement type CreateStmt not allowed in DML_ONLY mode") rather than generic failures + +4. **Timeout Handling**: Returns `TimeoutError` (not `ValueError`) for query timeouts, enabling callers to distinguish timeout from validation failures + +5. **DELETE/UPDATE without WHERE**: Required for safety. UPDATE and DELETE statements must include a WHERE clause to prevent accidental modification or deletion of all rows. This is a critical safety feature that helps prevent data loss from mistaken queries + +## Files Changed +- `src/postgres_mcp/server.py` - Add DML_ONLY to AccessMode enum, update driver selection +- `src/postgres_mcp/sql/dml_only_sql.py` - New DmlOnlySqlDriver implementation +- `src/postgres_mcp/sql/__init__.py` - Export DmlOnlySqlDriver +- `tests/unit/sql/test_dml_only_sql.py` - New comprehensive test suite +- `tests/unit/test_access_mode.py` - Add DML_ONLY test cases +- `README.md` - Document new access mode and usage + +## Related Documentation +Implementation follows the plan outlined in `DML_ONLY_MODE_IMPLEMENTATION.md` + +## Breaking Changes +None. This is a purely additive feature that doesn't modify existing behavior. diff --git a/src/postgres_mcp/sql/dml_only_sql.py b/src/postgres_mcp/sql/dml_only_sql.py index d6747537..cad176d1 100644 --- a/src/postgres_mcp/sql/dml_only_sql.py +++ b/src/postgres_mcp/sql/dml_only_sql.py @@ -133,6 +133,14 @@ def _validate(self, query: str) -> None: f"Statement type {type(stmt.stmt).__name__} not allowed in DML_ONLY mode. " f"Only SELECT, INSERT, UPDATE, DELETE, EXPLAIN, and SHOW are permitted." ) + # Require WHERE clause for UPDATE and DELETE to prevent accidental data loss + if isinstance(stmt.stmt, (UpdateStmt, DeleteStmt)): + if not stmt.stmt.whereClause: + stmt_type = "UPDATE" if isinstance(stmt.stmt, UpdateStmt) else "DELETE" + raise ValueError( + f"{stmt_type} statements require a WHERE clause in DML_ONLY mode. " + "This prevents accidental deletion or modification of all rows." + ) # Validate the inner statement node, not the RawStmt wrapper self._validate_node(stmt.stmt) else: @@ -142,6 +150,14 @@ def _validate(self, query: str) -> None: f"Statement type {type(stmt).__name__} not allowed in DML_ONLY mode. " f"Only SELECT, INSERT, UPDATE, DELETE, EXPLAIN, and SHOW are permitted." ) + # Require WHERE clause for UPDATE and DELETE + if isinstance(stmt, (UpdateStmt, DeleteStmt)): + if not stmt.whereClause: + stmt_type = "UPDATE" if isinstance(stmt, UpdateStmt) else "DELETE" + raise ValueError( + f"{stmt_type} statements require a WHERE clause in DML_ONLY mode. " + "This prevents accidental deletion or modification of all rows." + ) self._validate_node(stmt) except pglast.parser.ParseError as e: diff --git a/tests/unit/sql/test_dml_only_sql.py b/tests/unit/sql/test_dml_only_sql.py index 64852a44..10274e78 100644 --- a/tests/unit/sql/test_dml_only_sql.py +++ b/tests/unit/sql/test_dml_only_sql.py @@ -385,6 +385,7 @@ async def test_update_with_case(dml_only_driver, mock_sql_driver): WHEN total_orders > 10 THEN 'silver' ELSE 'bronze' END + WHERE id > 0 """ await dml_only_driver.execute_query(query) mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=False) @@ -504,3 +505,24 @@ async def test_force_readonly_can_be_set(dml_only_driver, mock_sql_driver): await dml_only_driver.execute_query(query, force_readonly=True) # Verify that force_readonly=True is passed to the underlying driver mock_sql_driver.execute_query.assert_awaited_once_with("/* crystaldba */ " + query, params=None, force_readonly=True) + + +# ======================================== +# Test WHERE Clause Requirement +# ======================================== + + +@pytest.mark.asyncio +async def test_delete_without_where_blocked(dml_only_driver): + """Test that DELETE without WHERE clause is blocked""" + query = "DELETE FROM users" + with pytest.raises(ValueError, match="DELETE statements require a WHERE clause"): + await dml_only_driver.execute_query(query) + + +@pytest.mark.asyncio +async def test_update_without_where_blocked(dml_only_driver): + """Test that UPDATE without WHERE clause is blocked""" + query = "UPDATE users SET status = 'inactive'" + with pytest.raises(ValueError, match="UPDATE statements require a WHERE clause"): + await dml_only_driver.execute_query(query)