Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
import httpx
import logging
import sys
from fastmcp import Client
from fastmcp.client import ClientTransport
from fastmcp.server.middleware.error_handling import RetryMiddleware
from fastmcp.server.middleware.logging import LoggingMiddleware
from fastmcp.server.proxy import FastMCPProxy, ProxyClient
from fastmcp.server.server import FastMCP
from mcp import McpError
from mcp.types import (
Expand Down Expand Up @@ -60,7 +60,7 @@ async def _initialize_client(transport: ClientTransport):
# logger.debug('First line from kiro %s', line)
async with contextlib.AsyncExitStack() as stack:
try:
client = await stack.enter_async_context(Client(transport))
client = await stack.enter_async_context(ProxyClient(transport))
if client.initialize_result:
print(
client.initialize_result.model_dump_json(
Expand Down Expand Up @@ -157,10 +157,20 @@ async def run_proxy(args) -> None:
transport = create_transport_with_sigv4(
args.endpoint, service, region, metadata, timeout, profile
)

async with _initialize_client(transport) as client:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The client should be a ProxyClient.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we probably want a mix of the if-else branches

  1. connect the session to be reused
  2. use ProxyClient to support features like elicitation and sampling; the regular client does not support


async def client_factory():
nonlocal client
if not client.is_connected():
logger.debug('Reinitialize client')
client = ProxyClient(transport)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if you do client.new() here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It returns Client and not ProxyClient

await client._connect()
Copy link
Contributor

@wzxxing wzxxing Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this method throws, the client also hangs, the _initialize_client context manager writes the jsonrpc error to stdout.

Can this be done here too? I think it is a bit tricky to get the new json rpc message id.

return client

try:
proxy = FastMCP.as_proxy(
client,
proxy = FastMCPProxy(
client_factory=client_factory,
name='MCP Proxy for AWS',
instructions=(
'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. '
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/test_initialize_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def test_successful_initialization():
mock_transport = Mock()
mock_client = Mock()

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)

Expand All @@ -48,7 +48,7 @@ async def test_http_error_with_jsonrpc_error(capsys):

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)

with pytest.raises(httpx.HTTPStatusError):
Expand All @@ -70,7 +70,7 @@ async def test_http_error_with_jsonrpc_response(capsys):

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)

with pytest.raises(httpx.HTTPStatusError):
Expand All @@ -91,7 +91,7 @@ async def test_http_error_with_invalid_json():

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)

with pytest.raises(httpx.HTTPStatusError):
Expand All @@ -109,7 +109,7 @@ async def test_http_error_with_non_jsonrpc_message():

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)

with pytest.raises(httpx.HTTPStatusError):
Expand All @@ -127,7 +127,7 @@ async def test_http_error_response_read_failure():

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)

with pytest.raises(httpx.HTTPStatusError):
Expand All @@ -144,7 +144,7 @@ async def test_generic_error_with_mcp_error_cause(capsys):
generic_error = Exception('Wrapper error')
generic_error.__cause__ = mcp_error

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error)

with pytest.raises(Exception):
Expand All @@ -162,7 +162,7 @@ async def test_generic_error_without_mcp_error_cause(capsys):
mock_transport = Mock()
generic_error = Exception('Generic error')

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error)

with pytest.raises(Exception):
Expand Down
50 changes: 28 additions & 22 deletions tests/unit/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
class TestServer:
"""Tests for the server module."""

@patch('mcp_proxy_for_aws.server.Client')
@patch('mcp_proxy_for_aws.server.ProxyClient')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -43,7 +43,7 @@ async def test_setup_mcp_mode(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_as_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_class,
):
Expand Down Expand Up @@ -73,13 +73,15 @@ async def test_setup_mcp_mode(
mock_create_transport.return_value = mock_transport

mock_client = Mock()
mock_client.initialize_result = None
mock_client.is_connected = Mock(return_value=True)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_class.return_value = mock_client

mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_as_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand All @@ -97,15 +99,14 @@ async def test_setup_mcp_mode(
# call_args[0][4] is the Timeout object
assert call_args[0][5] is None # profile
mock_client_class.assert_called_once_with(mock_transport)
mock_as_proxy.assert_called_once()
assert mock_as_proxy.call_args[0][0] == mock_client
mock_fastmcp_proxy.assert_called_once()
mock_add_filtering.assert_called_once_with(mock_proxy, True)
mock_add_retry.assert_called_once_with(mock_proxy, 1)
mock_proxy.run_async.assert_called_once()
mock_proxy.run_async.assert_called_once_with(transport='stdio')

@patch('mcp_proxy_for_aws.server.Client')
@patch('mcp_proxy_for_aws.server.ProxyClient')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -114,7 +115,7 @@ async def test_setup_mcp_mode_no_retries(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_as_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_class,
):
Expand Down Expand Up @@ -144,13 +145,15 @@ async def test_setup_mcp_mode_no_retries(
mock_create_transport.return_value = mock_transport

mock_client = Mock()
mock_client.initialize_result = None
mock_client.is_connected = Mock(return_value=True)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_class.return_value = mock_client

mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_as_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand All @@ -171,14 +174,13 @@ async def test_setup_mcp_mode_no_retries(
# call_args[0][4] is the Timeout object
assert call_args[0][5] == 'test-profile' # profile
mock_client_class.assert_called_once_with(mock_transport)
mock_as_proxy.assert_called_once()
assert mock_as_proxy.call_args[0][0] == mock_client
mock_fastmcp_proxy.assert_called_once()
mock_add_filtering.assert_called_once_with(mock_proxy, False)
mock_proxy.run_async.assert_called_once()
mock_proxy.run_async.assert_called_once_with(transport='stdio')

@patch('mcp_proxy_for_aws.server.Client')
@patch('mcp_proxy_for_aws.server.ProxyClient')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -187,7 +189,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_as_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_class,
):
Expand All @@ -214,13 +216,15 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
mock_create_transport.return_value = mock_transport

mock_client = Mock()
mock_client.initialize_result = None
mock_client.is_connected = Mock(return_value=True)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_class.return_value = mock_client

mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_as_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand All @@ -231,9 +235,9 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
metadata = call_args[0][3]
assert metadata == {'AWS_REGION': 'ap-southeast-1'}

@patch('mcp_proxy_for_aws.server.Client')
@patch('mcp_proxy_for_aws.server.ProxyClient')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -242,7 +246,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_as_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_class,
):
Expand All @@ -269,13 +273,15 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
mock_create_transport.return_value = mock_transport

mock_client = Mock()
mock_client.initialize_result = None
mock_client.is_connected = Mock(return_value=True)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_class.return_value = mock_client

mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_as_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand Down
Loading