Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ default-optional-dependency-keys = [
]

[project.optional-dependencies]
cli = [
"rich>=14.3.2"
]
mssql = [
"mssql-python>=1.0.0"
]
Expand Down
98 changes: 88 additions & 10 deletions src/databao_context_engine/build_sources/build_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from databao_context_engine.datasources.types import DatasourceId
from databao_context_engine.pluginlib.build_plugin import DatasourceType
from databao_context_engine.plugins.plugin_loader import load_plugins
from databao_context_engine.progress.progress import DatasourceStatus, ProgressCallback, ProgressEmitter
from databao_context_engine.project.layout import ProjectLayout

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -50,9 +51,7 @@ class IndexSummary:


def build(
project_layout: ProjectLayout,
*,
build_service: BuildService,
project_layout: ProjectLayout, *, build_service: BuildService, progress: ProgressCallback | None = None
) -> list[BuildContextResult]:
"""Build the context for all datasources in the project.
Expand All @@ -70,34 +69,54 @@ def build(

datasource_ids = discover_datasources(project_layout)

emitter = ProgressEmitter(progress)

if not datasource_ids:
logger.info("No sources discovered under %s", project_layout.src_dir)
emitter.task_started(total_datasources=0)
emitter.task_finished(ok=0, failed=0, skipped=0)
return []

emitter.task_started(total_datasources=len(datasource_ids))

number_of_failed_builds = 0
number_of_skipped_builds = 0

build_result = []
reset_all_results(project_layout.output_dir)
for datasource_id in datasource_ids:
for datasource_index, datasource_id in enumerate(datasource_ids, start=1):
try:
prepared_source = prepare_source(project_layout, datasource_id)

logger.info(
f'Found datasource of type "{prepared_source.datasource_type.full_type}" with name {prepared_source.datasource_id.datasource_path}'
)

emitter.datasource_started(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
)
plugin = plugins.get(prepared_source.datasource_type)
if plugin is None:
logger.warning(
"No plugin for '%s' (datasource=%s) — skipping.",
prepared_source.datasource_type.full_type,
prepared_source.datasource_id.relative_path_to_config_file(),
)
number_of_failed_builds += 1
emitter.datasource_finished(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
status=DatasourceStatus.SKIPPED,
)
number_of_skipped_builds += 1
continue

result = build_service.process_prepared_source(
prepared_source=prepared_source,
plugin=plugin,
progress=progress,
)

output_dir = project_layout.output_dir
Expand All @@ -113,23 +132,46 @@ def build(
context_file_path=context_file_path,
)
)
emitter.datasource_finished(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
status=DatasourceStatus.OK,
)
except Exception as e:
logger.debug(str(e), exc_info=True, stack_info=True)
logger.info(f"Failed to build source at ({datasource_id.relative_path_to_config_file()}): {str(e)}")

emitter.datasource_finished(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
status=DatasourceStatus.FAILED,
error=str(e),
)
number_of_failed_builds += 1

logger.debug(
"Successfully built %d datasources. %s",
"Successfully built %d datasources. %s %s",
len(build_result),
f"Skipped {number_of_skipped_builds}." if number_of_skipped_builds > 0 else "",
f"Failed to build {number_of_failed_builds}." if number_of_failed_builds > 0 else "",
)

emitter.task_finished(
ok=len(build_result),
failed=number_of_failed_builds,
skipped=number_of_skipped_builds,
)

return build_result


def run_indexing(
*, project_layout: ProjectLayout, build_service: BuildService, contexts: list[DatasourceContext]
*,
project_layout: ProjectLayout,
build_service: BuildService,
contexts: list[DatasourceContext],
progress: ProgressCallback | None = None,
) -> IndexSummary:
"""Index a list of built datasource contexts.
Expand All @@ -144,10 +186,25 @@ def run_indexing(

summary = IndexSummary(total=len(contexts), indexed=0, skipped=0, failed=0)

for context in contexts:
emitter = ProgressEmitter(progress)

if not contexts:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this if-block is not needed. It will behave the exact same way without it.

emitter.task_started(total_datasources=0)
emitter.task_finished(ok=0, failed=0, skipped=0)
return summary

emitter.task_started(total_datasources=len(contexts))

for datasource_index, context in enumerate(contexts, start=1):
try:
logger.info(f"Indexing datasource {context.datasource_id}")

emitter.datasource_started(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
)

datasource_type = read_datasource_type_from_context(context)

plugin = plugins.get(datasource_type)
Expand All @@ -158,14 +215,34 @@ def run_indexing(
context.datasource_id,
)
summary.skipped += 1
emitter.datasource_finished(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
status=DatasourceStatus.SKIPPED,
)
continue

build_service.index_built_context(context=context, plugin=plugin)
build_service.index_built_context(context=context, plugin=plugin, progress=progress)
summary.indexed += 1

emitter.datasource_finished(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
status=DatasourceStatus.OK,
)
except Exception as e:
logger.debug(str(e), exc_info=True, stack_info=True)
logger.info(f"Failed to build source at ({context.datasource_id}): {str(e)}")
summary.failed += 1
emitter.datasource_finished(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
status=DatasourceStatus.FAILED,
error=str(e),
)

logger.debug(
"Successfully indexed %d/%d datasource(s). %s",
Expand All @@ -174,4 +251,5 @@ def run_indexing(
f"Skipped {summary.skipped}. Failed {summary.failed}." if (summary.skipped or summary.failed) else "",
)

emitter.task_finished(ok=summary.indexed, failed=summary.failed, skipped=summary.skipped)
return summary
8 changes: 7 additions & 1 deletion src/databao_context_engine/build_sources/build_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from databao_context_engine.pluginlib.build_plugin import (
BuildPlugin,
)
from databao_context_engine.progress.progress import ProgressCallback
from databao_context_engine.project.layout import ProjectLayout
from databao_context_engine.serialization.yaml import to_yaml_string
from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingService
Expand All @@ -35,6 +36,7 @@ def process_prepared_source(
*,
prepared_source: PreparedDatasource,
plugin: BuildPlugin,
progress: ProgressCallback | None = None,
) -> BuiltDatasourceContext:
"""Process a single source to build its context.
Expand All @@ -58,11 +60,14 @@ def process_prepared_source(
result=to_yaml_string(result.context),
full_type=prepared_source.datasource_type.full_type,
datasource_id=result.datasource_id,
progress=progress,
)

return result

def index_built_context(self, *, context: DatasourceContext, plugin: BuildPlugin) -> None:
def index_built_context(
self, *, context: DatasourceContext, plugin: BuildPlugin, progress: ProgressCallback | None = None
) -> None:
"""Index a context file using the given plugin.
1) Parses the yaml context file contents
Expand All @@ -85,6 +90,7 @@ def index_built_context(self, *, context: DatasourceContext, plugin: BuildPlugin
full_type=built.datasource_type,
datasource_id=built.datasource_id,
override=True,
progress=progress,
)

def _deserialize_built_context(
Expand Down
13 changes: 11 additions & 2 deletions src/databao_context_engine/build_sources/build_wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
create_ollama_embedding_provider,
create_ollama_service,
)
from databao_context_engine.progress.progress import ProgressCallback
from databao_context_engine.project.layout import ProjectLayout
from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingMode
from databao_context_engine.services.factories import create_chunk_embedding_service
Expand All @@ -23,7 +24,10 @@


def build_all_datasources(
project_layout: ProjectLayout, chunk_embedding_mode: ChunkEmbeddingMode
project_layout: ProjectLayout,
chunk_embedding_mode: ChunkEmbeddingMode,
*,
progress: ProgressCallback | None = None,
) -> list[BuildContextResult]:
"""Build the context for all datasources in the project.
Expand Down Expand Up @@ -62,13 +66,16 @@ def build_all_datasources(
return build(
project_layout=project_layout,
build_service=build_service,
progress=progress,
)


def index_built_contexts(
project_layout: ProjectLayout,
contexts: list[DatasourceContext],
chunk_embedding_mode: ChunkEmbeddingMode,
*,
progress: ProgressCallback | None = None,
) -> IndexSummary:
"""Index the contexts into the database.
Expand Down Expand Up @@ -101,7 +108,9 @@ def index_built_contexts(
description_provider=description_provider,
chunk_embedding_mode=chunk_embedding_mode,
)
return run_indexing(project_layout=project_layout, build_service=build_service, contexts=contexts)
return run_indexing(
project_layout=project_layout, build_service=build_service, contexts=contexts, progress=progress
)


def _create_build_service(
Expand Down
18 changes: 12 additions & 6 deletions src/databao_context_engine/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from databao_context_engine.cli.info import echo_info
from databao_context_engine.config.logging import configure_logging
from databao_context_engine.mcp.mcp_runner import McpTransport, run_mcp_server
from databao_context_engine.progress.rich_progress import rich_progress


@click.group()
Expand Down Expand Up @@ -151,9 +152,12 @@ def build(
Internally, this indexes the context to be used by the MCP server and the "retrieve" command.
"""
result = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).build_context(
datasource_ids=None, chunk_embedding_mode=ChunkEmbeddingMode(chunk_embedding_mode.upper())
)
with rich_progress() as progress_cb:
result = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).build_context(
datasource_ids=None,
chunk_embedding_mode=ChunkEmbeddingMode(chunk_embedding_mode.upper()),
progress=progress_cb,
)

click.echo(f"Build complete. Processed {len(result)} datasources.")

Expand All @@ -175,9 +179,11 @@ def index(ctx: Context, datasources_config_files: tuple[str, ...]) -> None:
[DatasourceId.from_string_repr(p) for p in datasources_config_files] if datasources_config_files else None
)

summary = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).index_built_contexts(
datasource_ids=datasource_ids
)
with rich_progress() as progress_cb:
summary = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).index_built_contexts(
datasource_ids=datasource_ids,
progress=progress_cb,
)

suffix = []
if summary.skipped:
Expand Down
18 changes: 16 additions & 2 deletions src/databao_context_engine/databao_context_project_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from databao_context_engine.datasources.datasource_discovery import get_datasource_list
from databao_context_engine.datasources.types import Datasource, DatasourceId
from databao_context_engine.pluginlib.build_plugin import DatasourceType
from databao_context_engine.progress.progress import ProgressCallback
from databao_context_engine.project.layout import (
ProjectLayout,
ensure_project_dir,
Expand Down Expand Up @@ -77,6 +78,8 @@ def build_context(
self,
datasource_ids: list[DatasourceId] | None = None,
chunk_embedding_mode: ChunkEmbeddingMode = ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY,
*,
progress: ProgressCallback | None = None,
) -> list[BuildContextResult]:
"""Build the context for datasources in the project.
Expand All @@ -85,17 +88,24 @@ def build_context(
Args:
datasource_ids: The list of datasource ids to build. If None, all datasources will be built.
chunk_embedding_mode: The mode to use for chunk embedding.
progress: Optional callback that receives progress events during execution.
Returns:
The list of all built results.
"""
# TODO: Filter which datasources to build by datasource_ids
return build_all_datasources(project_layout=self._project_layout, chunk_embedding_mode=chunk_embedding_mode)
return build_all_datasources(
project_layout=self._project_layout,
chunk_embedding_mode=chunk_embedding_mode,
progress=progress,
)

def index_built_contexts(
self,
datasource_ids: list[DatasourceId] | None = None,
chunk_embedding_mode: ChunkEmbeddingMode = ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY,
*,
progress: ProgressCallback | None = None,
) -> IndexSummary:
"""Index built datasource contexts into the embeddings database.
Expand All @@ -105,6 +115,7 @@ def index_built_contexts(
Args:
datasource_ids: The list of datsource ids to index. If None, all datsources will be indexed.
chunk_embedding_mode: The mode to use for chunk embedding.
progress: Optional callback that receives progress events during execution.
Returns:
The summary of the index operation.
Expand All @@ -117,7 +128,10 @@ def index_built_contexts(
contexts = [c for c in contexts if c.datasource_id.datasource_path in wanted_paths]

return index_built_contexts(
project_layout=self._project_layout, contexts=contexts, chunk_embedding_mode=chunk_embedding_mode
project_layout=self._project_layout,
contexts=contexts,
chunk_embedding_mode=chunk_embedding_mode,
progress=progress,
)

def check_datasource_connection(
Expand Down
Empty file.
Loading