diff --git a/helm/mcp-optimizer/Chart.yaml b/helm/mcp-optimizer/Chart.yaml index 63f0dba..4e946aa 100644 --- a/helm/mcp-optimizer/Chart.yaml +++ b/helm/mcp-optimizer/Chart.yaml @@ -2,8 +2,8 @@ apiVersion: v2 name: mcp-optimizer description: A Helm chart for deploying MCP Optimizer MCP Server in Kubernetes type: application -version: 0.1.1 -appVersion: "0.2.0" +version: 0.2.1 +appVersion: "0.2.1" keywords: - mcp - mcp-optimizer @@ -14,7 +14,3 @@ sources: - https://github.com/StacklokLabs/mcp-optimizer maintainers: - name: MCP Optimizer Team - -# Changelog: -# 0.1.1 - Fix groupFiltering.allowedGroups feature with automatic env var merging in podTemplateSpec -# 0.1.0 - Initial release diff --git a/src/mcp_optimizer/db/models.py b/src/mcp_optimizer/db/models.py index d9e158f..47015bd 100644 --- a/src/mcp_optimizer/db/models.py +++ b/src/mcp_optimizer/db/models.py @@ -18,7 +18,7 @@ class TransportType(str, Enum): """ Enum for transport types. - There is 1:1 relation between ToolHive proxy modes to database transport types. + There is 1:1 relation between ToolHive transport modes to database transport types. """ SSE = "sse" diff --git a/src/mcp_optimizer/ingestion.py b/src/mcp_optimizer/ingestion.py index 40a7b81..4803547 100644 --- a/src/mcp_optimizer/ingestion.py +++ b/src/mcp_optimizer/ingestion.py @@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncConnection from mcp_optimizer.db.config import DatabaseConfig -from mcp_optimizer.db.exceptions import DbNotFoundError +from mcp_optimizer.db.exceptions import DbNotFoundError, DuplicateRegistryServersError from mcp_optimizer.db.models import ( McpStatus, RegistryServer, @@ -27,11 +27,14 @@ from mcp_optimizer.db.workload_server_ops import WorkloadServerOps from mcp_optimizer.db.workload_tool_ops import WorkloadToolOps from mcp_optimizer.embeddings import EmbeddingManager -from mcp_optimizer.mcp_client import MCPServerClient, WorkloadConnectionError +from mcp_optimizer.mcp_client import ( + MCPServerClient, + WorkloadConnectionError, + determine_transport_type, +) from mcp_optimizer.token_counter import TokenCounter from mcp_optimizer.toolhive.api_models.core import Workload from mcp_optimizer.toolhive.api_models.registry import ImageMetadata, Registry, RemoteServerMetadata -from mcp_optimizer.toolhive.enums import ToolHiveProxyMode, url_to_toolhive_proxy_mode from mcp_optimizer.toolhive.k8s_client import K8sClient from mcp_optimizer.toolhive.toolhive_client import ( ToolhiveClient, @@ -184,44 +187,6 @@ async def _batch_gather(self, tasks: list[Any], batch_size: int) -> list[Any]: ) return results - def _map_transport_type(self, workload: Workload) -> TransportType: - """Map Toolhive transport type to database transport type. - - Args: - workload: Workload object with proxy_mode and url - - Returns: - Mapped transport type for database storage - - Raises: - ValueError: If transport type is not supported - """ - mapping = { - ToolHiveProxyMode.SSE: TransportType.SSE, - ToolHiveProxyMode.STREAMABLE: TransportType.STREAMABLE, - } - - # Prefer using the proxy_mode field if available - if workload.proxy_mode: - proxy_mode_str = workload.proxy_mode.lower() - if proxy_mode_str == "sse": - return TransportType.SSE - elif proxy_mode_str == "streamable-http": - return TransportType.STREAMABLE - else: - logger.warning( - f"Unknown proxy_mode '{proxy_mode_str}', falling back to URL detection", - workload=workload.name, - ) - - # Fallback to URL-based detection for backwards compatibility - if workload.url is None: - raise IngestionError(f"Workload {workload.name} has no URL") - - toolhive_proxy_mode = url_to_toolhive_proxy_mode(workload.url) - - return mapping[toolhive_proxy_mode] - def _map_workload_status(self, workload_status: str | None) -> McpStatus: """Map workload status to McpStatus enum. @@ -607,8 +572,6 @@ async def _find_and_link_registry_server( Raises: DuplicateRegistryServersError: If multiple matching servers found """ - from mcp_optimizer.db.exceptions import DuplicateRegistryServersError - # Find matching registry servers matching_servers = await self.registry_server_ops.find_matching_servers( url=url, package=package, remote=remote, conn=conn @@ -710,9 +673,8 @@ async def _upsert_workload_server( ValueError: If workload data is invalid DuplicateRegistryServersError: If multiple matching registry servers found """ - from mcp_optimizer.db.exceptions import DuplicateRegistryServersError - - transport = self._map_transport_type(workload) + # Cast to TransportType (DB enum) from ToolHiveTransportType + transport = cast(TransportType, determine_transport_type(workload, self.runtime_mode)) status = self._map_workload_status(workload.status) if not workload.name: @@ -987,8 +949,6 @@ async def _process_workload(self, workload: Workload, conn: AsyncConnection) -> Returns: Processing result with status and counts """ - from mcp_optimizer.db.exceptions import DuplicateRegistryServersError - result = { "name": workload.name, "status": "failed", @@ -1016,7 +976,9 @@ async def _process_workload(self, workload: Workload, conn: AsyncConnection) -> ) # Get tools from MCP server - mcp_client = MCPServerClient(workload, timeout=self.mcp_timeout) + mcp_client = MCPServerClient( + workload, timeout=self.mcp_timeout, runtime_mode=self.runtime_mode + ) tools_result = await mcp_client.list_tools() # Sync tools with appropriate context @@ -1027,6 +989,18 @@ async def _process_workload(self, workload: Workload, conn: AsyncConnection) -> # Track if anything was updated was_updated = server_was_updated or tools_were_updated + logger.info( + "Processed workload", + server_id=server_id, + workload_name=workload.name, + url=workload.url, + transport_type=workload.transport_type, + group=workload.group, + tools_count=tools_count, + server_was_updated=server_was_updated, + tools_were_updated=tools_were_updated, + ) + # Calculate autonomous embedding if: # 1. Not linked to registry (registry_server_id is None) # 2. No server embedding exists diff --git a/src/mcp_optimizer/mcp_client.py b/src/mcp_optimizer/mcp_client.py index 132d3dc..ca77e45 100644 --- a/src/mcp_optimizer/mcp_client.py +++ b/src/mcp_optimizer/mcp_client.py @@ -4,17 +4,16 @@ import asyncio from typing import Any, Awaitable, Callable -from urllib.parse import urlparse, urlunparse import structlog from mcp import ClientSession from mcp.client.sse import sse_client -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client from mcp.shared.exceptions import McpError from mcp.types import CallToolResult, ListToolsResult from mcp_optimizer.toolhive.api_models.core import Workload -from mcp_optimizer.toolhive.enums import ToolHiveProxyMode, url_to_toolhive_proxy_mode +from mcp_optimizer.toolhive.enums import ToolHiveTransportMode, url_to_toolhive_transport_mode logger = structlog.get_logger(__name__) @@ -25,78 +24,95 @@ class WorkloadConnectionError(Exception): pass +def determine_transport_type(workload: Workload, runtime_mode: str) -> ToolHiveTransportMode: + """ + Determine the transport type from workload configuration based on runtime mode. + + Depending on the runtime mode, the transport type is determined differently: + - In docker mode: determined from proxy_mode (how the proxy connects to the container) + - In k8s mode: determined from transport_type (the direct connection type to the pod) + + Args: + workload: Workload configuration containing transport information + runtime_mode: Runtime environment - "docker" or "k8s" + + Returns: + ToolHiveProxyMode: The transport type to use (SSE or STREAMABLE) + + Raises: + WorkloadConnectionError: If transport type cannot be determined + """ + # Docker mode: Check proxy_mode field (proxy connection type) + if runtime_mode == "docker": + if workload.proxy_mode: + transport_field_lower = workload.proxy_mode.lower() + logger.debug( + f"Docker mode: determining transport from proxy_mode field: " + f"{transport_field_lower}", + workload=workload.name, + runtime_mode=runtime_mode, + ) + if transport_field_lower == "streamable-http": + return ToolHiveTransportMode.STREAMABLE + elif transport_field_lower == "sse": + return ToolHiveTransportMode.SSE + else: + logger.warning( + f"Unknown transport in proxy_mode: '{transport_field_lower}', " + "falling back to URL detection", + workload=workload.name, + ) + + # K8s mode: Check transport_type field (direct connection type) + elif runtime_mode == "k8s": + if workload.transport_type: + transport_field_lower = workload.transport_type.lower() + logger.debug( + f"K8s mode: determining transport from transport_type field: " + f"{transport_field_lower}", + workload=workload.name, + runtime_mode=runtime_mode, + ) + if transport_field_lower == "streamable-http": + return ToolHiveTransportMode.STREAMABLE + elif transport_field_lower == "sse": + return ToolHiveTransportMode.SSE + else: + logger.warning( + f"Unknown transport in transport_type: '{transport_field_lower}', " + "falling back to URL detection", + workload=workload.name, + ) + + # Fallback to URL-based detection for backwards compatibility + if not workload.url: + raise WorkloadConnectionError( + f"No transport type or URL available. Workload: {workload.name}", + ) + + logger.debug( + "No transport field available, falling back to URL-based detection", + workload=workload.name, + runtime_mode=runtime_mode, + ) + return url_to_toolhive_transport_mode(workload.url) + + class MCPServerClient: """Client for connecting to individual MCP servers.""" - def __init__(self, workload: Workload, timeout: float): + def __init__(self, workload: Workload, timeout: float, runtime_mode: str = "docker"): """ Initialize MCP client for a specific workload. Args: workload: The workload (MCP server) to connect to timeout: Timeout in seconds for operations (default: from config) + runtime_mode: Runtime environment - "docker" or "k8s" (default: "docker") """ self.workload = workload self.timeout = timeout - - def _normalize_url(self, url: str, proxy_mode: ToolHiveProxyMode) -> str: - """ - Normalize URL for the given proxy mode. - - For streamable-http: - - Fragments must be stripped as they're not supported - - Path must be /mcp (not /sse) as streamable-http uses /mcp endpoint - For SSE, fragments are preserved as they're used for container identification. - - Args: - url: Original URL from ToolHive - proxy_mode: The proxy mode being used - - Returns: - Normalized URL without fragments and with correct path for streamable-http, - original URL for SSE - """ - if proxy_mode == ToolHiveProxyMode.STREAMABLE: - # Strip fragments for streamable-http - # (fragments not supported by streamable-http client) - parsed = urlparse(url) - - # Fix path: streamable-http uses /mcp endpoint, not /sse - path = parsed.path - if path.endswith("/sse"): - path = path.replace("/sse", "/mcp") - elif not path.endswith("/mcp"): - # Only add /mcp if the path doesn't already contain /mcp - # This prevents double-adding /mcp to URLs like /mcp/test-server - if "/mcp" not in path: - # If path doesn't end with /mcp or /sse, and doesn't contain /mcp, - # ensure it ends with /mcp - if path.endswith("/"): - path = path + "mcp" - else: - path = path + "/mcp" - - # Reconstruct URL without fragment and with corrected path - normalized_tuple = ( - parsed.scheme, - parsed.netloc, - path, - parsed.params, - parsed.query, - "", # Empty fragment - ) - normalized = str(urlunparse(normalized_tuple)) - if normalized != url: - logger.debug( - "Normalized URL for streamable-http", - original_url=url, - normalized_url=normalized, - workload=self.workload.name, - ) - return normalized - else: - # SSE preserves fragments (used for container identification) - return url + self.runtime_mode = runtime_mode def _extract_error_from_exception_group(self, eg: ExceptionGroup) -> str: """ @@ -133,46 +149,6 @@ def collect_exceptions(exc_group): # Fallback to the exception group message return str(eg) - def _determine_proxy_mode(self) -> ToolHiveProxyMode: - """ - Determine the proxy mode from workload configuration. - - Returns: - ToolHiveProxyMode: The proxy mode to use - - Raises: - WorkloadConnectionError: If proxy mode is unknown or not set - """ - if self.workload.proxy_mode: - proxy_mode_lower = self.workload.proxy_mode.lower() - logger.debug( - f"Determining proxy mode from proxy_mode field: {proxy_mode_lower}", - workload=self.workload.name, - ) - if proxy_mode_lower == "streamable-http": - return ToolHiveProxyMode.STREAMABLE - elif proxy_mode_lower == "sse": - return ToolHiveProxyMode.SSE - else: - logger.warning( - f"Unknown proxy_mode '{proxy_mode_lower}', falling back to URL detection", - workload=self.workload.name, - ) - - # Fallback to URL-based detection for backwards compatibility - if not self.workload.url: - logger.warning( - "No proxy_mode or URL available, defaulting to SSE", - workload=self.workload.name, - ) - return ToolHiveProxyMode.SSE - - logger.debug( - "No proxy_mode available, falling back to URL-based detection", - workload=self.workload.name, - ) - return url_to_toolhive_proxy_mode(self.workload.url) - async def _execute_with_session(self, operation: Callable[[ClientSession], Awaitable]) -> Any: """ Execute an operation with an MCP session. @@ -188,31 +164,34 @@ async def _execute_with_session(self, operation: Callable[[ClientSession], Await logger.debug(f"Workload URL: {self.workload.url}") - # Determine proxy mode and normalize URL - proxy_mode = self._determine_proxy_mode() - normalized_url = self._normalize_url(self.workload.url, proxy_mode) + # Determine transport type based on runtime mode + # Docker: uses proxy_mode field (how proxy connects to container) + # K8s: uses transport_type field (direct connection to pod) + transport_type = determine_transport_type(self.workload, self.runtime_mode) logger.info( - f"Using {proxy_mode} client for workload '{self.workload.name}'", + f"Using {transport_type} client for workload '{self.workload.name}'", workload=self.workload.name, + transport_type_field=self.workload.transport_type, proxy_mode_field=self.workload.proxy_mode, - original_url=self.workload.url, - normalized_url=normalized_url, + url=self.workload.url, ) try: - if proxy_mode == ToolHiveProxyMode.STREAMABLE: + if transport_type == ToolHiveTransportMode.STREAMABLE: return await asyncio.wait_for( - self._execute_streamable_session(operation, normalized_url), + self._execute_streamable_session(operation, self.workload.url), timeout=self.timeout, ) - elif proxy_mode == ToolHiveProxyMode.SSE: + elif transport_type == ToolHiveTransportMode.SSE: return await asyncio.wait_for( - self._execute_sse_session(operation, normalized_url), timeout=self.timeout + self._execute_sse_session(operation, self.workload.url), timeout=self.timeout ) else: - logger.error(f"Unsupported transport type: {proxy_mode}", workload=self.workload) - raise WorkloadConnectionError(f"Unsupported transport type: {proxy_mode}") + logger.error( + f"Unsupported transport type: {transport_type}", workload=self.workload + ) + raise WorkloadConnectionError(f"Unsupported transport type: {transport_type}") except asyncio.TimeoutError as e: logger.error( f"Operation timed out after {self.timeout} seconds", workload=self.workload @@ -248,15 +227,15 @@ async def _execute_streamable_session( workload=self.workload.name, url=url, ) - async with streamablehttp_client(url) as (read_stream, write_stream, _): + async with streamable_http_client(url) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: logger.info( - f"Initializing MCP session for workload '{self.workload.name}'", + "Initializing Streamable MCP session for workload", workload=self.workload.name, ) await session.initialize() logger.debug( - f"MCP session initialized successfully for workload '{self.workload.name}'", + "Streamable MCP session initialized successfully", workload=self.workload.name, ) return await operation(session) @@ -273,12 +252,12 @@ async def _execute_sse_session( async with sse_client(url) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: logger.info( - f"Initializing MCP session for workload '{self.workload.name}'", + "Initializing SSE MCP session for workload", workload=self.workload.name, ) await session.initialize() logger.debug( - f"MCP session initialized successfully for workload '{self.workload.name}'", + "SSE MCP session initialized successfully for workload", workload=self.workload.name, ) return await operation(session) diff --git a/src/mcp_optimizer/server.py b/src/mcp_optimizer/server.py index 2dc4edf..3708b0c 100644 --- a/src/mcp_optimizer/server.py +++ b/src/mcp_optimizer/server.py @@ -545,7 +545,9 @@ async def call_tool(server_name: str, tool_name: str, parameters: dict) -> CallT logger.info( f"Calling tool '{tool_name}' on server '{server_name}' with parameters: {parameters}" ) - mcp_client = MCPServerClient(workload, timeout=_config.mcp_timeout) + mcp_client = MCPServerClient( + workload, timeout=_config.mcp_timeout, runtime_mode=_config.runtime_mode + ) # Call the tool using the MCP client try: diff --git a/src/mcp_optimizer/toolhive/enums.py b/src/mcp_optimizer/toolhive/enums.py index d5df4ae..84a5a79 100644 --- a/src/mcp_optimizer/toolhive/enums.py +++ b/src/mcp_optimizer/toolhive/enums.py @@ -1,17 +1,17 @@ from enum import Enum -class ToolHiveProxyMode(str, Enum): +class ToolHiveTransportMode(str, Enum): """ Enum for ToolHive proxy modes. - Proxy modes is the MCP transport type in which ToolHive's proxy operates for a workload. + Proxy modes represent the MCP transport type in which ToolHive's proxy operates for a workload. """ STREAMABLE = "streamable-http" SSE = "sse" def __str__(self) -> str: - """Return the string representation of the transport type.""" + """Return the string representation of the proxy mode.""" return self.value @@ -32,8 +32,8 @@ def __str__(self) -> str: return self.value -def url_to_toolhive_proxy_mode(toolhive_url: str) -> ToolHiveProxyMode: - """Map ToolHive URL to ToolHiveProxyMode enum. +def url_to_toolhive_transport_mode(toolhive_url: str) -> ToolHiveTransportMode: + """Map ToolHive URL to ToolHiveTransportMode enum. Args: toolhive_url: ToolHive URL @@ -45,8 +45,8 @@ def url_to_toolhive_proxy_mode(toolhive_url: str) -> ToolHiveProxyMode: IngestionError: If URL is not supported """ if "/mcp" in toolhive_url: - return ToolHiveProxyMode.STREAMABLE + return ToolHiveTransportMode.STREAMABLE elif "/sse" in toolhive_url: - return ToolHiveProxyMode.SSE + return ToolHiveTransportMode.SSE else: raise ValueError(f"Unsupported ToolHive URL: {toolhive_url}") diff --git a/tests/test_ingestion.py b/tests/test_ingestion.py index ea3bd93..0b86ffa 100644 --- a/tests/test_ingestion.py +++ b/tests/test_ingestion.py @@ -8,10 +8,9 @@ from mcp.types import Tool as McpTool from mcp_optimizer.db.config import DatabaseConfig -from mcp_optimizer.db.models import McpStatus, TransportType +from mcp_optimizer.db.models import McpStatus from mcp_optimizer.embeddings import EmbeddingManager from mcp_optimizer.ingestion import IngestionError, IngestionService -from mcp_optimizer.toolhive.api_models.core import Workload class TestIngestionServiceMapping: @@ -39,26 +38,6 @@ def ingestion_service(self, mock_db_config, mock_embedding_manager): encoding="cl100k_base", ) - def test_map_transport_type_sse(self, ingestion_service): - """Test mapping SSE transport to TransportType.SSE.""" - workload = Workload(name="test", url="http://localhost:8000/sse", proxy_mode="sse") - result = ingestion_service._map_transport_type(workload) - assert result == TransportType.SSE - - def test_map_transport_type_streamable(self, ingestion_service): - """Test mapping streamable-http transport to TransportType.STREAMABLE.""" - workload = Workload( - name="test", url="http://localhost:8000/mcp", proxy_mode="streamable-http" - ) - result = ingestion_service._map_transport_type(workload) - assert result == TransportType.STREAMABLE - - def test_map_transport_type_none_url_raises_error(self, ingestion_service): - """Test that None URL raises IngestionError.""" - workload = Workload(name="test", url=None, proxy_mode=None) - with pytest.raises(IngestionError, match="Workload test has no URL"): - ingestion_service._map_transport_type(workload) - def test_map_workload_status_running(self, ingestion_service): """Test mapping running status to McpStatus.RUNNING.""" result = ingestion_service._map_workload_status("running") diff --git a/tests/test_mcp_client.py b/tests/test_mcp_client.py index 6518de3..3804866 100644 --- a/tests/test_mcp_client.py +++ b/tests/test_mcp_client.py @@ -8,8 +8,13 @@ from mcp.shared.exceptions import McpError from mcp.types import ErrorData -from mcp_optimizer.mcp_client import MCPServerClient, WorkloadConnectionError +from mcp_optimizer.mcp_client import ( + MCPServerClient, + WorkloadConnectionError, + determine_transport_type, +) from mcp_optimizer.toolhive.api_models.core import Workload +from mcp_optimizer.toolhive.enums import ToolHiveTransportMode @pytest.fixture @@ -36,7 +41,7 @@ async def test_mcp_server_client_no_url(): tool_type="mcp", ) - client = MCPServerClient(workload, timeout=10) + client = MCPServerClient(workload, timeout=10, runtime_mode="docker") with pytest.raises(WorkloadConnectionError, match="Workload has no URL"): await client.list_tools() @@ -49,7 +54,7 @@ async def test_mcp_server_client_call_tool_no_url(): status="running", tool_type="mcp", ) - client = MCPServerClient(workload, timeout=10) + client = MCPServerClient(workload, timeout=10, runtime_mode="docker") with pytest.raises(WorkloadConnectionError, match="Workload has no URL"): await client.call_tool("test_tool", {"param": "value"}) @@ -64,7 +69,7 @@ async def test_mcp_server_client_call_tool_streamable(): status="running", tool_type="mcp", ) - client = MCPServerClient(workload, timeout=10) + client = MCPServerClient(workload, timeout=10, runtime_mode="docker") # Mock the MCP client session and result mock_result = AsyncMock() @@ -74,7 +79,7 @@ async def test_mcp_server_client_call_tool_streamable(): mock_session.call_tool.return_value = mock_result with ( - patch("mcp_optimizer.mcp_client.streamablehttp_client") as mock_client, + patch("mcp_optimizer.mcp_client.streamable_http_client") as mock_client, patch( "mcp_optimizer.mcp_client.ClientSession", return_value=mock_session ) as mock_session_class, @@ -99,7 +104,7 @@ async def test_mcp_server_client_call_tool_sse(): status="running", tool_type="mcp", ) - client = MCPServerClient(workload, timeout=10) + client = MCPServerClient(workload, timeout=10, runtime_mode="docker") # Mock the MCP client session and result mock_result = AsyncMock() @@ -134,7 +139,7 @@ async def test_mcp_server_client_call_tool_unsupported_transport(): status="running", tool_type="mcp", ) - client = MCPServerClient(workload, timeout=10) + client = MCPServerClient(workload, timeout=10, runtime_mode="docker") with pytest.raises(ValueError, match="Unsupported ToolHive URL"): await client.call_tool("test_tool", {"param": "value"}) @@ -149,7 +154,7 @@ async def test_mcp_server_client_handles_exception_group(): status="running", tool_type="mcp", ) - client = MCPServerClient(workload, timeout=10) + client = MCPServerClient(workload, timeout=10, runtime_mode="docker") # Create an ExceptionGroup with a nested McpError (simulating Python 3.13+ TaskGroup behavior) error_data = ErrorData(code=1, message="Session terminated") @@ -160,7 +165,7 @@ async def test_mcp_server_client_handles_exception_group(): mock_session.initialize.side_effect = exception_group with ( - patch("mcp_optimizer.mcp_client.streamablehttp_client") as mock_client, + patch("mcp_optimizer.mcp_client.streamable_http_client") as mock_client, patch( "mcp_optimizer.mcp_client.ClientSession", return_value=mock_session ) as mock_session_class, @@ -183,7 +188,7 @@ async def test_mcp_server_client_handles_mcp_error(): status="running", tool_type="mcp", ) - client = MCPServerClient(workload, timeout=10) + client = MCPServerClient(workload, timeout=10, runtime_mode="docker") # Create a direct McpError error_data = ErrorData(code=1, message="Connection refused") @@ -193,7 +198,7 @@ async def test_mcp_server_client_handles_mcp_error(): mock_session.initialize.side_effect = mcp_error with ( - patch("mcp_optimizer.mcp_client.streamablehttp_client") as mock_client, + patch("mcp_optimizer.mcp_client.streamable_http_client") as mock_client, patch( "mcp_optimizer.mcp_client.ClientSession", return_value=mock_session ) as mock_session_class, @@ -215,7 +220,7 @@ def test_extract_error_from_exception_group(): status="running", tool_type="mcp", ) - client = MCPServerClient(workload, timeout=10) + client = MCPServerClient(workload, timeout=10, runtime_mode="docker") # Test with McpError error_data = ErrorData(code=1, message="Test error") @@ -254,25 +259,25 @@ def mock_mcp_session(): @pytest.mark.parametrize( - "url,proxy_mode", + "url,transport_type", [ ("http://localhost:8080/sse/test-server", None), ("http://localhost:8080/mcp/test-server", "streamable-http"), ("http://localhost:8080/custom/endpoint", "sse"), ], ) -def test_workload_url_unchanged_after_init(url, proxy_mode): +def test_workload_url_unchanged_after_init(url, transport_type): """Test that workload URL is not modified during MCPServerClient initialization.""" workload = Workload( name="test-server", url=url, - proxy_mode=proxy_mode, + transport_type=transport_type, status="running", tool_type="mcp", ) # Create client - _client = MCPServerClient(workload, timeout=10) + _client = MCPServerClient(workload, timeout=10, runtime_mode="docker") # Verify URL is unchanged assert workload.url == url @@ -284,7 +289,7 @@ def test_workload_url_unchanged_after_init(url, proxy_mode): [ ( "http://localhost:8080/mcp/test-server", - "streamablehttp_client", + "streamable_http_client", (AsyncMock(), AsyncMock(), AsyncMock()), ), ("http://localhost:8080/sse/test-server", "sse_client", (AsyncMock(), AsyncMock())), @@ -301,7 +306,7 @@ async def test_workload_url_unchanged_during_list_tools( tool_type="mcp", ) - client = MCPServerClient(workload, timeout=10) + client = MCPServerClient(workload, timeout=10, runtime_mode="docker") with ( patch(f"mcp_optimizer.mcp_client.{client_mock_name}") as mock_client, @@ -325,28 +330,26 @@ async def test_workload_url_unchanged_during_list_tools( @pytest.mark.asyncio @pytest.mark.parametrize( - "url,proxy_mode,client_mock_name,context_return,expected_normalized_url", + "url,proxy_mode,client_mock_name,context_return", [ ( "http://localhost:8080/mcp/test-server", None, - "streamablehttp_client", + "streamable_http_client", (AsyncMock(), AsyncMock(), AsyncMock()), - "http://localhost:8080/mcp/test-server", # Already contains /mcp, no normalization ), ( "http://localhost:8080/custom/endpoint", "streamable-http", - "streamablehttp_client", + "streamable_http_client", (AsyncMock(), AsyncMock(), AsyncMock()), - "http://localhost:8080/custom/endpoint/mcp", # Normalized to add /mcp ), ], ) async def test_workload_url_unchanged_during_call_tool( - url, proxy_mode, client_mock_name, context_return, expected_normalized_url, mock_mcp_session + url, proxy_mode, client_mock_name, context_return, mock_mcp_session ): - """Test that workload URL remains unchanged during call_tool.""" + """Test that workload URL remains unchanged during call_tool in docker mode.""" workload = Workload( name="test-server", url=url, @@ -355,7 +358,7 @@ async def test_workload_url_unchanged_during_call_tool( tool_type="mcp", ) - client = MCPServerClient(workload, timeout=10) + client = MCPServerClient(workload, timeout=10, runtime_mode="docker") with ( patch(f"mcp_optimizer.mcp_client.{client_mock_name}") as mock_client, @@ -373,8 +376,8 @@ async def test_workload_url_unchanged_during_call_tool( # Verify URL is unchanged in workload (we don't mutate the workload object) assert workload.url == url - # Verify the client was called with the normalized URL (normalization happens internally) - mock_client.assert_called_once_with(expected_normalized_url) + # Verify the client was called with the original URL (no normalization) + mock_client.assert_called_once_with(url) @pytest.mark.asyncio @@ -388,10 +391,10 @@ async def test_workload_url_unchanged_multiple_operations(mock_mcp_session): tool_type="mcp", ) - client = MCPServerClient(workload, timeout=10) + client = MCPServerClient(workload, timeout=10, runtime_mode="docker") with ( - patch("mcp_optimizer.mcp_client.streamablehttp_client") as mock_client, + patch("mcp_optimizer.mcp_client.streamable_http_client") as mock_client, patch( "mcp_optimizer.mcp_client.ClientSession", return_value=mock_mcp_session ) as mock_session_class, @@ -414,3 +417,219 @@ async def test_workload_url_unchanged_multiple_operations(mock_mcp_session): assert mock_client.call_count == 3 for call in mock_client.call_args_list: assert call[0][0] == original_url + + +# Unit tests for determine_transport_type function + + +def test_determine_transport_type_streamable_http(): + """Test determine_transport_type with transport_type set to 'streamable-http' in k8s mode.""" + workload = Workload( + name="test-workload", + transport_type="streamable-http", + url="http://localhost:8080/some/path", + ) + result = determine_transport_type(workload, "k8s") + assert result == ToolHiveTransportMode.STREAMABLE + + +def test_determine_transport_type_sse(): + """Test determine_transport_type with transport_type set to 'sse' in k8s mode.""" + workload = Workload( + name="test-workload", + transport_type="sse", + url="http://localhost:8080/some/path", + ) + result = determine_transport_type(workload, "k8s") + assert result == ToolHiveTransportMode.SSE + + +def test_determine_transport_type_case_insensitive_streamable(): + """Test determine_transport_type with uppercase 'STREAMABLE-HTTP' in k8s mode.""" + workload = Workload( + name="test-workload", + transport_type="STREAMABLE-HTTP", + url="http://localhost:8080/some/path", + ) + result = determine_transport_type(workload, "k8s") + assert result == ToolHiveTransportMode.STREAMABLE + + +def test_determine_transport_type_case_insensitive_sse(): + """Test determine_transport_type with uppercase 'SSE' in k8s mode.""" + workload = Workload( + name="test-workload", + transport_type="SSE", + url="http://localhost:8080/some/path", + ) + result = determine_transport_type(workload, "k8s") + assert result == ToolHiveTransportMode.SSE + + +def test_determine_transport_type_mixed_case(): + """Test determine_transport_type with mixed case 'Streamable-Http' in k8s mode.""" + workload = Workload( + name="test-workload", + transport_type="Streamable-Http", + url="http://localhost:8080/some/path", + ) + result = determine_transport_type(workload, "k8s") + assert result == ToolHiveTransportMode.STREAMABLE + + +def test_determine_transport_type_fallback_to_url_mcp(): + """Test determine_transport_type falls back to URL detection for /mcp path in docker mode.""" + workload = Workload( + name="test-workload", + url="http://localhost:8080/mcp/test-server", + ) + result = determine_transport_type(workload, "docker") + assert result == ToolHiveTransportMode.STREAMABLE + + +def test_determine_transport_type_fallback_to_url_sse(): + """Test determine_transport_type falls back to URL detection for /sse path in docker mode.""" + workload = Workload( + name="test-workload", + url="http://localhost:8080/sse/test-server", + ) + result = determine_transport_type(workload, "docker") + assert result == ToolHiveTransportMode.SSE + + +def test_determine_transport_type_unknown_fallback_to_url_mcp(): + """Test determine_transport_type with unknown transport_type falls back to URL with /mcp + in k8s mode.""" + workload = Workload( + name="test-workload", + transport_type="unknown-transport", + url="http://localhost:8080/mcp/test-server", + ) + result = determine_transport_type(workload, "k8s") + assert result == ToolHiveTransportMode.STREAMABLE + + +def test_determine_transport_type_unknown_fallback_to_url_sse(): + """Test determine_transport_type with unknown transport_type falls back to URL with /sse + in k8s mode.""" + workload = Workload( + name="test-workload", + transport_type="unknown-transport", + url="http://localhost:8080/sse/test-server", + ) + result = determine_transport_type(workload, "k8s") + assert result == ToolHiveTransportMode.SSE + + +def test_determine_transport_type_no_transport_no_url(): + """Test determine_transport_type raises error when no transport_type and no URL + in docker mode.""" + workload = Workload( + name="test-workload", + ) + with pytest.raises(WorkloadConnectionError, match="No transport type or URL available"): + determine_transport_type(workload, "docker") + + +def test_determine_transport_type_unknown_transport_no_url(): + """Test determine_transport_type raises error when unknown transport_type and no URL + in k8s mode.""" + workload = Workload( + name="test-workload", + transport_type="unknown-transport", + ) + with pytest.raises(WorkloadConnectionError, match="No transport type or URL available"): + determine_transport_type(workload, "k8s") + + +def test_determine_transport_type_unknown_transport_unsupported_url(): + """Test determine_transport_type raises ValueError when unknown transport_type + and unsupported URL in k8s mode.""" + workload = Workload( + name="test-workload", + transport_type="unknown-transport", + url="http://localhost:8080/unsupported/path", + ) + with pytest.raises(ValueError, match="Unsupported ToolHive URL"): + determine_transport_type(workload, "k8s") + + +def test_determine_transport_type_no_transport_unsupported_url(): + """Test determine_transport_type raises ValueError when no transport_type + and unsupported URL in docker mode.""" + workload = Workload( + name="test-workload", + url="http://localhost:8080/unknown/path", + ) + with pytest.raises(ValueError, match="Unsupported ToolHive URL"): + determine_transport_type(workload, "docker") + + +def test_determine_transport_type_docker_mode_proxy_mode_streamable(): + """Test docker mode uses proxy_mode field for streamable-http.""" + workload = Workload( + name="test-workload", + proxy_mode="streamable-http", + url="http://localhost:8080/some/path", + ) + result = determine_transport_type(workload, "docker") + assert result == ToolHiveTransportMode.STREAMABLE + + +def test_determine_transport_type_docker_mode_proxy_mode_sse(): + """Test docker mode uses proxy_mode field for sse.""" + workload = Workload( + name="test-workload", + proxy_mode="sse", + url="http://localhost:8080/some/path", + ) + result = determine_transport_type(workload, "docker") + assert result == ToolHiveTransportMode.SSE + + +def test_determine_transport_type_docker_ignores_transport_type(): + """Test that docker mode ignores transport_type field when proxy_mode is set.""" + workload = Workload( + name="test-workload", + transport_type="sse", + proxy_mode="streamable-http", + url="http://localhost:8080/mcp/path", + ) + result = determine_transport_type(workload, "docker") + # Should use proxy_mode, not transport_type + assert result == ToolHiveTransportMode.STREAMABLE + + +def test_determine_transport_type_k8s_ignores_proxy_mode(): + """Test that k8s mode ignores proxy_mode field when transport_type is set.""" + workload = Workload( + name="test-workload", + proxy_mode="streamable-http", + transport_type="sse", + url="http://localhost:8080/sse/path", + ) + result = determine_transport_type(workload, "k8s") + # Should use transport_type, not proxy_mode + assert result == ToolHiveTransportMode.SSE + + +def test_determine_transport_type_k8s_fallback_to_url(): + """Test k8s mode falls back to URL when transport_type not set.""" + workload = Workload( + name="test-workload", + url="http://localhost:8080/sse/test-server", + ) + result = determine_transport_type(workload, "k8s") + assert result == ToolHiveTransportMode.SSE + + +def test_determine_transport_type_docker_fallback_when_no_proxy_mode(): + """Test docker mode falls back to URL when proxy_mode not set but transport_type is.""" + workload = Workload( + name="test-workload", + transport_type="streamable-http", + url="http://localhost:8080/mcp/test-server", + ) + result = determine_transport_type(workload, "docker") + # Docker mode should ignore transport_type and fallback to URL + assert result == ToolHiveTransportMode.STREAMABLE