From 3e80e249961e3bffa52d70d47bcc9aa37897baac Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 25 Mar 2025 19:22:43 +0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- codemcp/testing.py | 101 +++++++++++++++++++++++++++++++-------------- 1 file changed, 71 insertions(+), 30 deletions(-) diff --git a/codemcp/testing.py b/codemcp/testing.py index ac972d6f..b63abe35 100644 --- a/codemcp/testing.py +++ b/codemcp/testing.py @@ -1,19 +1,47 @@ #!/usr/bin/env python3 +# pyright: reportUnknownArgumentType=false, reportUnknownMemberType=false, reportUnknownVariableType=false import asyncio import os +import re import subprocess import sys import tempfile import unittest from contextlib import asynccontextmanager -from typing import Any, List, Union +from typing import ( + Any, + AsyncGenerator, + Dict, + List, + Optional, + Protocol, + TypeVar, + Union, + cast, +) from unittest import mock from expecttest import TestCase from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +# Define types for objects used in the testing module +T = TypeVar("T") + + +class TextContent(Protocol): + """Protocol for objects with a text attribute.""" + + text: str + + +class CallToolResult(Protocol): + """Protocol for objects returned by call_tool.""" + + isError: bool + content: Union[str, List[TextContent], Any] + class MCPEndToEndTestCase(TestCase, unittest.IsolatedAsyncioTestCase): """Base class for end-to-end tests of codemcp using MCP client.""" @@ -84,26 +112,26 @@ async def setup_repository(self): await self.git_run(["add", "README.md", "codemcp.toml"]) await self.git_run(["commit", "-m", "Initial commit"]) - def normalize_path(self, text): + def normalize_path(self, text: Any) -> Union[str, List[TextContent], Any]: """Normalize temporary directory paths in output text.""" if self.temp_dir and self.temp_dir.name: # Handle CallToolResult objects by converting to string first if hasattr(text, "content"): # This is a CallToolResult object, extract the content - text = text.content + text = cast(CallToolResult, text).content # Handle lists of TextContent objects if isinstance(text, list) and len(text) > 0 and hasattr(text[0], "text"): # For list of TextContent objects, we'll preserve the list structure # but normalize the path in each TextContent's text attribute - return text + return cast(List[TextContent], text) # Replace the actual temp dir path with a fixed placeholder if isinstance(text, str): return text.replace(self.temp_dir.name, "/tmp/test_dir") return text - def extract_text_from_result(self, result): + def extract_text_from_result(self, result: Any) -> str: """Extract text content from various result formats for assertions. Args: @@ -114,12 +142,12 @@ def extract_text_from_result(self, result): """ if isinstance(result, list) and len(result) > 0 and hasattr(result[0], "text"): - return result[0].text + return cast(TextContent, result[0]).text if isinstance(result, str): return result return str(result) - def extract_chat_id_from_text(self, text): + def extract_chat_id_from_text(self, text: str) -> str: """Extract chat_id from init_result_text. Args: @@ -131,13 +159,16 @@ def extract_chat_id_from_text(self, text): Raises: AssertionError: If chat_id cannot be found in text """ - import re - chat_id_match = re.search(r"chat ID: ([a-zA-Z0-9-]+)", text) assert chat_id_match is not None, "Could not find chat ID in text" return chat_id_match.group(1) - async def call_tool_assert_error(self, session, tool_name, tool_params): + async def call_tool_assert_error( + self, + session: Optional[ClientSession], + tool_name: str, + tool_params: Dict[str, Any], + ) -> str: """Call a tool and assert that it fails (isError=True). This is a helper method for the error path of tool calls, which: @@ -174,17 +205,25 @@ async def call_tool_assert_error(self, session, tool_name, tool_params): # If we get here, the call succeeded - but we expected it to fail self.fail(f"Tool call to {tool_name} succeeded, expected to fail") else: + assert session is not None, ( + "Session cannot be None when in_process=False" + ) result = await session.call_tool("codemcp", tool_params) self.assertTrue(result.isError, result) error_message = self.extract_text_from_result(result.content) - return self.normalize_path(error_message) + return cast(str, self.normalize_path(error_message)) except Exception as e: # The call failed as expected - return the error message error_message = f"Error executing tool {tool_name}: {str(e)}" normalized_result = self.normalize_path(error_message) - return normalized_result + return cast(str, normalized_result) - async def call_tool_assert_success(self, session, tool_name, tool_params): + async def call_tool_assert_success( + self, + session: Optional[ClientSession], + tool_name: str, + tool_params: Dict[str, Any], + ) -> str: """Call a tool and assert that it succeeds (isError=False). This is a helper method for the happy path of tool calls, which: @@ -221,12 +260,13 @@ async def call_tool_assert_success(self, session, tool_name, tool_params): normalized_result = self.normalize_path(result) return self.extract_text_from_result(normalized_result) else: + assert session is not None, "Session cannot be None when in_process=False" result = await session.call_tool("codemcp", tool_params) self.assertFalse(result.isError, result) - error_message = self.extract_text_from_result(result.content) - return self.normalize_path(error_message) + response_text = self.extract_text_from_result(result.content) + return cast(str, self.normalize_path(response_text)) - async def get_chat_id(self, session): + async def get_chat_id(self, session: Optional[ClientSession]) -> str: """Initialize project and get chat_id. Args: @@ -247,16 +287,16 @@ async def get_chat_id(self, session): ) # Extract chat_id from the init result - import re - - chat_id_match = re.search(r"chat ID: ([a-zA-Z0-9-]+)", init_result_text) + chat_id_match = re.search(r"chat ID: ([a-zA-Z0-9-]+)", str(init_result_text)) + assert chat_id_match is not None, ( + "Could not find chat ID in initialization result" + ) chat_id = chat_id_match.group(1) - assert chat_id is not None return chat_id @asynccontextmanager - async def _unwrap_exception_groups(self): + async def _unwrap_exception_groups(self) -> AsyncGenerator[None, None]: """Context manager that unwraps ExceptionGroups with single exceptions. Only unwraps if there's exactly one exception at each level. """ @@ -274,7 +314,9 @@ async def _unwrap_exception_groups(self): raise @asynccontextmanager - async def create_client_session(self): + async def create_client_session( + self, + ) -> AsyncGenerator[Optional[ClientSession], None]: """Create an MCP client session connected to codemcp server.""" if self.in_process: yield None @@ -285,7 +327,7 @@ async def create_client_session(self): command=sys.executable, # Current Python executable args=["-m", "codemcp"], # Module path to codemcp env=self.env, - cwd=self.temp_dir.name, # Set the working directory to our test directory + # Working directory is specified directly with kwargs in stdio_client ) async with self._unwrap_exception_groups(): @@ -303,7 +345,7 @@ async def git_run( capture_output: bool = False, text: bool = False, **kwargs: Any, - ) -> Union[subprocess.CompletedProcess, str]: + ) -> Union[subprocess.CompletedProcess[bytes], str]: """Run git command asynchronously with appropriate temp_dir and env settings. This helper method simplifies git subprocess calls in e2e tests by: @@ -320,7 +362,7 @@ async def git_run( Returns: If capture_output is False: subprocess.CompletedProcess instance - If capture_output is True and decode is True: The stdout content as string + If capture_output is True and text is True: The stdout content as string Example: # Run git add command @@ -350,22 +392,21 @@ async def git_run( stdout, stderr = await proc.communicate() # Build a CompletedProcess-like result - result = subprocess.CompletedProcess( + result = subprocess.CompletedProcess[bytes]( args=cmd, - returncode=proc.returncode, + returncode=proc.returncode or 0, # Use 0 if returncode is None stdout=stdout, stderr=stderr, ) # Check for error if requested - if check and proc.returncode != 0: - stderr.decode() if stderr else "Unknown error" + if check and proc.returncode and proc.returncode != 0: cmd_str = " ".join(cmd) raise subprocess.CalledProcessError( proc.returncode, cmd_str, output=stdout, stderr=stderr ) # Return the appropriate result type - if capture_output and text and stdout is not None: + if capture_output and text and stdout: # stdout is bytes or None return stdout.decode().strip() return result From 9c035ed4f73b9ace4d6a20e6cd4df632fa19f31b Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 25 Mar 2025 19:25:28 +0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- codemcp/testing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codemcp/testing.py b/codemcp/testing.py index b63abe35..e762f759 100644 --- a/codemcp/testing.py +++ b/codemcp/testing.py @@ -407,6 +407,7 @@ async def git_run( ) # Return the appropriate result type - if capture_output and text and stdout: # stdout is bytes or None - return stdout.decode().strip() + if capture_output and text: + # Always decode to string when text=True even if stdout is empty + return stdout.decode().strip() if stdout else "" return result