Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 63 additions & 4 deletions src/keboola_mcp_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
from starlette.routing import Route

from keboola_mcp_server.config import Config, ServerRuntimeInfo
from keboola_mcp_server.connections import (
DEFAULT_MAX_CONNECTIONS,
ConnectionLimitMiddleware,
init_connection_metrics,
)
from keboola_mcp_server.mcp import ForwardSlashMiddleware
from keboola_mcp_server.server import CustomRoutes, create_server

Expand Down Expand Up @@ -60,6 +65,33 @@ def parse_args(args: Optional[list[str]] = None) -> argparse.Namespace:
parser.add_argument('--port', type=int, default=8000, metavar='INT', help='The port to listen on.')
parser.add_argument('--log-config', type=pathlib.Path, metavar='PATH', help='Logging config file.')

# Scalability options for HTTP-based transports
# These help address the asyncio single-threaded event loop bottleneck
parser.add_argument(
'--workers',
type=int,
default=1,
metavar='INT',
help=(
'Number of uvicorn worker processes for HTTP transports. '
'Each worker runs its own event loop, helping distribute load across multiple CPU cores. '
'Recommended: 2-4 workers for production. Note: SSE connections are stateful, '
'so sticky sessions at the load balancer are required when using multiple workers.'
),
)
parser.add_argument(
'--max-connections',
type=int,
default=DEFAULT_MAX_CONNECTIONS,
metavar='INT',
help=(
'Maximum number of concurrent connections per worker. '
'When exceeded, new connections receive HTTP 503 (Service Unavailable). '
'This prevents event loop overload and ensures existing connections remain responsive. '
f'Default: {DEFAULT_MAX_CONNECTIONS}.'
),
)

return parser.parse_args(args)


Expand Down Expand Up @@ -156,6 +188,14 @@ async def run_server(args: Optional[list[str]] = None) -> None:
from fastmcp.server.http import StarletteWithLifespan
from starlette.applications import Starlette

# Initialize connection metrics for backpressure handling.
# This helps prevent event loop overload by rejecting new connections when at capacity.
# Each uvicorn worker process will have its own ConnectionMetrics instance.
connection_metrics = init_connection_metrics(max_connections=parsed_args.max_connections)
LOG.info(
f'Connection limit configured: max_connections={parsed_args.max_connections} per worker'
)

mount_paths: dict[str, StarletteWithLifespan] = {}
custom_routes: CustomRoutes | None = None
transports: list[str] = []
Expand Down Expand Up @@ -202,8 +242,14 @@ async def lifespan(_app: Starlette):
await stack.enter_async_context(_inner_app.lifespan(_app))
yield

# Middleware order matters:
# 1. ConnectionLimitMiddleware - first to reject connections early if at capacity
# 2. ForwardSlashMiddleware - handles path normalization
app = Starlette(
middleware=[Middleware(ForwardSlashMiddleware)],
middleware=[
Middleware(ConnectionLimitMiddleware, connection_metrics=connection_metrics),
Middleware(ForwardSlashMiddleware),
],
lifespan=lifespan,
exception_handlers=_exception_handlers,
)
Expand All @@ -217,19 +263,32 @@ async def lifespan(_app: Starlette):
tool.name: tool.parameters for tool in (await mcp_server.get_tools()).values()
}

config = uvicorn.Config(
# Configure uvicorn with workers for horizontal scaling within a single instance.
# Each worker runs its own asyncio event loop, helping distribute load across CPU cores.
# This addresses the primary bottleneck: Python's asyncio is single-threaded, so with
# many concurrent SSE connections, every request competes for the same event loop.
# Note: SSE connections are stateful, so sticky sessions at the load balancer are
# required when using multiple workers.
uvicorn_config = uvicorn.Config(
app,
host=parsed_args.host,
port=parsed_args.port,
log_config=log_config,
timeout_graceful_shutdown=0,
lifespan='on',
workers=parsed_args.workers,
)
server = uvicorn.Server(config)
server = uvicorn.Server(uvicorn_config)

workers_info = f' with {parsed_args.workers} worker(s)' if parsed_args.workers > 1 else ''
LOG.info(
f'Starting MCP server with {", ".join(transports)} transport{"s" if len(transports) > 1 else ""}'
f' on http://{parsed_args.host}:{parsed_args.port}/'
f'{workers_info} on http://{parsed_args.host}:{parsed_args.port}/'
)
if parsed_args.workers > 1:
LOG.info(
'Note: Multiple workers require sticky sessions at the load balancer for SSE connections.'
)

await server.serve()

Expand Down
242 changes: 242 additions & 0 deletions src/keboola_mcp_server/connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
"""
Connection management for the Keboola MCP server.

This module provides connection tracking and backpressure functionality to handle
high concurrency scenarios. Python's asyncio event loop runs on a single thread,
so with many concurrent SSE connections (e.g., 1000+), every new request competes
for the same event loop. This can cause simple operations to work but complex ones
(like tools/list) to timeout.

The connection limit with backpressure prevents degradation for existing connections
by rejecting new connections when the server is at capacity, returning HTTP 503.
"""

import json
import logging
import threading
from contextlib import contextmanager
from typing import Generator

from starlette.types import ASGIApp, Receive, Scope, Send

LOG = logging.getLogger(__name__)

# Default per-worker cap on concurrent connections; overridable via the
# --max-connections CLI option.
DEFAULT_MAX_CONNECTIONS = 1000


class ConnectionMetrics:
    """
    Thread-safe connection counter for tracking active SSE/HTTP connections.

    Tracks the number of active connections and enforces a maximum connection
    limit. When the limit is reached, new connections should be rejected with
    HTTP 503 (Service Unavailable) to prevent degradation of service for
    existing connections.

    A threading.Lock guards the counter: uvicorn workers run in separate
    processes (each with its own instance), but within a process multiple
    coroutines/threads may touch the counter concurrently.
    """

    def __init__(self, max_connections: int = DEFAULT_MAX_CONNECTIONS) -> None:
        """
        Initialize the connection metrics.

        :param max_connections: Maximum number of concurrent connections allowed.
        """
        self._lock = threading.Lock()
        self._count = 0
        self._max_connections = max_connections

    @property
    def count(self) -> int:
        """Return the current number of active connections."""
        with self._lock:
            return self._count

    @property
    def max_connections(self) -> int:
        """Return the maximum number of connections allowed."""
        return self._max_connections

    def is_at_capacity(self) -> bool:
        """
        Check if the server is at connection capacity.

        NOTE: advisory only — the answer may be stale by the time the caller
        acts on it. Use increment() for the atomic check-and-admit operation.
        """
        with self._lock:
            return self._count >= self._max_connections

    def increment(self) -> bool:
        """
        Atomically increment the connection count if not at capacity.

        :return: True if the connection was accepted, False if at capacity.
        """
        with self._lock:
            if self._count >= self._max_connections:
                # Lazy %-style args: the message is only rendered if the
                # record is actually emitted (hot path on every request).
                LOG.warning(
                    'Connection rejected: at capacity (%d/%d)', self._count, self._max_connections
                )
                return False
            self._count += 1
            LOG.debug('Connection accepted: %d/%d', self._count, self._max_connections)
            return True

    def decrement(self) -> None:
        """Decrement the connection count (never below zero)."""
        with self._lock:
            if self._count > 0:
                self._count -= 1
            LOG.debug('Connection closed: %d/%d', self._count, self._max_connections)

    @contextmanager
    def track_connection(self) -> Generator[bool, None, None]:
        """
        Context manager for tracking a connection's lifecycle.

        Usage:
            with connection_metrics.track_connection() as accepted:
                if not accepted:
                    return JSONResponse({"error": "Server at capacity"}, status_code=503)
                # Handle the connection...

        :yields: True if the connection was accepted, False if at capacity.
        """
        accepted = self.increment()
        try:
            yield accepted
        finally:
            # Only release a slot that was actually taken.
            if accepted:
                self.decrement()

    def get_stats(self) -> dict[str, int]:
        """Return a consistent snapshot of connection statistics."""
        with self._lock:
            return {
                'active_connections': self._count,
                'max_connections': self._max_connections,
                'available_connections': max(0, self._max_connections - self._count),
            }


# Process-wide singleton. Each uvicorn worker is a separate process, so every
# worker owns an independent instance and limits apply per worker.
_connection_metrics: ConnectionMetrics | None = None


def get_connection_metrics() -> ConnectionMetrics | None:
    """Return the process-wide ConnectionMetrics singleton, or None if it was never initialized."""
    return _connection_metrics


def init_connection_metrics(max_connections: int = DEFAULT_MAX_CONNECTIONS) -> ConnectionMetrics:
    """
    Create and install the process-wide ConnectionMetrics singleton.

    :param max_connections: Maximum number of concurrent connections allowed.
    :return: The initialized ConnectionMetrics instance.
    """
    global _connection_metrics
    metrics = ConnectionMetrics(max_connections)
    _connection_metrics = metrics
    LOG.info(f'Initialized connection metrics with max_connections={max_connections}')
    return metrics


class ConnectionLimitMiddleware:
    """
    ASGI middleware applying connection limits with backpressure.

    Every non-excluded HTTP request takes one slot from a ConnectionMetrics
    counter for its lifetime; when no slot is available the request is answered
    with HTTP 503 (Service Unavailable) instead of being processed.

    Rationale: Python's asyncio event loop is single-threaded, so with a very
    large number of concurrent SSE connections (e.g. 1000+) every request
    competes for the same loop and complex operations (like tools/list) can
    time out even while simple ones still work. Shedding load with 503s keeps
    the already-admitted connections responsive.

    Only HTTP-scope requests are tracked, and only on SSE/MCP endpoints;
    health-check and info endpoints bypass tracking entirely.
    """

    # Paths exempt from connection tracking (health checks and the info page).
    EXCLUDED_PATHS = frozenset(['/', '/health-check'])

    def __init__(self, app: ASGIApp, connection_metrics: ConnectionMetrics | None = None) -> None:
        """
        Initialize the connection limit middleware.

        :param app: The ASGI application to wrap.
        :param connection_metrics: ConnectionMetrics instance to use. If None, uses the global instance.
        """
        self._app = app
        self._connection_metrics = connection_metrics

    def _get_metrics(self) -> ConnectionMetrics | None:
        """Return the metrics to use: the explicitly injected one, else the global singleton."""
        return self._connection_metrics or get_connection_metrics()

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Handle one ASGI event, taking a connection slot for tracked HTTP requests.

        Flow for a tracked request: try to take a slot; on success run the
        wrapped app and release the slot when it finishes; on failure answer
        with 503 immediately.
        """
        # Non-HTTP scopes (lifespan, websocket, ...) and excluded paths pass
        # straight through without touching the counter.
        if scope['type'] != 'http' or scope.get('path', '') in self.EXCLUDED_PATHS:
            await self._app(scope, receive, send)
            return

        tracker = self._get_metrics()
        if tracker is None:
            # No limit configured anywhere — behave as a transparent wrapper.
            await self._app(scope, receive, send)
            return

        if not tracker.increment():
            await self._send_503_response(send, tracker)
            return

        try:
            await self._app(scope, receive, send)
        finally:
            # Release the slot even if the wrapped app raised.
            tracker.decrement()

    async def _send_503_response(self, send: Send, metrics: ConnectionMetrics) -> None:
        """Send a 503 Service Unavailable response with current capacity stats."""
        stats = metrics.get_stats()
        payload = {
            'error': 'Server at capacity',
            'message': (
                'The server has reached its maximum connection limit. '
                'Please retry your request later.'
            ),
            'active_connections': stats['active_connections'],
            'max_connections': stats['max_connections'],
        }
        body = json.dumps(payload).encode('utf-8')
        headers = [
            (b'content-type', b'application/json'),
            (b'content-length', str(len(body)).encode('utf-8')),
            (b'retry-after', b'5'),
        ]
        await send({'type': 'http.response.start', 'status': 503, 'headers': headers})
        await send({'type': 'http.response.body', 'body': body})
Loading