diff --git a/pyproject.toml b/pyproject.toml
index 20baecde..b772335e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,6 @@ dependencies = [
     "mcp>=1.23.3",
     "pydantic>=2.12.4",
     "jinja2>=3.1.6",
-    "docling>=2.70.0",
 ]
 default-optional-dependency-keys = [
     "mysql",
@@ -38,6 +37,9 @@ mysql = [
 postgresql = [
     "asyncpg>=0.31.0",
 ]
+pdf = [
+    "docling>=2.70.0",
+]
 
 [build-system]
 requires = ["uv_build>=0.9.6,<0.10.0"]
diff --git a/src/databao_context_engine/build_sources/build_runner.py b/src/databao_context_engine/build_sources/build_runner.py
index d34c7cba..0ddbc45d 100644
--- a/src/databao_context_engine/build_sources/build_runner.py
+++ b/src/databao_context_engine/build_sources/build_runner.py
@@ -9,6 +9,10 @@
     export_build_result,
     reset_all_results,
 )
+from databao_context_engine.datasources.datasource_context import (
+    DatasourceContext,
+    read_datasource_type_from_context,
+)
 from databao_context_engine.datasources.datasource_discovery import discover_datasources, prepare_source
 from databao_context_engine.datasources.types import DatasourceId
 from databao_context_engine.pluginlib.build_plugin import DatasourceType
@@ -35,6 +39,16 @@ class BuildContextResult:
     context_file_path: Path
 
 
+@dataclass
+class IndexSummary:
+    """Summary of an indexing run over built contexts."""
+
+    total: int
+    indexed: int
+    skipped: int
+    failed: int
+
+
 def build(
     project_layout: ProjectLayout,
     *,
@@ -47,8 +61,7 @@ def build(
 
     1) Load available plugins
    2) Discover sources
-    3) Create a run
-    4) For each source, call process_source
+    3) For each source, call process_source
 
     Returns:
         A list of all the contexts built.
@@ -113,3 +126,52 @@ def build(
     )
 
     return build_result
+
+
+def run_indexing(
+    *, project_layout: ProjectLayout, build_service: BuildService, contexts: list[DatasourceContext]
+) -> IndexSummary:
+    """Index a list of built datasource contexts.
+
+    1) Load available plugins
+    2) Infer the datasource type from each context's contents
+    3) For each context, call index_built_context
+
+    Returns:
+        A summary of the indexing run.
+    """
+    plugins = load_plugins()
+
+    summary = IndexSummary(total=len(contexts), indexed=0, skipped=0, failed=0)
+
+    for context in contexts:
+        try:
+            logger.info(f"Indexing datasource {context.datasource_id}")
+
+            datasource_type = read_datasource_type_from_context(context)
+
+            plugin = plugins.get(datasource_type)
+            if plugin is None:
+                logger.warning(
+                    "No plugin for datasource type '%s' — skipping indexing for %s.",
+                    getattr(datasource_type, "full_type", datasource_type),
+                    context.datasource_id,
+                )
+                summary.skipped += 1
+                continue
+
+            build_service.index_built_context(context=context, plugin=plugin)
+            summary.indexed += 1
+        except Exception as e:
+            logger.debug(str(e), exc_info=True, stack_info=True)
+            logger.info(f"Failed to index context for ({context.datasource_id}): {str(e)}")
+            summary.failed += 1
+
+    logger.debug(
+        "Successfully indexed %d/%d datasource(s). %s",
+        summary.indexed,
+        summary.total,
+        f"Skipped {summary.skipped}. Failed {summary.failed}." if (summary.skipped or summary.failed) else "",
+    )
+
+    return summary
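Reviewer note: a minimal sketch of how callers are expected to consume the new `IndexSummary`. The `project_layout`, `build_service`, and `contexts` values are stand-ins assumed to come from the surrounding wiring (see `index_built_contexts` in `build_wiring.py` below); the counters are disjoint, so the stated invariant always holds after a run.

```python
# Hypothetical caller-side sketch, not part of this change.
summary = run_indexing(
    project_layout=project_layout,
    build_service=build_service,
    contexts=contexts,
)
# Each context lands in exactly one bucket: indexed, skipped, or failed.
assert summary.total == summary.indexed + summary.skipped + summary.failed
if summary.failed:
    print(f"{summary.failed}/{summary.total} datasource(s) failed to index")
```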
diff --git a/src/databao_context_engine/build_sources/build_service.py b/src/databao_context_engine/build_sources/build_service.py
index be90240a..b80384f4 100644
--- a/src/databao_context_engine/build_sources/build_service.py
+++ b/src/databao_context_engine/build_sources/build_service.py
@@ -1,8 +1,14 @@
 from __future__ import annotations
 
 import logging
+from dataclasses import replace
+from typing import Any
+
+import yaml
+from pydantic import BaseModel, TypeAdapter
 
 from databao_context_engine.build_sources.plugin_execution import BuiltDatasourceContext, execute_plugin
+from databao_context_engine.datasources.datasource_context import DatasourceContext
 from databao_context_engine.datasources.types import PreparedDatasource
 from databao_context_engine.pluginlib.build_plugin import (
     BuildPlugin,
@@ -42,6 +48,7 @@ def process_prepared_source(
         result = execute_plugin(self._project_layout, prepared_source, plugin)
 
         chunks = plugin.divide_context_into_chunks(result.context)
+
         if not chunks:
             logger.info("No chunks for %s — skipping.", prepared_source.datasource_id.relative_path_to_config_file())
             return result
@@ -54,3 +61,46 @@ def process_prepared_source(
         )
 
         return result
+
+    def index_built_context(self, *, context: DatasourceContext, plugin: BuildPlugin) -> None:
+        """Index a context file using the given plugin.
+
+        1) Parse the YAML contents of the context file
+        2) Reconstruct the `BuiltDatasourceContext` object
+        3) Structure the inner `context` payload into the plugin's expected `context_type`
+        4) Call the plugin's chunker and persist the resulting chunks and embeddings
+        """
+        built = self._deserialize_built_context(context=context, context_type=plugin.context_type)
+
+        chunks = plugin.divide_context_into_chunks(built.context)
+        if not chunks:
+            logger.info(
+                "No chunks for %s — skipping indexing.", context.datasource_id.relative_path_to_context_file().name
+            )
+            return
+
+        self._chunk_embedding_service.embed_chunks(
+            chunks=chunks,
+            result=context.context,
+            full_type=built.datasource_type,
+            datasource_id=built.datasource_id,
+            override=True,
+        )
+
+    def _deserialize_built_context(
+        self,
+        *,
+        context: DatasourceContext,
+        context_type: type[Any],
+    ) -> BuiltDatasourceContext:
+        """Parse the YAML payload and return a BuiltDatasourceContext with a typed `.context`."""
+        raw_context = yaml.safe_load(context.context)
+
+        built = TypeAdapter(BuiltDatasourceContext).validate_python(raw_context)
+
+        if isinstance(context_type, type) and issubclass(context_type, BaseModel):
+            typed_context: Any = context_type.model_validate(built.context)
+        else:
+            typed_context = TypeAdapter(context_type).validate_python(built.context)
+
+        return replace(built, context=typed_context)
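The `_deserialize_built_context` round trip is easiest to see in isolation. Below is a minimal, self-contained sketch of the same technique; `Built` and `Doc` are hypothetical stand-ins for `BuiltDatasourceContext` and a plugin's `context_type`, not names from this codebase.

```python
from dataclasses import dataclass, replace
from typing import Any

import yaml
from pydantic import BaseModel, TypeAdapter


class Doc(BaseModel):  # stand-in for a plugin's context_type
    hello: str


@dataclass
class Built:  # stand-in for BuiltDatasourceContext
    datasource_id: str
    context: Any


raw = yaml.safe_load("datasource_id: files/two.md\ncontext:\n  hello: world\n")
built = TypeAdapter(Built).validate_python(raw)  # dict -> dataclass instance
typed = Doc.model_validate(built.context)        # untyped payload -> typed model
built = replace(built, context=typed)            # swap the typed payload back in
assert built.context.hello == "world"
```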
diff --git a/src/databao_context_engine/build_sources/build_wiring.py b/src/databao_context_engine/build_sources/build_wiring.py
index 02773af2..b65887ff 100644
--- a/src/databao_context_engine/build_sources/build_wiring.py
+++ b/src/databao_context_engine/build_sources/build_wiring.py
@@ -2,8 +2,9 @@
 
 from duckdb import DuckDBPyConnection
 
-from databao_context_engine.build_sources.build_runner import BuildContextResult, build
+from databao_context_engine.build_sources.build_runner import BuildContextResult, IndexSummary, build, run_indexing
 from databao_context_engine.build_sources.build_service import BuildService
+from databao_context_engine.datasources.datasource_context import DatasourceContext
 from databao_context_engine.llm.descriptions.provider import DescriptionProvider
 from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
 from databao_context_engine.llm.factory import (
@@ -64,6 +65,45 @@ def build_all_datasources(
     )
 
 
+def index_built_contexts(
+    project_layout: ProjectLayout,
+    contexts: list[DatasourceContext],
+    chunk_embedding_mode: ChunkEmbeddingMode,
+) -> IndexSummary:
+    """Index the given contexts into the database.
+
+    - Instantiates the build service.
+    - Creates the database first if it does not exist.
+
+    Returns:
+        The summary of the indexing run.
+    """
+    logger.debug("Starting to index %d context(s) for project %s", len(contexts), project_layout.project_dir.resolve())
+
+    db_path = get_db_path(project_layout.project_dir)
+    if not db_path.exists():
+        db_path.parent.mkdir(parents=True, exist_ok=True)
+        migrate(db_path)
+
+    with open_duckdb_connection(db_path) as conn:
+        ollama_service = create_ollama_service()
+        embedding_provider = create_ollama_embedding_provider(ollama_service)
+        description_provider = (
+            create_ollama_description_provider(ollama_service)
+            if chunk_embedding_mode.should_generate_description()
+            else None
+        )
+
+        build_service = _create_build_service(
+            conn,
+            project_layout=project_layout,
+            embedding_provider=embedding_provider,
+            description_provider=description_provider,
+            chunk_embedding_mode=chunk_embedding_mode,
+        )
+        return run_indexing(project_layout=project_layout, build_service=build_service, contexts=contexts)
+
+
 def _create_build_service(
     conn: DuckDBPyConnection,
     *,
diff --git a/src/databao_context_engine/cli/commands.py b/src/databao_context_engine/cli/commands.py
index d0e60b80..7563fa30 100644
--- a/src/databao_context_engine/cli/commands.py
+++ b/src/databao_context_engine/cli/commands.py
@@ -158,6 +158,37 @@ def build(
     click.echo(f"Build complete. Processed {len(result)} datasources.")
 
 
+@dce.command()
+@click.argument(
+    "datasources-config-files",
+    nargs=-1,
+    type=click.STRING,
+)
+@click.pass_context
+def index(ctx: Context, datasources_config_files: tuple[str, ...]) -> None:
+    """Index and create embeddings for built context files into DuckDB.
+
+    If one or more datasource config files are provided, only those datasources will be indexed.
+    If no paths are provided, all built contexts found in the output directory will be indexed.
+    """
+    datasource_ids = (
+        [DatasourceId.from_string_repr(p) for p in datasources_config_files] if datasources_config_files else None
+    )
+
+    summary = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).index_built_contexts(
+        datasource_ids=datasource_ids
+    )
+
+    suffix = []
+    if summary.skipped:
+        suffix.append(f"skipped {summary.skipped}")
+    if summary.failed:
+        suffix.append(f"failed {summary.failed}")
+
+    extra = f" ({', '.join(suffix)})" if suffix else ""
+    click.echo(f"Indexing complete. Indexed {summary.indexed}/{summary.total} datasource(s){extra}.")
+
+
 @dce.command()
 @click.argument(
     "retrieve-text",
+ """ + engine: DatabaoContextEngine = self.get_engine_for_project() + contexts: list[DatasourceContext] = engine.get_all_contexts() + + if datasource_ids is not None: + wanted_paths = {d.datasource_path for d in datasource_ids} + contexts = [c for c in contexts if c.datasource_id.datasource_path in wanted_paths] + + return index_built_contexts( + project_layout=self._project_layout, contexts=contexts, chunk_embedding_mode=chunk_embedding_mode + ) + def check_datasource_connection( self, datasource_ids: list[DatasourceId] | None = None ) -> list[CheckDatasourceConnectionResult]: diff --git a/src/databao_context_engine/datasources/datasource_context.py b/src/databao_context_engine/datasources/datasource_context.py index a9978a83..4531195c 100644 --- a/src/databao_context_engine/datasources/datasource_context.py +++ b/src/databao_context_engine/datasources/datasource_context.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass from pathlib import Path +from typing import Iterable import yaml @@ -25,15 +26,25 @@ class DatasourceContext: context: str -def _read_datasource_type_from_context_file(context_path: Path) -> DatasourceType: +def read_datasource_type_from_context_file(context_path: Path) -> DatasourceType: with context_path.open("r") as context_file: - type_key = "datasource_type" - for line in context_file: - if line.startswith(f"{type_key}: "): - datasource_type = yaml.safe_load(line)[type_key] - return DatasourceType(full_type=datasource_type) + return _read_datasource_type_from_lines(context_file, source_label=str(context_path)) - raise ValueError(f"Could not find type in context file {context_path}") + +def read_datasource_type_from_context(context: DatasourceContext) -> DatasourceType: + return _read_datasource_type_from_lines( + context.context.splitlines(True), + source_label=str(context.datasource_id), + ) + + +def _read_datasource_type_from_lines(lines: Iterable[str], *, source_label: str) -> DatasourceType: + type_key = "datasource_type" + for line in lines: + if line.startswith(f"{type_key}: "): + datasource_type = yaml.safe_load(line)[type_key] + return DatasourceType(full_type=datasource_type) + raise ValueError(f"Could not find type in context {source_label}") def get_introspected_datasource_list(project_layout: ProjectLayout) -> list[Datasource]: @@ -48,7 +59,7 @@ def get_introspected_datasource_list(project_layout: ProjectLayout) -> list[Data result.append( Datasource( id=DatasourceId.from_datasource_context_file_path(relative_context_file), - type=_read_datasource_type_from_context_file(context_file), + type=read_datasource_type_from_context_file(context_file), ) ) except ValueError as e: @@ -73,7 +84,10 @@ def get_all_contexts(project_layout: ProjectLayout) -> list[DatasourceContext]: result = [] for dirpath, dirnames, filenames in os.walk(project_layout.output_dir): for context_file_name in filenames: - if Path(context_file_name).suffix not in DatasourceId.ALLOWED_YAML_SUFFIXES: + if ( + Path(context_file_name).suffix not in DatasourceId.ALLOWED_YAML_SUFFIXES + or context_file_name == "all_results.yaml" + ): continue context_file = Path(dirpath).joinpath(context_file_name) relative_context_file = context_file.relative_to(project_layout.output_dir) diff --git a/src/databao_context_engine/pluginlib/build_plugin.py b/src/databao_context_engine/pluginlib/build_plugin.py index b1139b86..47aac7c2 100644 --- a/src/databao_context_engine/pluginlib/build_plugin.py +++ b/src/databao_context_engine/pluginlib/build_plugin.py @@ -20,6 +20,7 @@ class EmbeddableChunk: 
diff --git a/src/databao_context_engine/pluginlib/build_plugin.py b/src/databao_context_engine/pluginlib/build_plugin.py
index b1139b86..47aac7c2 100644
--- a/src/databao_context_engine/pluginlib/build_plugin.py
+++ b/src/databao_context_engine/pluginlib/build_plugin.py
@@ -20,6 +20,7 @@ class EmbeddableChunk:
 class BaseBuildPlugin(Protocol):
     id: str
     name: str
+    context_type: type[Any]
 
     def supported_types(self) -> set[str]: ...
diff --git a/src/databao_context_engine/plugins/databases/base_db_plugin.py b/src/databao_context_engine/plugins/databases/base_db_plugin.py
index 85723857..6ad215c5 100644
--- a/src/databao_context_engine/plugins/databases/base_db_plugin.py
+++ b/src/databao_context_engine/plugins/databases/base_db_plugin.py
@@ -11,6 +11,7 @@
 from databao_context_engine.pluginlib.config import ConfigPropertyAnnotation
 from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector
 from databao_context_engine.plugins.databases.database_chunker import build_database_chunks
+from databao_context_engine.plugins.databases.databases_types import DatabaseIntrospectionResult
 from databao_context_engine.plugins.databases.introspection_scope import IntrospectionScope
@@ -29,6 +30,7 @@ class BaseDatabaseConfigFile(BaseModel):
 class BaseDatabasePlugin(BuildDatasourcePlugin[T]):
     name: str
     supported: set[str]
+    context_type = DatabaseIntrospectionResult
 
     def __init__(self, introspector: BaseIntrospector):
         self._introspector = introspector
diff --git a/src/databao_context_engine/plugins/dbt/dbt_plugin.py b/src/databao_context_engine/plugins/dbt/dbt_plugin.py
index 757795ca..3fc580e4 100644
--- a/src/databao_context_engine/plugins/dbt/dbt_plugin.py
+++ b/src/databao_context_engine/plugins/dbt/dbt_plugin.py
@@ -4,13 +4,14 @@
 from databao_context_engine.pluginlib.build_plugin import EmbeddableChunk
 from databao_context_engine.plugins.dbt.dbt_chunker import build_dbt_chunks
 from databao_context_engine.plugins.dbt.dbt_context_extractor import check_connection, extract_context
-from databao_context_engine.plugins.dbt.types import DbtConfigFile
+from databao_context_engine.plugins.dbt.types import DbtConfigFile, DbtContext
 
 
 class DbtPlugin(BuildDatasourcePlugin[DbtConfigFile]):
     id = "jetbrains/dbt"
     name = "Dbt Plugin"
     config_file_type = DbtConfigFile
+    context_type = DbtContext
 
     def supported_types(self) -> set[str]:
         return {"dbt"}
diff --git a/src/databao_context_engine/plugins/files/pdf_plugin.py b/src/databao_context_engine/plugins/files/pdf_plugin.py
index 3d408b1a..6e53bcf3 100644
--- a/src/databao_context_engine/plugins/files/pdf_plugin.py
+++ b/src/databao_context_engine/plugins/files/pdf_plugin.py
@@ -4,6 +4,7 @@
 from docling.datamodel.base_models import DocumentStream, InputFormat
 from docling.datamodel.pipeline_options import PdfPipelineOptions
 from docling.document_converter import DocumentConverter, PdfFormatOption
+from docling_core.types import DoclingDocument
 
 from databao_context_engine import BuildFilePlugin
 from databao_context_engine.pluginlib.build_plugin import EmbeddableChunk
@@ -13,6 +14,7 @@
 class PDFPlugin(BuildFilePlugin):
     id = "jetbrains/pdf"
     name = "PDF Plugin"
+    context_type = DoclingDocument
 
     def supported_types(self) -> set[str]:
         return {"pdf"}
diff --git a/src/databao_context_engine/plugins/files/unstructured_files_plugin.py b/src/databao_context_engine/plugins/files/unstructured_files_plugin.py
index 2f8e629a..6b01389a 100644
--- a/src/databao_context_engine/plugins/files/unstructured_files_plugin.py
+++ b/src/databao_context_engine/plugins/files/unstructured_files_plugin.py
@@ -13,6 +13,7 @@ class FileChunk(TypedDict):
 class InternalUnstructuredFilesPlugin(BuildFilePlugin):
     id = "jetbrains/unstructured_files"
     name = "Unstructured Files Plugin"
+    context_type = dict
 
     _SUPPORTED_FILES_EXTENSIONS = {"txt", "md"}
diff --git a/src/databao_context_engine/plugins/plugin_loader.py b/src/databao_context_engine/plugins/plugin_loader.py
index d04bfc0d..25316764 100644
--- a/src/databao_context_engine/plugins/plugin_loader.py
+++ b/src/databao_context_engine/plugins/plugin_loader.py
@@ -34,10 +34,18 @@ def _load_builtin_plugins(exclude_file_plugins: bool = False) -> list[BuildPlugin]:
 
 
 def _load_builtin_file_plugins() -> list[BuildFilePlugin]:
-    from databao_context_engine.plugins.files.pdf_plugin import PDFPlugin
     from databao_context_engine.plugins.files.unstructured_files_plugin import InternalUnstructuredFilesPlugin
 
-    return [InternalUnstructuredFilesPlugin(), PDFPlugin()]
+    plugins: list[BuildFilePlugin] = [InternalUnstructuredFilesPlugin()]
+
+    try:
+        from databao_context_engine.plugins.files.pdf_plugin import PDFPlugin
+
+        plugins.append(PDFPlugin())
+    except ImportError:
+        pass
+
+    return plugins
 
 
 def _load_builtin_datasource_plugins() -> list[BuildDatasourcePlugin]:
diff --git a/src/databao_context_engine/plugins/resources/parquet_plugin.py b/src/databao_context_engine/plugins/resources/parquet_plugin.py
index 11fe1b14..eb3e1afa 100644
--- a/src/databao_context_engine/plugins/resources/parquet_plugin.py
+++ b/src/databao_context_engine/plugins/resources/parquet_plugin.py
@@ -4,6 +4,7 @@
 from databao_context_engine.plugins.resources.parquet_chunker import build_parquet_chunks
 from databao_context_engine.plugins.resources.parquet_introspector import (
     ParquetConfigFile,
+    ParquetIntrospectionResult,
     ParquetIntrospector,
     parquet_type,
 )
@@ -13,6 +14,7 @@
 class ParquetPlugin(BuildDatasourcePlugin[ParquetConfigFile]):
     id = "jetbrains/parquet"
     name = "Parquet Plugin"
     config_file_type = ParquetConfigFile
+    context_type = ParquetIntrospectionResult
 
     def __init__(self):
         self._introspector = ParquetIntrospector()
diff --git a/src/databao_context_engine/serialization/yaml.py b/src/databao_context_engine/serialization/yaml.py
index 70fc9d97..2950ed0c 100644
--- a/src/databao_context_engine/serialization/yaml.py
+++ b/src/databao_context_engine/serialization/yaml.py
@@ -1,4 +1,5 @@
 from dataclasses import fields, is_dataclass
+from enum import Enum
 from typing import Any, Mapping, TextIO, cast
 
 import yaml
@@ -7,6 +8,8 @@
 
 
 def default_representer(dumper: SafeDumper, data: object) -> Node:
+    if isinstance(data, Enum):
+        return dumper.represent_str(data.value)
     if isinstance(data, Mapping):
         return dumper.represent_dict(data)
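The serializer change is behavior-visible in dumps: enum members now serialize as their `.value` instead of the `str(Enum)` repr (compare the `test_yaml.py` expectation change below). A toy dumper registering the same kind of representer; the `MyEnum` fixture name mirrors the test file, and the per-class registration here is just for illustration:

```python
from enum import Enum

import yaml


class MyEnum(Enum):  # illustrative fixture, as in tests/serialization/test_yaml.py
    KEY_2 = "VALUE_2"


class Dumper(yaml.SafeDumper):
    pass


# Same idea as default_representer: represent the member by its value.
Dumper.add_representer(MyEnum, lambda dumper, data: dumper.represent_str(data.value))

print(yaml.dump({"enum_value": MyEnum.KEY_2}, Dumper=Dumper))
# enum_value: VALUE_2   (previously: enum_value: MyEnum.KEY_2)
```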
diff --git a/src/databao_context_engine/services/chunk_embedding_service.py b/src/databao_context_engine/services/chunk_embedding_service.py
index 803acd11..8d84a841 100644
--- a/src/databao_context_engine/services/chunk_embedding_service.py
+++ b/src/databao_context_engine/services/chunk_embedding_service.py
@@ -57,7 +57,9 @@ def __init__(
         if self._chunk_embedding_mode.should_generate_description() and description_provider is None:
             raise ValueError("A DescriptionProvider must be provided when generating descriptions")
 
-    def embed_chunks(self, *, chunks: list[EmbeddableChunk], result: str, full_type: str, datasource_id: str) -> None:
+    def embed_chunks(
+        self, *, chunks: list[EmbeddableChunk], result: str, full_type: str, datasource_id: str, override: bool = False
+    ) -> None:
         """Turn plugin chunks into persisted chunks and embeddings.
 
         Flow:
@@ -113,4 +115,5 @@ def embed_chunks(
             table_name=table_name,
             full_type=full_type,
             datasource_id=datasource_id,
+            override=override,
         )
diff --git a/src/databao_context_engine/services/persistence_service.py b/src/databao_context_engine/services/persistence_service.py
index ea4c63ae..ffe0b832 100644
--- a/src/databao_context_engine/services/persistence_service.py
+++ b/src/databao_context_engine/services/persistence_service.py
@@ -24,10 +24,18 @@ def __init__(
         self._dim = dim
 
     def write_chunks_and_embeddings(
-        self, *, chunk_embeddings: list[ChunkEmbedding], table_name: str, full_type: str, datasource_id: str
+        self,
+        *,
+        chunk_embeddings: list[ChunkEmbedding],
+        table_name: str,
+        full_type: str,
+        datasource_id: str,
+        override: bool = False,
     ):
         """Atomically persist chunks and their vectors.
 
+        If override is True, delete existing chunks and embeddings for the datasource before persisting.
+
         Raises:
             ValueError: If chunk_embeddings is an empty list.
@@ -35,6 +43,14 @@ def write_chunks_and_embeddings(
         if not chunk_embeddings:
             raise ValueError("chunk_embeddings must be a non-empty list")
 
+        # Runs outside the transaction due to DuckDB limitations: FK checks can behave unexpectedly
+        # across multiple statements in the same transaction when deleting and re-inserting related
+        # rows, and ON DELETE CASCADE is not supported yet. Because the embedding table holds a
+        # foreign key to chunk, the embeddings must be deleted first.
+        if override:
+            self._embedding_repo.delete_by_datasource_id(table_name=table_name, datasource_id=datasource_id)
+            self._chunk_repo.delete_by_datasource_id(datasource_id=datasource_id)
+
         with transaction(self._conn):
             for chunk_embedding in chunk_embeddings:
                 chunk_dto = self.create_chunk(
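The delete order above matters because embedding rows reference chunk rows. A sketch of the statements that effectively run for `override=True`; `conn` is an assumed `duckdb.DuckDBPyConnection`, and `emb_files_md` plus `ds1` are stand-in names, not values from this change:

```python
# 1) Clear embeddings first: they hold a chunk_id foreign key into chunk.
conn.execute(
    "DELETE FROM emb_files_md WHERE chunk_id IN "
    "(SELECT chunk_id FROM chunk WHERE datasource_id = ?)",
    ["ds1"],
)
# 2) Then delete the chunks themselves.
conn.execute("DELETE FROM chunk WHERE datasource_id = ?", ["ds1"])
# 3) Only after both deletes does the insert transaction begin;
#    DuckDB has no ON DELETE CASCADE yet, hence the manual ordering.
```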
diff --git a/src/databao_context_engine/storage/repositories/chunk_repository.py b/src/databao_context_engine/storage/repositories/chunk_repository.py
index 9051f604..cecb4098 100644
--- a/src/databao_context_engine/storage/repositories/chunk_repository.py
+++ b/src/databao_context_engine/storage/repositories/chunk_repository.py
@@ -108,6 +108,18 @@ def delete(self, chunk_id: int) -> int:
         )
         return 1 if row else 0
 
+    def delete_by_datasource_id(self, *, datasource_id: str) -> int:
+        deleted = self._conn.execute(
+            """
+            DELETE FROM
+                chunk
+            WHERE
+                datasource_id = ?
+            """,
+            [datasource_id],
+        ).rowcount
+        return int(deleted or 0)
+
     def list(self) -> list[ChunkDTO]:
         rows = self._conn.execute(
             """
diff --git a/src/databao_context_engine/storage/repositories/embedding_repository.py b/src/databao_context_engine/storage/repositories/embedding_repository.py
index 2cd927c9..43e46a3e 100644
--- a/src/databao_context_engine/storage/repositories/embedding_repository.py
+++ b/src/databao_context_engine/storage/repositories/embedding_repository.py
@@ -89,6 +89,27 @@ def delete(self, *, table_name: str, chunk_id: int) -> int:
         ).fetchone()
         return 1 if row else 0
 
+    def delete_by_datasource_id(self, *, table_name: str, datasource_id: str) -> int:
+        TableNamePolicy.validate_table_name(table_name=table_name)
+
+        deleted = self._conn.execute(
+            f"""
+            DELETE FROM
+                {table_name}
+            WHERE
+                chunk_id IN (
+                    SELECT
+                        chunk_id
+                    FROM
+                        chunk
+                    WHERE
+                        datasource_id = ?
+                )
+            """,
+            [datasource_id],
+        ).rowcount
+        return int(deleted or 0)
+
     def list(self, table_name: str) -> list[EmbeddingDTO]:
         TableNamePolicy.validate_table_name(table_name=table_name)
         rows = self._conn.execute(
diff --git a/tests/build_sources/test_build_runner.py b/tests/build_sources/test_build_runner.py
index 5053bbb4..af84e5d0 100644
--- a/tests/build_sources/test_build_runner.py
+++ b/tests/build_sources/test_build_runner.py
@@ -5,7 +5,7 @@
 import pytest
 import yaml
 
-from databao_context_engine import DatasourceId
+from databao_context_engine import DatasourceContext, DatasourceId
 from databao_context_engine.build_sources import build_runner
 from databao_context_engine.build_sources.plugin_execution import BuiltDatasourceContext
 from databao_context_engine.datasources.types import PreparedFile
@@ -133,3 +133,55 @@ def test_build_continues_on_service_exception(
     build_runner.build(project_layout=project_layout, build_service=mock_build_service)
 
     assert mock_build_service.process_prepared_source.call_count == 2
+
+
+def test_run_indexing_indexes_when_plugin_exists(mocker, mock_build_service, project_layout):
+    plugin = object()
+    ds_type = DatasourceType(full_type="files/md")
+
+    mocker.patch.object(build_runner, "load_plugins", return_value={ds_type: plugin})
+    mocker.patch.object(build_runner, "read_datasource_type_from_context", return_value=ds_type)
+
+    ctx = DatasourceContext(
+        datasource_id=DatasourceId.from_string_repr("files/one.md"),
+        context="irrelevant for this test",
+    )
+
+    build_runner.run_indexing(project_layout=project_layout, build_service=mock_build_service, contexts=[ctx])
+
+    mock_build_service.index_built_context.assert_called_once_with(context=ctx, plugin=plugin)
+
+
+def test_run_indexing_skips_when_plugin_missing(mocker, mock_build_service, project_layout, caplog):
+    ds_type = DatasourceType(full_type="files/md")
+
+    mocker.patch.object(build_runner, "load_plugins", return_value={})
+    mocker.patch.object(build_runner, "read_datasource_type_from_context", return_value=ds_type)
+
+    ctx = DatasourceContext(
+        datasource_id=DatasourceId.from_string_repr("files/one.md"),
+        context="irrelevant for this test",
+    )
+
+    build_runner.run_indexing(project_layout=project_layout, build_service=mock_build_service, contexts=[ctx])
+
+    mock_build_service.index_built_context.assert_not_called()
+
+
+def test_run_indexing_continues_on_exception(mocker, mock_build_service, project_layout):
+    plugin = object()
+    ds_type = DatasourceType(full_type="files/md")
+
+    mocker.patch.object(build_runner, "load_plugins", return_value={ds_type: plugin})
+    mocker.patch.object(build_runner, "read_datasource_type_from_context", return_value=ds_type)
+
+    c1 = DatasourceContext(DatasourceId.from_string_repr("files/a.md"), context="a")
+    c2 = DatasourceContext(DatasourceId.from_string_repr("files/b.md"), context="b")
+
+    mock_build_service.index_built_context.side_effect = [RuntimeError("boom"), None]
+
+    build_runner.run_indexing(project_layout=project_layout, build_service=mock_build_service, contexts=[c1, c2])
+
+    assert mock_build_service.index_built_context.call_count == 2
+    mock_build_service.index_built_context.assert_any_call(context=c1, plugin=plugin)
+    mock_build_service.index_built_context.assert_any_call(context=c2, plugin=plugin)
diff --git a/tests/build_sources/test_build_service.py b/tests/build_sources/test_build_service.py
index 3a23946a..384f7c13 100644
--- a/tests/build_sources/test_build_service.py
+++ b/tests/build_sources/test_build_service.py
@@ -3,8 +3,9 @@
 from pathlib import Path
 
 import pytest
+import yaml
 
-from databao_context_engine import DatasourceId
+from databao_context_engine import DatasourceContext, DatasourceId
 from databao_context_engine.build_sources.build_service import BuildService
 from databao_context_engine.build_sources.plugin_execution import BuiltDatasourceContext
 from databao_context_engine.datasources.types import PreparedDatasource, PreparedFile
@@ -103,3 +104,58 @@ def test_process_prepared_source_embed_error_bubbles_after_row_creation(svc, chunk_embed_svc):
 
     with pytest.raises(RuntimeError):
         svc.process_prepared_source(prepared_source=prepared, plugin=plugin)
+
+
+def test_index_built_context_happy_path_embeds(svc, chunk_embed_svc, mocker):
+    plugin = mocker.Mock(name="Plugin")
+    plugin.name = "pluggy"
+    plugin.context_type = dict
+
+    built_at = datetime(2026, 2, 4, 12, 0, 0)
+    raw = {
+        "datasource_id": "files/two.md",
+        "datasource_type": "files/md",
+        "context_built_at": built_at,
+        "context": {"hello": "world"},
+    }
+    yaml_text = yaml.safe_dump(raw)
+
+    dsid = DatasourceId.from_string_repr("files/two.md")
+    ctx = DatasourceContext(datasource_id=dsid, context=yaml_text)
+
+    chunks = [EmbeddableChunk("a", "A"), EmbeddableChunk("b", "B")]
+    plugin.divide_context_into_chunks.return_value = chunks
+
+    svc.index_built_context(context=ctx, plugin=plugin)
+
+    plugin.divide_context_into_chunks.assert_called_once_with({"hello": "world"})
+    chunk_embed_svc.embed_chunks.assert_called_once_with(
+        chunks=chunks,
+        result=yaml_text,
+        full_type="files/md",
+        datasource_id="files/two.md",
+        override=True,
+    )
+
+
+def test_index_built_context_no_chunks_skips_embed(svc, chunk_embed_svc, mocker):
+    plugin = mocker.Mock(name="Plugin")
+    plugin.name = "pluggy"
+    plugin.context_type = dict
+
+    raw = {
+        "datasource_id": "files/empty.md",
+        "datasource_type": "files/md",
+        "context_built_at": datetime(2026, 2, 4, 12, 0, 0),
+        "context": {"nothing": True},
+    }
+    yaml_text = yaml.safe_dump(raw)
+
+    dsid = DatasourceId.from_string_repr("files/empty.md")
+    ctx = DatasourceContext(datasource_id=dsid, context=yaml_text)
+
+    plugin.divide_context_into_chunks.return_value = []
+
+    svc.index_built_context(context=ctx, plugin=plugin)
+
+    chunk_embed_svc.embed_chunks.assert_not_called()
diff --git a/tests/plugins/test_plugin_loader.py b/tests/plugins/test_plugin_loader.py
index 923c7285..abff3df1 100644
--- a/tests/plugins/test_plugin_loader.py
+++ b/tests/plugins/test_plugin_loader.py
@@ -56,7 +56,6 @@ def test_loaded_plugins_no_extra():
     assert plugin_ids == {
         "jetbrains/duckdb",
         "jetbrains/parquet",
-        "jetbrains/pdf",
         "jetbrains/sqlite",
         "jetbrains/unstructured_files",
         "jetbrains/dbt",
diff --git a/tests/serialization/test_yaml.py b/tests/serialization/test_yaml.py
index a8b521ff..dd5dccb6 100644
--- a/tests/serialization/test_yaml.py
+++ b/tests/serialization/test_yaml.py
@@ -75,7 +75,7 @@ def get_expected(my_uuid, now):
     my_str: hello
     my_nested_class:
       nested_var: nested
-    enum_value: MyEnum.KEY_2
+    enum_value: VALUE_2
     my_int: 12
     my_uuid: {str(my_uuid)}
     my_date: {now.isoformat(" ")}
ChunkEmbedding(EmbeddableChunk("A", "a"), _vec(0.0), display_text="a", generated_description="g"), + ChunkEmbedding(EmbeddableChunk("B", "b"), _vec(1.0), display_text="b", generated_description="g"), + ] + ds2_pairs = [ + ChunkEmbedding(EmbeddableChunk("X", "x"), _vec(2.0), display_text="x", generated_description="g"), + ] + + persistence.write_chunks_and_embeddings( + chunk_embeddings=ds1_pairs, table_name=table_name, full_type="files/md", datasource_id="ds1" + ) + persistence.write_chunks_and_embeddings( + chunk_embeddings=ds2_pairs, table_name=table_name, full_type="files/md", datasource_id="ds2" + ) + + saved_before = chunk_repo.list() + old_ds1_chunk_ids = {c.chunk_id for c in saved_before if c.datasource_id == "ds1"} + assert len(old_ds1_chunk_ids) == 2 + + new_ds1_pairs = [ + ChunkEmbedding(EmbeddableChunk("C", "c"), _vec(3.0), display_text="c", generated_description="g"), + ] + persistence.write_chunks_and_embeddings( + chunk_embeddings=new_ds1_pairs, + table_name=table_name, + full_type="files/md", + datasource_id="ds1", + override=True, + ) + + saved_after = chunk_repo.list() + + ds1_rows = [c for c in saved_after if c.datasource_id == "ds1"] + assert [c.embeddable_text for c in ds1_rows] == ["C"] + assert {c.chunk_id for c in ds1_rows}.isdisjoint(old_ds1_chunk_ids) + + ds2_rows = [c for c in saved_after if c.datasource_id == "ds2"] + assert [c.embeddable_text for c in ds2_rows] == ["X"] + + embedding_rows = embedding_repo.list(table_name=table_name) + assert all(row.chunk_id not in old_ds1_chunk_ids for row in embedding_rows) + assert len(embedding_rows) == len(ds1_rows) + len(ds2_rows) + + def _vec(fill: float, dim: int = 768) -> list[float]: return [fill] * dim diff --git a/tests/storage/repositories/test_chunk_repository.py b/tests/storage/repositories/test_chunk_repository.py index 1a3ba9e2..6be9bf38 100644 --- a/tests/storage/repositories/test_chunk_repository.py +++ b/tests/storage/repositories/test_chunk_repository.py @@ -45,3 +45,20 @@ def test_list(chunk_repo): all_rows = chunk_repo.list() assert [s.chunk_id for s in all_rows] == [s3.chunk_id, s2.chunk_id, s1.chunk_id] + + +def test_delete_by_datasource_id(chunk_repo): + d1_a = chunk_repo.create(full_type="type/md", datasource_id="ds1", embeddable_text="a", display_text="a") + d1_b = chunk_repo.create(full_type="type/md", datasource_id="ds1", embeddable_text="b", display_text="b") + d2_c = chunk_repo.create(full_type="type/md", datasource_id="ds2", embeddable_text="c", display_text="c") + + chunk_repo.delete_by_datasource_id(datasource_id="ds1") + + remaining = chunk_repo.list() + remaining_ids = {c.chunk_id for c in remaining} + + assert d1_a.chunk_id not in remaining_ids + assert d1_b.chunk_id not in remaining_ids + assert d2_c.chunk_id in remaining_ids + + assert {c.datasource_id for c in remaining} == {"ds2"} diff --git a/tests/storage/repositories/test_embedding_repository.py b/tests/storage/repositories/test_embedding_repository.py index 3542d080..3bc15137 100644 --- a/tests/storage/repositories/test_embedding_repository.py +++ b/tests/storage/repositories/test_embedding_repository.py @@ -71,6 +71,25 @@ def test_update_with_missing_table_raises(embedding_repo): embedding_repo.update(table_name="123", chunk_id=1, vec=_vec(0.0)) +def test_delete_by_datasource_id(embedding_repo, chunk_repo, table_name): + ds1_a = make_chunk(chunk_repo, full_type="type/f", datasource_id="ds1", embeddable_text="a", display_text="a") + ds1_b = make_chunk(chunk_repo, full_type="type/f", datasource_id="ds1", embeddable_text="b", 
display_text="b") + ds2_c = make_chunk(chunk_repo, full_type="type/f", datasource_id="ds2", embeddable_text="c", display_text="c") + + embedding_repo.create(table_name=table_name, chunk_id=ds1_a.chunk_id, vec=_vec(1.0)) + embedding_repo.create(table_name=table_name, chunk_id=ds1_b.chunk_id, vec=_vec(2.0)) + embedding_repo.create(table_name=table_name, chunk_id=ds2_c.chunk_id, vec=_vec(3.0)) + + embedding_repo.delete_by_datasource_id(table_name=table_name, datasource_id="ds1") + + remaining = embedding_repo.list(table_name=table_name) + remaining_ids = {e.chunk_id for e in remaining} + + assert ds1_a.chunk_id not in remaining_ids + assert ds1_b.chunk_id not in remaining_ids + assert ds2_c.chunk_id in remaining_ids + + def _vec(fill: float | None = None, *, pattern_start: float | None = None) -> list[float]: dim = 768 if fill is not None: diff --git a/tests/test_databao_context_project_manager.py b/tests/test_databao_context_project_manager.py index c2fac999..13dcd3f1 100644 --- a/tests/test_databao_context_project_manager.py +++ b/tests/test_databao_context_project_manager.py @@ -8,6 +8,7 @@ ChunkEmbeddingMode, DatabaoContextProjectManager, Datasource, + DatasourceContext, DatasourceId, DatasourceType, ) @@ -108,6 +109,64 @@ def test_databao_context_project_manager__build_with_multiple_datasource(project ) +def test_databao_context_project_manager__index_built_contexts_indexes_all_when_no_ids(project_path, mocker): + pm = DatabaoContextProjectManager(project_dir=project_path) + + c1 = DatasourceContext(DatasourceId.from_string_repr("full/a.yaml"), context="A") + c2 = DatasourceContext(DatasourceId.from_string_repr("other/b.yaml"), context="B") + + engine = mocker.Mock() + engine.get_all_contexts.return_value = [c1, c2] + mocker.patch.object(pm, "get_engine_for_project", return_value=engine) + + index_fn = mocker.patch( + "databao_context_engine.databao_context_project_manager.index_built_contexts", + autospec=True, + return_value="OK", + ) + + result = pm.index_built_contexts(datasource_ids=None) + + assert result == "OK" + index_fn.assert_called_once_with( + project_layout=pm._project_layout, + contexts=[c1, c2], + chunk_embedding_mode=ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY, + ) + + +def test_databao_context_project_manager__index_built_contexts_filters_by_datasource_path(project_path, mocker): + pm = DatabaoContextProjectManager(project_dir=project_path) + + c1 = DatasourceContext(DatasourceId.from_string_repr("full/a.yaml"), context="A") + c2 = DatasourceContext(DatasourceId.from_string_repr("other/b.yaml"), context="B") + c3 = DatasourceContext(DatasourceId.from_string_repr("full/c.yaml"), context="C") + + engine = mocker.Mock() + engine.get_all_contexts.return_value = [c1, c2, c3] + mocker.patch.object(pm, "get_engine_for_project", return_value=engine) + + index_fn = mocker.patch( + "databao_context_engine.databao_context_project_manager.index_built_contexts", + autospec=True, + return_value="OK", + ) + + wanted = [ + DatasourceId.from_string_repr("full/a.yaml"), + DatasourceId.from_string_repr("full/c.yaml"), + ] + + result = pm.index_built_contexts(datasource_ids=wanted) + + assert result == "OK" + index_fn.assert_called_once_with( + project_layout=pm._project_layout, + contexts=[c1, c3], + chunk_embedding_mode=ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY, + ) + + def assert_build_context_result( context_result: BuildContextResult, project_dir: Path, diff --git a/tests/utils/dummy_build_plugin.py b/tests/utils/dummy_build_plugin.py index 6adfe949..2659bb66 100644 --- 
diff --git a/tests/utils/dummy_build_plugin.py b/tests/utils/dummy_build_plugin.py
index 6adfe949..2659bb66 100644
--- a/tests/utils/dummy_build_plugin.py
+++ b/tests/utils/dummy_build_plugin.py
@@ -57,6 +57,7 @@ class DummyBuildDatasourcePlugin(BuildDatasourcePlugin[DummyConfigFileType]):
     id = "jetbrains/dummy_db"
     name = "Dummy DB Plugin"
     config_file_type = DummyConfigFileType
+    context_type = dict
 
     def supported_types(self) -> set[str]:
         return {"dummy_db"}
@@ -101,6 +102,7 @@ def divide_context_into_chunks(self, context: Any) -> list[EmbeddableChunk]:
 class DummyDefaultDatasourcePlugin(DefaultBuildDatasourcePlugin):
     id = "jetbrains/dummy_default"
     name = "Dummy Plugin with a default type"
+    context_type = dict
 
     def supported_types(self) -> set[str]:
         return {"dummy_default"}
@@ -115,6 +117,7 @@ def divide_context_into_chunks(self, context: Any) -> list[EmbeddableChunk]:
 class DummyFilePlugin(BuildFilePlugin):
     id = "jetbrains/dummy_file"
     name = "Dummy Plugin with a default type"
+    context_type = dict
 
     def supported_types(self) -> set[str]:
         return {"dummy_txt"}
@@ -135,6 +138,7 @@ class AdditionalDummyPlugin(BuildDatasourcePlugin[AdditionalDummyConfigFile]):
     id = "additional/dummy"
     name = "Additional Dummy Plugin"
     config_file_type = AdditionalDummyConfigFile
+    context_type = dict
 
     def supported_types(self) -> set[str]:
         return {"additional_dummy_type"}
@@ -149,6 +153,7 @@ class DummyPluginWithNoConfigType(DefaultBuildDatasourcePlugin, CustomiseConfigProperties):
     id = "dummy/no_config_type"
     name = "Dummy Plugin With No Config Type"
+    context_type = dict
 
     def supported_types(self) -> set[str]:
         return {"no_config_type"}
"sha256:7350f6652dfd761f11f9ecb590bfe95b573e2961f7a242eccb3c8e78348d26fe", size = 79498248, upload-time = "2026-02-06T17:37:31.982Z" }, { url = "https://files.pythonhosted.org/packages/cc/af/758e242e9102e9988969b5e621d41f36b8f258bb4a099109b7a4b4b50ea4/torch-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5fd4117d89ffd47e3dcc71e71a22efac24828ad781c7e46aaaf56bf7f2796acf", size = 145996088, upload-time = "2026-01-21T16:24:44.171Z" }, { url = "https://files.pythonhosted.org/packages/23/8e/3c74db5e53bff7ed9e34c8123e6a8bfef718b2450c35eefab85bb4a7e270/torch-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:787124e7db3b379d4f1ed54dd12ae7c741c16a4d29b49c0226a89bea50923ffb", size = 915711952, upload-time = "2026-01-21T16:23:53.503Z" }, { url = "https://files.pythonhosted.org/packages/6e/01/624c4324ca01f66ae4c7cd1b74eb16fb52596dce66dbe51eff95ef9e7a4c/torch-2.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c66c61f44c5f903046cc696d088e21062644cbe541c7f1c4eaae88b2ad23547", size = 113757972, upload-time = "2026-01-21T16:24:39.516Z" },