From 6003254fda03e33db6023aa3a97e0dbf2ead4512 Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Fri, 20 Mar 2026 10:38:24 -0700 Subject: [PATCH 1/2] feat: revisit and simplify CLI-daemon connection model - CLI-daemon connections are stateless now. - Also stop caching project settings in memory. Much easier to pick up new version. --- req-to-pr.zip | Bin 0 -> 1958 bytes src/cocoindex_code/cli.py | 87 ++++----- src/cocoindex_code/client.py | 329 +++++++++++++++++---------------- src/cocoindex_code/daemon.py | 228 +++++++++++------------ src/cocoindex_code/indexer.py | 14 +- src/cocoindex_code/project.py | 18 +- src/cocoindex_code/server.py | 67 +++---- src/cocoindex_code/settings.py | 4 - src/cocoindex_code/shared.py | 5 +- tests/test_bg_index.py | 89 --------- tests/test_client.py | 98 +--------- tests/test_daemon.py | 21 ++- tests/test_e2e_daemon.py | 39 +--- 13 files changed, 366 insertions(+), 633 deletions(-) create mode 100644 req-to-pr.zip delete mode 100644 tests/test_bg_index.py diff --git a/req-to-pr.zip b/req-to-pr.zip new file mode 100644 index 0000000000000000000000000000000000000000..2adf95cb3d634c04d216ea2f588b31daff56e3c3 GIT binary patch literal 1958 zcmV;X2U++~O9KQH000080MWZ}T*FwDbsz@-0I&`K01^NI0CHtv!cO zVy7bGe`J7xEV3R(|Qd?B+_J74C(~% zqxOCpV>;DnT_05SNGb|Yqr-v!a*3P=XSgCJaiMZd+)2abQTcji1Yj$H!87`tsm>42 z=)Oa#>6)1OGTeWNtMSly*zk1ud10&QJ#IJ;sy?#>ckwF|)?7~Gzjp@P8gKpRo z*u*8sK$olU5L@$?$eDq|!L*Z`Pz)6w$ype3mGorQE00{*jHphe>4dr$9^kNmviIn;`m=?mGL#csm>D z1=ELaf4zBlb5{;EnFid`0iXXRP~W3Q2wc}j?k|QCY7QchTA1HiePkdKY&TEN)N%ks zZ;jR3Lz6fnrkGY%WUpOb_Jpi#iE@=I`5 zq-8v1t7UXUpLGYoo3=ZfOz`~#t>Xr3l<7IynP)Iip$C}-^u{62$of9A_1Xl5-Y1FZ zOhJT;`X;gK9m|9n5S&kW4(MG8g1fqfWH9ydC;Q=cICy2NE>{bbJa)*Aj!Fu9h3WTO zdM$@HSqhU5CKY6v1HQMjZZAXdNzX2LPSOiqBDX!n{#EN$#hqS25aJXH>x#%J?h{5I zBhUqL`cvj_?$$)H7fOhBGbO z*LOQAild5ExBnZW!TPXRk&=GyWC$($Lo$Sla&GOswByrrO26$`+1sLLg^Q0h=_ zBwR93FbGjQ7a;c+I-?EeQgCxT8K~0k;luCm-XI5jFZisk_YhQBL@???nDpOqea=d? zAc~L<(4w>`Vl@3_Di~F1^ror!Gj8|HEE&Md@MXnat^Q#!&!bohb3LVjO~*UWx_0(? zq}uUx7U@5#6E8uk4~{x^s_D11|E%GH+tXm0df7-uh=HoQbuN!AQXSrEi}G;oKSnpE zVLSPBr4pqu!vLuQ)EL!V-J9ysx+5@bGx&3y7%Y_ua*S2$Vw<=+hMZH8ueVBRz@0|V zrt{}5ns+4G3^;|+b0`U?MLZ#`QGrEr6oSN5MU{GC)TQqRgjM1alVYCpXXC&J70e@F zisbc|ybMR@P;cDZa0Qar=#yv#mOTCjbm}ith*u_wHx~>;yJT%DBMS?Qko+5pf$<`< zBzasmtY0)pf!SFZfCc&#^z!BVv=c~m!^@Xjx=rhVi0F}t@#mEVOG;qtj4-zZWJki& zm7p00$!lIQ?spI+b|hO;o>@mPqk`St^dMd!KkrboR?I7utMm3_uz;KwK{>eW57@#$ ze4E{TfX<>RS@iRuWG~$KGWZxN9^cm*Ft@LFj+e=a>F`1GM#9wJ7d&V1gkbh;K;~=b z9B{ZGhSYej0>BHLS8ypTOs11mjb)!B;?`{)faeV`@7njflJ4N=G`CCsFWP|JXWp8n zSNUSHdkdwRtf^#zVrcQ8i@Z6pkJsMlhSI3S Path: return root -def require_daemon_for_project() -> tuple[DaemonClient, str]: - """Resolve project root, then connect to daemon (auto-starting if needed). +def require_daemon_for_project() -> str: + """Resolve project root, then ensure daemon is running (auto-starting if needed). - Returns ``(client, project_root_str)``. Exits on failure. + Returns ``project_root_str``. Exits on failure. """ from .client import ensure_daemon project_root = require_project_root() try: - client = ensure_daemon() + ensure_daemon() except Exception as e: _typer.echo(f"Error: Failed to connect to daemon: {e}", err=True) raise _typer.Exit(code=1) - return client, str(project_root) + return str(project_root) def resolve_default_path(project_root: Path) -> str | None: @@ -128,12 +124,14 @@ def print_search_results(response: SearchResponse) -> None: _typer.echo(r.content) -def _run_index_with_progress(client: DaemonClient, project_root: str) -> None: +def _run_index_with_progress(project_root: str) -> None: """Run indexing with streaming progress display. Exits on failure.""" from rich.console import Console as _Console from rich.live import Live as _Live from rich.spinner import Spinner as _Spinner + from . import client as _client + err_console = _Console(stderr=True) last_progress_line: str | None = None @@ -153,7 +151,7 @@ def _on_progress(progress: IndexingProgress) -> None: live.update(_Spinner("dots", last_progress_line)) try: - resp = client.index(project_root, on_progress=_on_progress, on_waiting=_on_waiting) + resp = _client.index(project_root, on_progress=_on_progress, on_waiting=_on_waiting) except RuntimeError as e: live.stop() _typer.echo(f"Indexing failed: {e}", err=True) @@ -169,7 +167,6 @@ def _on_progress(progress: IndexingProgress) -> None: def _search_with_wait_spinner( - client: DaemonClient, project_root: str, query: str, languages: list[str] | None = None, @@ -182,6 +179,8 @@ def _search_with_wait_spinner( from rich.live import Live as _Live from rich.spinner import Spinner as _Spinner + from . import client as _client + err_console = _Console(stderr=True) with _Live(_Spinner("dots", "Searching..."), console=err_console, transient=True) as live: @@ -192,7 +191,7 @@ def _on_waiting() -> None: refresh=True, ) - resp = client.search( + resp = _client.search( project_root=project_root, query=query, languages=languages, @@ -305,13 +304,12 @@ def init( @app.command() def index() -> None: """Create/update index for the codebase.""" - client, project_root = require_daemon_for_project() - print_project_header(project_root) + from . import client as _client - _run_index_with_progress(client, project_root) - - status = client.project_status(project_root) - print_index_stats(status) + project_root = require_daemon_for_project() + print_project_header(project_root) + _run_index_with_progress(project_root) + print_index_stats(_client.project_status(project_root)) @app.command() @@ -324,12 +322,11 @@ def search( refresh: bool = _typer.Option(False, "--refresh", help="Refresh index before searching"), ) -> None: """Semantic search across the codebase.""" - client, project_root = require_daemon_for_project() + project_root = require_daemon_for_project() query_str = " ".join(query) - # Refresh index with progress display before searching if refresh: - _run_index_with_progress(client, project_root) + _run_index_with_progress(project_root) # Default path filter from CWD paths: list[str] | None = None @@ -341,7 +338,6 @@ def search( paths = [default] resp = _search_with_wait_spinner( - client, project_root=project_root, query=query_str, languages=lang or None, @@ -355,10 +351,11 @@ def search( @app.command() def status() -> None: """Show project status.""" - client, project_root = require_daemon_for_project() + from . import client as _client + + project_root = require_daemon_for_project() print_project_header(project_root) - resp = client.project_status(project_root) - print_index_stats(resp) + print_index_stats(_client.project_status(project_root)) @app.command() @@ -400,12 +397,9 @@ def reset( # Remove project from daemon first so it releases file handles try: - from .client import DaemonClient + from . import client as _client - client = DaemonClient.connect() - client.handshake() - client.remove_project(str(project_root)) - client.close() + _client.remove_project(str(project_root)) except (ConnectionRefusedError, OSError, RuntimeError): pass # Daemon not running — that's fine @@ -442,13 +436,12 @@ def mcp() -> None: """Run as MCP server (stdio mode).""" import asyncio - client, project_root = require_daemon_for_project() + project_root = require_daemon_for_project() async def _run_mcp() -> None: from .server import create_mcp_server - mcp_server = create_mcp_server(client, project_root) - # Trigger initial indexing in background + mcp_server = create_mcp_server(project_root) asyncio.create_task(_bg_index(project_root)) await mcp_server.run_stdio_async() @@ -456,27 +449,14 @@ async def _run_mcp() -> None: async def _bg_index(project_root: str) -> None: - """Index in background using a dedicated daemon connection. - - A fresh DaemonClient is used so that background indexing does not share - the multiprocessing connection used by foreground MCP requests, which - would corrupt data ("Input data was truncated"). - """ + """Index in background. Each call opens its own daemon connection.""" import asyncio - from .client import ensure_daemon + from . import client as _client loop = asyncio.get_event_loop() - - def _run_index() -> None: - bg_client = ensure_daemon() - try: - bg_client.index(project_root) - finally: - bg_client.close() - try: - await loop.run_in_executor(None, _run_index) + await loop.run_in_executor(None, lambda: _client.index(project_root)) except Exception: pass @@ -487,15 +467,15 @@ def _run_index() -> None: @daemon_app.command("status") def daemon_status() -> None: """Show daemon status.""" - from .client import ensure_daemon + from . import client as _client try: - client = ensure_daemon() + _client.ensure_daemon() except Exception as e: _typer.echo(f"Error: {e}", err=True) raise _typer.Exit(code=1) - resp = client.daemon_status() + resp = _client.daemon_status() _typer.echo(f"Daemon version: {resp.version}") _typer.echo(f"Uptime: {resp.uptime_seconds:.1f}s") if resp.projects: @@ -505,7 +485,6 @@ def daemon_status() -> None: _typer.echo(f" {p.project_root} [{state}]") else: _typer.echo("No projects loaded.") - client.close() @daemon_app.command("restart") diff --git a/src/cocoindex_code/client.py b/src/cocoindex_code/client.py index 68ba10a..b14785a 100644 --- a/src/cocoindex_code/client.py +++ b/src/cocoindex_code/client.py @@ -1,4 +1,9 @@ -"""Client for communicating with the daemon.""" +"""Client for communicating with the daemon. + +Per-request connection model: each function opens a fresh connection, +performs the version handshake, sends one request, reads the response(s), +and closes. There is no persistent connection object. +""" from __future__ import annotations @@ -41,41 +46,90 @@ logger = logging.getLogger(__name__) -class DaemonClient: - """Client for communicating with the daemon.""" +# --------------------------------------------------------------------------- +# Per-request connection helpers +# --------------------------------------------------------------------------- - _conn: Connection - def __init__(self, conn: Connection) -> None: - self._conn = conn +def _connect_and_handshake() -> Connection: + """Connect to the daemon and perform the version handshake. - @classmethod - def connect(cls) -> DaemonClient: - """Connect to daemon. Raises ConnectionRefusedError if not running.""" - sock = daemon_socket_path() - if not os.path.exists(sock): - raise ConnectionRefusedError(f"Daemon socket not found: {sock}") - try: - conn = Client(sock, family=_connection_family()) - except (ConnectionRefusedError, FileNotFoundError, OSError) as e: - raise ConnectionRefusedError(f"Cannot connect to daemon: {e}") from e - return cls(conn) - - def handshake(self) -> HandshakeResponse: - """Send version handshake.""" - return self._send(HandshakeRequest(version=__version__)) # type: ignore[return-value] - - def index( - self, - project_root: str, - on_progress: Callable[[IndexingProgress], None] | None = None, - on_waiting: Callable[[], None] | None = None, - ) -> IndexResponse: - """Request indexing with streaming progress. Blocks until complete.""" - self._conn.send_bytes(encode_request(IndexRequest(project_root=project_root))) + Returns the open connection for the caller to send exactly one request. + Raises ``ConnectionRefusedError`` if the daemon is not running, or + ``RuntimeError`` on protocol/version errors. + """ + sock = daemon_socket_path() + if sys.platform != "win32" and not os.path.exists(sock): + raise ConnectionRefusedError(f"Daemon socket not found: {sock}") + try: + conn = Client(sock, family=_connection_family()) + except (ConnectionRefusedError, FileNotFoundError, OSError) as e: + raise ConnectionRefusedError(f"Cannot connect to daemon: {e}") from e + + try: + conn.send_bytes(encode_request(HandshakeRequest(version=__version__))) + data = conn.recv_bytes() + except (EOFError, OSError) as e: + conn.close() + raise ConnectionRefusedError(f"Handshake failed: {e}") from e + + resp = decode_response(data) + if isinstance(resp, ErrorResponse): + conn.close() + raise RuntimeError(f"Daemon error: {resp.message}") + if not isinstance(resp, HandshakeResponse): + conn.close() + raise RuntimeError(f"Unexpected handshake response: {type(resp).__name__}") + if not resp.ok: + conn.close() + raise _VersionMismatchError(resp) + return conn + + +class _VersionMismatchError(Exception): + """Raised when the daemon version or settings are stale.""" + + def __init__(self, resp: HandshakeResponse) -> None: + self.resp = resp + super().__init__( + f"Daemon version mismatch: {resp.daemon_version} " + f"(settings_mtime={resp.global_settings_mtime_us})" + ) + + +def _send(req: Request) -> Response: + """Open connection, handshake, send one request, read one response, close.""" + conn = _connect_and_handshake() + try: + conn.send_bytes(encode_request(req)) + data = conn.recv_bytes() + except (EOFError, OSError) as e: + raise RuntimeError(f"Connection to daemon lost: {e}") from e + finally: + conn.close() + resp = decode_response(data) + if isinstance(resp, ErrorResponse): + raise RuntimeError(f"Daemon error: {resp.message}") + return resp + + +# --------------------------------------------------------------------------- +# Public API — one function per request type +# --------------------------------------------------------------------------- + + +def index( + project_root: str, + on_progress: Callable[[IndexingProgress], None] | None = None, + on_waiting: Callable[[], None] | None = None, +) -> IndexResponse: + """Request indexing with streaming progress. Blocks until complete.""" + conn = _connect_and_handshake() + try: + conn.send_bytes(encode_request(IndexRequest(project_root=project_root))) while True: try: - data = self._conn.recv_bytes() + data = conn.recv_bytes() except EOFError: raise RuntimeError("Connection to daemon lost during indexing") resp = decode_response(data) @@ -92,24 +146,28 @@ def index( if isinstance(resp, IndexResponse): return resp raise RuntimeError(f"Unexpected response: {type(resp).__name__}") - - def search( - self, - project_root: str, - query: str, - languages: list[str] | None = None, - paths: list[str] | None = None, - limit: int = 5, - offset: int = 0, - on_waiting: Callable[[], None] | None = None, - ) -> SearchResponse: - """Search the codebase. - - If the daemon sends ``IndexWaitingNotice`` (load-time indexing in - progress), calls *on_waiting* (if provided) then continues reading - until the final ``SearchResponse``. - """ - self._conn.send_bytes( + finally: + conn.close() + + +def search( + project_root: str, + query: str, + languages: list[str] | None = None, + paths: list[str] | None = None, + limit: int = 5, + offset: int = 0, + on_waiting: Callable[[], None] | None = None, +) -> SearchResponse: + """Search the codebase. + + If the daemon sends ``IndexWaitingNotice`` (load-time indexing in + progress), calls *on_waiting* (if provided) then continues reading + until the final ``SearchResponse``. + """ + conn = _connect_and_handshake() + try: + conn.send_bytes( encode_request( SearchRequest( project_root=project_root, @@ -123,7 +181,7 @@ def search( ) while True: try: - data = self._conn.recv_bytes() + data = conn.recv_bytes() except EOFError: raise RuntimeError("Connection to daemon lost during search") resp = decode_response(data) @@ -136,38 +194,26 @@ def search( if isinstance(resp, SearchResponse): return resp raise RuntimeError(f"Unexpected response: {type(resp).__name__}") + finally: + conn.close() - def project_status(self, project_root: str) -> ProjectStatusResponse: - return self._send( # type: ignore[return-value] - ProjectStatusRequest(project_root=project_root) - ) - def daemon_status(self) -> DaemonStatusResponse: - from .protocol import DaemonStatusRequest +def project_status(project_root: str) -> ProjectStatusResponse: + return _send(ProjectStatusRequest(project_root=project_root)) # type: ignore[return-value] - return self._send(DaemonStatusRequest()) # type: ignore[return-value] - def remove_project(self, project_root: str) -> RemoveProjectResponse: - return self._send( # type: ignore[return-value] - RemoveProjectRequest(project_root=project_root) - ) +def daemon_status() -> DaemonStatusResponse: + from .protocol import DaemonStatusRequest - def stop(self) -> StopResponse: - return self._send(StopRequest()) # type: ignore[return-value] + return _send(DaemonStatusRequest()) # type: ignore[return-value] - def close(self) -> None: - try: - self._conn.close() - except Exception: - pass - def _send(self, req: Request) -> Response: - self._conn.send_bytes(encode_request(req)) - data = self._conn.recv_bytes() - resp = decode_response(data) - if isinstance(resp, ErrorResponse): - raise RuntimeError(f"Daemon error: {resp.message}") - return resp +def remove_project(project_root: str) -> RemoveProjectResponse: + return _send(RemoveProjectRequest(project_root=project_root)) # type: ignore[return-value] + + +def stop() -> StopResponse: + return _send(StopRequest()) # type: ignore[return-value] # --------------------------------------------------------------------------- @@ -178,8 +224,6 @@ def _send(self, req: Request) -> Response: def is_daemon_running() -> bool: """Check if the daemon is running.""" if sys.platform == "win32": - # os.path.exists is unreliable for Windows named pipes; - # try connecting instead. try: conn = Client(daemon_socket_path(), family=_connection_family()) conn.close() @@ -196,18 +240,14 @@ def start_daemon() -> None: daemon_dir().mkdir(parents=True, exist_ok=True) log_path = daemon_dir() / "daemon.log" - # Use the ccc entry point if available, otherwise fall back to python -m ccc_path = _find_ccc_executable() if ccc_path: cmd = [ccc_path, "run-daemon"] else: cmd = [sys.executable, "-m", "cocoindex_code.cli", "run-daemon"] - log_fd = open(log_path, "a") + log_fd = open(log_path, "w") if sys.platform == "win32": - # CREATE_NO_WINDOW prevents the daemon from showing a visible - # console window. DETACHED_PROCESS alone is not sufficient — - # it detaches from the parent console but still creates a new one. _create_no_window = 0x08000000 subprocess.Popen( cmd, @@ -230,7 +270,6 @@ def start_daemon() -> None: def _find_ccc_executable() -> str | None: """Find the ccc executable in PATH or the same directory as python.""" python_dir = Path(sys.executable).parent - # On Windows the script is ccc.exe; on Unix it's just ccc names = ["ccc.exe", "ccc"] if sys.platform == "win32" else ["ccc"] for name in names: ccc = python_dir / name @@ -242,10 +281,6 @@ def _find_ccc_executable() -> str | None: def _pid_alive(pid: int) -> bool: """Return True if *pid* is still running.""" if sys.platform == "win32": - # Avoid os.kill(pid, 0) on Windows — it has a CPython bug that corrupts - # the C-level exception state, causing subsequent C function calls - # (time.monotonic, time.sleep) to raise SystemError even after the - # OSError is caught. Use OpenProcess via ctypes instead. import ctypes kernel32 = getattr(ctypes, "windll").kernel32 @@ -255,95 +290,75 @@ def _pid_alive(pid: int) -> bool: return True return False try: - os.kill(pid, 0) # signal 0: check existence without killing + os.kill(pid, 0) return True except ProcessLookupError: return False except PermissionError: - return True # process exists but we can't signal it + return True + + +def _wait_for_daemon_exit(timeout: float) -> bool: + """Wait up to *timeout* seconds for the daemon to finish cleanup. + + Returns True when the daemon's PID file is gone (meaning it completed its + shutdown sequence). This is more reliable than checking process liveness + because the daemon process may linger as a zombie. + """ + pid_path = daemon_pid_path() + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if not pid_path.exists(): + return True + time.sleep(0.1) + return not pid_path.exists() def stop_daemon() -> None: """Stop the daemon gracefully. - Sends a StopRequest, waits for the process to exit, falls back to - SIGTERM → SIGKILL. Only removes the PID file after confirming that - the specific PID is no longer alive. + Escalation: StopRequest → SIGTERM → SIGKILL. """ pid_path = daemon_pid_path() - # Read the PID early so we can track the actual process. pid: int | None = None try: pid = int(pid_path.read_text().strip()) if pid == os.getpid(): - pid = None # safety: never kill ourselves + pid = None except (FileNotFoundError, ValueError): pass - # Step 1: try sending StopRequest via socket + # 1) Graceful StopRequest via socket try: - client = DaemonClient.connect() - client.handshake() - client.stop() - client.close() + stop() except (ConnectionRefusedError, OSError, RuntimeError): pass - # Step 2: wait for process to exit (up to 5s) - if pid is not None: - deadline = time.monotonic() + 5.0 - while time.monotonic() < deadline and _pid_alive(pid): - time.sleep(0.1) - if not _pid_alive(pid): - _cleanup_stale_files(pid_path, pid) - return + if _wait_for_daemon_exit(timeout=3.0): + return - # Step 3: if still running, try SIGTERM + # 2) SIGTERM if pid is not None and _pid_alive(pid): try: os.kill(pid, signal.SIGTERM) except (ProcessLookupError, PermissionError): pass - - deadline = time.monotonic() + 2.0 - while time.monotonic() < deadline and _pid_alive(pid): - time.sleep(0.1) - - if not _pid_alive(pid): - _cleanup_stale_files(pid_path, pid) + if _wait_for_daemon_exit(timeout=2.0): return - # Step 4: escalate to SIGKILL (Unix only; - # on Windows SIGTERM already calls TerminateProcess) + # 3) SIGKILL (Unix) — on Windows SIGTERM already calls TerminateProcess if sys.platform != "win32" and pid is not None and _pid_alive(pid): try: os.kill(pid, signal.SIGKILL) except (ProcessLookupError, PermissionError): pass - # SIGKILL is async; give the kernel a moment to reap - deadline = time.monotonic() + 1.0 - while time.monotonic() < deadline and _pid_alive(pid): - time.sleep(0.1) - - # Step 4b: on Windows, wait for the process to fully exit after TerminateProcess - # so that named pipe handles are released before starting a new daemon. - if sys.platform == "win32" and pid is not None: - deadline = time.monotonic() + 3.0 - while time.monotonic() < deadline and _pid_alive(pid): - time.sleep(0.1) - - # Step 5: clean up stale files _cleanup_stale_files(pid_path, pid) def _cleanup_stale_files(pid_path: Path, pid: int | None) -> None: - """Remove socket and PID file after the daemon has exited. - - Only removes the PID file when *pid* matches what is on disk, to - avoid accidentally deleting a newer daemon's PID file. - """ + """Remove socket and PID file after the daemon has exited.""" if sys.platform != "win32": sock = daemon_socket_path() try: @@ -358,7 +373,6 @@ def _cleanup_stale_files(pid_path: Path, pid: int | None) -> None: except (FileNotFoundError, ValueError): pass else: - # No PID known — cautiously remove if file exists try: pid_path.unlink(missing_ok=True) except Exception: @@ -371,8 +385,6 @@ def _wait_for_daemon(timeout: float = 30.0) -> None: sock_path = daemon_socket_path() while time.monotonic() < deadline: if sys.platform == "win32": - # os.path.exists is unreliable for Windows named pipes; - # try an actual connection to verify the daemon is listening. try: conn = Client(sock_path, family=_connection_family()) conn.close() @@ -387,11 +399,7 @@ def _wait_for_daemon(timeout: float = 30.0) -> None: def _needs_restart(resp: HandshakeResponse) -> bool: - """Check if the daemon needs to be restarted. - - Returns True if the version mismatches or if global_settings.yml has been - modified since the daemon loaded it. - """ + """Check if the daemon needs to be restarted.""" if not resp.ok: return True from .settings import global_settings_mtime_us @@ -402,22 +410,18 @@ def _needs_restart(resp: HandshakeResponse) -> bool: return False -def ensure_daemon() -> DaemonClient: - """Connect to daemon, starting or restarting as needed. +def ensure_daemon() -> None: + """Ensure daemon is running with correct version. Starts or restarts as needed. - 1. Try to connect to existing daemon. - 2. If connection refused: start daemon, retry connect with backoff. - 3. If connected but version mismatch or global settings changed: - stop old daemon, start new one. + After this returns, per-request functions (``index``, ``search``, etc.) + can be called directly — each opens its own connection. """ # Try connecting to existing daemon try: - client = DaemonClient.connect() - resp = client.handshake() - if not _needs_restart(resp): - return client - # Version or settings mismatch — restart - client.close() + conn = _connect_and_handshake() + conn.close() + return # daemon is up and version matches + except _VersionMismatchError: stop_daemon() except (ConnectionRefusedError, OSError): pass @@ -426,17 +430,14 @@ def ensure_daemon() -> DaemonClient: start_daemon() _wait_for_daemon() - # Connect with retries + # Verify with retries for _attempt in range(10): try: - client = DaemonClient.connect() - resp = client.handshake() - if not _needs_restart(resp): - return client - raise RuntimeError( - f"Daemon mismatch after fresh start: version={resp.daemon_version}, " - f"settings_mtime={resp.global_settings_mtime_us}" - ) + conn = _connect_and_handshake() + conn.close() + return + except _VersionMismatchError as e: + raise RuntimeError(f"Daemon mismatch after fresh start: {e}") from e except (ConnectionRefusedError, OSError): time.sleep(0.5) diff --git a/src/cocoindex_code/daemon.py b/src/cocoindex_code/daemon.py index 5974869..e524d66 100644 --- a/src/cocoindex_code/daemon.py +++ b/src/cocoindex_code/daemon.py @@ -48,7 +48,6 @@ from .query import query_codebase from .settings import ( global_settings_mtime_us, - load_project_settings, load_user_settings, user_settings_dir, ) @@ -122,8 +121,7 @@ async def get_project(self, project_root: str, *, suppress_auto_index: bool = Fa """ if project_root not in self._projects: root = Path(project_root) - project_settings = load_project_settings(root) - project = await Project.create(root, project_settings, self._embedder) + project = await Project.create(root, self._embedder) self._projects[project_root] = project self._index_locks[project_root] = asyncio.Lock() self._load_time_done[project_root] = asyncio.Event() @@ -354,66 +352,55 @@ async def handle_connection( conn: Connection, registry: ProjectRegistry, start_time: float, - shutdown_event: asyncio.Event, + on_shutdown: Callable[[], None], settings_mtime_us: int | None, ) -> None: - """Handle a single client connection.""" - loop = asyncio.get_event_loop() - handshake_done = False - - def _recv() -> bytes: - """Blocking recv that also checks for shutdown.""" - # Use poll with a timeout so we can check shutdown_event periodically - while not shutdown_event.is_set(): - if conn.poll(0.5): - return conn.recv_bytes() - raise EOFError("shutdown") + """Handle a single client connection (per-request model). + Reads exactly two messages: a ``HandshakeRequest`` followed by one + ``Request``. Sends the response(s) and closes the connection. + """ + loop = asyncio.get_event_loop() try: - while not shutdown_event.is_set(): - try: - data: bytes = await loop.run_in_executor(None, _recv) - except (EOFError, OSError): - break + # 1. Handshake + data: bytes = await loop.run_in_executor(None, conn.recv_bytes) + req = decode_request(data) - try: - req = decode_request(data) - except Exception as e: - resp: Response = ErrorResponse(message=f"Invalid request: {e}") - conn.send_bytes(encode_response(resp)) - continue - - if not handshake_done: - if not isinstance(req, HandshakeRequest): - resp = ErrorResponse(message="First message must be a handshake") - conn.send_bytes(encode_response(resp)) - break + if not isinstance(req, HandshakeRequest): + conn.send_bytes( + encode_response(ErrorResponse(message="First message must be a handshake")) + ) + return - ok = req.version == __version__ - resp = HandshakeResponse( + ok = req.version == __version__ + conn.send_bytes( + encode_response( + HandshakeResponse( ok=ok, daemon_version=__version__, global_settings_mtime_us=settings_mtime_us, ) - conn.send_bytes(encode_response(resp)) - if not ok: - break - handshake_done = True - continue - - result = await _dispatch(req, registry, start_time, shutdown_event) - if isinstance(result, AsyncIterator): - try: - async for resp in result: - conn.send_bytes(encode_response(resp)) - except Exception as exc: - logger.exception("Error during streaming response") - conn.send_bytes(encode_response(ErrorResponse(message=str(exc)))) - else: - conn.send_bytes(encode_response(result)) - - if isinstance(req, StopRequest): - break + ) + ) + if not ok: + return + + # 2. Single request + data = await loop.run_in_executor(None, conn.recv_bytes) + req = decode_request(data) + + result = await _dispatch(req, registry, start_time, on_shutdown) + if isinstance(result, AsyncIterator): + try: + async for resp in result: + conn.send_bytes(encode_response(resp)) + except Exception as exc: + logger.exception("Error during streaming response") + conn.send_bytes(encode_response(ErrorResponse(message=str(exc)))) + else: + conn.send_bytes(encode_response(result)) + except (EOFError, OSError, asyncio.CancelledError): + pass except Exception: logger.exception("Error handling connection") finally: @@ -452,7 +439,7 @@ async def _dispatch( req: Request, registry: ProjectRegistry, start_time: float, - shutdown_event: asyncio.Event, + on_shutdown: Callable[[], None], ) -> Response | AsyncIterator[IndexStreamResponse] | AsyncIterator[SearchStreamResponse]: """Dispatch a request to the appropriate handler. @@ -502,7 +489,7 @@ async def _dispatch( return RemoveProjectResponse(ok=True) if isinstance(req, StopRequest): - shutdown_event.set() + on_shutdown() return StopResponse(ok=True) return ErrorResponse(message=f"Unknown request type: {type(req).__name__}") @@ -517,7 +504,12 @@ async def _dispatch( def run_daemon() -> None: - """Main entry point for the daemon process (blocking).""" + """Main entry point for the daemon process (blocking). + + Sets up the listener, runs the asyncio event loop (``loop.run_forever``) + to serve connections, and performs cleanup when shutdown is requested via + ``StopRequest`` or a signal (SIGTERM / SIGINT). + """ daemon_dir().mkdir(parents=True, exist_ok=True) # Load user settings and record mtime for staleness detection @@ -540,44 +532,16 @@ def run_daemon() -> None: logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s", - handlers=[logging.FileHandler(str(log_path)), logging.StreamHandler()], + handlers=[logging.FileHandler(str(log_path), mode="w"), logging.StreamHandler()], force=True, ) logger.info("Daemon starting (PID %d, version %s)", os.getpid(), __version__) - try: - asyncio.run(_async_daemon_main(embedder, settings_mtime_us)) - finally: - # Clean up socket first, then PID file last. - # The PID file is the authoritative "daemon is alive" indicator, so it - # must be the very last thing removed to avoid races where a client - # sees the PID gone but the socket (or process) is still lingering. - if sys.platform != "win32": - sock = daemon_socket_path() - try: - Path(sock).unlink(missing_ok=True) - except Exception: - pass - # Only remove the PID file if it still contains *our* PID. - # A new daemon may have already overwritten it during a restart race. - try: - stored = pid_path.read_text().strip() - if stored == str(os.getpid()): - pid_path.unlink(missing_ok=True) - except Exception: - pass - logger.info("Daemon stopped") - - -async def _async_daemon_main(embedder: Embedder, settings_mtime_us: int | None) -> None: - """Async main loop for the daemon.""" start_time = time.monotonic() registry = ProjectRegistry(embedder) - shutdown_event = asyncio.Event() sock_path = daemon_socket_path() - # Remove stale socket (not applicable for Windows named pipes) if sys.platform != "win32": try: Path(sock_path).unlink(missing_ok=True) @@ -587,56 +551,82 @@ async def _async_daemon_main(embedder: Embedder, settings_mtime_us: int | None) listener = Listener(sock_path, family=_connection_family()) logger.info("Listening on %s", sock_path) - loop = asyncio.get_event_loop() + loop = asyncio.new_event_loop() + tasks: set[asyncio.Task[Any]] = set() + + def _request_shutdown() -> None: + """Trigger daemon shutdown — called by StopRequest or signal handler.""" + loop.stop() + + def _spawn_handler(conn: Connection) -> None: + task = loop.create_task( + handle_connection( + conn, + registry, + start_time, + _request_shutdown, + settings_mtime_us, + ) + ) + tasks.add(task) + task.add_done_callback(tasks.discard) - # Handle signals for graceful shutdown (not supported on all platforms/contexts) + # Handle signals for graceful shutdown try: for sig in (signal.SIGTERM, signal.SIGINT): - loop.add_signal_handler(sig, shutdown_event.set) + loop.add_signal_handler(sig, _request_shutdown) except (RuntimeError, NotImplementedError): pass # Not in main thread, or not supported on this platform (e.g. Windows) - tasks: set[asyncio.Task[Any]] = set() - - async def _spawn_handler( - conn: Connection, - reg: ProjectRegistry, - st: float, - evt: asyncio.Event, - task_set: set[asyncio.Task[Any]], - ) -> None: - task = asyncio.create_task(handle_connection(conn, reg, st, evt, settings_mtime_us)) - task_set.add(task) - task.add_done_callback(task_set.discard) - - # Run accept loop in a thread so we can shut down cleanly + # Accept loop runs in a background thread; new connections are dispatched + # to the event loop via call_soon_threadsafe. The loop exits when + # listener.close() (called during shutdown) causes accept() to raise. def _accept_loop() -> None: - while not shutdown_event.is_set(): + while True: try: - try: - listener._listener._socket.settimeout(0.5) # type: ignore[attr-defined] - except AttributeError: - pass # AF_PIPE (Windows) doesn't expose ._socket conn = listener.accept() - # Schedule the handler on the event loop - asyncio.run_coroutine_threadsafe( - _spawn_handler(conn, registry, start_time, shutdown_event, tasks), - loop, - ) + loop.call_soon_threadsafe(_spawn_handler, conn) except OSError: - if shutdown_event.is_set(): - break - # Socket timeout — just retry - continue + break accept_thread = threading.Thread(target=_accept_loop, daemon=True) accept_thread.start() + # --- Serve until shutdown --- try: - await shutdown_event.wait() + loop.run_forever() finally: + # 1. Stop accepting new connections. listener.close() - accept_thread.join(timeout=2) + + # 2. Cancel handler tasks (they may be blocked in run_in_executor). + for task in tasks: + task.cancel() if tasks: - await asyncio.gather(*tasks, return_exceptions=True) + loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + + # 3. Release project resources. registry.close_all() + loop.close() + + # 4. Remove socket and PID file. + if sys.platform != "win32": + try: + Path(sock_path).unlink(missing_ok=True) + except Exception: + pass + try: + stored = pid_path.read_text().strip() + if stored == str(os.getpid()): + pid_path.unlink(missing_ok=True) + except Exception: + pass + + logger.info("Daemon stopped") + + # 5. Hard-exit to avoid slow Python teardown (torch, threadpool, etc.). + # All resources are already cleaned up above. Only do this when + # running as the main entry point (not when the daemon is started + # in-process for testing). + if threading.current_thread() is threading.main_thread(): + os._exit(0) diff --git a/src/cocoindex_code/indexer.py b/src/cocoindex_code/indexer.py index 64ff3ad..b2a5bc5 100644 --- a/src/cocoindex_code/indexer.py +++ b/src/cocoindex_code/indexer.py @@ -14,12 +14,10 @@ from cocoindex.resources.id import IdGenerator from pathspec import GitIgnoreSpec -from .settings import PROJECT_SETTINGS +from .settings import load_gitignore_spec, load_project_settings from .shared import ( CODEBASE_DIR, EMBEDDER, - EXT_LANG_OVERRIDE_MAP, - GITIGNORE_SPEC, SQLITE_DB, CodeChunk, ) @@ -151,9 +149,11 @@ async def process_file( return suffix = file.file_path.path.suffix - ext_lang_override_map = coco.use_context(EXT_LANG_OVERRIDE_MAP) + project_root = coco.use_context(CODEBASE_DIR) + ps = load_project_settings(project_root) + ext_lang_map = {f".{lo.ext}": lo.lang for lo in ps.language_overrides} language = ( - ext_lang_override_map.get(suffix) + ext_lang_map.get(suffix) or detect_code_language(filename=file.file_path.path.name) or "text" ) @@ -187,9 +187,9 @@ async def process(chunk: Chunk) -> None: @coco.fn async def indexer_main() -> None: """Main indexing function - walks files and processes each.""" - ps = coco.use_context(PROJECT_SETTINGS) - gitignore_spec = coco.use_context(GITIGNORE_SPEC) project_root = coco.use_context(CODEBASE_DIR) + ps = load_project_settings(project_root) + gitignore_spec = load_gitignore_spec(project_root) table = await sqlite.mount_table_target( db=SQLITE_DB, diff --git a/src/cocoindex_code/project.py b/src/cocoindex_code/project.py index f9f60a4..f2adff7 100644 --- a/src/cocoindex_code/project.py +++ b/src/cocoindex_code/project.py @@ -11,12 +11,9 @@ from .indexer import indexer_main from .protocol import IndexingProgress -from .settings import PROJECT_SETTINGS, ProjectSettings, load_gitignore_spec from .shared import ( CODEBASE_DIR, EMBEDDER, - EXT_LANG_OVERRIDE_MAP, - GITIGNORE_SPEC, SQLITE_DB, Embedder, ) @@ -84,10 +81,14 @@ def is_initial_index_done(self) -> bool: @staticmethod async def create( project_root: Path, - project_settings: ProjectSettings, embedder: Embedder, ) -> Project: - """Create a project with explicit settings and embedder.""" + """Create a project with explicit embedder. + + Project-level settings and .gitignore are NOT cached here — the + indexer loads them fresh from disk on every run so that user edits + take effect without restarting the daemon. + """ index_dir = project_root / ".cocoindex_code" index_dir.mkdir(parents=True, exist_ok=True) @@ -95,18 +96,11 @@ async def create( target_sqlite_db_path = index_dir / "target_sqlite.db" settings = coco.Settings.from_env(cocoindex_db_path) - gitignore_spec = load_gitignore_spec(project_root) context = coco.ContextProvider() context.provide(CODEBASE_DIR, project_root) context.provide(SQLITE_DB, sqlite.connect(str(target_sqlite_db_path), load_vec=True)) context.provide(EMBEDDER, embedder) - context.provide(PROJECT_SETTINGS, project_settings) - context.provide( - EXT_LANG_OVERRIDE_MAP, - {f".{lo.ext}": lo.lang for lo in project_settings.language_overrides}, - ) - context.provide(GITIGNORE_SPEC, gitignore_spec) env = coco.Environment(settings, context_provider=context) app = coco.App( diff --git a/src/cocoindex_code/server.py b/src/cocoindex_code/server.py index c56229e..955ffe4 100644 --- a/src/cocoindex_code/server.py +++ b/src/cocoindex_code/server.py @@ -2,7 +2,7 @@ Supports two modes: 1. Daemon-backed: ``create_mcp_server(client, project_root)`` — lightweight MCP - server that delegates to the daemon via a ``DaemonClient``. + server that delegates to the daemon via per-request client functions. 2. Legacy entry point: ``main()`` — backward-compatible ``cocoindex-code`` CLI that auto-creates settings from env vars and delegates to the daemon. """ @@ -13,16 +13,10 @@ import json import os from pathlib import Path -from typing import TYPE_CHECKING from mcp.server.fastmcp import FastMCP from pydantic import BaseModel, Field -if TYPE_CHECKING: - from .client import DaemonClient - -from .protocol import IndexingProgress - _MCP_INSTRUCTIONS = ( "Code search and codebase understanding tools." "\n" @@ -62,7 +56,7 @@ class SearchResultModel(BaseModel): # === Daemon-backed MCP server factory === -def create_mcp_server(client: DaemonClient, project_root: str) -> FastMCP: +def create_mcp_server(project_root: str) -> FastMCP: """Create a lightweight MCP server that delegates to the daemon.""" mcp = FastMCP("cocoindex-code", instructions=_MCP_INSTRUCTIONS) @@ -125,13 +119,15 @@ async def search( ), ) -> SearchResultModel: """Query the codebase index via the daemon.""" + from . import client as _client + loop = asyncio.get_event_loop() try: if refresh_index: - await loop.run_in_executor(None, lambda: client.index(project_root)) + await loop.run_in_executor(None, lambda: _client.index(project_root)) resp = await loop.run_in_executor( None, - lambda: client.search( + lambda: _client.search( project_root=project_root, query=query, languages=languages, @@ -185,7 +181,6 @@ def main() -> None: """ import argparse - from .client import ensure_daemon from .settings import ( EmbeddingSettings, LanguageOverride, @@ -276,6 +271,11 @@ def main() -> None: save_user_settings(us) # --- Delegate to daemon --- + from . import client as _client + from .protocol import IndexingProgress + + _client.ensure_daemon() + if args.command == "index": import sys @@ -285,7 +285,6 @@ def main() -> None: from .cli import _format_progress - client = ensure_daemon() err_console = Console(stderr=True) last_progress_line: str | None = None @@ -304,54 +303,32 @@ def _on_progress(progress: IndexingProgress) -> None: last_progress_line = f"Indexing: {_format_progress(progress)}" live.update(Spinner("dots", last_progress_line)) - resp = client.index(str(project_root), on_progress=_on_progress, on_waiting=_on_waiting) + resp = _client.index( + str(project_root), on_progress=_on_progress, on_waiting=_on_waiting + ) if last_progress_line is not None: print(last_progress_line, file=sys.stderr) if resp.success: - status = client.project_status(str(project_root)) + st = _client.project_status(str(project_root)) print("\nIndex stats:") - print(f" Chunks: {status.total_chunks}") - print(f" Files: {status.total_files}") - if status.languages: + print(f" Chunks: {st.total_chunks}") + print(f" Files: {st.total_files}") + if st.languages: print(" Languages:") - for lang, count in sorted(status.languages.items(), key=lambda x: -x[1]): + for lang, count in sorted(st.languages.items(), key=lambda x: -x[1]): print(f" {lang}: {count} chunks") else: print(f"Indexing failed: {resp.message}") - client.close() else: # Default: run MCP server - client = ensure_daemon() - mcp_server = create_mcp_server(client, str(project_root)) + mcp_server = create_mcp_server(str(project_root)) async def _serve() -> None: + from .cli import _bg_index + asyncio.create_task(_bg_index(str(project_root))) await mcp_server.run_stdio_async() asyncio.run(_serve()) - - -async def _bg_index(project_root: str) -> None: - """Index in background using a dedicated daemon connection. - - A fresh DaemonClient is used so that background indexing does not share - the multiprocessing connection used by foreground MCP requests, which - would corrupt data ("Input data was truncated"). - """ - from .client import ensure_daemon - - loop = asyncio.get_event_loop() - - def _run_index() -> None: - bg_client = ensure_daemon() - try: - bg_client.index(project_root) - finally: - bg_client.close() - - try: - await loop.run_in_executor(None, _run_index) - except Exception: - pass diff --git a/src/cocoindex_code/settings.py b/src/cocoindex_code/settings.py index dcb155a..0216487 100644 --- a/src/cocoindex_code/settings.py +++ b/src/cocoindex_code/settings.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Any -import cocoindex as _coco import yaml as _yaml from pathspec import GitIgnoreSpec @@ -90,9 +89,6 @@ class ProjectSettings: language_overrides: list[LanguageOverride] = field(default_factory=list) -# CocoIndex context key for project settings -PROJECT_SETTINGS = _coco.ContextKey[ProjectSettings]("project_settings") - # --------------------------------------------------------------------------- # Default factories # --------------------------------------------------------------------------- diff --git a/src/cocoindex_code/shared.py b/src/cocoindex_code/shared.py index 319ff17..b024526 100644 --- a/src/cocoindex_code/shared.py +++ b/src/cocoindex_code/shared.py @@ -10,7 +10,6 @@ import cocoindex as coco from cocoindex.connectors import sqlite from numpy.typing import NDArray -from pathspec import GitIgnoreSpec if TYPE_CHECKING: from cocoindex.ops.litellm import LiteLLMEmbedder @@ -32,8 +31,6 @@ EMBEDDER = coco.ContextKey[Embedder]("embedder") SQLITE_DB = coco.ContextKey[sqlite.ManagedConnection]("index_db", tracked=False) CODEBASE_DIR = coco.ContextKey[pathlib.Path]("codebase", tracked=False) -GITIGNORE_SPEC = coco.ContextKey[GitIgnoreSpec | None]("gitignore_spec", tracked=False) -EXT_LANG_OVERRIDE_MAP = coco.ContextKey[dict[str, str]]("ext_lang_override_map") # Module-level variable — set by daemon at startup (needed for CodeChunk annotation). embedder: Embedder | None = None @@ -85,4 +82,4 @@ class CodeChunk: content: str start_line: int end_line: int - embedding: Annotated[NDArray, embedder] # type: ignore[type-arg] + embedding: Annotated[NDArray, EMBEDDER] diff --git a/tests/test_bg_index.py b/tests/test_bg_index.py deleted file mode 100644 index ef3aace..0000000 --- a/tests/test_bg_index.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Tests for MCP background indexing connection isolation.""" - -from __future__ import annotations - -from unittest.mock import MagicMock, patch - -import pytest - - -class TestBgIndexIsolation: - """Verify _bg_index uses a dedicated DaemonClient, not the shared one.""" - - @pytest.mark.asyncio - async def test_cli_bg_index_creates_own_client(self) -> None: - """cli._bg_index should call ensure_daemon() for a fresh client.""" - from cocoindex_code.cli import _bg_index - - mock_client = MagicMock() - mock_client.index = MagicMock() - mock_client.close = MagicMock() - - with patch("cocoindex_code.client.ensure_daemon", return_value=mock_client) as mock_ensure: - await _bg_index("/tmp/project") - - mock_ensure.assert_called_once() - mock_client.index.assert_called_once_with("/tmp/project") - mock_client.close.assert_called_once() - - @pytest.mark.asyncio - async def test_cli_bg_index_closes_client_on_error(self) -> None: - """cli._bg_index should close the client even if indexing fails.""" - from cocoindex_code.cli import _bg_index - - mock_client = MagicMock() - mock_client.index = MagicMock(side_effect=RuntimeError("boom")) - mock_client.close = MagicMock() - - with patch("cocoindex_code.client.ensure_daemon", return_value=mock_client): - await _bg_index("/tmp/project") # should not raise - - mock_client.close.assert_called_once() - - @pytest.mark.asyncio - async def test_server_bg_index_creates_own_client(self) -> None: - """server._bg_index should call ensure_daemon() for a fresh client.""" - from cocoindex_code.server import _bg_index - - mock_client = MagicMock() - mock_client.index = MagicMock() - mock_client.close = MagicMock() - - with patch("cocoindex_code.client.ensure_daemon", return_value=mock_client) as mock_ensure: - await _bg_index("/tmp/project") - - mock_ensure.assert_called_once() - mock_client.index.assert_called_once_with("/tmp/project") - mock_client.close.assert_called_once() - - @pytest.mark.asyncio - async def test_server_bg_index_closes_client_on_error(self) -> None: - """server._bg_index should close the client even if indexing fails.""" - from cocoindex_code.server import _bg_index - - mock_client = MagicMock() - mock_client.index = MagicMock(side_effect=RuntimeError("boom")) - mock_client.close = MagicMock() - - with patch("cocoindex_code.client.ensure_daemon", return_value=mock_client): - await _bg_index("/tmp/project") # should not raise - - mock_client.close.assert_called_once() - - @pytest.mark.asyncio - async def test_bg_index_does_not_use_shared_client(self) -> None: - """The shared MCP client must NOT be passed to _bg_index.""" - from cocoindex_code.cli import _bg_index - - shared_client = MagicMock() - bg_client = MagicMock() - bg_client.index = MagicMock() - bg_client.close = MagicMock() - - with patch("cocoindex_code.client.ensure_daemon", return_value=bg_client): - await _bg_index("/tmp/project") - - # The shared client should never have been called - shared_client.index.assert_not_called() - # The bg client should have been used instead - bg_client.index.assert_called_once() diff --git a/tests/test_client.py b/tests/test_client.py index c4eb587..d007a15 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,98 +1,13 @@ -"""Tests for DaemonClient and ensure_daemon().""" +"""Tests for client connection handling.""" from __future__ import annotations -import sys import tempfile -import threading -import time -import uuid -from collections.abc import Iterator -from multiprocessing.connection import Client from pathlib import Path import pytest -from cocoindex_code._version import __version__ -from cocoindex_code.client import DaemonClient -from cocoindex_code.daemon import _connection_family -from cocoindex_code.protocol import ( - HandshakeRequest, - StopRequest, - encode_request, -) -from cocoindex_code.settings import ( - default_user_settings, - save_user_settings, -) - - -@pytest.fixture() -def daemon_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> tuple[Path, str, Path]: - """Set up daemon environment for client tests.""" - user_dir = tmp_path / "user_home" / ".cocoindex_code" - user_dir.mkdir(parents=True) - - if sys.platform == "win32": - sock_path = rf"\\.\pipe\ccc_client_{uuid.uuid4().hex[:12]}" - else: - sock_dir = Path(tempfile.mkdtemp(prefix="ccc_client_")) - sock_path = str(sock_dir / "d.sock") - pid_path = user_dir / "daemon.pid" - - monkeypatch.setattr("cocoindex_code.settings.user_settings_dir", lambda: user_dir) - monkeypatch.setattr( - "cocoindex_code.settings.user_settings_path", - lambda: user_dir / "global_settings.yml", - ) - save_user_settings(default_user_settings()) - - # Override socket/pid paths for short AF_UNIX paths - monkeypatch.setattr("cocoindex_code.daemon.daemon_socket_path", lambda: sock_path) - monkeypatch.setattr("cocoindex_code.client.daemon_socket_path", lambda: sock_path) - monkeypatch.setattr("cocoindex_code.client.daemon_pid_path", lambda: pid_path) - - return user_dir, sock_path, pid_path - - -@pytest.fixture() -def daemon_thread(daemon_env: tuple[Path, str, Path]) -> Iterator[str]: - """Start daemon in thread, yield socket path.""" - user_dir, sock_path, pid_path = daemon_env - - from cocoindex_code.daemon import run_daemon - - thread = threading.Thread(target=run_daemon, daemon=True) - thread.start() - - # Wait for socket/pipe - import os - - deadline = time.monotonic() + 30 - while time.monotonic() < deadline: - if os.path.exists(sock_path): - break - time.sleep(0.2) - - yield sock_path - - try: - conn = Client(sock_path, family=_connection_family()) - conn.send_bytes(encode_request(HandshakeRequest(version=__version__))) - conn.recv_bytes() - conn.send_bytes(encode_request(StopRequest())) - conn.recv_bytes() - conn.close() - except Exception: - pass - thread.join(timeout=5) - - -def test_client_connect_to_running_daemon(daemon_thread: str) -> None: - client = DaemonClient.connect() - resp = client.handshake() - assert resp.ok is True - client.close() +from cocoindex_code import client def test_client_connect_refuses_when_no_daemon( @@ -103,11 +18,4 @@ def test_client_connect_refuses_when_no_daemon( monkeypatch.setattr("cocoindex_code.client.daemon_socket_path", lambda: sock_path) with pytest.raises(ConnectionRefusedError): - DaemonClient.connect() - - -def test_client_close_is_idempotent(daemon_thread: str) -> None: - client = DaemonClient.connect() - client.handshake() - client.close() - client.close() # should not raise + client._connect_and_handshake() diff --git a/tests/test_daemon.py b/tests/test_daemon.py index 477ffa7..8a95ac9 100644 --- a/tests/test_daemon.py +++ b/tests/test_daemon.py @@ -220,13 +220,15 @@ def test_daemon_remove_project(daemon_sock: str, daemon_project: str) -> None: conn.send_bytes(encode_request(RemoveProjectRequest(project_root=daemon_project))) resp = decode_response(conn.recv_bytes()) assert resp.ok is True # type: ignore[union-attr] + conn.close() - # Verify project is gone from daemon status - conn.send_bytes(encode_request(DaemonStatusRequest())) - status = decode_response(conn.recv_bytes()) + # Verify project is gone from daemon status (fresh connection) + conn2, _ = _connect_and_handshake(daemon_sock) + conn2.send_bytes(encode_request(DaemonStatusRequest())) + status = decode_response(conn2.recv_bytes()) project_roots = [p.project_root for p in status.projects] # type: ignore[union-attr] assert daemon_project not in project_roots - conn.close() + conn2.close() def test_daemon_remove_project_not_loaded(daemon_sock: str) -> None: @@ -318,9 +320,12 @@ def test_daemon_search_waits_for_load_time_indexing(daemon_sock: str) -> None: assert len(final_resp.results) > 0 assert "main.py" in final_resp.results[0].file_path - # Second search — load-time indexing is done, no waiting expected - conn.send_bytes(encode_request(SearchRequest(project_root=str(project), query="fibonacci"))) - resp2 = decode_response(conn.recv_bytes()) + conn.close() + + # Second search — load-time indexing is done, no waiting expected (fresh connection) + conn2, _ = _connect_and_handshake(daemon_sock) + conn2.send_bytes(encode_request(SearchRequest(project_root=str(project), query="fibonacci"))) + resp2 = decode_response(conn2.recv_bytes()) assert isinstance(resp2, SearchResponse) assert resp2.success is True - conn.close() + conn2.close() diff --git a/tests/test_e2e_daemon.py b/tests/test_e2e_daemon.py index 042db1e..46908a3 100644 --- a/tests/test_e2e_daemon.py +++ b/tests/test_e2e_daemon.py @@ -1,8 +1,8 @@ """End-to-end tests for the CLI → daemon subprocess flow. These tests start a real daemon subprocess via ``start_daemon()`` and interact -with it through ``DaemonClient``, mirroring how ``ccc index`` / ``ccc search`` -actually work. +with it through the per-request client functions, mirroring how ``ccc index`` / +``ccc search`` actually work. """ from __future__ import annotations @@ -15,8 +15,9 @@ import pytest +from cocoindex_code import client from cocoindex_code._version import __version__ -from cocoindex_code.client import DaemonClient, start_daemon, stop_daemon +from cocoindex_code.client import start_daemon, stop_daemon from cocoindex_code.daemon import daemon_socket_path from cocoindex_code.settings import ( default_project_settings, @@ -79,21 +80,15 @@ def e2e_daemon() -> Iterator[tuple[str, Path]]: def test_daemon_subprocess_starts(e2e_daemon: tuple[str, Path]) -> None: - """The daemon should be reachable via DaemonClient after start_daemon().""" - client = DaemonClient.connect() - resp = client.handshake() - assert resp.ok - assert resp.daemon_version == __version__ - client.close() + """The daemon should be reachable via a fresh connection after start_daemon().""" + resp = client.daemon_status() + assert resp.version == __version__ def test_index_and_search_via_client(e2e_daemon: tuple[str, Path]) -> None: """Index a project and search via the client, same as ccc index / ccc search.""" _, project_dir = e2e_daemon - client = DaemonClient.connect() - client.handshake() - resp = client.index(str(project_dir)) assert resp.success @@ -105,23 +100,3 @@ def test_index_and_search_via_client(e2e_daemon: tuple[str, Path]) -> None: assert search_resp.success assert len(search_resp.results) > 0 assert "main.py" in search_resp.results[0].file_path - - client.close() - - -def test_daemon_survives_client_disconnect(e2e_daemon: tuple[str, Path]) -> None: - """Daemon should keep running after a client disconnects.""" - _, project_dir = e2e_daemon - - c1 = DaemonClient.connect() - c1.handshake() - c1.search(str(project_dir), query="fibonacci") - c1.close() - - c2 = DaemonClient.connect() - resp = c2.handshake() - assert resp.ok - search_resp = c2.search(str(project_dir), query="fibonacci") - assert search_resp.success - assert len(search_resp.results) > 0 - c2.close() From 6bce3652406d7bced7b5434d0cacf5773b3983d6 Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Fri, 20 Mar 2026 11:41:46 -0700 Subject: [PATCH 2/2] fix: type annotation --- src/cocoindex_code/shared.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/cocoindex_code/shared.py b/src/cocoindex_code/shared.py index b024526..b5ce756 100644 --- a/src/cocoindex_code/shared.py +++ b/src/cocoindex_code/shared.py @@ -8,8 +8,9 @@ from typing import TYPE_CHECKING, Annotated, Union import cocoindex as coco +import numpy as np +import numpy.typing as npt from cocoindex.connectors import sqlite -from numpy.typing import NDArray if TYPE_CHECKING: from cocoindex.ops.litellm import LiteLLMEmbedder @@ -82,4 +83,4 @@ class CodeChunk: content: str start_line: int end_line: int - embedding: Annotated[NDArray, EMBEDDER] + embedding: Annotated[npt.NDArray[np.float32], EMBEDDER]