Merged
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -12,7 +12,6 @@ dependencies = [
"mcp>=1.23.3",
"pydantic>=2.12.4",
"jinja2>=3.1.6",
"docling>=2.70.0",
]
default-optional-dependency-keys = [
"mysql",
@@ -38,6 +37,9 @@ mysql = [
postgresql = [
"asyncpg>=0.31.0",
]
pdf = [
"docling>=2.70.0",
]

[build-system]
requires = ["uv_build>=0.9.6,<0.10.0"]
66 changes: 64 additions & 2 deletions src/databao_context_engine/build_sources/build_runner.py
@@ -9,6 +9,10 @@
export_build_result,
reset_all_results,
)
from databao_context_engine.datasources.datasource_context import (
DatasourceContext,
read_datasource_type_from_context,
)
from databao_context_engine.datasources.datasource_discovery import discover_datasources, prepare_source
from databao_context_engine.datasources.types import DatasourceId
from databao_context_engine.pluginlib.build_plugin import DatasourceType
@@ -35,6 +39,16 @@ class BuildContextResult:
context_file_path: Path


@dataclass
class IndexSummary:
"""Summary of an indexing run over built contexts."""

total: int
indexed: int
Collaborator:

Instead of int for each of these properties, should we return a list of DatasourceId?
We can still keep the indexed as a calculated property for quick access if we want

@dataclass
class IndexSummary:
    """Summary of an indexing run over built contexts."""

    indexed: set[DatasourceId]
    skipped: set[DatasourceId]
    failed: set[DatasourceId]

    @property
    def number_indexed(self) -> int:
        return len(self.indexed)

    ...

At the very least, I think we should be able to know which datasource failed

Collaborator (Author):

That is already being logged in the exception catch.

Collaborator:

> That is already being logged in the exception catch.

But logs are information that might or might not be shown depending on who is calling: I'm guessing we will still show info logs in the CLI console (but since the CLI is going to live in a separate repo, we shouldn't make any strong assumptions about this). And what about the agents? I don't think they would output the logs anywhere.

And since it's not returned in Python, it's not usable by any callers of the method. The most obvious usage in Python would be to retry indexing only the datasources that failed.
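
A minimal sketch of that retry pattern, assuming the set-based IndexSummary proposed above and the index_built_contexts API from this PR (the manager variable is hypothetical):

    # Sketch only: with set[DatasourceId] fields, the failed ids are directly reusable.
    summary = manager.index_built_contexts()
    if summary.failed:
        # Retry indexing for just the datasources that failed.
        summary = manager.index_built_contexts(datasource_ids=list(summary.failed))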

skipped: int
failed: int


def build(
project_layout: ProjectLayout,
*,
@@ -47,8 +61,7 @@ def build(

1) Load available plugins
2) Discover sources
3) Create a run
4) For each source, call process_source
3) For each source, call process_source

Returns:
A list of all the contexts built.
@@ -113,3 +126,52 @@ def build(
)

return build_result


def run_indexing(
*, project_layout: ProjectLayout, build_service: BuildService, contexts: list[DatasourceContext]
) -> IndexSummary:
"""Index a list of built datasource contexts.

1) Load available plugins
2) Infer datasource type from context file
3) For each context, call index_built_context

Returns:
A summary of the indexing run.
"""
plugins = load_plugins()

summary = IndexSummary(total=len(contexts), indexed=0, skipped=0, failed=0)

for context in contexts:
try:
logger.info(f"Indexing datasource {context.datasource_id}")

datasource_type = read_datasource_type_from_context(context)

plugin = plugins.get(datasource_type)
if plugin is None:
logger.warning(
"No plugin for datasource type '%s' — skipping indexing for %s.",
getattr(datasource_type, "full_type", datasource_type),
context.datasource_id,
)
summary.skipped += 1
continue

build_service.index_built_context(context=context, plugin=plugin)
summary.indexed += 1
except Exception as e:
logger.debug(str(e), exc_info=True, stack_info=True)
logger.info(f"Failed to index datasource ({context.datasource_id}): {str(e)}")
summary.failed += 1

logger.debug(
"Successfully indexed %d/%d datasource(s). %s",
summary.indexed,
summary.total,
f"Skipped {summary.skipped}. Failed {summary.failed}." if (summary.skipped or summary.failed) else "",
)

return summary
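
For reference, a hedged sketch of calling run_indexing directly; it is normally reached through index_built_contexts in build_wiring, and layout, service, and contexts are assumed to be constructed elsewhere:

    # Sketch only: index already-built contexts and report the outcome.
    summary = run_indexing(project_layout=layout, build_service=service, contexts=contexts)
    print(f"Indexed {summary.indexed}/{summary.total} (skipped {summary.skipped}, failed {summary.failed})")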
50 changes: 50 additions & 0 deletions src/databao_context_engine/build_sources/build_service.py
@@ -1,8 +1,14 @@
from __future__ import annotations

import logging
from dataclasses import replace
from typing import Any

import yaml
from pydantic import BaseModel, TypeAdapter

from databao_context_engine.build_sources.plugin_execution import BuiltDatasourceContext, execute_plugin
from databao_context_engine.datasources.datasource_context import DatasourceContext
from databao_context_engine.datasources.types import PreparedDatasource
from databao_context_engine.pluginlib.build_plugin import (
BuildPlugin,
@@ -42,6 +48,7 @@ def process_prepared_source(
result = execute_plugin(self._project_layout, prepared_source, plugin)

chunks = plugin.divide_context_into_chunks(result.context)

if not chunks:
logger.info("No chunks for %s — skipping.", prepared_source.datasource_id.relative_path_to_config_file())
return result
@@ -54,3 +61,46 @@ def process_prepared_source(
)

return result

def index_built_context(self, *, context: DatasourceContext, plugin: BuildPlugin) -> None:
"""Index a context file using the given plugin.

1) Parses the yaml context file contents
2) Reconstructs the `BuiltDatasourceContext` object
3) Structures the inner `context` payload into the plugin's expected `context_type`
4) Calls the plugin's chunker and persists the resulting chunks and embeddings.
"""
built = self._deserialize_built_context(context=context, context_type=plugin.context_type)

chunks = plugin.divide_context_into_chunks(built.context)
if not chunks:
logger.info(
"No chunks for %s — skipping indexing.", context.datasource_id.relative_path_to_context_file().name
)
return

self._chunk_embedding_service.embed_chunks(
chunks=chunks,
result=context.context,
full_type=built.datasource_type,
datasource_id=built.datasource_id,
override=True,
)

def _deserialize_built_context(
self,
*,
context: DatasourceContext,
context_type: type[Any],
) -> BuiltDatasourceContext:
"""Parse the YAML payload and return a BuiltDatasourceContext with a typed `.context`."""
raw_context = yaml.safe_load(context.context)

built = TypeAdapter(BuiltDatasourceContext).validate_python(raw_context)

if isinstance(context_type, type) and issubclass(context_type, BaseModel):
typed_context: Any = context_type.model_validate(built.context)
else:
typed_context = TypeAdapter(context_type).validate_python(built.context)

return replace(built, context=typed_context)
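
To make the two validation branches in _deserialize_built_context concrete, here is a minimal standalone sketch; OrdersContext and FileContext are hypothetical context_type examples, not types from this repo:

    from dataclasses import dataclass
    from typing import Any

    from pydantic import BaseModel, TypeAdapter

    class OrdersContext(BaseModel):  # hypothetical BaseModel context_type
        table: str
        row_count: int

    @dataclass
    class FileContext:  # hypothetical non-BaseModel context_type
        path: str

    def structure(payload: Any, context_type: type[Any]) -> Any:
        # Mirrors the branch above: BaseModel subclasses use model_validate,
        # everything else goes through pydantic's TypeAdapter.
        if isinstance(context_type, type) and issubclass(context_type, BaseModel):
            return context_type.model_validate(payload)
        return TypeAdapter(context_type).validate_python(payload)

    assert structure({"table": "orders", "row_count": 42}, OrdersContext).row_count == 42
    assert structure({"path": "a.pdf"}, FileContext).path == "a.pdf"

Both branches return a typed object, which is what lets replace(built, context=typed_context) hand the plugin a context of its declared context_type.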
42 changes: 41 additions & 1 deletion src/databao_context_engine/build_sources/build_wiring.py
@@ -2,8 +2,9 @@

from duckdb import DuckDBPyConnection

from databao_context_engine.build_sources.build_runner import BuildContextResult, build
from databao_context_engine.build_sources.build_runner import BuildContextResult, IndexSummary, build, run_indexing
from databao_context_engine.build_sources.build_service import BuildService
from databao_context_engine.datasources.datasource_context import DatasourceContext
from databao_context_engine.llm.descriptions.provider import DescriptionProvider
from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
from databao_context_engine.llm.factory import (
@@ -64,6 +65,45 @@ def build_all_datasources(
)


def index_built_contexts(
project_layout: ProjectLayout,
contexts: list[DatasourceContext],
chunk_embedding_mode: ChunkEmbeddingMode,
) -> IndexSummary:
"""Index the contexts into the database.

- Instantiates the build service
- If the database does not exist, it creates it.

Returns:
The summary of the indexing run.
"""
logger.debug("Starting to index %d context(s) for project %s", len(contexts), project_layout.project_dir.resolve())

db_path = get_db_path(project_layout.project_dir)
if not db_path.exists():
db_path.parent.mkdir(parents=True, exist_ok=True)
migrate(db_path)

with open_duckdb_connection(db_path) as conn:
ollama_service = create_ollama_service()
embedding_provider = create_ollama_embedding_provider(ollama_service)
description_provider = (
create_ollama_description_provider(ollama_service)
if chunk_embedding_mode.should_generate_description()
else None
)

build_service = _create_build_service(
conn,
project_layout=project_layout,
embedding_provider=embedding_provider,
description_provider=description_provider,
chunk_embedding_mode=chunk_embedding_mode,
)
return run_indexing(project_layout=project_layout, build_service=build_service, contexts=contexts)


def _create_build_service(
conn: DuckDBPyConnection,
*,
31 changes: 31 additions & 0 deletions src/databao_context_engine/cli/commands.py
@@ -158,6 +158,37 @@ def build(
click.echo(f"Build complete. Processed {len(result)} datasources.")


@dce.command()
@click.argument(
"datasources-config-files",
nargs=-1,
type=click.STRING,
)
@click.pass_context
def index(ctx: Context, datasources_config_files: tuple[str, ...]) -> None:
"""Index and create embeddings for built context files into duckdb.

If one or more datasource config files are provided, only those datasources will be indexed.
If no paths are provided, all built contexts found in the output directory will be indexed.
"""
datasource_ids = (
[DatasourceId.from_string_repr(p) for p in datasources_config_files] if datasources_config_files else None
)

summary = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).index_built_contexts(
datasource_ids=datasource_ids
)

suffix = []
if summary.skipped:
suffix.append(f"skipped {summary.skipped}")
if summary.failed:
suffix.append(f"failed {summary.failed}")

extra = f" ({', '.join(suffix)})" if suffix else ""
click.echo(f"Indexing complete. Indexed {summary.indexed}/{summary.total} datasource(s){extra}.")


@dce.command()
@click.argument(
"retrieve-text",
31 changes: 31 additions & 0 deletions src/databao_context_engine/databao_context_project_manager.py
@@ -3,13 +3,16 @@
from typing import Any, overload

from databao_context_engine.build_sources import BuildContextResult, build_all_datasources
from databao_context_engine.build_sources.build_runner import IndexSummary
from databao_context_engine.build_sources.build_wiring import index_built_contexts
from databao_context_engine.databao_context_engine import DatabaoContextEngine
from databao_context_engine.datasources.check_config import (
CheckDatasourceConnectionResult,
)
from databao_context_engine.datasources.check_config import (
check_datasource_connection as check_datasource_connection_internal,
)
from databao_context_engine.datasources.datasource_context import DatasourceContext
from databao_context_engine.datasources.datasource_discovery import get_datasource_list
from databao_context_engine.datasources.types import Datasource, DatasourceId
from databao_context_engine.pluginlib.build_plugin import DatasourceType
@@ -89,6 +92,34 @@ def build_context(
# TODO: Filter which datasources to build by datasource_ids
return build_all_datasources(project_layout=self._project_layout, chunk_embedding_mode=chunk_embedding_mode)

def index_built_contexts(
self,
datasource_ids: list[DatasourceId] | None = None,
chunk_embedding_mode: ChunkEmbeddingMode = ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY,
) -> IndexSummary:
"""Index built datasource contexts into the embeddings database.

It reads already built context files from the output directory, chunks them using the appropriate plugin,
embeds the chunks and persists both the chunks and embeddings.

Args:
datasource_ids: The list of datasource ids to index. If None, all datasources will be indexed.
chunk_embedding_mode: The mode to use for chunk embedding.

Returns:
The summary of the index operation.
"""
engine: DatabaoContextEngine = self.get_engine_for_project()
contexts: list[DatasourceContext] = engine.get_all_contexts()
Collaborator:

Improvement for another PR: we probably should have an API in the engine to get only the datasources from a list.

Right now, we only have:

  • get one datasource context
  • get all datasource context

We should add:

  • get multiple datasource contexts

Collaborator (Author):

Indeed, we could add that. I considered implementing it for this PR, but I'm not sure it actually saves much IO. It might, though, especially if some contexts are very big.


if datasource_ids is not None:
wanted_paths = {d.datasource_path for d in datasource_ids}
contexts = [c for c in contexts if c.datasource_id.datasource_path in wanted_paths]
Collaborator:
Wouldn't you be able to simply check if c.datasource_id in datasource_ids?
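
For illustration, the suggested direct membership check would look like this, assuming DatasourceId has value-based __eq__/__hash__ (the set-of-paths workaround above suggests verifying that first):

    wanted = set(datasource_ids)
    contexts = [c for c in contexts if c.datasource_id in wanted]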


return index_built_contexts(
project_layout=self._project_layout, contexts=contexts, chunk_embedding_mode=chunk_embedding_mode
)

def check_datasource_connection(
self, datasource_ids: list[DatasourceId] | None = None
) -> list[CheckDatasourceConnectionResult]:
32 changes: 23 additions & 9 deletions src/databao_context_engine/datasources/datasource_context.py
@@ -2,6 +2,7 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

import yaml

@@ -25,15 +26,25 @@ class DatasourceContext:
context: str


def _read_datasource_type_from_context_file(context_path: Path) -> DatasourceType:
def read_datasource_type_from_context_file(context_path: Path) -> DatasourceType:
with context_path.open("r") as context_file:
type_key = "datasource_type"
for line in context_file:
if line.startswith(f"{type_key}: "):
datasource_type = yaml.safe_load(line)[type_key]
return DatasourceType(full_type=datasource_type)
return _read_datasource_type_from_lines(context_file, source_label=str(context_path))

raise ValueError(f"Could not find type in context file {context_path}")

def read_datasource_type_from_context(context: DatasourceContext) -> DatasourceType:
return _read_datasource_type_from_lines(
context.context.splitlines(True),
source_label=str(context.datasource_id),
)


def _read_datasource_type_from_lines(lines: Iterable[str], *, source_label: str) -> DatasourceType:
type_key = "datasource_type"
for line in lines:
if line.startswith(f"{type_key}: "):
datasource_type = yaml.safe_load(line)[type_key]
return DatasourceType(full_type=datasource_type)
raise ValueError(f"Could not find type in context {source_label}")
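
A small illustration of the line-sniffing behaviour above (the values are made up):

    lines = [
        "datasource_id: databases/orders\n",
        "datasource_type: database/postgresql\n",
    ]
    # Returns DatasourceType(full_type="database/postgresql"); raises ValueError
    # if no "datasource_type: " line is present.
    _read_datasource_type_from_lines(lines, source_label="example")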


def get_introspected_datasource_list(project_layout: ProjectLayout) -> list[Datasource]:
@@ -48,7 +59,7 @@ def get_introspected_datasource_list(project_layout: ProjectLayout) -> list[Datasource]:
result.append(
Datasource(
id=DatasourceId.from_datasource_context_file_path(relative_context_file),
type=_read_datasource_type_from_context_file(context_file),
type=read_datasource_type_from_context_file(context_file),
)
)
except ValueError as e:
@@ -73,7 +84,10 @@ def get_all_contexts(project_layout: ProjectLayout) -> list[DatasourceContext]:
result = []
for dirpath, dirnames, filenames in os.walk(project_layout.output_dir):
for context_file_name in filenames:
if Path(context_file_name).suffix not in DatasourceId.ALLOWED_YAML_SUFFIXES:
if (
Path(context_file_name).suffix not in DatasourceId.ALLOWED_YAML_SUFFIXES
or context_file_name == "all_results.yaml"
Collaborator:
👍 well spotted

):
continue
context_file = Path(dirpath).joinpath(context_file_name)
relative_context_file = context_file.relative_to(project_layout.output_dir)
1 change: 1 addition & 0 deletions src/databao_context_engine/pluginlib/build_plugin.py
@@ -20,6 +20,7 @@ class EmbeddableChunk:
class BaseBuildPlugin(Protocol):
id: str
name: str
context_type: type[Any]

def supported_types(self) -> set[str]: ...

Expand Down
Loading