diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 6a825a1..f3ba8f8 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -568,9 +568,9 @@ async def main(): parser.add_argument( "--transport", type=str, - choices=["stdio", "sse"], + choices=["stdio", "sse", "streamable-http"], default="stdio", - help="Select MCP transport: stdio (default) or sse", + help="Select MCP transport: stdio (default), sse, or streamable-http", ) parser.add_argument( "--sse-host", @@ -584,6 +584,18 @@ async def main(): default=8000, help="Port for SSE server (default: 8000)", ) + parser.add_argument( + "--streamable-http-host", + type=str, + default="localhost", + help="Host to bind streamable HTTP server to (default: localhost)", + ) + parser.add_argument( + "--streamable-http-port", + type=int, + default=8000, + help="Port for streamable HTTP server (default: 8000)", + ) args = parser.parse_args() @@ -647,11 +659,14 @@ async def main(): # Run the server with the selected transport (always async) if args.transport == "stdio": await mcp.run_stdio_async() - else: - # Update FastMCP settings based on command line arguments + elif args.transport == "sse": mcp.settings.host = args.sse_host mcp.settings.port = args.sse_port await mcp.run_sse_async() + elif args.transport == "streamable-http": + mcp.settings.host = args.streamable_http_host + mcp.settings.port = args.streamable_http_port + await mcp.run_streamable_http_async() async def shutdown(sig=None): diff --git a/tests/unit/test_transport.py b/tests/unit/test_transport.py new file mode 100644 index 0000000..4197ace --- /dev/null +++ b/tests/unit/test_transport.py @@ -0,0 +1,129 @@ +import sys +from unittest.mock import AsyncMock +from unittest.mock import patch + +import pytest + + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", ["stdio", "sse", "streamable-http"]) +async def test_transport_argument_parsing(transport): + """Test that all transport options are parsed correctly.""" + from postgres_mcp.server import main + + original_argv = sys.argv + try: + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()) as mock_stdio, + patch("postgres_mcp.server.mcp.run_sse_async", AsyncMock()) as mock_sse, + patch("postgres_mcp.server.mcp.run_streamable_http_async", AsyncMock()) as mock_http, + ): + await main() + + # Verify the correct transport method was called + if transport == "stdio": + mock_stdio.assert_called_once() + mock_sse.assert_not_called() + mock_http.assert_not_called() + elif transport == "sse": + mock_stdio.assert_not_called() + mock_sse.assert_called_once() + mock_http.assert_not_called() + elif transport == "streamable-http": + mock_stdio.assert_not_called() + mock_sse.assert_not_called() + mock_http.assert_called_once() + finally: + sys.argv = original_argv + + +@pytest.mark.asyncio +async def test_streamable_http_host_port_arguments(): + """Test that streamable-http host and port arguments are applied correctly.""" + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + original_argv = sys.argv + try: + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + "--transport=streamable-http", + "--streamable-http-host=0.0.0.0", + "--streamable-http-port=9000", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch("postgres_mcp.server.mcp.run_streamable_http_async", AsyncMock()), + ): + await main() + + # Verify the host and port were set correctly + assert mcp.settings.host == "0.0.0.0" + assert mcp.settings.port == 9000 + finally: + sys.argv = original_argv + + +@pytest.mark.asyncio +async def test_sse_host_port_arguments(): + """Test that SSE host and port arguments are applied correctly.""" + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + original_argv = sys.argv + try: + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + "--transport=sse", + "--sse-host=0.0.0.0", + "--sse-port=8080", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch("postgres_mcp.server.mcp.run_sse_async", AsyncMock()), + ): + await main() + + # Verify the host and port were set correctly + assert mcp.settings.host == "0.0.0.0" + assert mcp.settings.port == 8080 + finally: + sys.argv = original_argv + + +@pytest.mark.asyncio +async def test_default_transport_is_stdio(): + """Test that the default transport is stdio when not specified.""" + from postgres_mcp.server import main + + original_argv = sys.argv + try: + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()) as mock_stdio, + patch("postgres_mcp.server.mcp.run_sse_async", AsyncMock()) as mock_sse, + patch("postgres_mcp.server.mcp.run_streamable_http_async", AsyncMock()) as mock_http, + ): + await main() + + mock_stdio.assert_called_once() + mock_sse.assert_not_called() + mock_http.assert_not_called() + finally: + sys.argv = original_argv