From 09f7fcdc4f30aa8caa4a0e9ed64fe9d32b2f0ab7 Mon Sep 17 00:00:00 2001 From: Arne Wouters Date: Wed, 26 Nov 2025 21:06:24 +0100 Subject: [PATCH 1/3] Use factory to refresh session once it is finished --- mcp_proxy_for_aws/server.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index c782bd1..1b4c209 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -28,6 +28,7 @@ from fastmcp import Client from fastmcp.server.middleware.error_handling import RetryMiddleware from fastmcp.server.middleware.logging import LoggingMiddleware +from fastmcp.server.proxy import FastMCPProxy from fastmcp.server.server import FastMCP from mcp_proxy_for_aws.cli import parse_args from mcp_proxy_for_aws.logging_config import configure_logging @@ -85,8 +86,17 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None: args.endpoint, service, region, metadata, timeout, profile ) async with Client(transport=transport) as client: - # Create proxy with the transport - proxy = FastMCP.as_proxy(client) + + async def client_factory(): + nonlocal client + if not client.is_connected(): + logger.debug('Reinitialize client') + client = client.new() + await client._connect() + return client + + proxy = FastMCPProxy(client_factory=client_factory) + add_logging_middleware(proxy, args.log_level) add_tool_filtering_middleware(proxy, args.read_only) From f720d226b0305017b95e1f3f83eb7f97aa302ed4 Mon Sep 17 00:00:00 2001 From: Arne Wouters Date: Thu, 27 Nov 2025 10:32:27 +0100 Subject: [PATCH 2/3] Use ProxyClient --- mcp_proxy_for_aws/server.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index c2874b6..9becb18 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -27,11 +27,10 @@ import httpx import logging import sys -from fastmcp import Client from fastmcp.client import ClientTransport from fastmcp.server.middleware.error_handling import RetryMiddleware from fastmcp.server.middleware.logging import LoggingMiddleware -from fastmcp.server.proxy import FastMCPProxy +from fastmcp.server.proxy import FastMCPProxy, ProxyClient from fastmcp.server.server import FastMCP from mcp import McpError from mcp.types import ( @@ -61,7 +60,7 @@ async def _initialize_client(transport: ClientTransport): # logger.debug('First line from kiro %s', line) async with contextlib.AsyncExitStack() as stack: try: - client = await stack.enter_async_context(Client(transport)) + client = await stack.enter_async_context(ProxyClient(transport)) if client.initialize_result: print( client.initialize_result.model_dump_json( @@ -165,7 +164,7 @@ async def client_factory(): nonlocal client if not client.is_connected(): logger.debug('Reinitialize client') - client = client.new() + client = ProxyClient(transport) await client._connect() return client From 12441925379ca395b47dfad4475b890ece68611c Mon Sep 17 00:00:00 2001 From: Arne Wouters Date: Thu, 27 Nov 2025 10:32:41 +0100 Subject: [PATCH 3/3] Update unit tests --- tests/unit/test_initialize_client.py | 16 ++++----- tests/unit/test_server.py | 50 ++++++++++++++++------------ 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/tests/unit/test_initialize_client.py b/tests/unit/test_initialize_client.py index b2cfabe..a9568c6 100644 --- a/tests/unit/test_initialize_client.py +++ b/tests/unit/test_initialize_client.py @@ -28,7 +28,7 @@ async def test_successful_initialization(): mock_transport = Mock() mock_client = Mock() - with patch('mcp_proxy_for_aws.server.Client') as mock_client_class: + with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) @@ -48,7 +48,7 @@ async def test_http_error_with_jsonrpc_error(capsys): http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - with patch('mcp_proxy_for_aws.server.Client') as mock_client_class: + with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error) with pytest.raises(httpx.HTTPStatusError): @@ -70,7 +70,7 @@ async def test_http_error_with_jsonrpc_response(capsys): http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - with patch('mcp_proxy_for_aws.server.Client') as mock_client_class: + with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error) with pytest.raises(httpx.HTTPStatusError): @@ -91,7 +91,7 @@ async def test_http_error_with_invalid_json(): http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - with patch('mcp_proxy_for_aws.server.Client') as mock_client_class: + with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error) with pytest.raises(httpx.HTTPStatusError): @@ -109,7 +109,7 @@ async def test_http_error_with_non_jsonrpc_message(): http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - with patch('mcp_proxy_for_aws.server.Client') as mock_client_class: + with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error) with pytest.raises(httpx.HTTPStatusError): @@ -127,7 +127,7 @@ async def test_http_error_response_read_failure(): http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - with patch('mcp_proxy_for_aws.server.Client') as mock_client_class: + with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error) with pytest.raises(httpx.HTTPStatusError): @@ -144,7 +144,7 @@ async def test_generic_error_with_mcp_error_cause(capsys): generic_error = Exception('Wrapper error') generic_error.__cause__ = mcp_error - with patch('mcp_proxy_for_aws.server.Client') as mock_client_class: + with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error) with pytest.raises(Exception): @@ -162,7 +162,7 @@ async def test_generic_error_without_mcp_error_cause(capsys): mock_transport = Mock() generic_error = Exception('Generic error') - with patch('mcp_proxy_for_aws.server.Client') as mock_client_class: + with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error) with pytest.raises(Exception): diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index 015fa89..e42433b 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -30,9 +30,9 @@ class TestServer: """Tests for the server module.""" - @patch('mcp_proxy_for_aws.server.Client') + @patch('mcp_proxy_for_aws.server.ProxyClient') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') - @patch('mcp_proxy_for_aws.server.FastMCP.as_proxy') + @patch('mcp_proxy_for_aws.server.FastMCPProxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @patch('mcp_proxy_for_aws.server.determine_service_name') @patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware') @@ -43,7 +43,7 @@ async def test_setup_mcp_mode( mock_add_filtering, mock_determine_service, mock_determine_region, - mock_as_proxy, + mock_fastmcp_proxy, mock_create_transport, mock_client_class, ): @@ -73,13 +73,15 @@ async def test_setup_mcp_mode( mock_create_transport.return_value = mock_transport mock_client = Mock() + mock_client.initialize_result = None + mock_client.is_connected = Mock(return_value=True) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) mock_client_class.return_value = mock_client mock_proxy = Mock() mock_proxy.run_async = AsyncMock() - mock_as_proxy.return_value = mock_proxy + mock_fastmcp_proxy.return_value = mock_proxy # Act await run_proxy(mock_args) @@ -97,15 +99,14 @@ async def test_setup_mcp_mode( # call_args[0][4] is the Timeout object assert call_args[0][5] is None # profile mock_client_class.assert_called_once_with(mock_transport) - mock_as_proxy.assert_called_once() - assert mock_as_proxy.call_args[0][0] == mock_client + mock_fastmcp_proxy.assert_called_once() mock_add_filtering.assert_called_once_with(mock_proxy, True) mock_add_retry.assert_called_once_with(mock_proxy, 1) - mock_proxy.run_async.assert_called_once() + mock_proxy.run_async.assert_called_once_with(transport='stdio') - @patch('mcp_proxy_for_aws.server.Client') + @patch('mcp_proxy_for_aws.server.ProxyClient') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') - @patch('mcp_proxy_for_aws.server.FastMCP.as_proxy') + @patch('mcp_proxy_for_aws.server.FastMCPProxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @patch('mcp_proxy_for_aws.server.determine_service_name') @patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware') @@ -114,7 +115,7 @@ async def test_setup_mcp_mode_no_retries( mock_add_filtering, mock_determine_service, mock_determine_region, - mock_as_proxy, + mock_fastmcp_proxy, mock_create_transport, mock_client_class, ): @@ -144,13 +145,15 @@ async def test_setup_mcp_mode_no_retries( mock_create_transport.return_value = mock_transport mock_client = Mock() + mock_client.initialize_result = None + mock_client.is_connected = Mock(return_value=True) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) mock_client_class.return_value = mock_client mock_proxy = Mock() mock_proxy.run_async = AsyncMock() - mock_as_proxy.return_value = mock_proxy + mock_fastmcp_proxy.return_value = mock_proxy # Act await run_proxy(mock_args) @@ -171,14 +174,13 @@ async def test_setup_mcp_mode_no_retries( # call_args[0][4] is the Timeout object assert call_args[0][5] == 'test-profile' # profile mock_client_class.assert_called_once_with(mock_transport) - mock_as_proxy.assert_called_once() - assert mock_as_proxy.call_args[0][0] == mock_client + mock_fastmcp_proxy.assert_called_once() mock_add_filtering.assert_called_once_with(mock_proxy, False) - mock_proxy.run_async.assert_called_once() + mock_proxy.run_async.assert_called_once_with(transport='stdio') - @patch('mcp_proxy_for_aws.server.Client') + @patch('mcp_proxy_for_aws.server.ProxyClient') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') - @patch('mcp_proxy_for_aws.server.FastMCP.as_proxy') + @patch('mcp_proxy_for_aws.server.FastMCPProxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @patch('mcp_proxy_for_aws.server.determine_service_name') @patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware') @@ -187,7 +189,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region( mock_add_filtering, mock_determine_service, mock_determine_region, - mock_as_proxy, + mock_fastmcp_proxy, mock_create_transport, mock_client_class, ): @@ -214,13 +216,15 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region( mock_create_transport.return_value = mock_transport mock_client = Mock() + mock_client.initialize_result = None + mock_client.is_connected = Mock(return_value=True) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) mock_client_class.return_value = mock_client mock_proxy = Mock() mock_proxy.run_async = AsyncMock() - mock_as_proxy.return_value = mock_proxy + mock_fastmcp_proxy.return_value = mock_proxy # Act await run_proxy(mock_args) @@ -231,9 +235,9 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region( metadata = call_args[0][3] assert metadata == {'AWS_REGION': 'ap-southeast-1'} - @patch('mcp_proxy_for_aws.server.Client') + @patch('mcp_proxy_for_aws.server.ProxyClient') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') - @patch('mcp_proxy_for_aws.server.FastMCP.as_proxy') + @patch('mcp_proxy_for_aws.server.FastMCPProxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @patch('mcp_proxy_for_aws.server.determine_service_name') @patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware') @@ -242,7 +246,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it( mock_add_filtering, mock_determine_service, mock_determine_region, - mock_as_proxy, + mock_fastmcp_proxy, mock_create_transport, mock_client_class, ): @@ -269,13 +273,15 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it( mock_create_transport.return_value = mock_transport mock_client = Mock() + mock_client.initialize_result = None + mock_client.is_connected = Mock(return_value=True) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) mock_client_class.return_value = mock_client mock_proxy = Mock() mock_proxy.run_async = AsyncMock() - mock_as_proxy.return_value = mock_proxy + mock_fastmcp_proxy.return_value = mock_proxy # Act await run_proxy(mock_args)