Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions codemcp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def codemcp(
if path is None:
raise ValueError("path is required for ReadFile subtool")

result = await read_file(path, offset, limit, chat_id, commit_hash)
result = await read_file(**provided_params)
return result

if subtool == "WriteFile":
Expand All @@ -193,7 +193,7 @@ async def codemcp(
if chat_id is None:
raise ValueError("chat_id is required for WriteFile subtool")

result = await write_file(path, content, description, chat_id, commit_hash)
result = await write_file(**provided_params)
return result

if subtool == "EditFile":
Expand All @@ -209,19 +209,14 @@ async def codemcp(
if chat_id is None:
raise ValueError("chat_id is required for EditFile subtool")

# Accept either old_string or old_str (prefer old_string if both are provided)
old_content = old_string or old_str
# Accept either new_string or new_str (prefer new_string if both are provided)
new_content = new_string or new_str

result = await edit_file(path, old_content, new_content, None, description, chat_id, commit_hash)
result = await edit_file(**provided_params)
return result

if subtool == "LS":
if path is None:
raise ValueError("path is required for LS subtool")

result = await ls(path, chat_id, commit_hash)
result = await ls(**provided_params)
return result

if subtool == "InitProject":
Expand All @@ -231,14 +226,19 @@ async def codemcp(
raise ValueError("user_prompt is required for InitProject subtool")
if subject_line is None:
raise ValueError("subject_line is required for InitProject subtool")
if reuse_head_chat_id is None:
reuse_head_chat_id = (
False # Default value in main.py only, not in the implementation
)

return await init_project(
path, user_prompt, subject_line, reuse_head_chat_id
)
# Handle parameter naming differences with adapter pattern in the central point
if "path" in provided_params and "directory" not in provided_params:
provided_params["directory"] = provided_params.pop("path")

# Ensure reuse_head_chat_id has a default value
if (
"reuse_head_chat_id" not in provided_params
or provided_params["reuse_head_chat_id"] is None
):
provided_params["reuse_head_chat_id"] = False

return await init_project(**provided_params)

if subtool == "RunCommand":
# When is something a command as opposed to a subtool? They are
Expand All @@ -253,7 +253,11 @@ async def codemcp(
if chat_id is None:
raise ValueError("chat_id is required for RunCommand subtool")

result = await run_command(path, command, arguments, chat_id, commit_hash)
# Handle parameter naming differences with adapter pattern in the central point
if "path" in provided_params and "project_dir" not in provided_params:
provided_params["project_dir"] = provided_params.pop("path")

result = await run_command(**provided_params)
return result

if subtool == "Grep":
Expand All @@ -263,7 +267,7 @@ async def codemcp(
raise ValueError("path is required for Grep subtool")

try:
result_string = await grep(pattern, path, include, chat_id, commit_hash)
result_string = await grep(**provided_params)
return result_string
except Exception as e:
logging.error(f"Error in Grep subtool: {e}", exc_info=True)
Expand All @@ -276,7 +280,7 @@ async def codemcp(
raise ValueError("path is required for Glob subtool")

try:
result_string = await glob(pattern, path, limit, offset, chat_id, commit_hash)
result_string = await glob(**provided_params)
return result_string
except Exception as e:
logging.error(f"Error in Glob subtool: {e}", exc_info=True)
Expand All @@ -290,7 +294,7 @@ async def codemcp(
if chat_id is None:
raise ValueError("chat_id is required for RM subtool")

result = await rm(path, description, chat_id, commit_hash)
result = await rm(**provided_params)
return result

if subtool == "MV":
Expand All @@ -307,14 +311,14 @@ async def codemcp(
if chat_id is None:
raise ValueError("chat_id is required for MV subtool")

result = await mv(source_path, target_path, description, chat_id, commit_hash)
result = await mv(**provided_params)
return result

if subtool == "Think":
if thought is None:
raise ValueError("thought is required for Think subtool")

result = await think(thought, chat_id, commit_hash)
result = await think(**provided_params)
return result

if subtool == "Chmod":
Expand All @@ -325,7 +329,7 @@ async def codemcp(
if chat_id is None:
raise ValueError("chat_id is required for Chmod subtool")

result_string = await chmod(path, mode, chat_id, commit_hash)
result_string = await chmod(**provided_params)
return result_string

except Exception:
Expand Down
144 changes: 127 additions & 17 deletions codemcp/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,118 @@ def extract_chat_id_from_text(self, text: str) -> str:
assert chat_id_match is not None, "Could not find chat ID in text"
return chat_id_match.group(1)

async def _dispatch_to_subtool(self, subtool: str, kwargs: Dict[str, Any]) -> Any:
"""Dispatch to the appropriate subtool function based on the subtool name.

This is a helper method that both call_tool_assert_success and call_tool_assert_error
use to route the call to the appropriate function in the tools module.

Args:
subtool: The name of the subtool to call
kwargs: Dictionary of parameters to pass to the subtool

Returns:
The result from the subtool function

Raises:
ValueError: If the subtool is unknown
"""
# Call the function directly instead of using codemcp.main.codemcp
if subtool == "ReadFile":
from codemcp.tools.read_file import read_file

return await read_file(**kwargs)

elif subtool == "WriteFile":
from codemcp.tools.write_file import write_file

return await write_file(**kwargs)

elif subtool == "EditFile":
from codemcp.tools.edit_file import edit_file

return await edit_file(**kwargs)

elif subtool == "LS":
from codemcp.tools.ls import ls

return await ls(**kwargs)

elif subtool == "InitProject":
from codemcp.tools.init_project import init_project

# Convert 'path' parameter to 'directory' as expected by init_project
if "path" in kwargs:
kwargs = kwargs.copy() # Make a copy to avoid modifying the original
directory = kwargs.pop("path")
return await init_project(directory=directory, **kwargs)
else:
return await init_project(**kwargs)

elif subtool == "RunCommand":
from codemcp.tools.run_command import run_command

# Convert 'path' parameter to 'project_dir' as expected by run_command
if "path" in kwargs:
kwargs = kwargs.copy() # Make a copy to avoid modifying the original
project_dir = kwargs.pop("path")
return await run_command(project_dir=project_dir, **kwargs)
else:
return await run_command(**kwargs)

elif subtool == "Grep":
from codemcp.tools.grep import grep

return await grep(**kwargs)

elif subtool == "Glob":
from codemcp.tools.glob import glob

return await glob(**kwargs)

elif subtool == "RM":
from codemcp.tools.rm import rm

return await rm(**kwargs)

elif subtool == "MV":
from codemcp.tools.mv import mv

return await mv(**kwargs)

elif subtool == "Think":
from codemcp.tools.think import think

return await think(**kwargs)

elif subtool == "Chmod":
from codemcp.tools.chmod import chmod

return await chmod(**kwargs)

elif subtool == "GitLog":
from codemcp.tools.git_log import git_log

return await git_log(**kwargs)

elif subtool == "GitDiff":
from codemcp.tools.git_diff import git_diff

return await git_diff(**kwargs)

elif subtool == "GitShow":
from codemcp.tools.git_show import git_show

return await git_show(**kwargs)

elif subtool == "GitBlame":
from codemcp.tools.git_blame import git_blame

return await git_blame(**kwargs)

else:
raise ValueError(f"Unknown subtool: {subtool}")

async def call_tool_assert_error(
self,
session: Optional[ClientSession],
Expand All @@ -210,7 +322,7 @@ async def call_tool_assert_error(
"""Call a tool and assert that it fails (isError=True).

This is a helper method for the error path of tool calls, which:
1. Calls the specified tool with the given parameters using codemcp.main.codemcp directly
1. Calls the specified tool function directly based on subtool parameter
2. Asserts that the call raises an exception
3. Returns the exception string

Expand All @@ -225,21 +337,20 @@ async def call_tool_assert_error(
Raises:
AssertionError: If the tool call does not result in an error
"""
import codemcp.main

# Only codemcp tool is supported
assert tool_name == "codemcp", (
f"Only 'codemcp' tool is supported, got '{tool_name}'"
)

# Extract the parameters to pass to codemcp.main.codemcp
# Extract the parameters to pass to the direct function
subtool = tool_params.get("subtool")
kwargs = {k: v for k, v in tool_params.items() if k != "subtool"}

try:
if self.in_process:
# Call codemcp.main.codemcp directly instead of using the client session
await codemcp.main.codemcp(subtool, **kwargs)
# Use the dispatcher to call the appropriate function
await self._dispatch_to_subtool(subtool, kwargs)

# If we get here, the call succeeded - but we expected it to fail
self.fail(f"Tool call to {tool_name} succeeded, expected to fail")
else:
Expand All @@ -265,7 +376,7 @@ async def call_tool_assert_success(
"""Call a tool and assert that it succeeds (isError=False).

This is a helper method for the happy path of tool calls, which:
1. Calls the specified tool with the given parameters using codemcp.main.codemcp directly
1. Calls the specified tool function directly based on subtool parameter
2. Asserts that the call succeeds (no exception)
3. Returns the result text

Expand All @@ -280,20 +391,20 @@ async def call_tool_assert_success(
Raises:
AssertionError: If the tool call results in an error
"""
import codemcp.main

# Only codemcp tool is supported
assert tool_name == "codemcp", (
f"Only 'codemcp' tool is supported, got '{tool_name}'"
)

# Extract the parameters to pass to codemcp.main.codemcp
# Extract the parameters to pass to the direct function
subtool = tool_params.get("subtool")
kwargs = {k: v for k, v in tool_params.items() if k != "subtool"}

# Call codemcp.main.codemcp directly instead of using the client session
# Call the function directly instead of using codemcp.main.codemcp
if self.in_process:
result = await codemcp.main.codemcp(subtool, **kwargs)
# Use the dispatcher to call the appropriate function
result = await self._dispatch_to_subtool(subtool, kwargs)

# Return the normalized, extracted text result
normalized_result = self.normalize_path(result)
return self.extract_text_from_result(normalized_result)
Expand All @@ -313,12 +424,11 @@ async def get_chat_id(self, session: Optional[ClientSession]) -> str:
Returns:
str: The chat_id
"""
import codemcp.main
from codemcp.tools.init_project import init_project

# First initialize project to get chat_id using codemcp.main.codemcp directly
init_result_text = await codemcp.main.codemcp(
"InitProject",
path=self.temp_dir.name,
# First initialize project to get chat_id using init_project directly
init_result_text = await init_project(
directory=self.temp_dir.name,
user_prompt="Test initialization for get_chat_id",
subject_line="test: initialize for e2e testing",
reuse_head_chat_id=False,
Expand Down
4 changes: 2 additions & 2 deletions codemcp/tools/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ async def glob(
else:
output = f"Found {total_matches} files matching '{pattern}' in {path}"
if offset_val > 0 or total_matches > offset_val + limit_val:
output += f" (showing {offset_val+1}-{min(offset_val+limit_val, total_matches)} of {total_matches})"
output += f" (showing {offset_val + 1}-{min(offset_val + limit_val, total_matches)} of {total_matches})"
output += ":\n\n"

for match in matches:
output += f"{match}\n"

Expand Down
4 changes: 1 addition & 3 deletions codemcp/tools/write_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ async def write_file(
raise ValueError(error_message)

# Check git tracking for existing files
is_tracked, track_error = await check_git_tracking_for_existing_file(
path, chat_id
)
is_tracked, track_error = await check_git_tracking_for_existing_file(path, chat_id)
if not is_tracked:
raise ValueError(track_error)

Expand Down
Loading