Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 98 additions & 29 deletions context_scribe/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import shutil
from pathlib import Path
from datetime import datetime
from typing import Optional
from typing import List, Optional
import click
from rich.console import Console
from rich.live import Live
Expand All @@ -20,6 +20,7 @@
from context_scribe.evaluator import get_evaluator, EVALUATOR_REGISTRY
from context_scribe.bridge.mcp_client import MemoryBankClient


logger = logging.getLogger("context_scribe")
console: Console = Console()

Expand Down Expand Up @@ -163,6 +164,31 @@ def bootstrap_claude_config() -> None:
f.write(f"\n{MASTER_RETRIEVAL_RULE}\n")


# Registry of supported AI tools. Each entry maps a CLI tool name to a
# (Provider class, bootstrap function) pair: the provider watches that tool's
# log stream, and the bootstrap function writes the tool-specific config
# before watching starts. Keys double as the valid values for --tool/--tools.
TOOL_REGISTRY = {
    "gemini-cli": (GeminiCliProvider, bootstrap_global_config),
    "copilot": (CopilotProvider, bootstrap_copilot_config),
    "claude": (ClaudeProvider, bootstrap_claude_config),
}


def _create_providers(tools: List[str]):
    """Create and bootstrap providers for the given tool names.

    Args:
        tools: Tool names to instantiate; each must be a key of TOOL_REGISTRY.

    Returns:
        A list of ``(tool_name, provider_instance)`` tuples, in input order.

    Raises:
        ValueError: If any tool name is not in TOOL_REGISTRY. All names are
            validated *before* any bootstrap function runs, so an invalid
            name never leaves the environment partially bootstrapped.
    """
    # Validate the full list first: bootstrap functions have filesystem side
    # effects, and the original per-item check could run some of them before
    # discovering a bad name later in the list.
    unknown = [t for t in tools if t not in TOOL_REGISTRY]
    if unknown:
        raise ValueError(
            f"Unknown tool '{unknown[0]}'. Available: {', '.join(sorted(TOOL_REGISTRY))}"
        )

    providers = []
    for tool in tools:
        provider_cls, bootstrap_fn = TOOL_REGISTRY[tool]
        bootstrap_fn()  # ensure tool-specific config exists before watching
        providers.append((tool, provider_cls()))
    return providers


def _detect_evaluator(preferred_tool: Optional[str] = None) -> str:
"""Auto-detect which evaluator CLI is available, prioritizing the preferred tool."""
# Map tool names to their corresponding CLI commands
Expand Down Expand Up @@ -205,22 +231,20 @@ def _status(msg: str, db, live, debug: bool):
live.update(db.generate_layout())


async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_name: str = "auto") -> bool:
if tool == "gemini-cli":
bootstrap_global_config()
provider = GeminiCliProvider()
elif tool == "copilot":
bootstrap_copilot_config()
provider = CopilotProvider()
elif tool == "claude":
bootstrap_claude_config()
provider = ClaudeProvider()
async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_name: str = "auto", tools: Optional[List[str]] = None) -> bool:
# Build provider list: --tools takes precedence over --tool
if tools is not None:
if not tools:
raise ValueError("--tools was provided but resolved to an empty list.")
tool_names = tools
else:
provider = None
if not provider: return False
tool_names = [tool]
providers = _create_providers(tool_names)
if not providers:
return False

if evaluator_name == "auto":
evaluator_name = _detect_evaluator(tool)
evaluator_name = _detect_evaluator(tool_names[0])
evaluator = get_evaluator(evaluator_name)
mcp_client = MemoryBankClient(bank_path=bank_path)

Expand All @@ -230,29 +254,54 @@ async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_n
console.print("[bold red]Fatal Error: Could not connect to the Memory Bank MCP server.[/bold red]")
raise SystemExit(1)

db = Dashboard(tool, bank_path)
display_name = ",".join(tool_names)
db = Dashboard(display_name, bank_path)
queue: asyncio.Queue = asyncio.Queue(maxsize=1000)

async def _watch_provider(tool_name: str, provider):
"""Run a provider's watch() in a thread and feed interactions into the shared queue."""
loop = asyncio.get_event_loop()
watch_iter = provider.watch()
try:
while True:
interaction = await loop.run_in_executor(None, next, watch_iter)
if interaction is not None:
await queue.put((tool_name, interaction))
except (StopIteration, asyncio.CancelledError, KeyboardInterrupt):
pass
except Exception as e:
logger.error("Watcher for %s failed: %s", tool_name, e)

async def _loop(live=None):
watcher_tasks = []
try:
loop = asyncio.get_event_loop()
watch_iter = provider.watch()
# Start a watcher task for each provider
watcher_tasks = [
asyncio.create_task(_watch_provider(name, prov))
for name, prov in providers
]
_status("🔍 Watching log stream...", db, live, debug)

while True:
if live: live.update(db.generate_layout())
interaction = await loop.run_in_executor(None, next, watch_iter)
if interaction is None:
if live:
live.update(db.generate_layout())

# Wait for next interaction from any provider
try:
tool_name, interaction = await asyncio.wait_for(queue.get(), timeout=1.0)
except asyncio.TimeoutError:
continue

_status(f"🤔 Analyzing user message ({interaction.project_name})", db, live, debug)
_status(f"🤔 [{tool_name}] Analyzing user message ({interaction.project_name})", db, live, debug)
if debug:
logging.getLogger("context_scribe").info(" content: %s", interaction.content[:120])
logger.info(" content: %s", interaction.content[:120])

_status(f"📖 Accessing Memory Bank ({interaction.project_name})...", db, live, debug)
_status(f"📖 [{tool_name}] Accessing Memory Bank ({interaction.project_name})...", db, live, debug)
existing_global = await mcp_client.read_rules("global", "global_rules.md")
existing_project = await mcp_client.read_rules(interaction.project_name, "rules.md")

_status(f"🧠 Thinking: Extracting rules for {interaction.project_name}...", db, live, debug)
_status(f"🧠 [{tool_name}] Extracting rules for {interaction.project_name}...", db, live, debug)
loop = asyncio.get_event_loop()
rule_output = await loop.run_in_executor(None, evaluator.evaluate_interaction, interaction, existing_global, existing_project)

if rule_output:
Expand All @@ -272,11 +321,11 @@ async def _loop(live=None):
seen.add(stripped)
deduped_content = "\n".join(unique_lines).strip()

_status(f"📝 Committing: {dest_path}", db, live, debug)
_status(f"📝 [{tool_name}] Committing: {dest_path}", db, live, debug)
await mcp_client.save_rule(deduped_content, dest_proj, dest_file)

db.add_history(dest_path, rule_output.description)
_status(f"✅ SUCCESS: Updated {dest_path}", db, live, debug)
_status(f"✅ [{tool_name}] Updated {dest_path}", db, live, debug)
if not debug:
console.print(f"[bold green]▶ UPDATED:[/bold green] [cyan]{dest_path}[/cyan] ({rule_output.description})")
else:
Expand All @@ -287,6 +336,8 @@ async def _loop(live=None):
except (KeyboardInterrupt, asyncio.CancelledError):
_status("🛑 Stopping...", db, live, debug)
finally:
for task in watcher_tasks:
task.cancel()
await mcp_client.close()

if debug:
Expand All @@ -297,16 +348,34 @@ async def _loop(live=None):
return True

@click.command()
@click.option('--tool', default='gemini-cli', type=click.Choice(['gemini-cli', 'copilot', 'claude']), help='Single AI tool to monitor (use --tools for multiple)')
@click.option('--tools', 'tools_csv', default=None, help='Comma-separated tools to monitor concurrently (e.g. gemini-cli,claude,copilot)')
@click.option('--bank-path', default='~/.memory-bank', help='Path to your Memory Bank root')
@click.option('--evaluator', 'evaluator_name', default='auto', type=click.Choice(['auto'] + sorted(EVALUATOR_REGISTRY)), help='Evaluator LLM to use (default: auto-detect)')
@click.option('--debug', is_flag=True, default=False, help='Stream plain debug logs instead of dashboard UI')
def cli(tool, tools_csv, bank_path, evaluator_name, debug):
    """Context-Scribe: Persistent Secretary Daemon"""
    if debug:
        logging.basicConfig(level=logging.DEBUG, format='%(asctime)s [%(levelname)s] %(name)s: %(message)s')

    # Parse --tools if provided; it takes precedence over --tool inside
    # run_daemon, so a populated list here means --tool is ignored.
    tools = None
    if tools_csv is not None:
        # dict.fromkeys deduplicates while preserving the user's ordering
        # (the first tool listed is later used for evaluator auto-detection).
        tools = list(dict.fromkeys(
            t.strip() for t in tools_csv.split(",") if t.strip()
        ))
        if not tools:
            raise click.ClickException("--tools requires at least one tool name.")
        # Validate eagerly so the user gets a friendly CLI error instead of a
        # ValueError traceback from _create_providers inside the event loop.
        valid_tools = set(TOOL_REGISTRY)
        invalid = [t for t in tools if t not in valid_tools]
        if invalid:
            raise click.ClickException(
                f"Unknown tool(s): {', '.join(invalid)}. "
                f"Available: {', '.join(sorted(valid_tools))}"
            )

    try:
        asyncio.run(run_daemon(tool, bank_path, debug=debug, evaluator_name=evaluator_name, tools=tools))
    except KeyboardInterrupt:
        # Clean shutdown on Ctrl-C; run_daemon performs its own teardown.
        pass

Expand Down
55 changes: 27 additions & 28 deletions tests/test_daemons.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,40 @@
from context_scribe.main import run_daemon

@pytest.mark.asyncio
@pytest.mark.parametrize("tool, bootstrap_func, evaluator_name", [
    ("gemini-cli", "bootstrap_global_config", "gemini"),
    ("copilot", "bootstrap_copilot_config", "copilot"),
    ("claude", "bootstrap_claude_config", "claude"),
])
async def test_run_daemon_tools(tool, bootstrap_func, evaluator_name, daemon_mocks):
    """Test the daemon run loop for all supported tools."""
    # Patch _create_providers so no real provider is constructed and no
    # bootstrap side effects hit the filesystem; the bootstrap patch is kept
    # as a belt-and-braces guard in case provider creation changes.
    with patch("context_scribe.main._create_providers", return_value=[(tool, daemon_mocks.provider)]):
        with patch("context_scribe.main.get_evaluator", return_value=daemon_mocks.evaluator):
            with patch("context_scribe.main.MemoryBankClient", return_value=daemon_mocks.mcp):
                with patch(f"context_scribe.main.{bootstrap_func}"):
                    # Mock Live to avoid rich rendering logic completely
                    with patch("context_scribe.main.Live") as mock_live:
                        # Make the context manager work
                        mock_live.return_value.__enter__.return_value = MagicMock()

                        # Start daemon and wait for it to process the mocked interaction
                        daemon_task = asyncio.create_task(run_daemon(tool, "~/.memory-bank", evaluator_name=evaluator_name))

                        # Wait until save_rule is called (meaning interaction processed)
                        for _ in range(100):
                            if daemon_mocks.processed_interaction:
                                break
                            await asyncio.sleep(0.1)

                        daemon_task.cancel()
                        try:
                            await daemon_task
                        except asyncio.CancelledError:
                            pass

    # Verify calls
    daemon_mocks.mcp.connect.assert_called_once()
    daemon_mocks.mcp.read_rules.assert_called()
    daemon_mocks.evaluator.evaluate_interaction.assert_called()
    daemon_mocks.mcp.save_rule.assert_called_once_with("Extracted Rule", "global", "global_rules.md")
Loading
Loading