Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 73 additions & 31 deletions codemcp/testing.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,47 @@
#!/usr/bin/env python3
# pyright: reportUnknownArgumentType=false, reportUnknownMemberType=false, reportUnknownVariableType=false

import asyncio
import os
import re
import subprocess
import sys
import tempfile
import unittest
from contextlib import asynccontextmanager
from typing import Any, List, Union
from typing import (
Any,
AsyncGenerator,
Dict,
List,
Optional,
Protocol,
TypeVar,
Union,
cast,
)
from unittest import mock

from expecttest import TestCase
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# Define types for objects used in the testing module
T = TypeVar("T")


class TextContent(Protocol):
"""Protocol for objects with a text attribute."""

text: str


class CallToolResult(Protocol):
"""Protocol for objects returned by call_tool."""

isError: bool
content: Union[str, List[TextContent], Any]


class MCPEndToEndTestCase(TestCase, unittest.IsolatedAsyncioTestCase):
"""Base class for end-to-end tests of codemcp using MCP client."""
Expand Down Expand Up @@ -84,26 +112,26 @@ async def setup_repository(self):
await self.git_run(["add", "README.md", "codemcp.toml"])
await self.git_run(["commit", "-m", "Initial commit"])

def normalize_path(self, text):
def normalize_path(self, text: Any) -> Union[str, List[TextContent], Any]:
"""Normalize temporary directory paths in output text."""
if self.temp_dir and self.temp_dir.name:
# Handle CallToolResult objects by converting to string first
if hasattr(text, "content"):
# This is a CallToolResult object, extract the content
text = text.content
text = cast(CallToolResult, text).content

# Handle lists of TextContent objects
if isinstance(text, list) and len(text) > 0 and hasattr(text[0], "text"):
# For list of TextContent objects, we'll preserve the list structure
# but normalize the path in each TextContent's text attribute
return text
return cast(List[TextContent], text)

# Replace the actual temp dir path with a fixed placeholder
if isinstance(text, str):
return text.replace(self.temp_dir.name, "/tmp/test_dir")
return text

def extract_text_from_result(self, result):
def extract_text_from_result(self, result: Any) -> str:
"""Extract text content from various result formats for assertions.

Args:
Expand All @@ -114,12 +142,12 @@ def extract_text_from_result(self, result):

"""
if isinstance(result, list) and len(result) > 0 and hasattr(result[0], "text"):
return result[0].text
return cast(TextContent, result[0]).text
if isinstance(result, str):
return result
return str(result)

def extract_chat_id_from_text(self, text):
def extract_chat_id_from_text(self, text: str) -> str:
"""Extract chat_id from init_result_text.

Args:
Expand All @@ -131,13 +159,16 @@ def extract_chat_id_from_text(self, text):
Raises:
AssertionError: If chat_id cannot be found in text
"""
import re

chat_id_match = re.search(r"chat ID: ([a-zA-Z0-9-]+)", text)
assert chat_id_match is not None, "Could not find chat ID in text"
return chat_id_match.group(1)

async def call_tool_assert_error(self, session, tool_name, tool_params):
async def call_tool_assert_error(
self,
session: Optional[ClientSession],
tool_name: str,
tool_params: Dict[str, Any],
) -> str:
"""Call a tool and assert that it fails (isError=True).

This is a helper method for the error path of tool calls, which:
Expand Down Expand Up @@ -174,17 +205,25 @@ async def call_tool_assert_error(self, session, tool_name, tool_params):
# If we get here, the call succeeded - but we expected it to fail
self.fail(f"Tool call to {tool_name} succeeded, expected to fail")
else:
assert session is not None, (
"Session cannot be None when in_process=False"
)
result = await session.call_tool("codemcp", tool_params)
self.assertTrue(result.isError, result)
error_message = self.extract_text_from_result(result.content)
return self.normalize_path(error_message)
return cast(str, self.normalize_path(error_message))
except Exception as e:
# The call failed as expected - return the error message
error_message = f"Error executing tool {tool_name}: {str(e)}"
normalized_result = self.normalize_path(error_message)
return normalized_result
return cast(str, normalized_result)

async def call_tool_assert_success(self, session, tool_name, tool_params):
async def call_tool_assert_success(
self,
session: Optional[ClientSession],
tool_name: str,
tool_params: Dict[str, Any],
) -> str:
"""Call a tool and assert that it succeeds (isError=False).

This is a helper method for the happy path of tool calls, which:
Expand Down Expand Up @@ -221,12 +260,13 @@ async def call_tool_assert_success(self, session, tool_name, tool_params):
normalized_result = self.normalize_path(result)
return self.extract_text_from_result(normalized_result)
else:
assert session is not None, "Session cannot be None when in_process=False"
result = await session.call_tool("codemcp", tool_params)
self.assertFalse(result.isError, result)
error_message = self.extract_text_from_result(result.content)
return self.normalize_path(error_message)
response_text = self.extract_text_from_result(result.content)
return cast(str, self.normalize_path(response_text))

async def get_chat_id(self, session):
async def get_chat_id(self, session: Optional[ClientSession]) -> str:
"""Initialize project and get chat_id.

Args:
Expand All @@ -247,16 +287,16 @@ async def get_chat_id(self, session):
)

# Extract chat_id from the init result
import re

chat_id_match = re.search(r"chat ID: ([a-zA-Z0-9-]+)", init_result_text)
chat_id_match = re.search(r"chat ID: ([a-zA-Z0-9-]+)", str(init_result_text))
assert chat_id_match is not None, (
"Could not find chat ID in initialization result"
)
chat_id = chat_id_match.group(1)
assert chat_id is not None

return chat_id

@asynccontextmanager
async def _unwrap_exception_groups(self):
async def _unwrap_exception_groups(self) -> AsyncGenerator[None, None]:
"""Context manager that unwraps ExceptionGroups with single exceptions.
Only unwraps if there's exactly one exception at each level.
"""
Expand All @@ -274,7 +314,9 @@ async def _unwrap_exception_groups(self):
raise

@asynccontextmanager
async def create_client_session(self):
async def create_client_session(
self,
) -> AsyncGenerator[Optional[ClientSession], None]:
"""Create an MCP client session connected to codemcp server."""
if self.in_process:
yield None
Expand All @@ -285,7 +327,7 @@ async def create_client_session(self):
command=sys.executable, # Current Python executable
args=["-m", "codemcp"], # Module path to codemcp
env=self.env,
cwd=self.temp_dir.name, # Set the working directory to our test directory
# Working directory is specified directly with kwargs in stdio_client
)

async with self._unwrap_exception_groups():
Expand All @@ -303,7 +345,7 @@ async def git_run(
capture_output: bool = False,
text: bool = False,
**kwargs: Any,
) -> Union[subprocess.CompletedProcess, str]:
) -> Union[subprocess.CompletedProcess[bytes], str]:
"""Run git command asynchronously with appropriate temp_dir and env settings.

This helper method simplifies git subprocess calls in e2e tests by:
Expand All @@ -320,7 +362,7 @@ async def git_run(

Returns:
If capture_output is False: subprocess.CompletedProcess instance
If capture_output is True and decode is True: The stdout content as string
If capture_output is True and text is True: The stdout content as string

Example:
# Run git add command
Expand Down Expand Up @@ -350,22 +392,22 @@ async def git_run(
stdout, stderr = await proc.communicate()

# Build a CompletedProcess-like result
result = subprocess.CompletedProcess(
result = subprocess.CompletedProcess[bytes](
args=cmd,
returncode=proc.returncode,
returncode=proc.returncode or 0, # Use 0 if returncode is None
stdout=stdout,
stderr=stderr,
)

# Check for error if requested
if check and proc.returncode != 0:
stderr.decode() if stderr else "Unknown error"
if check and proc.returncode and proc.returncode != 0:
cmd_str = " ".join(cmd)
raise subprocess.CalledProcessError(
proc.returncode, cmd_str, output=stdout, stderr=stderr
)

# Return the appropriate result type
if capture_output and text and stdout is not None:
return stdout.decode().strip()
if capture_output and text:
# Always decode to string when text=True even if stdout is empty
return stdout.decode().strip() if stdout else ""
return result
Loading