Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/postgres_mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def main():
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

asyncio.run(server.main())
try:
asyncio.run(server.main())
except KeyboardInterrupt:
# Handle Ctrl+C gracefully without printing a traceback
pass


# Optionally expose other important items at package level
Expand Down
53 changes: 13 additions & 40 deletions src/postgres_mcp/server.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
# ruff: noqa: B008
import argparse
import asyncio
import logging
import os
import signal
import sys
from enum import Enum
from typing import Any
from typing import List
Expand Down Expand Up @@ -56,7 +53,6 @@ class AccessMode(str, Enum):
# Global variables
db_connection = DbConnPool()
current_access_mode = AccessMode.UNRESTRICTED
shutdown_in_progress = False


async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]:
Expand Down Expand Up @@ -633,47 +629,24 @@ async def main():
"The MCP server will start but database operations will fail until a valid connection is established.",
)

# Set up proper shutdown handling
# Run the server with the selected transport, with proper cleanup on exit
try:
loop = asyncio.get_running_loop()
signals = (signal.SIGTERM, signal.SIGINT)
for s in signals:
loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s)))
except NotImplementedError:
# Windows doesn't support signals properly
logger.warning("Signal handling not supported on Windows")
pass

# Run the server with the selected transport (always async)
if args.transport == "stdio":
await mcp.run_stdio_async()
else:
# Update FastMCP settings based on command line arguments
mcp.settings.host = args.sse_host
mcp.settings.port = args.sse_port
await mcp.run_sse_async()


async def shutdown(sig=None):
"""Clean shutdown of the server."""
global shutdown_in_progress

if shutdown_in_progress:
logger.warning("Forcing immediate exit")
# Use sys.exit instead of os._exit to allow for proper cleanup
sys.exit(1)

shutdown_in_progress = True
if args.transport == "stdio":
await mcp.run_stdio_async()
else:
mcp.settings.host = args.sse_host
mcp.settings.port = args.sse_port
await mcp.run_sse_async()
finally:
# Clean up database connections on exit
await cleanup()
Comment on lines +640 to +642

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Ensure cleanup runs under cancellation

When the server is interrupted (SIGINT/SIGTERM), asyncio.run() cancels the main() task; that cancellation propagates into the finally block and will cancel the awaited cleanup() call unless it is shielded or the CancelledError is handled. In that case, db_connection.close() may never run, so database connections can remain open on Ctrl+C despite the new shutdown simplification. Consider shielding cleanup() or catching CancelledError so cleanup completes even under cancellation.

Useful? React with 👍 / 👎.


if sig:
logger.info(f"Received exit signal {sig.name}")

# Close database connections
async def cleanup():
"""Clean up resources on server shutdown."""
logger.info("Shutting down server...")
try:
await db_connection.close()
logger.info("Closed database connections")
except Exception as e:
logger.error(f"Error closing database connections: {e}")

# Exit with appropriate status code
sys.exit(128 + sig if sig is not None else 0)
2 changes: 1 addition & 1 deletion tests/unit/test_access_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def test_command_line_parsing():
patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED),
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()),
patch("postgres_mcp.server.shutdown", AsyncMock()),
patch("postgres_mcp.server.cleanup", AsyncMock()),
):
# Reset the current_access_mode to UNRESTRICTED
import postgres_mcp.server
Expand Down
85 changes: 85 additions & 0 deletions tests/unit/test_shutdown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import sys
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest


@pytest.mark.asyncio
async def test_cleanup_closes_db_connection():
"""Test that cleanup properly closes database connections."""
from postgres_mcp.server import cleanup

mock_db = MagicMock()
mock_db.close = AsyncMock()

with patch("postgres_mcp.server.db_connection", mock_db):
await cleanup()
mock_db.close.assert_called_once()


@pytest.mark.asyncio
async def test_cleanup_handles_db_close_error():
"""Test that cleanup handles errors when closing database connections."""
from postgres_mcp.server import cleanup

mock_db = MagicMock()
mock_db.close = AsyncMock(side_effect=Exception("Connection error"))

with patch("postgres_mcp.server.db_connection", mock_db):
# Should not raise, just log the error
await cleanup()
mock_db.close.assert_called_once()


@pytest.mark.asyncio
async def test_main_calls_cleanup_on_normal_exit():
"""Test that main() calls cleanup when transport exits normally."""
from postgres_mcp.server import main

original_argv = sys.argv
try:
sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
]

mock_cleanup = AsyncMock()

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()),
patch("postgres_mcp.server.cleanup", mock_cleanup),
):
await main()
mock_cleanup.assert_called_once()
finally:
sys.argv = original_argv


@pytest.mark.asyncio
async def test_main_calls_cleanup_on_exception():
"""Test that main() calls cleanup even when transport raises an exception."""
from postgres_mcp.server import main

original_argv = sys.argv
try:
sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
]

mock_cleanup = AsyncMock()

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock(side_effect=Exception("Transport error"))),
patch("postgres_mcp.server.cleanup", mock_cleanup),
):
with pytest.raises(Exception, match="Transport error"):
await main()
# Cleanup should still be called due to finally block
mock_cleanup.assert_called_once()
finally:
sys.argv = original_argv