Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 106 additions & 2 deletions src/mcp_codebase_index/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@

from __future__ import annotations

import argparse
import fnmatch
import json
import os
import sys
import pickle
import time
import tomllib
import traceback

from mcp.server import Server
Expand Down Expand Up @@ -66,6 +68,10 @@
_tool_call_counts: dict[str, int] = {}
_total_chars_returned: int = 0

# Tool toggling
PROTECTED_TOOLS: frozenset[str] = frozenset({"reindex", "get_usage_stats"})
_disabled_tools: set[str] = set()

# Realistic estimate of what % of codebase you'd need to read without the indexer
_TOOL_COST_MULTIPLIERS: dict[str, float] = {
"get_project_summary": 0.10,
Expand All @@ -88,6 +94,78 @@
}


_ALL_TOOL_NAMES: frozenset[str] = frozenset() # set after TOOLS is defined


def _load_disabled_tools_from_config(project_root: str) -> set[str]:
"""Read `.mcp-codebase-index.toml` and return the set of disabled tool names."""
config_path = os.path.join(project_root, ".mcp-codebase-index.toml")
if not os.path.isfile(config_path):
return set()
try:
with open(config_path, "rb") as f:
data = tomllib.load(f)
except Exception as exc:
print(
f"[mcp-codebase-index] Warning: failed to parse {config_path}: {exc}",
file=sys.stderr,
)
return set()
raw = data.get("disabled_tools")
if raw is None:
return set()
if not isinstance(raw, list) or not all(isinstance(s, str) for s in raw):
print(
"[mcp-codebase-index] Warning: disabled_tools must be a list of strings, ignoring",
file=sys.stderr,
)
return set()
return {s.strip() for s in raw if s.strip()}


def _init_disabled_tools(
cli_disabled: list[str] | None = None,
*,
project_root: str | None = None,
) -> None:
"""Union CLI + config disabled tools, strip protected, warn about unknowns."""
global _disabled_tools
merged: set[str] = set()

if cli_disabled:
merged.update(cli_disabled)

root = project_root or os.environ.get("PROJECT_ROOT", os.getcwd())
merged |= _load_disabled_tools_from_config(root)

# Warn about protected tools that a user tried to disable
protected_requested = merged & PROTECTED_TOOLS
if protected_requested:
print(
f"[mcp-codebase-index] Warning: cannot disable protected tools: "
f"{', '.join(sorted(protected_requested))}",
file=sys.stderr,
)
merged -= PROTECTED_TOOLS

# Warn about unknown tool names
unknown = merged - _ALL_TOOL_NAMES
if unknown:
print(
f"[mcp-codebase-index] Warning: unknown tools ignored: "
f"{', '.join(sorted(unknown))}",
file=sys.stderr,
)
merged &= _ALL_TOOL_NAMES

_disabled_tools = merged
if _disabled_tools:
print(
f"[mcp-codebase-index] Disabled tools: {', '.join(sorted(_disabled_tools))}",
file=sys.stderr,
)


def _format_result(value: object) -> str:
"""Format a query result as readable text."""
if isinstance(value, str):
Expand Down Expand Up @@ -648,6 +726,9 @@ def _maybe_incremental_update() -> None:
),
]

# Now that TOOLS is defined, set the real _ALL_TOOL_NAMES
_ALL_TOOL_NAMES = frozenset(t.name for t in TOOLS)


# ---------------------------------------------------------------------------
# MCP handlers
Expand All @@ -656,13 +737,22 @@ def _maybe_incremental_update() -> None:

@server.list_tools()
async def list_tools() -> list[Tool]:
if _disabled_tools:
return [t for t in TOOLS if t.name not in _disabled_tools]
return TOOLS


@server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[types.TextContent]:
global _query_fns, _total_chars_returned

# Reject disabled tools before doing any work
if name in _disabled_tools:
return [TextContent(
type="text",
text=f"Error: tool '{name}' is disabled.",
)]

# Track tool call counts (including reindex/stats themselves)
_tool_call_counts[name] = _tool_call_counts.get(name, 0) + 1

Expand Down Expand Up @@ -787,7 +877,8 @@ async def call_tool(name: str, arguments: dict) -> list[types.TextContent]:
# ---------------------------------------------------------------------------


async def main():
async def main(cli_disabled: list[str] | None = None):
_init_disabled_tools(cli_disabled)
async with stdio_server() as (read_stream, write_stream):
await server.run(
read_stream,
Expand All @@ -800,7 +891,20 @@ def main_sync():
"""Synchronous entry point for console_scripts."""
import asyncio

asyncio.run(main())
parser = argparse.ArgumentParser(description="MCP codebase index server")
parser.add_argument(
"--disabled-tools",
type=lambda s: [t.strip() for t in s.split(",") if t.strip()],
default=None,
help="Comma-separated list of tool names to disable (e.g. search_codebase,get_call_chain)",
)
args, unknown = parser.parse_known_args()
if unknown:
print(
f"[mcp-codebase-index] Ignoring unknown arguments: {' '.join(unknown)}",
file=sys.stderr,
)
asyncio.run(main(cli_disabled=args.disabled_tools))


if __name__ == "__main__":
Expand Down
227 changes: 227 additions & 0 deletions tests/test_tool_toggle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
"""Tests for the tool toggling feature (--disabled-tools + TOML config)."""

import asyncio

import pytest

import mcp_codebase_index.server as srv


@pytest.fixture(autouse=True)
def _reset_toggle_state():
"""Reset the disabled-tools set and related state before each test."""
srv._disabled_tools = set()
yield
srv._disabled_tools = set()


# ---------------------------------------------------------------------------
# _load_disabled_tools_from_config
# ---------------------------------------------------------------------------


class TestLoadDisabledToolsFromConfig:
def test_no_config_file(self, tmp_path):
result = srv._load_disabled_tools_from_config(str(tmp_path))
assert result == set()

def test_valid_config(self, tmp_path):
(tmp_path / ".mcp-codebase-index.toml").write_text(
'disabled_tools = ["search_codebase", "get_call_chain"]\n'
)
result = srv._load_disabled_tools_from_config(str(tmp_path))
assert result == {"search_codebase", "get_call_chain"}

def test_empty_list(self, tmp_path):
(tmp_path / ".mcp-codebase-index.toml").write_text("disabled_tools = []\n")
result = srv._load_disabled_tools_from_config(str(tmp_path))
assert result == set()

def test_invalid_type_not_list(self, tmp_path):
(tmp_path / ".mcp-codebase-index.toml").write_text('disabled_tools = "search_codebase"\n')
result = srv._load_disabled_tools_from_config(str(tmp_path))
assert result == set()

def test_invalid_type_list_of_non_strings(self, tmp_path):
(tmp_path / ".mcp-codebase-index.toml").write_text("disabled_tools = [1, 2]\n")
result = srv._load_disabled_tools_from_config(str(tmp_path))
assert result == set()

def test_malformed_toml(self, tmp_path):
(tmp_path / ".mcp-codebase-index.toml").write_text("not valid toml [[[")
result = srv._load_disabled_tools_from_config(str(tmp_path))
assert result == set()

def test_missing_key(self, tmp_path):
(tmp_path / ".mcp-codebase-index.toml").write_text("[other]\nfoo = 1\n")
result = srv._load_disabled_tools_from_config(str(tmp_path))
assert result == set()

def test_whitespace_in_values_stripped(self, tmp_path):
(tmp_path / ".mcp-codebase-index.toml").write_text(
'disabled_tools = [" search_codebase ", "get_call_chain "]\n'
)
result = srv._load_disabled_tools_from_config(str(tmp_path))
assert result == {"search_codebase", "get_call_chain"}

def test_blank_strings_filtered(self, tmp_path):
(tmp_path / ".mcp-codebase-index.toml").write_text(
'disabled_tools = ["search_codebase", " ", ""]\n'
)
result = srv._load_disabled_tools_from_config(str(tmp_path))
assert result == {"search_codebase"}


# ---------------------------------------------------------------------------
# _init_disabled_tools
# ---------------------------------------------------------------------------


class TestInitDisabledTools:
def test_cli_only(self, tmp_path):
srv._init_disabled_tools(["search_codebase"], project_root=str(tmp_path))
assert srv._disabled_tools == {"search_codebase"}

def test_config_only(self, tmp_path):
(tmp_path / ".mcp-codebase-index.toml").write_text(
'disabled_tools = ["get_call_chain"]\n'
)
srv._init_disabled_tools(None, project_root=str(tmp_path))
assert srv._disabled_tools == {"get_call_chain"}

def test_union_of_cli_and_config(self, tmp_path):
(tmp_path / ".mcp-codebase-index.toml").write_text(
'disabled_tools = ["get_call_chain"]\n'
)
srv._init_disabled_tools(["search_codebase"], project_root=str(tmp_path))
assert srv._disabled_tools == {"search_codebase", "get_call_chain"}

def test_protected_tools_cannot_be_disabled(self, tmp_path):
srv._init_disabled_tools(["reindex", "get_usage_stats", "search_codebase"],
project_root=str(tmp_path))
assert "reindex" not in srv._disabled_tools
assert "get_usage_stats" not in srv._disabled_tools
assert "search_codebase" in srv._disabled_tools

def test_unknown_tools_ignored(self, tmp_path):
srv._init_disabled_tools(["not_a_real_tool", "search_codebase"],
project_root=str(tmp_path))
assert "not_a_real_tool" not in srv._disabled_tools
assert "search_codebase" in srv._disabled_tools

def test_empty_cli_list(self, tmp_path):
srv._init_disabled_tools([], project_root=str(tmp_path))
assert srv._disabled_tools == set()

def test_none_cli_no_config(self, tmp_path):
srv._init_disabled_tools(None, project_root=str(tmp_path))
assert srv._disabled_tools == set()


# ---------------------------------------------------------------------------
# list_tools filtering
# ---------------------------------------------------------------------------


class TestListToolsFiltering:
def test_no_disabled_returns_all(self):
srv._disabled_tools = set()
tools = asyncio.run(srv.list_tools())
assert len(tools) == len(srv.TOOLS)

def test_disabled_tools_excluded(self):
srv._disabled_tools = {"search_codebase", "get_call_chain"}
tools = asyncio.run(srv.list_tools())
names = {t.name for t in tools}
assert "search_codebase" not in names
assert "get_call_chain" not in names
assert len(tools) == len(srv.TOOLS) - 2

def test_protected_always_present(self):
srv._disabled_tools = {"search_codebase"}
tools = asyncio.run(srv.list_tools())
names = {t.name for t in tools}
assert "reindex" in names
assert "get_usage_stats" in names


# ---------------------------------------------------------------------------
# call_tool guard
# ---------------------------------------------------------------------------


class TestCallToolGuard:
def test_disabled_tool_returns_error(self):
srv._disabled_tools = {"search_codebase"}
result = asyncio.run(srv.call_tool("search_codebase", {"pattern": "foo"}))
assert len(result) == 1
assert "disabled" in result[0].text
assert "search_codebase" in result[0].text

def test_disabled_tool_not_counted(self):
srv._tool_call_counts.clear()
srv._disabled_tools = {"search_codebase"}
asyncio.run(srv.call_tool("search_codebase", {"pattern": "foo"}))
assert "search_codebase" not in srv._tool_call_counts

def test_enabled_tool_not_blocked(self, monkeypatch):
"""An enabled tool should proceed past the guard (we mock _ensure_index)."""
srv._disabled_tools = set()
srv._tool_call_counts.clear()
# Prevent actual indexing
monkeypatch.setattr(srv, "_ensure_index", lambda: None)
monkeypatch.setattr(srv, "_maybe_incremental_update", lambda: None)
monkeypatch.setattr(srv, "_query_fns", {
"get_project_summary": lambda: "mock summary",
})
result = asyncio.run(srv.call_tool("get_project_summary", {}))
assert result[0].text == "mock summary"
assert srv._tool_call_counts.get("get_project_summary") == 1


# ---------------------------------------------------------------------------
# CLI argument parsing (main_sync integration)
# ---------------------------------------------------------------------------


class TestCliParsing:
def _parse(self, argv: list[str]) -> tuple:
"""Run the same argparse logic as main_sync with custom argv."""
import argparse

parser = argparse.ArgumentParser(description="MCP codebase index server")
parser.add_argument(
"--disabled-tools",
type=lambda s: [t.strip() for t in s.split(",") if t.strip()],
default=None,
help="Comma-separated list of tool names to disable",
)
return parser.parse_known_args(argv)

def test_disabled_tools_parsed(self):
args, unknown = self._parse(["--disabled-tools", "search_codebase,get_call_chain"])
assert args.disabled_tools == ["search_codebase", "get_call_chain"]
assert unknown == []

def test_disabled_tools_with_spaces(self):
args, _ = self._parse(["--disabled-tools", " search_codebase , get_call_chain "])
assert args.disabled_tools == ["search_codebase", "get_call_chain"]

def test_no_disabled_tools_flag(self):
args, unknown = self._parse([])
assert args.disabled_tools is None
assert unknown == []

def test_unknown_args_not_fatal(self):
"""parse_known_args must not raise SystemExit on unknown flags."""
args, unknown = self._parse(["--some-future-flag", "value"])
assert args.disabled_tools is None
assert "--some-future-flag" in unknown

def test_unknown_args_coexist_with_disabled_tools(self):
args, unknown = self._parse([
"--disabled-tools", "search_codebase",
"--unknown-flag",
])
assert args.disabled_tools == ["search_codebase"]
assert "--unknown-flag" in unknown