diff --git a/codemcp/main.py b/codemcp/main.py index a508ef0..8b7322e 100644 --- a/codemcp/main.py +++ b/codemcp/main.py @@ -182,7 +182,7 @@ async def codemcp( if path is None: raise ValueError("path is required for ReadFile subtool") - result = await read_file(path, offset, limit, chat_id, commit_hash) + result = await read_file(**provided_params) return result if subtool == "WriteFile": @@ -193,7 +193,7 @@ async def codemcp( if chat_id is None: raise ValueError("chat_id is required for WriteFile subtool") - result = await write_file(path, content, description, chat_id, commit_hash) + result = await write_file(**provided_params) return result if subtool == "EditFile": @@ -209,19 +209,14 @@ async def codemcp( if chat_id is None: raise ValueError("chat_id is required for EditFile subtool") - # Accept either old_string or old_str (prefer old_string if both are provided) - old_content = old_string or old_str - # Accept either new_string or new_str (prefer new_string if both are provided) - new_content = new_string or new_str - - result = await edit_file(path, old_content, new_content, None, description, chat_id, commit_hash) + result = await edit_file(**provided_params) return result if subtool == "LS": if path is None: raise ValueError("path is required for LS subtool") - result = await ls(path, chat_id, commit_hash) + result = await ls(**provided_params) return result if subtool == "InitProject": @@ -231,14 +226,19 @@ async def codemcp( raise ValueError("user_prompt is required for InitProject subtool") if subject_line is None: raise ValueError("subject_line is required for InitProject subtool") - if reuse_head_chat_id is None: - reuse_head_chat_id = ( - False # Default value in main.py only, not in the implementation - ) - return await init_project( - path, user_prompt, subject_line, reuse_head_chat_id - ) + # Handle parameter naming differences with adapter pattern in the central point + if "path" in provided_params and "directory" not in provided_params: + provided_params["directory"] = provided_params.pop("path") + + # Ensure reuse_head_chat_id has a default value + if ( + "reuse_head_chat_id" not in provided_params + or provided_params["reuse_head_chat_id"] is None + ): + provided_params["reuse_head_chat_id"] = False + + return await init_project(**provided_params) if subtool == "RunCommand": # When is something a command as opposed to a subtool? They are @@ -253,7 +253,11 @@ async def codemcp( if chat_id is None: raise ValueError("chat_id is required for RunCommand subtool") - result = await run_command(path, command, arguments, chat_id, commit_hash) + # Handle parameter naming differences with adapter pattern in the central point + if "path" in provided_params and "project_dir" not in provided_params: + provided_params["project_dir"] = provided_params.pop("path") + + result = await run_command(**provided_params) return result if subtool == "Grep": @@ -263,7 +267,7 @@ async def codemcp( raise ValueError("path is required for Grep subtool") try: - result_string = await grep(pattern, path, include, chat_id, commit_hash) + result_string = await grep(**provided_params) return result_string except Exception as e: logging.error(f"Error in Grep subtool: {e}", exc_info=True) @@ -276,7 +280,7 @@ async def codemcp( raise ValueError("path is required for Glob subtool") try: - result_string = await glob(pattern, path, limit, offset, chat_id, commit_hash) + result_string = await glob(**provided_params) return result_string except Exception as e: logging.error(f"Error in Glob subtool: {e}", exc_info=True) @@ -290,7 +294,7 @@ async def codemcp( if chat_id is None: raise ValueError("chat_id is required for RM subtool") - result = await rm(path, description, chat_id, commit_hash) + result = await rm(**provided_params) return result if subtool == "MV": @@ -307,14 +311,14 @@ async def codemcp( if chat_id is None: raise ValueError("chat_id is required for MV subtool") - result = await mv(source_path, target_path, description, chat_id, commit_hash) + result = await mv(**provided_params) return result if subtool == "Think": if thought is None: raise ValueError("thought is required for Think subtool") - result = await think(thought, chat_id, commit_hash) + result = await think(**provided_params) return result if subtool == "Chmod": @@ -325,7 +329,7 @@ async def codemcp( if chat_id is None: raise ValueError("chat_id is required for Chmod subtool") - result_string = await chmod(path, mode, chat_id, commit_hash) + result_string = await chmod(**provided_params) return result_string except Exception: diff --git a/codemcp/testing.py b/codemcp/testing.py index cc8a028..8c3726d 100644 --- a/codemcp/testing.py +++ b/codemcp/testing.py @@ -201,6 +201,118 @@ def extract_chat_id_from_text(self, text: str) -> str: assert chat_id_match is not None, "Could not find chat ID in text" return chat_id_match.group(1) + async def _dispatch_to_subtool(self, subtool: str, kwargs: Dict[str, Any]) -> Any: + """Dispatch to the appropriate subtool function based on the subtool name. + + This is a helper method that both call_tool_assert_success and call_tool_assert_error + use to route the call to the appropriate function in the tools module. + + Args: + subtool: The name of the subtool to call + kwargs: Dictionary of parameters to pass to the subtool + + Returns: + The result from the subtool function + + Raises: + ValueError: If the subtool is unknown + """ + # Call the function directly instead of using codemcp.main.codemcp + if subtool == "ReadFile": + from codemcp.tools.read_file import read_file + + return await read_file(**kwargs) + + elif subtool == "WriteFile": + from codemcp.tools.write_file import write_file + + return await write_file(**kwargs) + + elif subtool == "EditFile": + from codemcp.tools.edit_file import edit_file + + return await edit_file(**kwargs) + + elif subtool == "LS": + from codemcp.tools.ls import ls + + return await ls(**kwargs) + + elif subtool == "InitProject": + from codemcp.tools.init_project import init_project + + # Convert 'path' parameter to 'directory' as expected by init_project + if "path" in kwargs: + kwargs = kwargs.copy() # Make a copy to avoid modifying the original + directory = kwargs.pop("path") + return await init_project(directory=directory, **kwargs) + else: + return await init_project(**kwargs) + + elif subtool == "RunCommand": + from codemcp.tools.run_command import run_command + + # Convert 'path' parameter to 'project_dir' as expected by run_command + if "path" in kwargs: + kwargs = kwargs.copy() # Make a copy to avoid modifying the original + project_dir = kwargs.pop("path") + return await run_command(project_dir=project_dir, **kwargs) + else: + return await run_command(**kwargs) + + elif subtool == "Grep": + from codemcp.tools.grep import grep + + return await grep(**kwargs) + + elif subtool == "Glob": + from codemcp.tools.glob import glob + + return await glob(**kwargs) + + elif subtool == "RM": + from codemcp.tools.rm import rm + + return await rm(**kwargs) + + elif subtool == "MV": + from codemcp.tools.mv import mv + + return await mv(**kwargs) + + elif subtool == "Think": + from codemcp.tools.think import think + + return await think(**kwargs) + + elif subtool == "Chmod": + from codemcp.tools.chmod import chmod + + return await chmod(**kwargs) + + elif subtool == "GitLog": + from codemcp.tools.git_log import git_log + + return await git_log(**kwargs) + + elif subtool == "GitDiff": + from codemcp.tools.git_diff import git_diff + + return await git_diff(**kwargs) + + elif subtool == "GitShow": + from codemcp.tools.git_show import git_show + + return await git_show(**kwargs) + + elif subtool == "GitBlame": + from codemcp.tools.git_blame import git_blame + + return await git_blame(**kwargs) + + else: + raise ValueError(f"Unknown subtool: {subtool}") + async def call_tool_assert_error( self, session: Optional[ClientSession], @@ -210,7 +322,7 @@ async def call_tool_assert_error( """Call a tool and assert that it fails (isError=True). This is a helper method for the error path of tool calls, which: - 1. Calls the specified tool with the given parameters using codemcp.main.codemcp directly + 1. Calls the specified tool function directly based on subtool parameter 2. Asserts that the call raises an exception 3. Returns the exception string @@ -225,21 +337,20 @@ async def call_tool_assert_error( Raises: AssertionError: If the tool call does not result in an error """ - import codemcp.main - # Only codemcp tool is supported assert tool_name == "codemcp", ( f"Only 'codemcp' tool is supported, got '{tool_name}'" ) - # Extract the parameters to pass to codemcp.main.codemcp + # Extract the parameters to pass to the direct function subtool = tool_params.get("subtool") kwargs = {k: v for k, v in tool_params.items() if k != "subtool"} try: if self.in_process: - # Call codemcp.main.codemcp directly instead of using the client session - await codemcp.main.codemcp(subtool, **kwargs) + # Use the dispatcher to call the appropriate function + await self._dispatch_to_subtool(subtool, kwargs) + # If we get here, the call succeeded - but we expected it to fail self.fail(f"Tool call to {tool_name} succeeded, expected to fail") else: @@ -265,7 +376,7 @@ async def call_tool_assert_success( """Call a tool and assert that it succeeds (isError=False). This is a helper method for the happy path of tool calls, which: - 1. Calls the specified tool with the given parameters using codemcp.main.codemcp directly + 1. Calls the specified tool function directly based on subtool parameter 2. Asserts that the call succeeds (no exception) 3. Returns the result text @@ -280,20 +391,20 @@ async def call_tool_assert_success( Raises: AssertionError: If the tool call results in an error """ - import codemcp.main - # Only codemcp tool is supported assert tool_name == "codemcp", ( f"Only 'codemcp' tool is supported, got '{tool_name}'" ) - # Extract the parameters to pass to codemcp.main.codemcp + # Extract the parameters to pass to the direct function subtool = tool_params.get("subtool") kwargs = {k: v for k, v in tool_params.items() if k != "subtool"} - # Call codemcp.main.codemcp directly instead of using the client session + # Call the function directly instead of using codemcp.main.codemcp if self.in_process: - result = await codemcp.main.codemcp(subtool, **kwargs) + # Use the dispatcher to call the appropriate function + result = await self._dispatch_to_subtool(subtool, kwargs) + # Return the normalized, extracted text result normalized_result = self.normalize_path(result) return self.extract_text_from_result(normalized_result) @@ -313,12 +424,11 @@ async def get_chat_id(self, session: Optional[ClientSession]) -> str: Returns: str: The chat_id """ - import codemcp.main + from codemcp.tools.init_project import init_project - # First initialize project to get chat_id using codemcp.main.codemcp directly - init_result_text = await codemcp.main.codemcp( - "InitProject", - path=self.temp_dir.name, + # First initialize project to get chat_id using init_project directly + init_result_text = await init_project( + directory=self.temp_dir.name, user_prompt="Test initialization for get_chat_id", subject_line="test: initialize for e2e testing", reuse_head_chat_id=False, diff --git a/codemcp/tools/glob.py b/codemcp/tools/glob.py index c675797..7f97cd8 100644 --- a/codemcp/tools/glob.py +++ b/codemcp/tools/glob.py @@ -117,9 +117,9 @@ async def glob( else: output = f"Found {total_matches} files matching '{pattern}' in {path}" if offset_val > 0 or total_matches > offset_val + limit_val: - output += f" (showing {offset_val+1}-{min(offset_val+limit_val, total_matches)} of {total_matches})" + output += f" (showing {offset_val + 1}-{min(offset_val + limit_val, total_matches)} of {total_matches})" output += ":\n\n" - + for match in matches: output += f"{match}\n" diff --git a/codemcp/tools/write_file.py b/codemcp/tools/write_file.py index cda9e66..5c12b4d 100644 --- a/codemcp/tools/write_file.py +++ b/codemcp/tools/write_file.py @@ -71,9 +71,7 @@ async def write_file( raise ValueError(error_message) # Check git tracking for existing files - is_tracked, track_error = await check_git_tracking_for_existing_file( - path, chat_id - ) + is_tracked, track_error = await check_git_tracking_for_existing_file(path, chat_id) if not is_tracked: raise ValueError(track_error)