Skip to content

Commit c7ff533

Browse files
authored
fix: use factory to refresh session once it is finished (#97)
* Use factory to refresh session once it is finished * Use ProxyClient * Update unit tests
1 parent 9632822 commit c7ff533

File tree

3 files changed

+50
-34
lines changed

3 files changed

+50
-34
lines changed

mcp_proxy_for_aws/server.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@
2727
import httpx
2828
import logging
2929
import sys
30-
from fastmcp import Client
3130
from fastmcp.client import ClientTransport
3231
from fastmcp.server.middleware.error_handling import RetryMiddleware
3332
from fastmcp.server.middleware.logging import LoggingMiddleware
33+
from fastmcp.server.proxy import FastMCPProxy, ProxyClient
3434
from fastmcp.server.server import FastMCP
3535
from mcp import McpError
3636
from mcp.types import (
@@ -60,7 +60,7 @@ async def _initialize_client(transport: ClientTransport):
6060
# logger.debug('First line from kiro %s', line)
6161
async with contextlib.AsyncExitStack() as stack:
6262
try:
63-
client = await stack.enter_async_context(Client(transport))
63+
client = await stack.enter_async_context(ProxyClient(transport))
6464
if client.initialize_result:
6565
print(
6666
client.initialize_result.model_dump_json(
@@ -157,10 +157,20 @@ async def run_proxy(args) -> None:
157157
transport = create_transport_with_sigv4(
158158
args.endpoint, service, region, metadata, timeout, profile
159159
)
160+
160161
async with _initialize_client(transport) as client:
162+
163+
async def client_factory():
164+
nonlocal client
165+
if not client.is_connected():
166+
logger.debug('Reinitialize client')
167+
client = ProxyClient(transport)
168+
await client._connect()
169+
return client
170+
161171
try:
162-
proxy = FastMCP.as_proxy(
163-
client,
172+
proxy = FastMCPProxy(
173+
client_factory=client_factory,
164174
name='MCP Proxy for AWS',
165175
instructions=(
166176
'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. '

tests/unit/test_initialize_client.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ async def test_successful_initialization():
2828
mock_transport = Mock()
2929
mock_client = Mock()
3030

31-
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
31+
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
3232
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
3333
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)
3434

@@ -48,7 +48,7 @@ async def test_http_error_with_jsonrpc_error(capsys):
4848

4949
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
5050

51-
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
51+
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
5252
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
5353

5454
with pytest.raises(httpx.HTTPStatusError):
@@ -70,7 +70,7 @@ async def test_http_error_with_jsonrpc_response(capsys):
7070

7171
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
7272

73-
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
73+
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
7474
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
7575

7676
with pytest.raises(httpx.HTTPStatusError):
@@ -91,7 +91,7 @@ async def test_http_error_with_invalid_json():
9191

9292
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
9393

94-
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
94+
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
9595
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
9696

9797
with pytest.raises(httpx.HTTPStatusError):
@@ -109,7 +109,7 @@ async def test_http_error_with_non_jsonrpc_message():
109109

110110
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
111111

112-
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
112+
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
113113
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
114114

115115
with pytest.raises(httpx.HTTPStatusError):
@@ -127,7 +127,7 @@ async def test_http_error_response_read_failure():
127127

128128
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
129129

130-
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
130+
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
131131
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
132132

133133
with pytest.raises(httpx.HTTPStatusError):
@@ -144,7 +144,7 @@ async def test_generic_error_with_mcp_error_cause(capsys):
144144
generic_error = Exception('Wrapper error')
145145
generic_error.__cause__ = mcp_error
146146

147-
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
147+
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
148148
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error)
149149

150150
with pytest.raises(Exception):
@@ -162,7 +162,7 @@ async def test_generic_error_without_mcp_error_cause(capsys):
162162
mock_transport = Mock()
163163
generic_error = Exception('Generic error')
164164

165-
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
165+
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
166166
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error)
167167

168168
with pytest.raises(Exception):

tests/unit/test_server.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
class TestServer:
3131
"""Tests for the server module."""
3232

33-
@patch('mcp_proxy_for_aws.server.Client')
33+
@patch('mcp_proxy_for_aws.server.ProxyClient')
3434
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
35-
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
35+
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
3636
@patch('mcp_proxy_for_aws.server.determine_aws_region')
3737
@patch('mcp_proxy_for_aws.server.determine_service_name')
3838
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
@@ -43,7 +43,7 @@ async def test_setup_mcp_mode(
4343
mock_add_filtering,
4444
mock_determine_service,
4545
mock_determine_region,
46-
mock_as_proxy,
46+
mock_fastmcp_proxy,
4747
mock_create_transport,
4848
mock_client_class,
4949
):
@@ -73,13 +73,15 @@ async def test_setup_mcp_mode(
7373
mock_create_transport.return_value = mock_transport
7474

7575
mock_client = Mock()
76+
mock_client.initialize_result = None
77+
mock_client.is_connected = Mock(return_value=True)
7678
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
7779
mock_client.__aexit__ = AsyncMock(return_value=None)
7880
mock_client_class.return_value = mock_client
7981

8082
mock_proxy = Mock()
8183
mock_proxy.run_async = AsyncMock()
82-
mock_as_proxy.return_value = mock_proxy
84+
mock_fastmcp_proxy.return_value = mock_proxy
8385

8486
# Act
8587
await run_proxy(mock_args)
@@ -97,15 +99,14 @@ async def test_setup_mcp_mode(
9799
# call_args[0][4] is the Timeout object
98100
assert call_args[0][5] is None # profile
99101
mock_client_class.assert_called_once_with(mock_transport)
100-
mock_as_proxy.assert_called_once()
101-
assert mock_as_proxy.call_args[0][0] == mock_client
102+
mock_fastmcp_proxy.assert_called_once()
102103
mock_add_filtering.assert_called_once_with(mock_proxy, True)
103104
mock_add_retry.assert_called_once_with(mock_proxy, 1)
104-
mock_proxy.run_async.assert_called_once()
105+
mock_proxy.run_async.assert_called_once_with(transport='stdio')
105106

106-
@patch('mcp_proxy_for_aws.server.Client')
107+
@patch('mcp_proxy_for_aws.server.ProxyClient')
107108
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
108-
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
109+
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
109110
@patch('mcp_proxy_for_aws.server.determine_aws_region')
110111
@patch('mcp_proxy_for_aws.server.determine_service_name')
111112
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
@@ -114,7 +115,7 @@ async def test_setup_mcp_mode_no_retries(
114115
mock_add_filtering,
115116
mock_determine_service,
116117
mock_determine_region,
117-
mock_as_proxy,
118+
mock_fastmcp_proxy,
118119
mock_create_transport,
119120
mock_client_class,
120121
):
@@ -144,13 +145,15 @@ async def test_setup_mcp_mode_no_retries(
144145
mock_create_transport.return_value = mock_transport
145146

146147
mock_client = Mock()
148+
mock_client.initialize_result = None
149+
mock_client.is_connected = Mock(return_value=True)
147150
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
148151
mock_client.__aexit__ = AsyncMock(return_value=None)
149152
mock_client_class.return_value = mock_client
150153

151154
mock_proxy = Mock()
152155
mock_proxy.run_async = AsyncMock()
153-
mock_as_proxy.return_value = mock_proxy
156+
mock_fastmcp_proxy.return_value = mock_proxy
154157

155158
# Act
156159
await run_proxy(mock_args)
@@ -171,14 +174,13 @@ async def test_setup_mcp_mode_no_retries(
171174
# call_args[0][4] is the Timeout object
172175
assert call_args[0][5] == 'test-profile' # profile
173176
mock_client_class.assert_called_once_with(mock_transport)
174-
mock_as_proxy.assert_called_once()
175-
assert mock_as_proxy.call_args[0][0] == mock_client
177+
mock_fastmcp_proxy.assert_called_once()
176178
mock_add_filtering.assert_called_once_with(mock_proxy, False)
177-
mock_proxy.run_async.assert_called_once()
179+
mock_proxy.run_async.assert_called_once_with(transport='stdio')
178180

179-
@patch('mcp_proxy_for_aws.server.Client')
181+
@patch('mcp_proxy_for_aws.server.ProxyClient')
180182
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
181-
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
183+
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
182184
@patch('mcp_proxy_for_aws.server.determine_aws_region')
183185
@patch('mcp_proxy_for_aws.server.determine_service_name')
184186
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
@@ -187,7 +189,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
187189
mock_add_filtering,
188190
mock_determine_service,
189191
mock_determine_region,
190-
mock_as_proxy,
192+
mock_fastmcp_proxy,
191193
mock_create_transport,
192194
mock_client_class,
193195
):
@@ -214,13 +216,15 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
214216
mock_create_transport.return_value = mock_transport
215217

216218
mock_client = Mock()
219+
mock_client.initialize_result = None
220+
mock_client.is_connected = Mock(return_value=True)
217221
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
218222
mock_client.__aexit__ = AsyncMock(return_value=None)
219223
mock_client_class.return_value = mock_client
220224

221225
mock_proxy = Mock()
222226
mock_proxy.run_async = AsyncMock()
223-
mock_as_proxy.return_value = mock_proxy
227+
mock_fastmcp_proxy.return_value = mock_proxy
224228

225229
# Act
226230
await run_proxy(mock_args)
@@ -231,9 +235,9 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
231235
metadata = call_args[0][3]
232236
assert metadata == {'AWS_REGION': 'ap-southeast-1'}
233237

234-
@patch('mcp_proxy_for_aws.server.Client')
238+
@patch('mcp_proxy_for_aws.server.ProxyClient')
235239
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
236-
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
240+
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
237241
@patch('mcp_proxy_for_aws.server.determine_aws_region')
238242
@patch('mcp_proxy_for_aws.server.determine_service_name')
239243
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
@@ -242,7 +246,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
242246
mock_add_filtering,
243247
mock_determine_service,
244248
mock_determine_region,
245-
mock_as_proxy,
249+
mock_fastmcp_proxy,
246250
mock_create_transport,
247251
mock_client_class,
248252
):
@@ -269,13 +273,15 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
269273
mock_create_transport.return_value = mock_transport
270274

271275
mock_client = Mock()
276+
mock_client.initialize_result = None
277+
mock_client.is_connected = Mock(return_value=True)
272278
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
273279
mock_client.__aexit__ = AsyncMock(return_value=None)
274280
mock_client_class.return_value = mock_client
275281

276282
mock_proxy = Mock()
277283
mock_proxy.run_async = AsyncMock()
278-
mock_as_proxy.return_value = mock_proxy
284+
mock_fastmcp_proxy.return_value = mock_proxy
279285

280286
# Act
281287
await run_proxy(mock_args)

0 commit comments

Comments
 (0)