Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,16 @@ For Windsurf, the format in `mcp_config.json` is slightly different:
}
```

For local integration with your browser using, for example, [MCP for claude.ai](https://chromewebstore.google.com/detail/jbdhaamjibfahpekpnjeikanebpdpfpb?utm_source=item-share-cb), you may need to allow certain CORS origins, such as https://claude.ai. To do this, start the server with the `--cors-origins` parameter and provide the list of origins you want to whitelist.

For example, with Docker run:

```bash
docker run -p 8000:8000 \
-e DATABASE_URI=postgresql://username:password@localhost:5432/dbname \
crystaldba/postgres-mcp --access-mode=unrestricted --transport=sse --cors-origins https://claude.ai
```

## Postgres Extension Installation (Optional)

To enable index tuning and comprehensive performance analysis you need to load the `pg_statements` and `hypopg` extensions on your database.
Expand Down
31 changes: 27 additions & 4 deletions src/postgres_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from typing import Union

import mcp.types as types
import uvicorn
from mcp.server.fastmcp import FastMCP
from pydantic import Field
from pydantic import validate_call
from starlette.middleware.cors import CORSMiddleware

from postgres_mcp.index.dta_calc import DatabaseTuningAdvisor

Expand Down Expand Up @@ -539,6 +541,12 @@ async def main():
default=8000,
help="Port for SSE server (default: 8000)",
)
parser.add_argument(
"--cors-origins",
nargs="*",
default=[],
help="List of allowed CORS origins (default: empty, no CORS)",
)

args = parser.parse_args()

Expand Down Expand Up @@ -589,10 +597,25 @@ async def main():
if args.transport == "stdio":
await mcp.run_stdio_async()
else:
# Update FastMCP settings based on command line arguments
mcp.settings.host = args.sse_host
mcp.settings.port = args.sse_port
await mcp.run_sse_async()
starlette_app = mcp.sse_app()

if args.cors_origins:
logger.info(f"Enabling CORS for origins: {', '.join(args.cors_origins)}")
starlette_app.add_middleware(
CORSMiddleware,
allow_origins=args.cors_origins,
allow_methods=['GET', 'POST', 'OPTIONS'],
allow_headers=['*']
)

config = uvicorn.Config(
starlette_app,
host=args.sse_host,
port=args.sse_port,
log_level="info",
)
server = uvicorn.Server(config)
await server.serve()


async def shutdown(sig=None):
Expand Down
260 changes: 260 additions & 0 deletions tests/unit/test_cors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
"""Tests for CORS support in SSE transport."""

import multiprocessing
import socket
import time

import pytest
import requests
from starlette.middleware.cors import CORSMiddleware
from starlette.testclient import TestClient

from postgres_mcp.server import mcp


def find_free_port():
"""Find a free port to use for testing."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]


def run_server(port: int, cors_origins: list[str]):
"""Run the MCP server in a subprocess."""
import asyncio

import uvicorn
from starlette.middleware.cors import CORSMiddleware

from postgres_mcp.server import mcp

starlette_app = mcp.sse_app()
if cors_origins:
starlette_app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["*"],
)

config = uvicorn.Config(
starlette_app,
host="127.0.0.1",
port=port,
log_level="error",
)
server = uvicorn.Server(config)
asyncio.run(server.serve())


@pytest.fixture
def app_with_cors():
"""Create an SSE app with CORS middleware configured."""
app = mcp.sse_app()
app.add_middleware(
CORSMiddleware,
allow_origins=["https://claude.ai", "https://example.com"],
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["*"],
)
return app


@pytest.fixture
def app_without_cors():
"""Create an SSE app without CORS middleware."""
return mcp.sse_app()


class TestCorsPreflightRequests:
"""Test CORS preflight (OPTIONS) requests."""

def test_preflight_allowed_origin_returns_cors_headers(self, app_with_cors):
"""OPTIONS preflight from allowed origin should return CORS headers."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
response = client.options(
"/sse",
headers={
"Origin": "https://claude.ai",
"Access-Control-Request-Method": "GET",
},
)
assert response.status_code == 200
assert response.headers.get("access-control-allow-origin") == "https://claude.ai"
assert "GET" in response.headers.get("access-control-allow-methods", "")

def test_preflight_second_allowed_origin(self, app_with_cors):
"""OPTIONS preflight from second allowed origin should also work."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
response = client.options(
"/sse",
headers={
"Origin": "https://example.com",
"Access-Control-Request-Method": "GET",
},
)
assert response.status_code == 200
assert response.headers.get("access-control-allow-origin") == "https://example.com"

def test_preflight_disallowed_origin_no_cors_header(self, app_with_cors):
"""OPTIONS preflight from non-allowed origin should not return CORS header."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
response = client.options(
"/sse",
headers={
"Origin": "https://malicious.com",
"Access-Control-Request-Method": "GET",
},
)
# The response may be 200 or 400, but should NOT have the allow-origin header
assert response.headers.get("access-control-allow-origin") is None

def test_preflight_messages_endpoint(self, app_with_cors):
"""OPTIONS preflight on /messages/ endpoint should also work."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
response = client.options(
"/messages/",
headers={
"Origin": "https://claude.ai",
"Access-Control-Request-Method": "POST",
},
)
assert response.status_code == 200
assert response.headers.get("access-control-allow-origin") == "https://claude.ai"
assert "POST" in response.headers.get("access-control-allow-methods", "")


class TestCorsOnActualRequests:
"""Test CORS headers on actual (non-preflight) requests."""

def test_post_request_with_allowed_origin(self, app_with_cors):
"""POST request from allowed origin should include CORS header in response."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
# Send a POST to /messages/ - it will fail (no valid session) but CORS headers should be present
response = client.post(
"/messages/",
headers={"Origin": "https://claude.ai"},
content="test",
)
# Even if the request fails, CORS headers should be present
assert response.headers.get("access-control-allow-origin") == "https://claude.ai"

def test_post_request_with_disallowed_origin(self, app_with_cors):
"""POST request from non-allowed origin should not have CORS header."""
client = TestClient(app_with_cors, raise_server_exceptions=False)
response = client.post(
"/messages/",
headers={"Origin": "https://malicious.com"},
content="test",
)
assert response.headers.get("access-control-allow-origin") is None


class TestCorsDisabled:
"""Test behavior when CORS middleware is not configured."""

def test_preflight_without_cors_middleware(self, app_without_cors):
"""App without CORS middleware should not handle preflight specially."""
client = TestClient(app_without_cors, raise_server_exceptions=False)
response = client.options(
"/sse",
headers={
"Origin": "https://claude.ai",
"Access-Control-Request-Method": "GET",
},
)
assert response.headers.get("access-control-allow-origin") is None

def test_request_without_cors_middleware(self, app_without_cors):
"""App without CORS middleware should not return CORS headers."""
client = TestClient(app_without_cors, raise_server_exceptions=False)
response = client.post(
"/messages/",
headers={"Origin": "https://claude.ai"},
content="test",
)
assert response.headers.get("access-control-allow-origin") is None


class TestCorsEndToEnd:
"""End-to-end tests that start an actual server process."""

def test_server_with_cors_enabled(self):
"""Test that a real server with CORS returns correct headers."""
port = find_free_port()
cors_origins = ["https://claude.ai", "https://example.com"]

# Start server in subprocess
proc = multiprocessing.Process(target=run_server, args=(port, cors_origins))
proc.start()

try:
# Wait for server to start
for _ in range(50):
try:
requests.options(f"http://127.0.0.1:{port}/sse", timeout=0.1)
break
except requests.exceptions.ConnectionError:
time.sleep(0.1)
else:
pytest.fail("Server did not start in time")

# Test allowed origin
response = requests.options(
f"http://127.0.0.1:{port}/sse",
headers={
"Origin": "https://claude.ai",
"Access-Control-Request-Method": "GET",
},
timeout=5,
)
assert response.headers.get("access-control-allow-origin") == "https://claude.ai"

# Test disallowed origin
response = requests.options(
f"http://127.0.0.1:{port}/sse",
headers={
"Origin": "https://malicious.com",
"Access-Control-Request-Method": "GET",
},
timeout=5,
)
assert response.headers.get("access-control-allow-origin") is None

finally:
proc.terminate()
proc.join(timeout=5)

def test_server_without_cors(self):
"""Test that a server without CORS does not return CORS headers."""
port = find_free_port()

# Start server without CORS
proc = multiprocessing.Process(target=run_server, args=(port, []))
proc.start()

try:
# Wait for server to start
for _ in range(50):
try:
requests.options(f"http://127.0.0.1:{port}/sse", timeout=0.1)
break
except requests.exceptions.ConnectionError:
time.sleep(0.1)
else:
pytest.fail("Server did not start in time")

# Test that no CORS headers are returned
response = requests.options(
f"http://127.0.0.1:{port}/sse",
headers={
"Origin": "https://claude.ai",
"Access-Control-Request-Method": "GET",
},
timeout=5,
)
assert response.headers.get("access-control-allow-origin") is None

finally:
proc.terminate()
proc.join(timeout=5)