diff --git a/src/mcp_codebase_index/server.py b/src/mcp_codebase_index/server.py index 72126da..50c4f20 100644 --- a/src/mcp_codebase_index/server.py +++ b/src/mcp_codebase_index/server.py @@ -28,12 +28,14 @@ from __future__ import annotations +import argparse import fnmatch import json import os import sys import pickle import time +import tomllib import traceback from mcp.server import Server @@ -66,6 +68,10 @@ _tool_call_counts: dict[str, int] = {} _total_chars_returned: int = 0 +# Tool toggling +PROTECTED_TOOLS: frozenset[str] = frozenset({"reindex", "get_usage_stats"}) +_disabled_tools: set[str] = set() + # Realistic estimate of what % of codebase you'd need to read without the indexer _TOOL_COST_MULTIPLIERS: dict[str, float] = { "get_project_summary": 0.10, @@ -88,6 +94,78 @@ } +_ALL_TOOL_NAMES: frozenset[str] = frozenset() # set after TOOLS is defined + + +def _load_disabled_tools_from_config(project_root: str) -> set[str]: + """Read `.mcp-codebase-index.toml` and return the set of disabled tool names.""" + config_path = os.path.join(project_root, ".mcp-codebase-index.toml") + if not os.path.isfile(config_path): + return set() + try: + with open(config_path, "rb") as f: + data = tomllib.load(f) + except Exception as exc: + print( + f"[mcp-codebase-index] Warning: failed to parse {config_path}: {exc}", + file=sys.stderr, + ) + return set() + raw = data.get("disabled_tools") + if raw is None: + return set() + if not isinstance(raw, list) or not all(isinstance(s, str) for s in raw): + print( + "[mcp-codebase-index] Warning: disabled_tools must be a list of strings, ignoring", + file=sys.stderr, + ) + return set() + return {s.strip() for s in raw if s.strip()} + + +def _init_disabled_tools( + cli_disabled: list[str] | None = None, + *, + project_root: str | None = None, +) -> None: + """Union CLI + config disabled tools, strip protected, warn about unknowns.""" + global _disabled_tools + merged: set[str] = set() + + if cli_disabled: + merged.update(cli_disabled) + + root = project_root or os.environ.get("PROJECT_ROOT", os.getcwd()) + merged |= _load_disabled_tools_from_config(root) + + # Warn about protected tools that a user tried to disable + protected_requested = merged & PROTECTED_TOOLS + if protected_requested: + print( + f"[mcp-codebase-index] Warning: cannot disable protected tools: " + f"{', '.join(sorted(protected_requested))}", + file=sys.stderr, + ) + merged -= PROTECTED_TOOLS + + # Warn about unknown tool names + unknown = merged - _ALL_TOOL_NAMES + if unknown: + print( + f"[mcp-codebase-index] Warning: unknown tools ignored: " + f"{', '.join(sorted(unknown))}", + file=sys.stderr, + ) + merged &= _ALL_TOOL_NAMES + + _disabled_tools = merged + if _disabled_tools: + print( + f"[mcp-codebase-index] Disabled tools: {', '.join(sorted(_disabled_tools))}", + file=sys.stderr, + ) + + def _format_result(value: object) -> str: """Format a query result as readable text.""" if isinstance(value, str): @@ -648,6 +726,9 @@ def _maybe_incremental_update() -> None: ), ] +# Now that TOOLS is defined, set the real _ALL_TOOL_NAMES +_ALL_TOOL_NAMES = frozenset(t.name for t in TOOLS) + # --------------------------------------------------------------------------- # MCP handlers @@ -656,6 +737,8 @@ def _maybe_incremental_update() -> None: @server.list_tools() async def list_tools() -> list[Tool]: + if _disabled_tools: + return [t for t in TOOLS if t.name not in _disabled_tools] return TOOLS @@ -663,6 +746,13 @@ async def list_tools() -> list[Tool]: async def call_tool(name: str, arguments: dict) -> list[types.TextContent]: global _query_fns, _total_chars_returned + # Reject disabled tools before doing any work + if name in _disabled_tools: + return [TextContent( + type="text", + text=f"Error: tool '{name}' is disabled.", + )] + # Track tool call counts (including reindex/stats themselves) _tool_call_counts[name] = _tool_call_counts.get(name, 0) + 1 @@ -787,7 +877,8 @@ async def call_tool(name: str, arguments: dict) -> list[types.TextContent]: # --------------------------------------------------------------------------- -async def main(): +async def main(cli_disabled: list[str] | None = None): + _init_disabled_tools(cli_disabled) async with stdio_server() as (read_stream, write_stream): await server.run( read_stream, @@ -800,7 +891,20 @@ def main_sync(): """Synchronous entry point for console_scripts.""" import asyncio - asyncio.run(main()) + parser = argparse.ArgumentParser(description="MCP codebase index server") + parser.add_argument( + "--disabled-tools", + type=lambda s: [t.strip() for t in s.split(",") if t.strip()], + default=None, + help="Comma-separated list of tool names to disable (e.g. search_codebase,get_call_chain)", + ) + args, unknown = parser.parse_known_args() + if unknown: + print( + f"[mcp-codebase-index] Ignoring unknown arguments: {' '.join(unknown)}", + file=sys.stderr, + ) + asyncio.run(main(cli_disabled=args.disabled_tools)) if __name__ == "__main__": diff --git a/tests/test_tool_toggle.py b/tests/test_tool_toggle.py new file mode 100644 index 0000000..d5ae297 --- /dev/null +++ b/tests/test_tool_toggle.py @@ -0,0 +1,227 @@ +"""Tests for the tool toggling feature (--disabled-tools + TOML config).""" + +import asyncio + +import pytest + +import mcp_codebase_index.server as srv + + +@pytest.fixture(autouse=True) +def _reset_toggle_state(): + """Reset the disabled-tools set and related state before each test.""" + srv._disabled_tools = set() + yield + srv._disabled_tools = set() + + +# --------------------------------------------------------------------------- +# _load_disabled_tools_from_config +# --------------------------------------------------------------------------- + + +class TestLoadDisabledToolsFromConfig: + def test_no_config_file(self, tmp_path): + result = srv._load_disabled_tools_from_config(str(tmp_path)) + assert result == set() + + def test_valid_config(self, tmp_path): + (tmp_path / ".mcp-codebase-index.toml").write_text( + 'disabled_tools = ["search_codebase", "get_call_chain"]\n' + ) + result = srv._load_disabled_tools_from_config(str(tmp_path)) + assert result == {"search_codebase", "get_call_chain"} + + def test_empty_list(self, tmp_path): + (tmp_path / ".mcp-codebase-index.toml").write_text("disabled_tools = []\n") + result = srv._load_disabled_tools_from_config(str(tmp_path)) + assert result == set() + + def test_invalid_type_not_list(self, tmp_path): + (tmp_path / ".mcp-codebase-index.toml").write_text('disabled_tools = "search_codebase"\n') + result = srv._load_disabled_tools_from_config(str(tmp_path)) + assert result == set() + + def test_invalid_type_list_of_non_strings(self, tmp_path): + (tmp_path / ".mcp-codebase-index.toml").write_text("disabled_tools = [1, 2]\n") + result = srv._load_disabled_tools_from_config(str(tmp_path)) + assert result == set() + + def test_malformed_toml(self, tmp_path): + (tmp_path / ".mcp-codebase-index.toml").write_text("not valid toml [[[") + result = srv._load_disabled_tools_from_config(str(tmp_path)) + assert result == set() + + def test_missing_key(self, tmp_path): + (tmp_path / ".mcp-codebase-index.toml").write_text("[other]\nfoo = 1\n") + result = srv._load_disabled_tools_from_config(str(tmp_path)) + assert result == set() + + def test_whitespace_in_values_stripped(self, tmp_path): + (tmp_path / ".mcp-codebase-index.toml").write_text( + 'disabled_tools = [" search_codebase ", "get_call_chain "]\n' + ) + result = srv._load_disabled_tools_from_config(str(tmp_path)) + assert result == {"search_codebase", "get_call_chain"} + + def test_blank_strings_filtered(self, tmp_path): + (tmp_path / ".mcp-codebase-index.toml").write_text( + 'disabled_tools = ["search_codebase", " ", ""]\n' + ) + result = srv._load_disabled_tools_from_config(str(tmp_path)) + assert result == {"search_codebase"} + + +# --------------------------------------------------------------------------- +# _init_disabled_tools +# --------------------------------------------------------------------------- + + +class TestInitDisabledTools: + def test_cli_only(self, tmp_path): + srv._init_disabled_tools(["search_codebase"], project_root=str(tmp_path)) + assert srv._disabled_tools == {"search_codebase"} + + def test_config_only(self, tmp_path): + (tmp_path / ".mcp-codebase-index.toml").write_text( + 'disabled_tools = ["get_call_chain"]\n' + ) + srv._init_disabled_tools(None, project_root=str(tmp_path)) + assert srv._disabled_tools == {"get_call_chain"} + + def test_union_of_cli_and_config(self, tmp_path): + (tmp_path / ".mcp-codebase-index.toml").write_text( + 'disabled_tools = ["get_call_chain"]\n' + ) + srv._init_disabled_tools(["search_codebase"], project_root=str(tmp_path)) + assert srv._disabled_tools == {"search_codebase", "get_call_chain"} + + def test_protected_tools_cannot_be_disabled(self, tmp_path): + srv._init_disabled_tools(["reindex", "get_usage_stats", "search_codebase"], + project_root=str(tmp_path)) + assert "reindex" not in srv._disabled_tools + assert "get_usage_stats" not in srv._disabled_tools + assert "search_codebase" in srv._disabled_tools + + def test_unknown_tools_ignored(self, tmp_path): + srv._init_disabled_tools(["not_a_real_tool", "search_codebase"], + project_root=str(tmp_path)) + assert "not_a_real_tool" not in srv._disabled_tools + assert "search_codebase" in srv._disabled_tools + + def test_empty_cli_list(self, tmp_path): + srv._init_disabled_tools([], project_root=str(tmp_path)) + assert srv._disabled_tools == set() + + def test_none_cli_no_config(self, tmp_path): + srv._init_disabled_tools(None, project_root=str(tmp_path)) + assert srv._disabled_tools == set() + + +# --------------------------------------------------------------------------- +# list_tools filtering +# --------------------------------------------------------------------------- + + +class TestListToolsFiltering: + def test_no_disabled_returns_all(self): + srv._disabled_tools = set() + tools = asyncio.run(srv.list_tools()) + assert len(tools) == len(srv.TOOLS) + + def test_disabled_tools_excluded(self): + srv._disabled_tools = {"search_codebase", "get_call_chain"} + tools = asyncio.run(srv.list_tools()) + names = {t.name for t in tools} + assert "search_codebase" not in names + assert "get_call_chain" not in names + assert len(tools) == len(srv.TOOLS) - 2 + + def test_protected_always_present(self): + srv._disabled_tools = {"search_codebase"} + tools = asyncio.run(srv.list_tools()) + names = {t.name for t in tools} + assert "reindex" in names + assert "get_usage_stats" in names + + +# --------------------------------------------------------------------------- +# call_tool guard +# --------------------------------------------------------------------------- + + +class TestCallToolGuard: + def test_disabled_tool_returns_error(self): + srv._disabled_tools = {"search_codebase"} + result = asyncio.run(srv.call_tool("search_codebase", {"pattern": "foo"})) + assert len(result) == 1 + assert "disabled" in result[0].text + assert "search_codebase" in result[0].text + + def test_disabled_tool_not_counted(self): + srv._tool_call_counts.clear() + srv._disabled_tools = {"search_codebase"} + asyncio.run(srv.call_tool("search_codebase", {"pattern": "foo"})) + assert "search_codebase" not in srv._tool_call_counts + + def test_enabled_tool_not_blocked(self, monkeypatch): + """An enabled tool should proceed past the guard (we mock _ensure_index).""" + srv._disabled_tools = set() + srv._tool_call_counts.clear() + # Prevent actual indexing + monkeypatch.setattr(srv, "_ensure_index", lambda: None) + monkeypatch.setattr(srv, "_maybe_incremental_update", lambda: None) + monkeypatch.setattr(srv, "_query_fns", { + "get_project_summary": lambda: "mock summary", + }) + result = asyncio.run(srv.call_tool("get_project_summary", {})) + assert result[0].text == "mock summary" + assert srv._tool_call_counts.get("get_project_summary") == 1 + + +# --------------------------------------------------------------------------- +# CLI argument parsing (main_sync integration) +# --------------------------------------------------------------------------- + + +class TestCliParsing: + def _parse(self, argv: list[str]) -> tuple: + """Run the same argparse logic as main_sync with custom argv.""" + import argparse + + parser = argparse.ArgumentParser(description="MCP codebase index server") + parser.add_argument( + "--disabled-tools", + type=lambda s: [t.strip() for t in s.split(",") if t.strip()], + default=None, + help="Comma-separated list of tool names to disable", + ) + return parser.parse_known_args(argv) + + def test_disabled_tools_parsed(self): + args, unknown = self._parse(["--disabled-tools", "search_codebase,get_call_chain"]) + assert args.disabled_tools == ["search_codebase", "get_call_chain"] + assert unknown == [] + + def test_disabled_tools_with_spaces(self): + args, _ = self._parse(["--disabled-tools", " search_codebase , get_call_chain "]) + assert args.disabled_tools == ["search_codebase", "get_call_chain"] + + def test_no_disabled_tools_flag(self): + args, unknown = self._parse([]) + assert args.disabled_tools is None + assert unknown == [] + + def test_unknown_args_not_fatal(self): + """parse_known_args must not raise SystemExit on unknown flags.""" + args, unknown = self._parse(["--some-future-flag", "value"]) + assert args.disabled_tools is None + assert "--some-future-flag" in unknown + + def test_unknown_args_coexist_with_disabled_tools(self): + args, unknown = self._parse([ + "--disabled-tools", "search_codebase", + "--unknown-flag", + ]) + assert args.disabled_tools == ["search_codebase"] + assert "--unknown-flag" in unknown