diff --git a/src/keboola_mcp_server/cli.py b/src/keboola_mcp_server/cli.py index 4ad1cf36..6b56e62f 100644 --- a/src/keboola_mcp_server/cli.py +++ b/src/keboola_mcp_server/cli.py @@ -21,6 +21,11 @@ from starlette.routing import Route from keboola_mcp_server.config import Config, ServerRuntimeInfo +from keboola_mcp_server.connections import ( + DEFAULT_MAX_CONNECTIONS, + ConnectionLimitMiddleware, + init_connection_metrics, +) from keboola_mcp_server.mcp import ForwardSlashMiddleware from keboola_mcp_server.server import CustomRoutes, create_server @@ -60,6 +65,33 @@ def parse_args(args: Optional[list[str]] = None) -> argparse.Namespace: parser.add_argument('--port', type=int, default=8000, metavar='INT', help='The port to listen on.') parser.add_argument('--log-config', type=pathlib.Path, metavar='PATH', help='Logging config file.') + # Scalability options for HTTP-based transports + # These help address the asyncio single-threaded event loop bottleneck + parser.add_argument( + '--workers', + type=int, + default=1, + metavar='INT', + help=( + 'Number of uvicorn worker processes for HTTP transports. ' + 'Each worker runs its own event loop, helping distribute load across multiple CPU cores. ' + 'Recommended: 2-4 workers for production. Note: SSE connections are stateful, ' + 'so sticky sessions at the load balancer are required when using multiple workers.' + ), + ) + parser.add_argument( + '--max-connections', + type=int, + default=DEFAULT_MAX_CONNECTIONS, + metavar='INT', + help=( + 'Maximum number of concurrent connections per worker. ' + 'When exceeded, new connections receive HTTP 503 (Service Unavailable). ' + 'This prevents event loop overload and ensures existing connections remain responsive. ' + f'Default: {DEFAULT_MAX_CONNECTIONS}.' 
+ ), + ) + return parser.parse_args(args) @@ -156,6 +188,14 @@ async def run_server(args: Optional[list[str]] = None) -> None: from fastmcp.server.http import StarletteWithLifespan from starlette.applications import Starlette + # Initialize connection metrics for backpressure handling. + # This helps prevent event loop overload by rejecting new connections when at capacity. + # Each uvicorn worker process will have its own ConnectionMetrics instance. + connection_metrics = init_connection_metrics(max_connections=parsed_args.max_connections) + LOG.info( + f'Connection limit configured: max_connections={parsed_args.max_connections} per worker' + ) + mount_paths: dict[str, StarletteWithLifespan] = {} custom_routes: CustomRoutes | None = None transports: list[str] = [] @@ -202,8 +242,14 @@ async def lifespan(_app: Starlette): await stack.enter_async_context(_inner_app.lifespan(_app)) yield + # Middleware order matters: + # 1. ConnectionLimitMiddleware - first to reject connections early if at capacity + # 2. ForwardSlashMiddleware - handles path normalization app = Starlette( - middleware=[Middleware(ForwardSlashMiddleware)], + middleware=[ + Middleware(ConnectionLimitMiddleware, connection_metrics=connection_metrics), + Middleware(ForwardSlashMiddleware), + ], lifespan=lifespan, exception_handlers=_exception_handlers, ) @@ -217,19 +263,32 @@ async def lifespan(_app: Starlette): tool.name: tool.parameters for tool in (await mcp_server.get_tools()).values() } - config = uvicorn.Config( + # Configure uvicorn with workers for horizontal scaling within a single instance. + # Each worker runs its own asyncio event loop, helping distribute load across CPU cores. + # This addresses the primary bottleneck: Python's asyncio is single-threaded, so with + # many concurrent SSE connections, every request competes for the same event loop. + # Note: SSE connections are stateful, so sticky sessions at the load balancer are + # required when using multiple workers. 
+ uvicorn_config = uvicorn.Config( app, host=parsed_args.host, port=parsed_args.port, log_config=log_config, timeout_graceful_shutdown=0, lifespan='on', + workers=parsed_args.workers, ) - server = uvicorn.Server(config) + server = uvicorn.Server(uvicorn_config) + + workers_info = f' with {parsed_args.workers} worker(s)' if parsed_args.workers > 1 else '' LOG.info( f'Starting MCP server with {", ".join(transports)} transport{"s" if len(transports) > 1 else ""}' - f' on http://{parsed_args.host}:{parsed_args.port}/' + f'{workers_info} on http://{parsed_args.host}:{parsed_args.port}/' ) + if parsed_args.workers > 1: + LOG.info( + 'Note: Multiple workers require sticky sessions at the load balancer for SSE connections.' + ) await server.serve() 
diff --git a/src/keboola_mcp_server/connections.py b/src/keboola_mcp_server/connections.py
new file mode 100644
index 00000000..b02d82dd
--- /dev/null
+++ b/src/keboola_mcp_server/connections.py
@@ -0,0 +1,272 @@
+"""
+Connection management for the Keboola MCP server.
+
+This module provides connection tracking and backpressure functionality to handle
+high concurrency scenarios. Python's asyncio event loop runs on a single thread,
+so with many concurrent SSE connections (e.g., 1000+), every new request competes
+for the same event loop. This can cause simple operations to work but complex ones
+(like tools/list) to time out.
+
+The connection limit with backpressure prevents degradation for existing connections
+by rejecting new connections when the server is at capacity, returning HTTP 503.
+"""
+
+import json
+import logging
+import threading
+from contextlib import contextmanager
+from typing import Generator
+
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+LOG = logging.getLogger(__name__)
+
+DEFAULT_MAX_CONNECTIONS = 1000
+
+
+class ConnectionMetrics:
+    """
+    Thread-safe counter of active SSE/HTTP connections with an upper limit.
+
+    The counter goes up when a connection is accepted and down when it is closed.
+    Once the limit is reached, ``increment()`` refuses further connections so that the
+    caller can answer with HTTP 503 (Service Unavailable) instead of degrading the
+    service for the connections that are already open.
+
+    All methods are synchronous and never await, so coroutines running on one event loop
+    cannot interleave inside them. The ``threading.Lock`` keeps the counter consistent when
+    it is also used from other threads of the same process. The counter is a plain
+    in-process object, so every uvicorn worker process has its own independent count.
+    """
+
+    def __init__(self, max_connections: int = DEFAULT_MAX_CONNECTIONS) -> None:
+        """
+        Initialize the connection metrics.
+
+        :param max_connections: Maximum number of concurrent connections allowed.
+            A value below 1 is accepted, but then every connection is rejected.
+        """
+        if max_connections < 1:
+            # Not an error for backward compatibility, but almost certainly a misconfiguration.
+            LOG.warning(f'max_connections={max_connections} rejects every tracked connection')
+        self._lock = threading.Lock()
+        self._count = 0
+        self._max_connections = max_connections
+
+    @property
+    def count(self) -> int:
+        """Return the current number of active connections."""
+        with self._lock:
+            return self._count
+
+    @property
+    def max_connections(self) -> int:
+        """Return the maximum number of connections allowed."""
+        return self._max_connections
+
+    def is_at_capacity(self) -> bool:
+        """Check if the number of active connections has reached the limit."""
+        with self._lock:
+            return self._count >= self._max_connections
+
+    def increment(self) -> bool:
+        """
+        Increment the connection count if not at capacity.
+
+        The capacity check and the increment happen atomically under the lock.
+
+        :return: True if the connection was accepted, False if at capacity.
+        """
+        with self._lock:
+            accepted = self._count < self._max_connections
+            if accepted:
+                self._count += 1
+            current = self._count
+
+        # Log outside the lock so that slow log handlers do not extend the critical section.
+        if accepted:
+            LOG.debug(f'Connection accepted: {current}/{self._max_connections}')
+        else:
+            LOG.warning(f'Connection rejected: at capacity ({current}/{self._max_connections})')
+        return accepted
+
+    def decrement(self) -> None:
+        """
+        Decrement the connection count.
+
+        The count never drops below zero. A call without a matching successful
+        ``increment()`` is reported with a warning instead of being ignored silently.
+        """
+        with self._lock:
+            underflow = self._count <= 0
+            if not underflow:
+                self._count -= 1
+            current = self._count
+
+        if underflow:
+            LOG.warning('Connection count underflow: decrement() without an active connection')
+        else:
+            LOG.debug(f'Connection closed: {current}/{self._max_connections}')
+
+    @contextmanager
+    def track_connection(self) -> Generator[bool, None, None]:
+        """
+        Context manager for tracking a connection's lifecycle.
+
+        Usage:
+            with connection_metrics.track_connection() as accepted:
+                if not accepted:
+                    return JSONResponse({"error": "Server at capacity"}, status_code=503)
+                # Handle the connection...
+
+        :yields: True if the connection was accepted, False if at capacity.
+        """
+        accepted = self.increment()
+        try:
+            yield accepted
+        finally:
+            # Release the slot only if one was actually taken.
+            if accepted:
+                self.decrement()
+
+    def get_stats(self) -> dict[str, int]:
+        """Return a snapshot of the connection statistics."""
+        with self._lock:
+            active = self._count
+        return {
+            'active_connections': active,
+            'max_connections': self._max_connections,
+            'available_connections': max(0, self._max_connections - active),
+        }
+
+
+# Global connection metrics instance - shared across the application
+# Each uvicorn worker process will have its own instance
+_connection_metrics: ConnectionMetrics | None = None
+
+
+def get_connection_metrics() -> ConnectionMetrics | None:
+    """Get the global connection metrics instance, or None if it has not been initialized."""
+    return _connection_metrics
+
+
+def init_connection_metrics(max_connections: int = DEFAULT_MAX_CONNECTIONS) -> ConnectionMetrics:
+    """
+    Initialize the global connection metrics instance.
+
+    A previously initialized global instance, if any, is replaced.
+
+    :param max_connections: Maximum number of concurrent connections allowed.
+    :return: The initialized ConnectionMetrics instance.
+    """
+    global _connection_metrics
+    _connection_metrics = ConnectionMetrics(max_connections)
+    LOG.info(f'Initialized connection metrics with max_connections={max_connections}')
+    return _connection_metrics
+
+
+class ConnectionLimitMiddleware:
+    """
+    ASGI middleware that enforces connection limits with backpressure.
+
+    This middleware tracks active connections and returns HTTP 503 (Service Unavailable)
+    when the server is at capacity. This prevents degradation for existing connections
+    by rejecting new ones rather than allowing the event loop to become overloaded.
+
+    Why this is needed:
+    - Python's asyncio event loop runs on a single thread
+    - With many concurrent SSE connections (e.g., 1000+), every new request competes
+      for the same event loop
+    - This causes simple operations to work but complex ones (like tools/list) to time out
+    - By limiting connections and returning 503, we ensure existing connections remain responsive
+
+    Only scopes of type 'http' are tracked: each HTTP request occupies one slot for as long
+    as the wrapped application is handling it. Requests whose path is in EXCLUDED_PATHS
+    (health checks, etc.) are neither counted nor rejected.
+    """
+
+    # Paths that should be excluded from connection tracking (health checks, etc.)
+    EXCLUDED_PATHS = frozenset(['/', '/health-check'])
+
+    def __init__(self, app: ASGIApp, connection_metrics: ConnectionMetrics | None = None) -> None:
+        """
+        Initialize the connection limit middleware.
+
+        :param app: The ASGI application to wrap.
+        :param connection_metrics: ConnectionMetrics instance to use. If None, uses the global instance.
+        """
+        self._app = app
+        self._connection_metrics = connection_metrics
+
+    def _get_metrics(self) -> ConnectionMetrics | None:
+        """Get the connection metrics instance (local or global)."""
+        # Compare with None explicitly so the choice never depends on the object's truthiness.
+        if self._connection_metrics is not None:
+            return self._connection_metrics
+        return get_connection_metrics()
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        """
+        Process an ASGI request with connection tracking.
+
+        For HTTP requests to tracked endpoints:
+        1. Atomically reserve a connection slot - if at capacity, return 503
+        2. Process the request
+        3. Release the slot when done, even if the application raises
+        """
+        # Non-HTTP scopes and excluded paths (health checks, etc.) are passed through untracked.
+        if scope['type'] != 'http' or scope.get('path', '') in self.EXCLUDED_PATHS:
+            await self._app(scope, receive, send)
+            return
+
+        metrics = self._get_metrics()
+
+        # If no metrics configured, pass through without tracking
+        if metrics is None:
+            await self._app(scope, receive, send)
+            return
+
+        # Check capacity and reject if at limit
+        if not metrics.increment():
+            await self._send_503_response(send, metrics)
+            return
+
+        try:
+            await self._app(scope, receive, send)
+        finally:
+            metrics.decrement()
+
+    async def _send_503_response(self, send: Send, metrics: ConnectionMetrics) -> None:
+        """
+        Send a 503 Service Unavailable response.
+
+        The JSON body reports the current and maximum connection counts, and a
+        ``retry-after`` header of 5 seconds is included.
+
+        :param send: The ASGI send callable of the rejected request.
+        :param metrics: The metrics instance whose statistics are reported in the body.
+        """
+        stats = metrics.get_stats()
+        body = json.dumps({
+            'error': 'Server at capacity',
+            'message': (
+                'The server has reached its maximum connection limit. '
+                'Please retry your request later.'
+            ),
+            'active_connections': stats['active_connections'],
+            'max_connections': stats['max_connections'],
+        }).encode('utf-8')
+
+        await send({
+            'type': 'http.response.start',
+            'status': 503,
+            'headers': [
+                (b'content-type', b'application/json'),
+                (b'content-length', str(len(body)).encode('utf-8')),
+                (b'retry-after', b'5'),
+            ],
+        })
+        await send({
+            'type': 'http.response.body',
+            'body': body,
+        })