Skip to content

Commit d3b3667

Browse files
committed
feat: pass client info such as kiro/1.0.0 to user-agent
1 parent 46cd380 commit d3b3667

File tree

7 files changed

+256
-9
lines changed

7 files changed

+256
-9
lines changed

mcp_proxy_for_aws/context.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Module-level storage for session-scoped data."""
16+
17+
from mcp.types import Implementation
18+
from typing import Optional
19+
20+
21+
_client_info: Optional[Implementation] = None
22+
23+
24+
def get_client_info() -> Optional[Implementation]:
25+
"""Get the stored client info."""
26+
return _client_info
27+
28+
29+
def set_client_info(info: Optional[Implementation]) -> None:
30+
"""Set the client info."""
31+
global _client_info
32+
_client_info = info
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
from collections.abc import Awaitable, Callable
17+
from fastmcp.server.middleware import Middleware, MiddlewareContext
18+
from mcp import types as mt
19+
from mcp_proxy_for_aws.context import set_client_info
20+
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
class ClientInfoMiddleware(Middleware):
26+
"""Middleware to capture client_info from initialize method."""
27+
28+
async def on_initialize(
29+
self,
30+
context: MiddlewareContext[mt.InitializeRequest],
31+
call_next: Callable[[MiddlewareContext[mt.InitializeRequest]], Awaitable[None]],
32+
) -> None:
33+
"""Capture client_info from initialize request."""
34+
if context.message.params and context.message.params.clientInfo:
35+
info = context.message.params.clientInfo
36+
set_client_info(info)
37+
logger.info('Captured client_info: name=%s, version=%s', info.name, info.version)
38+
39+
await call_next(context)

mcp_proxy_for_aws/middleware/tool_filter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
from typing import Sequence
2020

2121

22+
logger = logging.getLogger(__name__)
23+
24+
2225
class ToolFilteringMiddleware(Middleware):
2326
"""Middleware to filter tools based on read only flag."""
2427

25-
def __init__(self, read_only: bool, logger: logging.Logger | None = None):
28+
def __init__(self, read_only: bool):
2629
"""Initialize the middleware."""
2730
self.read_only = read_only
28-
self.logger = logger or logging.getLogger(__name__)
2931

3032
async def on_list_tools(
3133
self,
@@ -35,7 +37,7 @@ async def on_list_tools(
3537
"""Filter tools based on read only flag."""
3638
# Get list of FastMCP Components
3739
tools = await call_next(context)
38-
self.logger.info('Filtering tools for read only: %s', self.read_only)
40+
logger.info('Filtering tools for read only: %s', self.read_only)
3941

4042
# If not read only, return the list of tools as is
4143
if not self.read_only:
@@ -50,7 +52,7 @@ async def on_list_tools(
5052
read_only_hint = getattr(annotations, 'readOnlyHint', False)
5153
if not read_only_hint:
5254
# Skip tools that don't have readOnlyHint=True
53-
self.logger.info('Skipping tool %s needing write permissions', tool.name)
55+
logger.info('Skipping tool %s needing write permissions', tool.name)
5456
continue
5557

5658
filtered_tools.append(tool)

mcp_proxy_for_aws/server.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
)
4343
from mcp_proxy_for_aws.cli import parse_args
4444
from mcp_proxy_for_aws.logging_config import configure_logging
45+
from mcp_proxy_for_aws.middleware.client_info import ClientInfoMiddleware
4546
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
4647
from mcp_proxy_for_aws.utils import (
4748
create_transport_with_sigv4,
@@ -167,6 +168,7 @@ async def client_factory():
167168
'This proxy handles authentication and request routing to the appropriate backend services.'
168169
),
169170
)
171+
add_client_info_middleware(proxy)
170172
add_logging_middleware(proxy, args.log_level)
171173
add_tool_filtering_middleware(proxy, args.read_only)
172174

@@ -178,6 +180,16 @@ async def client_factory():
178180
raise e
179181

180182

183+
def add_client_info_middleware(mcp: FastMCP) -> None:
184+
"""Add client info middleware to capture client_info from initialize.
185+
186+
Args:
187+
mcp: The FastMCP instance to add client info middleware to
188+
"""
189+
logger.info('Adding client info middleware')
190+
mcp.add_middleware(ClientInfoMiddleware())
191+
192+
181193
def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:
182194
"""Add tool filtering middleware to target MCP server.
183195
@@ -186,11 +198,7 @@ def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None
186198
read_only: Whether or not to filter out tools that require write permissions
187199
"""
188200
logger.info('Adding tool filtering middleware')
189-
mcp.add_middleware(
190-
ToolFilteringMiddleware(
191-
read_only=read_only,
192-
)
193-
)
201+
mcp.add_middleware(ToolFilteringMiddleware(read_only))
194202

195203

196204
def add_retry_middleware(mcp: FastMCP, retries: int) -> None:

mcp_proxy_for_aws/sigv4_helper.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from botocore.awsrequest import AWSRequest
2323
from botocore.credentials import Credentials
2424
from functools import partial
25+
from mcp_proxy_for_aws import __version__
26+
from mcp_proxy_for_aws.context import get_client_info
2527
from typing import Any, Dict, Generator, Optional
2628

2729

@@ -228,6 +230,13 @@ async def _sign_request_hook(
228230
# Set Content-Length for signing
229231
request.headers['Content-Length'] = str(len(request.content))
230232

233+
# Build User-Agent from client_info if available
234+
info = get_client_info()
235+
if info:
236+
user_agent = f'{info.name}/{info.version} mcp-proxy-for-aws/{__version__}'
237+
request.headers['User-Agent'] = user_agent
238+
logger.info('Set User-Agent header: %s', user_agent)
239+
231240
# Get AWS credentials
232241
session = create_aws_session(profile)
233242
credentials = session.get_credentials()
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from datetime import datetime
17+
from fastmcp.server.middleware import MiddlewareContext
18+
from mcp import types as mt
19+
from mcp_proxy_for_aws.context import get_client_info, set_client_info
20+
from mcp_proxy_for_aws.middleware.client_info import ClientInfoMiddleware
21+
22+
23+
@pytest.fixture
24+
def middleware():
25+
"""Create a ClientInfoMiddleware instance."""
26+
return ClientInfoMiddleware()
27+
28+
29+
@pytest.fixture
30+
def mock_context_with_client_info():
31+
"""Create a mock context with client_info."""
32+
params = mt.InitializeRequestParams(
33+
protocolVersion='2024-11-05',
34+
capabilities=mt.ClientCapabilities(),
35+
clientInfo=mt.Implementation(name='test-client', version='1.0.0'),
36+
)
37+
message = mt.InitializeRequest(
38+
method='initialize',
39+
params=params,
40+
)
41+
return MiddlewareContext(
42+
message=message,
43+
fastmcp_context=None,
44+
source='client',
45+
type='request',
46+
method='initialize',
47+
timestamp=datetime.now(),
48+
)
49+
50+
51+
@pytest.mark.asyncio
52+
async def test_captures_client_info(middleware, mock_context_with_client_info):
53+
"""Test that middleware captures client_info from initialize request."""
54+
# Reset context variable
55+
set_client_info(None)
56+
57+
async def call_next(ctx):
58+
pass
59+
60+
await middleware.on_initialize(mock_context_with_client_info, call_next)
61+
62+
# Verify client_info was captured
63+
info = get_client_info()
64+
assert info is not None
65+
assert info.name == 'test-client'
66+
assert info.version == '1.0.0'
67+
68+
69+
@pytest.mark.asyncio
70+
async def test_calls_next_middleware(middleware, mock_context_with_client_info):
71+
"""Test that middleware calls the next middleware in chain."""
72+
called = False
73+
74+
async def call_next(ctx):
75+
nonlocal called
76+
called = True
77+
78+
await middleware.on_initialize(mock_context_with_client_info, call_next)
79+
80+
assert called is True
81+
82+
83+
@pytest.mark.asyncio
84+
async def test_captures_different_client_info(middleware):
85+
"""Test that middleware captures different client_info values."""
86+
# Reset context variable
87+
set_client_info(None)
88+
89+
params = mt.InitializeRequestParams(
90+
protocolVersion='2024-11-05',
91+
capabilities=mt.ClientCapabilities(),
92+
clientInfo=mt.Implementation(name='another-client', version='2.5.3'),
93+
)
94+
message = mt.InitializeRequest(
95+
method='initialize',
96+
params=params,
97+
)
98+
context = MiddlewareContext(
99+
message=message,
100+
fastmcp_context=None,
101+
source='client',
102+
type='request',
103+
method='initialize',
104+
timestamp=datetime.now(),
105+
)
106+
107+
async def call_next(ctx):
108+
pass
109+
110+
await middleware.on_initialize(context, call_next)
111+
112+
# Verify client_info was captured with correct values
113+
info = get_client_info()
114+
assert info is not None
115+
assert info.name == 'another-client'
116+
assert info.version == '2.5.3'

tests/unit/test_hooks.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,44 @@ async def test_sign_request_hook_with_partial_application(self, mock_create_sess
435435
assert 'authorization' in request.headers
436436
assert 'x-amz-date' in request.headers
437437
mock_create_session.assert_called_once_with(profile)
438+
439+
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
440+
@pytest.mark.asyncio
441+
async def test_sign_request_hook_sets_user_agent_from_client_info(self, mock_create_session):
442+
"""Test that sign_request_hook sets User-Agent from client_info context."""
443+
from mcp.types import Implementation
444+
from mcp_proxy_for_aws import __version__
445+
from mcp_proxy_for_aws.context import set_client_info
446+
447+
mock_create_session.return_value = create_mock_session()
448+
449+
# Set client_info in context
450+
info = Implementation(name='test-client', version='2.5.0')
451+
set_client_info(info)
452+
453+
request = httpx.Request('POST', 'https://example.com/mcp', content=b'test')
454+
await _sign_request_hook('us-east-1', 'bedrock-agentcore', None, request)
455+
456+
assert (
457+
request.headers['user-agent'] == f'test-client/2.5.0 mcp-proxy-for-aws/{__version__}'
458+
)
459+
460+
# Clean up
461+
set_client_info(None)
462+
463+
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
464+
@pytest.mark.asyncio
465+
async def test_sign_request_hook_without_client_info(self, mock_create_session):
466+
"""Test that sign_request_hook works without client_info."""
467+
from mcp_proxy_for_aws.context import set_client_info
468+
469+
mock_create_session.return_value = create_mock_session()
470+
471+
# Ensure client_info is None
472+
set_client_info(None)
473+
474+
request = httpx.Request('POST', 'https://example.com/mcp', content=b'test')
475+
await _sign_request_hook('us-east-1', 'bedrock-agentcore', None, request)
476+
477+
assert 'user-agent' not in request.headers
478+
assert 'authorization' in request.headers

0 commit comments

Comments
 (0)