diff --git a/pyproject.toml b/pyproject.toml index 516599d..f095021 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,9 +91,11 @@ select = ["E", "F", "I", "N", "W", "UP"] python_version = "3.11" strict = true ignore_missing_imports = true +explicit_package_bases = true [tool.pytest.ini_options] testpaths = ["tests"] python_files = ["test_*.py"] python_functions = ["test_*"] addopts = "-v --tb=short" +asyncio_mode = "auto" diff --git a/src/cocoindex_code/chunking.py b/src/cocoindex_code/chunking.py new file mode 100644 index 0000000..7b6580a --- /dev/null +++ b/src/cocoindex_code/chunking.py @@ -0,0 +1,31 @@ +"""Public API for writing custom chunkers. + +Example usage:: + + from pathlib import Path + from cocoindex_code.chunking import Chunk, ChunkerFn, TextPosition + + def my_chunker(path: Path, content: str) -> tuple[str | None, list[Chunk]]: + pos = TextPosition(byte_offset=0, char_offset=0, line=1, column=0) + return "mylang", [Chunk(text=content, start=pos, end=pos)] +""" + +from __future__ import annotations + +import pathlib as _pathlib +from collections.abc import Callable as _Callable + +import cocoindex as _coco + +from cocoindex.resources.chunk import Chunk +from cocoindex.resources.chunk import TextPosition + +# Callable alias (not Protocol) — consistent with codebase style. +# language_override=None keeps the language detected by detect_code_language. +# path is not resolved (no syscall); call path.resolve() inside the chunker if needed. +ChunkerFn = _Callable[[_pathlib.Path, str], tuple[str | None, list[Chunk]]] + +# tracked=False: callables are not fingerprint-able; daemon restart re-indexes anyway. +CHUNKER_REGISTRY = _coco.ContextKey[dict[str, ChunkerFn]]("chunker_registry", tracked=False) + +__all__ = ["Chunk", "ChunkerFn", "CHUNKER_REGISTRY", "TextPosition"] diff --git a/src/cocoindex_code/daemon.py b/src/cocoindex_code/daemon.py index 5974869..0f7242b 100644 --- a/src/cocoindex_code/daemon.py +++ b/src/cocoindex_code/daemon.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import importlib import logging import os import signal @@ -45,8 +46,10 @@ decode_request, encode_response, ) +from .chunking import ChunkerFn as _ChunkerFn from .query import query_codebase from .settings import ( + ChunkerMapping, global_settings_mtime_us, load_project_settings, load_user_settings, @@ -56,6 +59,26 @@ logger = logging.getLogger(__name__) + +def _resolve_chunker_registry(mappings: list[ChunkerMapping]) -> dict[str, _ChunkerFn]: + """Resolve ``ChunkerMapping`` settings entries to a ``{suffix: fn}`` dict. + + Each ``mapping.module`` must be a ``"module.path:callable"`` string importable + from the current environment. + """ + registry: dict[str, _ChunkerFn] = {} + for cm in mappings: + module_path, _, attr = cm.module.partition(":") + if not attr: + raise ValueError(f"chunker module {cm.module!r} must use 'module.path:callable' format") + mod = importlib.import_module(module_path) + fn = getattr(mod, attr) + if not callable(fn): + raise ValueError(f"chunker {cm.module!r}: {attr!r} is not callable") + registry[f".{cm.ext}"] = fn + return registry + + # --------------------------------------------------------------------------- # Daemon paths # --------------------------------------------------------------------------- @@ -123,7 +146,10 @@ async def get_project(self, project_root: str, *, suppress_auto_index: bool = Fa if project_root not in self._projects: root = Path(project_root) project_settings = load_project_settings(root) - project = await Project.create(root, project_settings, self._embedder) + chunker_registry = _resolve_chunker_registry(project_settings.chunkers) + project = await Project.create( + root, project_settings, self._embedder, chunker_registry=chunker_registry + ) self._projects[project_root] = project self._index_locks[project_root] = asyncio.Lock() self._load_time_done[project_root] = asyncio.Event() diff --git a/src/cocoindex_code/indexer.py b/src/cocoindex_code/indexer.py index 64ff3ad..70ad99f 100644 --- a/src/cocoindex_code/indexer.py +++ b/src/cocoindex_code/indexer.py @@ -14,6 +14,7 @@ from cocoindex.resources.id import IdGenerator from pathspec import GitIgnoreSpec +from .chunking import CHUNKER_REGISTRY from .settings import PROJECT_SETTINGS from .shared import ( CODEBASE_DIR, @@ -158,13 +159,20 @@ async def process_file( or "text" ) - chunks = splitter.split( - content, - chunk_size=CHUNK_SIZE, - min_chunk_size=MIN_CHUNK_SIZE, - chunk_overlap=CHUNK_OVERLAP, - language=language, - ) + chunker_registry = coco.use_context(CHUNKER_REGISTRY) + chunker = chunker_registry.get(suffix) + if chunker is not None: + language_override, chunks = chunker(Path(file.file_path.path), content) + if language_override is not None: + language = language_override + else: + chunks = splitter.split( + content, + chunk_size=CHUNK_SIZE, + min_chunk_size=MIN_CHUNK_SIZE, + chunk_overlap=CHUNK_OVERLAP, + language=language, + ) id_gen = IdGenerator() diff --git a/src/cocoindex_code/project.py b/src/cocoindex_code/project.py index f9f60a4..0079ea8 100644 --- a/src/cocoindex_code/project.py +++ b/src/cocoindex_code/project.py @@ -9,6 +9,7 @@ import cocoindex as coco from cocoindex.connectors import sqlite +from .chunking import CHUNKER_REGISTRY, ChunkerFn from .indexer import indexer_main from .protocol import IndexingProgress from .settings import PROJECT_SETTINGS, ProjectSettings, load_gitignore_spec @@ -86,8 +87,21 @@ async def create( project_root: Path, project_settings: ProjectSettings, embedder: Embedder, + chunker_registry: dict[str, ChunkerFn] | None = None, ) -> Project: - """Create a project with explicit settings and embedder.""" + """Create a project with explicit settings and embedder. + + Args: + project_root: Root directory of the codebase to index. + project_settings: Include/exclude patterns and language overrides. + embedder: Embedding model instance. + chunker_registry: Optional mapping of file suffix (e.g. ``".sls"``) + to a ``ChunkerFn``. When a suffix matches, the registered + chunker is called instead of the built-in ``RecursiveSplitter``. + Defaults to an empty registry. Shallow-copied on creation. + Passed as a parameter rather than via ``env`` to keep + ``env`` internals out of the public API. + """ index_dir = project_root / ".cocoindex_code" index_dir.mkdir(parents=True, exist_ok=True) @@ -107,6 +121,7 @@ async def create( {f".{lo.ext}": lo.lang for lo in project_settings.language_overrides}, ) context.provide(GITIGNORE_SPEC, gitignore_spec) + context.provide(CHUNKER_REGISTRY, dict(chunker_registry) if chunker_registry else {}) env = coco.Environment(settings, context_provider=context) app = coco.App( diff --git a/src/cocoindex_code/settings.py b/src/cocoindex_code/settings.py index dcb155a..eb29a23 100644 --- a/src/cocoindex_code/settings.py +++ b/src/cocoindex_code/settings.py @@ -83,11 +83,18 @@ class LanguageOverride: lang: str # e.g. "php" +@dataclass +class ChunkerMapping: + ext: str # without dot, e.g. "toml" + module: str # "module.path:callable", e.g. "cocoindex_code.toml_chunker:toml_chunker" + + @dataclass class ProjectSettings: include_patterns: list[str] = field(default_factory=lambda: list(DEFAULT_INCLUDED_PATTERNS)) exclude_patterns: list[str] = field(default_factory=lambda: list(DEFAULT_EXCLUDED_PATTERNS)) language_overrides: list[LanguageOverride] = field(default_factory=list) + chunkers: list[ChunkerMapping] = field(default_factory=list) # CocoIndex context key for project settings @@ -265,6 +272,8 @@ def _project_settings_to_dict(settings: ProjectSettings) -> dict[str, Any]: d["language_overrides"] = [ {"ext": lo.ext, "lang": lo.lang} for lo in settings.language_overrides ] + if settings.chunkers: + d["chunkers"] = [{"ext": cm.ext, "module": cm.module} for cm in settings.chunkers] return d @@ -272,10 +281,12 @@ def _project_settings_from_dict(d: dict[str, Any]) -> ProjectSettings: overrides = [ LanguageOverride(ext=lo["ext"], lang=lo["lang"]) for lo in d.get("language_overrides", []) ] + chunkers = [ChunkerMapping(ext=cm["ext"], module=cm["module"]) for cm in d.get("chunkers", [])] return ProjectSettings( include_patterns=d.get("include_patterns", list(DEFAULT_INCLUDED_PATTERNS)), exclude_patterns=d.get("exclude_patterns", list(DEFAULT_EXCLUDED_PATTERNS)), language_overrides=overrides, + chunkers=chunkers, ) diff --git a/tests/example_toml_chunker.py b/tests/example_toml_chunker.py new file mode 100644 index 0000000..7d2d84a --- /dev/null +++ b/tests/example_toml_chunker.py @@ -0,0 +1,45 @@ +"""Demo chunker: splits TOML files at top-level [section] boundaries. + +Each ``[section]`` header starts a new chunk, keeping the section header +and its key-value pairs together. This produces semantically coherent units +instead of the arbitrary line-window slices from the default splitter. + +Register in ``.cocoindex_code/settings.yml``:: + + chunkers: + - ext: toml + module: example_toml_chunker:toml_chunker +""" + +from __future__ import annotations + +import re as _re +from pathlib import Path as _Path + +from cocoindex_code.chunking import Chunk, TextPosition + +_SECTION_RE = _re.compile(r"^\[(?!\[)") + + +def _pos(line: int) -> TextPosition: + return TextPosition(byte_offset=0, char_offset=0, line=line, column=0) + + +def toml_chunker(path: _Path, content: str) -> tuple[str | None, list[Chunk]]: + """Split a TOML file at top-level ``[section]`` headers.""" + lines = content.splitlines() + section_starts = [i for i, ln in enumerate(lines) if _SECTION_RE.match(ln)] + + if not section_starts: + return "toml", [Chunk(text=content, start=_pos(1), end=_pos(len(lines)))] + + boundaries = section_starts + [len(lines)] + chunks: list[Chunk] = [] + for start_idx, end_idx in zip(boundaries, boundaries[1:]): + text = "\n".join(lines[start_idx:end_idx]).strip() + if text: + chunks.append(Chunk(text=text, start=_pos(start_idx + 1), end=_pos(end_idx))) + return "toml", chunks + + +__all__ = ["toml_chunker"] diff --git a/tests/test_chunker_registry.py b/tests/test_chunker_registry.py new file mode 100644 index 0000000..84c5fdd --- /dev/null +++ b/tests/test_chunker_registry.py @@ -0,0 +1,185 @@ +"""Tests for the pluggable chunker registry. + +Uses Project.create() directly with a mock embedder so no real embedding model +is needed. Each test writes files to a temp directory, indexes them, and +queries the resulting SQLite database to verify chunk content and language. +""" + +from __future__ import annotations + +import sqlite3 +from pathlib import Path +from typing import Any + +import numpy as np +import pytest + +import cocoindex_code.shared as _shared +from cocoindex.connectors import sqlite as coco_sqlite +from cocoindex.resources.schema import VectorSchema +from cocoindex_code.chunking import CHUNKER_REGISTRY, Chunk, TextPosition +from cocoindex_code.project import Project +from cocoindex_code.settings import ProjectSettings +from example_toml_chunker import toml_chunker + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_EMBED_DIM = 4 # tiny dimension — enough to satisfy the vector table schema + + +class _StubEmbedder: + """Minimal embedder stub satisfying cocoindex memo-key and vector-schema requirements.""" + + def __coco_memo_key__(self) -> str: + return "stub-embedder" + + async def __coco_vector_schema__(self) -> VectorSchema: + return VectorSchema(dtype=np.dtype("float32"), size=_EMBED_DIM) + + async def embed(self, text: str) -> np.ndarray: + return np.zeros(_EMBED_DIM, dtype=np.float32) + + +async def _index_project( + project_root: Path, + monkeypatch: pytest.MonkeyPatch, + **create_kwargs: Any, +) -> Project: + """Create a Project and run a full index pass.""" + settings = ProjectSettings( + include_patterns=["**/*.*"], + exclude_patterns=["**/.cocoindex_code"], + ) + stub = _StubEmbedder() + # shared.embedder is read by CodeChunk.embedding at schema resolution time. + monkeypatch.setattr(_shared, "embedder", stub) + project = await Project.create( + project_root, + settings, + stub, + **create_kwargs, + ) + await project.update_index() + return project + + +def _query_chunks(project_root: Path) -> list[dict[str, Any]]: + """Read all stored chunks from the target SQLite database.""" + db_path = project_root / ".cocoindex_code" / "target_sqlite.db" + conn = coco_sqlite.connect(str(db_path), load_vec=True) + try: + with conn.readonly() as db: + db.row_factory = sqlite3.Row + rows = db.execute( + "SELECT file_path, language, content, start_line, end_line FROM code_chunks_vec" + ).fetchall() + return [dict(row) for row in rows] + finally: + conn.close() + + +def _pos(line: int) -> TextPosition: + """TextPosition with only line number set; suitable for line-granularity chunkers.""" + return TextPosition(byte_offset=0, char_offset=0, line=line, column=0) + + +# --------------------------------------------------------------------------- +# TOML fixture content +# --------------------------------------------------------------------------- + +_TOML_CONTENT = """\ +[section_one] +key = "value" +answer = 42 + +[section_two] +other = "hello" +flag = true +""" + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +async def test_default_registry_is_empty(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """CHUNKER_REGISTRY is an empty dict when no registry is passed.""" + (tmp_path / ".git").mkdir() + (tmp_path / "hello.py").write_text("x = 1\n") + + project = await _index_project(tmp_path, monkeypatch) + registry = project.env.get_context(CHUNKER_REGISTRY) + assert isinstance(registry, dict) + assert registry == {} + + +async def test_unregistered_suffix_uses_splitter( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Files with no registered chunker are processed by RecursiveSplitter.""" + (tmp_path / ".git").mkdir() + (tmp_path / "sample.py").write_text("def foo():\n return 1\n") + + await _index_project(tmp_path, monkeypatch) + chunks = _query_chunks(tmp_path) + + assert len(chunks) >= 1 + assert all(c["language"] == "python" for c in chunks) + assert any("foo" in c["content"] for c in chunks) + + +async def test_registered_chunker_is_called( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """A registered ChunkerFn splits files and may override the language.""" + (tmp_path / ".git").mkdir() + (tmp_path / "config.toml").write_text(_TOML_CONTENT) + + await _index_project(tmp_path, monkeypatch, chunker_registry={".toml": toml_chunker}) + chunks = _query_chunks(tmp_path) + + assert len(chunks) == 2 + contents = {c["content"] for c in chunks} + assert any("section_one" in c for c in contents) + assert any("section_two" in c for c in contents) + assert all(c["language"] == "toml" for c in chunks) + + +async def test_chunker_language_none_preserves_detected( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """When ChunkerFn returns language=None, detect_code_language() is used.""" + + def _passthrough_chunker(path: Path, content: str) -> tuple[str | None, list[Chunk]]: + lines = content.splitlines() + return None, [Chunk(text=content, start=_pos(1), end=_pos(len(lines)))] + + (tmp_path / ".git").mkdir() + (tmp_path / "script.py").write_text("x = 1\n") + + await _index_project(tmp_path, monkeypatch, chunker_registry={".py": _passthrough_chunker}) + chunks = _query_chunks(tmp_path) + + assert all(c["language"] == "python" for c in chunks) + + +async def test_registry_does_not_affect_other_suffixes( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Registering a chunker for .toml does not affect .py files.""" + (tmp_path / ".git").mkdir() + (tmp_path / "config.toml").write_text(_TOML_CONTENT) + (tmp_path / "code.py").write_text("def bar():\n pass\n") + + await _index_project(tmp_path, monkeypatch, chunker_registry={".toml": toml_chunker}) + chunks = _query_chunks(tmp_path) + + toml_chunks = [c for c in chunks if c["language"] == "toml"] + py_chunks = [c for c in chunks if c["language"] == "python"] + + assert len(toml_chunks) == 2 + assert len(py_chunks) >= 1 + assert any("bar" in c["content"] for c in py_chunks) diff --git a/tests/test_daemon.py b/tests/test_daemon.py index 477ffa7..4c0c0b4 100644 --- a/tests/test_daemon.py +++ b/tests/test_daemon.py @@ -74,8 +74,8 @@ def daemon_sock() -> Iterator[str]: os.environ["COCOINDEX_CODE_DIR"] = str(user_dir) # Patch create_embedder to reuse the already-loaded embedder (performance) - _orig_create_embedder = dm.create_embedder # type: ignore[attr-defined] - dm.create_embedder = lambda settings: emb # type: ignore[attr-defined] + _orig_create_embedder = dm.create_embedder + dm.create_embedder = lambda settings: emb save_user_settings(default_user_settings()) @@ -107,7 +107,7 @@ def daemon_sock() -> Iterator[str]: thread.join(timeout=5) # Restore patches and env var - dm.create_embedder = _orig_create_embedder # type: ignore[attr-defined] + dm.create_embedder = _orig_create_embedder if old_env is None: os.environ.pop("COCOINDEX_CODE_DIR", None) else: @@ -156,8 +156,8 @@ def _connect_and_handshake(sock_path: str) -> tuple[Connection, Response]: def test_daemon_starts_and_accepts_handshake(daemon_sock: str) -> None: conn, resp = _connect_and_handshake(daemon_sock) - assert resp.ok is True # type: ignore[union-attr] - assert resp.daemon_version == __version__ # type: ignore[union-attr] + assert resp.ok is True + assert resp.daemon_version == __version__ conn.close() @@ -165,7 +165,7 @@ def test_daemon_rejects_version_mismatch(daemon_sock: str) -> None: conn = Client(daemon_sock, family=_connection_family()) conn.send_bytes(encode_request(HandshakeRequest(version="0.0.0-fake"))) resp = decode_response(conn.recv_bytes()) - assert resp.ok is False # type: ignore[union-attr] + assert resp.ok is False conn.close() @@ -173,8 +173,8 @@ def test_daemon_status(daemon_sock: str) -> None: conn, _ = _connect_and_handshake(daemon_sock) conn.send_bytes(encode_request(DaemonStatusRequest())) resp = decode_response(conn.recv_bytes()) - assert resp.version == __version__ # type: ignore[union-attr] - assert resp.uptime_seconds > 0 # type: ignore[union-attr] + assert resp.version == __version__ + assert resp.uptime_seconds > 0 conn.close() @@ -182,8 +182,8 @@ def test_daemon_project_status_after_index(daemon_sock: str, daemon_project: str conn, _ = _connect_and_handshake(daemon_sock) conn.send_bytes(encode_request(ProjectStatusRequest(project_root=daemon_project))) resp = decode_response(conn.recv_bytes()) - assert resp.total_chunks > 0 # type: ignore[union-attr] - assert resp.total_files > 0 # type: ignore[union-attr] + assert resp.total_chunks > 0 + assert resp.total_files > 0 conn.close() @@ -191,9 +191,9 @@ def test_daemon_search_after_index(daemon_sock: str, daemon_project: str) -> Non conn, _ = _connect_and_handshake(daemon_sock) conn.send_bytes(encode_request(SearchRequest(project_root=daemon_project, query="fibonacci"))) resp = decode_response(conn.recv_bytes()) - assert resp.success is True # type: ignore[union-attr] - assert len(resp.results) > 0 # type: ignore[union-attr] - assert "main.py" in resp.results[0].file_path # type: ignore[union-attr] + assert resp.success is True + assert len(resp.results) > 0 + assert "main.py" in resp.results[0].file_path conn.close() @@ -219,12 +219,12 @@ def test_daemon_remove_project(daemon_sock: str, daemon_project: str) -> None: conn, _ = _connect_and_handshake(daemon_sock) conn.send_bytes(encode_request(RemoveProjectRequest(project_root=daemon_project))) resp = decode_response(conn.recv_bytes()) - assert resp.ok is True # type: ignore[union-attr] + assert resp.ok is True # Verify project is gone from daemon status conn.send_bytes(encode_request(DaemonStatusRequest())) status = decode_response(conn.recv_bytes()) - project_roots = [p.project_root for p in status.projects] # type: ignore[union-attr] + project_roots = [p.project_root for p in status.projects] assert daemon_project not in project_roots conn.close() @@ -234,7 +234,7 @@ def test_daemon_remove_project_not_loaded(daemon_sock: str) -> None: conn, _ = _connect_and_handshake(daemon_sock) conn.send_bytes(encode_request(RemoveProjectRequest(project_root="/nonexistent/path"))) resp = decode_response(conn.recv_bytes()) - assert resp.ok is True # type: ignore[union-attr] + assert resp.ok is True conn.close() diff --git a/tests/test_settings.py b/tests/test_settings.py index fd56f31..8971439 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -6,9 +6,13 @@ import pytest +# _resolve_chunker_registry is private to daemon.py (single call site), but its +# error paths (bad format, non-callable) are not exercised by integration tests. +from cocoindex_code.daemon import _resolve_chunker_registry from cocoindex_code.settings import ( DEFAULT_EXCLUDED_PATTERNS, DEFAULT_INCLUDED_PATTERNS, + ChunkerMapping, EmbeddingSettings, LanguageOverride, ProjectSettings, @@ -198,3 +202,25 @@ def test_project_settings_with_language_overrides(tmp_path: Path) -> None: assert len(loaded.language_overrides) == 1 assert loaded.language_overrides[0].ext == "inc" assert loaded.language_overrides[0].lang == "php" + + +def test_project_settings_with_chunkers(tmp_path: Path) -> None: + settings = ProjectSettings( + chunkers=[ChunkerMapping(ext="toml", module="example_toml_chunker:toml_chunker")], + ) + save_project_settings(tmp_path, settings) + loaded = load_project_settings(tmp_path) + assert len(loaded.chunkers) == 1 + assert loaded.chunkers[0].ext == "toml" + assert loaded.chunkers[0].module == "example_toml_chunker:toml_chunker" + + +def test_resolve_chunker_registry_missing_colon() -> None: + with pytest.raises(ValueError, match="module.path:callable"): + _resolve_chunker_registry([ChunkerMapping(ext="toml", module="no_colon_here")]) + + +def test_resolve_chunker_registry_not_callable() -> None: + # os.path is a module attribute that is a string — not callable. + with pytest.raises(ValueError, match="not callable"): + _resolve_chunker_registry([ChunkerMapping(ext="toml", module="os:sep")])