diff --git a/codemcp/tools/user_prompt.py b/codemcp/tools/user_prompt.py index 0f19a951..5b6f5188 100644 --- a/codemcp/tools/user_prompt.py +++ b/codemcp/tools/user_prompt.py @@ -2,29 +2,113 @@ import logging import os +import re +from pathlib import Path from ..git_query import find_git_root from ..rules import get_applicable_rules_content __all__ = [ "user_prompt", + "is_slash_command", + "resolve_slash_command", + "get_command_content", ] +def is_slash_command(text: str) -> bool: + """Check if the user's text starts with a slash command. + + Args: + text: The user's text to check + + Returns: + True if the text starts with a slash, False otherwise + """ + return bool(text and text.strip().startswith("/")) + + +def resolve_slash_command(command: str) -> tuple[bool, str, str | None]: + """Resolve a slash command to a file path. + + Args: + command: The slash command (including the slash) + + Returns: + A tuple of (success, command_name, file_path) + If success is False, file_path will be None + """ + # Strip the leading slash and any whitespace + command = command.strip()[1:].strip() + + # Check for the command format: user:command-name + match = re.match(r"^user:([a-zA-Z0-9_-]+)$", command) + if not match: + return False, command, None + + command_name = match.group(1) + + # Get the commands directory path + commands_dir = Path.home() / ".claude" / "commands" + + # Create the commands directory if it doesn't exist + os.makedirs(commands_dir, exist_ok=True) + + # Check if the command file exists + command_file = commands_dir / f"{command_name}.md" + if not command_file.exists(): + return False, command_name, None + + return True, command_name, str(command_file) + + +async def get_command_content(file_path: str) -> str: + """Get the content of a command file. + + Args: + file_path: The path to the command file + + Returns: + The content of the command file + """ + try: + # Import here to avoid circular imports + from ..file_utils import async_open_text + + # Read the file content + content = await async_open_text(file_path) + return content + except Exception as e: + logging.error(f"Error reading command file {file_path}: {e}") + return f"Error reading command file: {e}" + + async def user_prompt(user_text: str, chat_id: str | None = None) -> str: """Store the user's verbatim prompt text for later use. This function processes the user's prompt and applies any relevant cursor rules. + If the user's prompt starts with a slash, it tries to resolve it as a command. Args: user_text: The user's original prompt verbatim chat_id: The unique ID of the current chat session Returns: - A message with any applicable cursor rules + A message with any applicable cursor rules or command content """ logging.info(f"Received user prompt for chat ID {chat_id}: {user_text}") + # Check if this is a slash command + if is_slash_command(user_text): + success, command_name, file_path = resolve_slash_command(user_text) + if success and file_path: + command_content = await get_command_content(file_path) + logging.info(f"Resolved slash command {command_name} to file {file_path}") + return command_content + else: + logging.info(f"Failed to resolve slash command {user_text}") + return f"Unknown slash command: {command_name}" + # Get the current working directory to find repo root cwd = os.getcwd() repo_root = find_git_root(cwd) diff --git a/e2e/test_slash_commands.py b/e2e/test_slash_commands.py new file mode 100644 index 00000000..6a991b52 --- /dev/null +++ b/e2e/test_slash_commands.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 + +import asyncio +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from codemcp.tools.user_prompt import user_prompt + + +@pytest.fixture +def mock_commands_dir() -> Path: + """Create a temporary directory with test command files.""" + # Create a temporary directory + temp_dir = tempfile.mkdtemp() + + # Create the commands directory + commands_dir = Path(temp_dir) / ".claude" / "commands" + os.makedirs(commands_dir, exist_ok=True) + + # Create test command files + test_cmd = commands_dir / "test-command.md" + with open(test_cmd, "w") as f: + f.write("# Test Command\nThis is a test command content.") + + help_cmd = commands_dir / "help.md" + with open(help_cmd, "w") as f: + f.write( + "# Available Commands\n- `/user:test-command`: A test command\n- `/user:help`: This help message" + ) + + return commands_dir + + +def test_slash_command_e2e(mock_commands_dir: Path) -> None: + """Test slash commands in an end-to-end scenario.""" + # Save original home + original_home = Path.home + + try: + # Mock Path.home to use our temp directory + Path.home = MagicMock(return_value=mock_commands_dir.parent.parent) + + # Mock async_open_text to read the actual files + with patch("codemcp.file_utils.async_open_text") as mock_open: + # Set up different return values based on which file is being read + def side_effect(file_path, **kwargs): + if "test-command.md" in file_path: + return "# Test Command\nThis is a test command content." + elif "help.md" in file_path: + return "# Available Commands\n- `/user:test-command`: A test command\n- `/user:help`: This help message" + else: + raise FileNotFoundError(f"File not found: {file_path}") + + mock_open.side_effect = side_effect + + # Test a valid slash command + result = asyncio.run(user_prompt("/user:test-command", "test-chat-id")) + assert "# Test Command" in result + assert "This is a test command content." in result + + # Test the help command + result = asyncio.run(user_prompt("/user:help", "test-chat-id")) + assert "# Available Commands" in result + assert "`/user:test-command`" in result + assert "`/user:help`" in result + + # Test an invalid slash command + result = asyncio.run(user_prompt("/user:invalid-command", "test-chat-id")) + assert "Unknown slash command: invalid-command" in result + + # Test a non-slash command + with patch("codemcp.tools.user_prompt.find_git_root", return_value=None): + result = asyncio.run(user_prompt("normal message", "test-chat-id")) + assert "User prompt received" in result + finally: + # Restore original home + Path.home = original_home + + # Clean up the temporary directory + import shutil + + shutil.rmtree(mock_commands_dir.parent.parent) diff --git a/pyproject.toml b/pyproject.toml index d6b1ea63..0e273c79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "anyio>=3.7.0", "pyyaml>=6.0.0", "pytest-xdist>=3.6.1", + "pytest-asyncio>=0.21.0", "editorconfig>=0.17.0", "click>=8.1.8", ] diff --git a/tests/test_slash_commands.py b/tests/test_slash_commands.py new file mode 100644 index 00000000..2654ecc8 --- /dev/null +++ b/tests/test_slash_commands.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 + +import asyncio +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +from codemcp.tools.user_prompt import is_slash_command, resolve_slash_command + + +def test_is_slash_command(): + """Test the is_slash_command function.""" + # Test valid slash commands + assert is_slash_command("/command") is True + assert is_slash_command(" /command ") is True + assert is_slash_command("/user:command-name") is True + + # Test invalid cases + assert is_slash_command("command") is False + assert is_slash_command("") is False + assert is_slash_command(None) is False + assert is_slash_command(" command ") is False + + +def test_resolve_slash_command(): + """Test the resolve_slash_command function.""" + # Valid command format but non-existent file + with patch("os.makedirs"), patch("pathlib.Path.exists", return_value=False): + success, command_name, file_path = resolve_slash_command("/user:test-command") + assert success is False + assert command_name == "test-command" + assert file_path is None + + # Valid command format with existing file + with ( + patch("os.makedirs"), + patch("pathlib.Path.exists", return_value=True), + patch( + "pathlib.Path.__truediv__", + return_value=Path("/home/user/.claude/commands/test-command.md"), + ), + ): + success, command_name, file_path = resolve_slash_command("/user:test-command") + assert success is True + assert command_name == "test-command" + assert file_path == "/home/user/.claude/commands/test-command.md" + + # Invalid command format (missing user: prefix) + success, command_name, file_path = resolve_slash_command("/test-command") + assert success is False + assert command_name == "test-command" + assert file_path is None + + # Invalid command format (invalid characters) + success, command_name, file_path = resolve_slash_command("/user:test@command") + assert success is False + assert command_name == "user:test@command" + assert file_path is None + + +def test_get_command_content(): + """Test the get_command_content function.""" + from codemcp.tools.user_prompt import get_command_content + + # Create a temporary file for testing + with tempfile.NamedTemporaryFile(mode="w+", suffix=".md") as temp_file: + temp_file.write("# Test Command\nThis is a test command content.") + temp_file.flush() + + # Mock file_utils.async_open_text to return our test content + with patch("codemcp.file_utils.async_open_text") as mock_open: + # Set up the mock to return our content + mock_open.return_value = "# Test Command\nThis is a test command content." + + # Run the coroutine in the event loop + result = asyncio.run(get_command_content(temp_file.name)) + + # Verify the result + assert "# Test Command" in result + assert "This is a test command content." in result + + # Test error handling + with patch( + "codemcp.file_utils.async_open_text", side_effect=Exception("Test error") + ): + # Run the coroutine in the event loop + result = asyncio.run(get_command_content("non-existent-file")) + + # Verify error handling + assert "Error reading command file" in result + assert "Test error" in result + + +def test_user_prompt_with_slash_command(): + """Test the user_prompt function with slash commands.""" + from codemcp.tools.user_prompt import user_prompt + + # Create a temporary directory and markdown file for testing + with tempfile.TemporaryDirectory() as temp_dir: + # Mock Path.home() to return our temporary directory + original_home = Path.home + Path.home = MagicMock(return_value=Path(temp_dir)) + + try: + # Create the .claude/commands directory + commands_dir = Path(temp_dir) / ".claude" / "commands" + os.makedirs(commands_dir, exist_ok=True) + + # Create a test command file + command_file = commands_dir / "test-command.md" + with open(command_file, "w") as f: + f.write("# Test Command\nThis is a test command content.") + + # Mock file_utils.async_open_text to return our test content + with patch("codemcp.file_utils.async_open_text") as mock_open: + mock_open.return_value = ( + "# Test Command\nThis is a test command content." + ) + + # Test with a valid slash command + result = asyncio.run(user_prompt("/user:test-command", "test-chat-id")) + assert "# Test Command" in result + assert "This is a test command content." in result + + # Test with an invalid slash command + result = asyncio.run(user_prompt("/user:non-existent", "test-chat-id")) + assert "Unknown slash command: non-existent" in result + + # Test with a non-slash command + with patch( + "codemcp.tools.user_prompt.find_git_root", return_value=None + ): + result = asyncio.run(user_prompt("regular command", "test-chat-id")) + assert "User prompt received" in result + finally: + # Restore original Path.home + Path.home = original_home