From 6723ba8052aeea79277cf6d347ca124eed622acd Mon Sep 17 00:00:00 2001 From: MateusCordeiro Date: Thu, 5 Feb 2026 23:03:23 +0100 Subject: [PATCH 1/6] progress bar v1 --- pyproject.toml | 1 + .../build_sources/build_runner.py | 54 ++++- .../build_sources/build_service.py | 24 ++ .../build_sources/build_wiring.py | 7 +- .../build_sources/export_results.py | 2 +- src/databao_context_engine/cli/commands.py | 8 +- .../databao_context_project_manager.py | 9 +- .../progress/__init__.py | 0 .../progress/progress.py | 216 +++++++++++++++++ .../progress/rich_progress.py | 220 ++++++++++++++++++ .../services/chunk_embedding_service.py | 27 ++- .../services/persistence_service.py | 25 +- uv.lock | 36 +++ 13 files changed, 611 insertions(+), 18 deletions(-) create mode 100644 src/databao_context_engine/progress/__init__.py create mode 100644 src/databao_context_engine/progress/progress.py create mode 100644 src/databao_context_engine/progress/rich_progress.py diff --git a/pyproject.toml b/pyproject.toml index e66cacdb..029f9d4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "mcp>=1.23.3", "pydantic>=2.12.4", "jinja2>=3.1.6", + "rich>=14.3.2" ] default-optional-dependency-keys = [ "mysql", diff --git a/src/databao_context_engine/build_sources/build_runner.py b/src/databao_context_engine/build_sources/build_runner.py index 54ede56f..42193424 100644 --- a/src/databao_context_engine/build_sources/build_runner.py +++ b/src/databao_context_engine/build_sources/build_runner.py @@ -13,6 +13,7 @@ from databao_context_engine.datasources.types import DatasourceId from databao_context_engine.pluginlib.build_plugin import DatasourceType from databao_context_engine.plugins.plugin_loader import load_plugins +from databao_context_engine.progress.progress import DatasourceStatus, ProgressCallback, ProgressEmitter from databao_context_engine.project.layout import ProjectLayout logger = logging.getLogger(__name__) @@ -36,9 +37,7 @@ class BuildContextResult: def build( - project_layout: ProjectLayout, - *, - build_service: BuildService, + project_layout: ProjectLayout, *, build_service: BuildService, progress: ProgressCallback | None = None ) -> list[BuildContextResult]: """Build the context for all datasources in the project. @@ -57,21 +56,35 @@ def build( datasources = discover_datasources(project_layout) + emitter = ProgressEmitter(progress) + if not datasources: logger.info("No sources discovered under %s", project_layout.src_dir) + emitter.build_started(total_datasources=0) + emitter.build_finished(ok=0, failed=0, skipped=0) return [] + emitter.build_started(total_datasources=len(datasources)) + number_of_failed_builds = 0 build_result = [] reset_all_results(project_layout.output_dir) - for discovered_datasource in datasources: + for datasource_index, discovered_datasource in enumerate(datasources, start=1): + datasource_id = str(discovered_datasource.datasource_id) try: prepared_source = prepare_source(discovered_datasource) - logger.info( - f'Found datasource of type "{prepared_source.datasource_type.full_type}" with name {prepared_source.path.stem}' - ) + # logger.info( + # f'Found datasource of type "{prepared_source.datasource_type.full_type}" with name {prepared_source.path.stem}' + # ) + emitter.datasource_started( + datasource_id=datasource_id, + datasource_type=prepared_source.datasource_type.full_type, + datasource_path=str(prepared_source.path), + index=datasource_index, + total=len(datasources), + ) plugin = plugins.get(prepared_source.datasource_type) if plugin is None: logger.warning( @@ -79,12 +92,19 @@ def build( prepared_source.datasource_type.full_type, prepared_source.path, ) + emitter.datasource_finished( + datasource_id=datasource_id, + index=datasource_index, + total=len(datasources), + status=DatasourceStatus.SKIPPED, + ) number_of_failed_builds += 1 continue result = build_service.process_prepared_source( prepared_source=prepared_source, plugin=plugin, + progress=progress, ) output_dir = project_layout.output_dir @@ -100,10 +120,22 @@ def build( context_file_path=context_file_path, ) ) + emitter.datasource_finished( + datasource_id=datasource_id, + index=datasource_index, + total=len(datasources), + status=DatasourceStatus.OK, + ) except Exception as e: logger.debug(str(e), exc_info=True, stack_info=True) logger.info(f"Failed to build source at ({discovered_datasource.path}): {str(e)}") - + emitter.datasource_finished( + datasource_id=datasource_id, + index=datasource_index, + total=len(datasources), + status=DatasourceStatus.FAILED, + error=str(e), + ) number_of_failed_builds += 1 logger.debug( @@ -112,4 +144,10 @@ def build( f"Failed to build {number_of_failed_builds}." if number_of_failed_builds > 0 else "", ) + emitter.build_finished( + ok=len(build_result), + failed=number_of_failed_builds, + skipped=0, + ) + return build_result diff --git a/src/databao_context_engine/build_sources/build_service.py b/src/databao_context_engine/build_sources/build_service.py index ab18c36f..4e0f181b 100644 --- a/src/databao_context_engine/build_sources/build_service.py +++ b/src/databao_context_engine/build_sources/build_service.py @@ -7,6 +7,7 @@ from databao_context_engine.pluginlib.build_plugin import ( BuildPlugin, ) +from databao_context_engine.progress.progress import ProgressCallback, ProgressEmitter, DatasourcePhase from databao_context_engine.serialization.yaml import to_yaml_string from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingService @@ -26,6 +27,7 @@ def process_prepared_source( *, prepared_source: PreparedDatasource, plugin: BuildPlugin, + progress: ProgressCallback | None = None, ) -> BuiltDatasourceContext: """Process a single source to build its context. @@ -36,18 +38,40 @@ def process_prepared_source( Returns: The built context. """ + emitter = ProgressEmitter(progress) + + emitter.datasource_phase( + datasource_id=prepared_source.datasource_id if hasattr(prepared_source, 'datasource_id') else '', + phase=DatasourcePhase.EXECUTE_PLUGIN, + message="Executing plugin", + ) + result = execute(prepared_source, plugin) + emitter.datasource_phase( + datasource_id=result.datasource_id, + phase=DatasourcePhase.DIVIDE_CHUNKS, + message="Dividing context into chunks", + ) + chunks = plugin.divide_context_into_chunks(result.context) + emitter.chunks_discovered(datasource_id=result.datasource_id, total_chunks=len(chunks)) if not chunks: logger.info("No chunks for %s — skipping.", prepared_source.path.name) return result + emitter.datasource_phase( + datasource_id=result.datasource_id, + phase=DatasourcePhase.EMBED, + message="Embedding chunks", + ) + self._chunk_embedding_service.embed_chunks( chunks=chunks, result=to_yaml_string(result.context), full_type=prepared_source.datasource_type.full_type, datasource_id=result.datasource_id, + progress=progress ) return result diff --git a/src/databao_context_engine/build_sources/build_wiring.py b/src/databao_context_engine/build_sources/build_wiring.py index a3c26d84..d40c3d79 100644 --- a/src/databao_context_engine/build_sources/build_wiring.py +++ b/src/databao_context_engine/build_sources/build_wiring.py @@ -11,6 +11,7 @@ create_ollama_embedding_provider, create_ollama_service, ) +from databao_context_engine.progress.progress import ProgressCallback from databao_context_engine.project.layout import ProjectLayout from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingMode from databao_context_engine.services.factories import create_chunk_embedding_service @@ -22,7 +23,10 @@ def build_all_datasources( - project_layout: ProjectLayout, chunk_embedding_mode: ChunkEmbeddingMode + project_layout: ProjectLayout, + chunk_embedding_mode: ChunkEmbeddingMode, + *, + progress: ProgressCallback | None = None, ) -> list[BuildContextResult]: """Build the context for all datasources in the project. @@ -60,6 +64,7 @@ def build_all_datasources( return build( project_layout=project_layout, build_service=build_service, + progress=progress, ) diff --git a/src/databao_context_engine/build_sources/export_results.py b/src/databao_context_engine/build_sources/export_results.py index 3fd1814b..b00de834 100644 --- a/src/databao_context_engine/build_sources/export_results.py +++ b/src/databao_context_engine/build_sources/export_results.py @@ -20,7 +20,7 @@ def export_build_result(output_dir: Path, result: BuiltDatasourceContext) -> Pat with export_file_path.open("w") as export_file: write_yaml_to_stream(data=result, file_stream=export_file) - logger.info(f"Exported result to {export_file_path.resolve()}") + # logger.info(f"Exported result to {export_file_path.resolve()}") return export_file_path diff --git a/src/databao_context_engine/cli/commands.py b/src/databao_context_engine/cli/commands.py index d0e60b80..956c6636 100644 --- a/src/databao_context_engine/cli/commands.py +++ b/src/databao_context_engine/cli/commands.py @@ -20,6 +20,7 @@ from databao_context_engine.cli.info import echo_info from databao_context_engine.config.logging import configure_logging from databao_context_engine.mcp.mcp_runner import McpTransport, run_mcp_server +from databao_context_engine.progress.rich_progress import rich_progress @click.group() @@ -151,9 +152,10 @@ def build( Internally, this indexes the context to be used by the MCP server and the "retrieve" command. """ - result = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).build_context( - datasource_ids=None, chunk_embedding_mode=ChunkEmbeddingMode(chunk_embedding_mode.upper()) - ) + with rich_progress() as progress_cb: + result = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).build_context( + datasource_ids=None, chunk_embedding_mode=ChunkEmbeddingMode(chunk_embedding_mode.upper()), progress=progress_cb + ) click.echo(f"Build complete. Processed {len(result)} datasources.") diff --git a/src/databao_context_engine/databao_context_project_manager.py b/src/databao_context_engine/databao_context_project_manager.py index 5bb9fdc9..4271d5bf 100644 --- a/src/databao_context_engine/databao_context_project_manager.py +++ b/src/databao_context_engine/databao_context_project_manager.py @@ -13,6 +13,7 @@ from databao_context_engine.datasources.datasource_discovery import get_datasource_list from databao_context_engine.datasources.types import Datasource, DatasourceId from databao_context_engine.pluginlib.build_plugin import DatasourceType +from databao_context_engine.progress.progress import ProgressCallback from databao_context_engine.project.layout import ( ProjectLayout, ensure_project_dir, @@ -74,6 +75,8 @@ def build_context( self, datasource_ids: list[DatasourceId] | None = None, chunk_embedding_mode: ChunkEmbeddingMode = ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY, + *, + progress: ProgressCallback | None = None, ) -> list[BuildContextResult]: """Build the context for datasources in the project. @@ -87,7 +90,11 @@ def build_context( The list of all built results. """ # TODO: Filter which datasources to build by datasource_ids - return build_all_datasources(project_layout=self._project_layout, chunk_embedding_mode=chunk_embedding_mode) + return build_all_datasources( + project_layout=self._project_layout, + chunk_embedding_mode=chunk_embedding_mode, + progress=progress, + ) def check_datasource_connection( self, datasource_ids: list[DatasourceId] | None = None diff --git a/src/databao_context_engine/progress/__init__.py b/src/databao_context_engine/progress/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/databao_context_engine/progress/progress.py b/src/databao_context_engine/progress/progress.py new file mode 100644 index 00000000..3f462e42 --- /dev/null +++ b/src/databao_context_engine/progress/progress.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Callable + + +class ProgressKind(str, Enum): + # Build-level + BUILD_STARTED = "build_started" + BUILD_FINISHED = "build_finished" + + # Datasource-level lifecycle + DATASOURCE_STARTED = "datasource_started" + DATASOURCE_FINISHED = "datasource_finished" + + # Datasource-level detail / phases + DATASOURCE_PHASE = "datasource_phase" + CHUNKS_DISCOVERED = "chunks_discovered" + + # Chunk embedding + EMBEDDING_STARTED = "embedding_started" + EMBEDDING_PROGRESS = "embedding_progress" + EMBEDDING_FINISHED = "embedding_finished" + + # Persistence + PERSIST_STARTED = "persist_started" + PERSIST_PROGRESS = "persist_progress" + PERSIST_FINISHED = "persist_finished" + + +class DatasourceStatus(str, Enum): + OK = "ok" + SKIPPED = "skipped" + FAILED = "failed" + + +class DatasourcePhase(str, Enum): + EXECUTE_PLUGIN = "execute_plugin" + DIVIDE_CHUNKS = "divide_chunks" + EMBED = "embed" + PERSIST = "persist" + EXPORT = "export" + + +@dataclass(frozen=True, slots=True) +class ProgressEvent: + """A structured progress update emitted by the build pipeline. + + Conventions: + - `total` may be None when unknown. + - For *_PROGRESS events, `done` should be monotonic increasing up to `total` (if total known). + - `message` is optional human-readable text; callers can ignore it and render their own. + """ + + kind: ProgressKind + + # Identifiers (often present, but not required for build-level events) + datasource_id: str | None = None + datasource_type: str | None = None # e.g. full_type + datasource_path: str | None = None # if helpful for debugging/logging + + # Build progress (datasources) + datasource_index: int | None = None # 1-based index + datasource_total: int | None = None # total datasources + + # Sub-progress (chunks / persistence) + done: int | None = None + total: int | None = None + + # Extra structured fields + phase: DatasourcePhase | None = None + status: DatasourceStatus | None = None + error: str | None = None + + # Human text + message: str = "" + + +ProgressCallback = Callable[[ProgressEvent], None] + + +class ProgressEmitter: + """Small helper so you don't sprinkle `if progress: progress(...)` everywhere. + + You can also extend this later with throttling helpers. + """ + + def __init__(self, cb: ProgressCallback | None): + self._cb = cb + + def emit(self, event: ProgressEvent) -> None: + if self._cb is not None: + self._cb(event) + + # Convenience builders (optional, but nice for consistency) + def build_started(self, *, total_datasources: int | None) -> None: + self.emit(ProgressEvent(kind=ProgressKind.BUILD_STARTED, datasource_total=total_datasources)) + + def build_finished(self, *, ok: int, failed: int, skipped: int) -> None: + self.emit( + ProgressEvent( + kind=ProgressKind.BUILD_FINISHED, + message=f"Finished (ok={ok}, failed={failed}, skipped={skipped})", + ) + ) + + def datasource_started( + self, + *, + datasource_id: str, + datasource_type: str | None, + datasource_path: str | None, + index: int, + total: int, + ) -> None: + self.emit( + ProgressEvent( + kind=ProgressKind.DATASOURCE_STARTED, + datasource_id=datasource_id, + datasource_type=datasource_type, + datasource_path=datasource_path, + datasource_index=index, + datasource_total=total, + message=f"Starting {datasource_id}", + ) + ) + + def datasource_phase(self, *, datasource_id: str, phase: DatasourcePhase, message: str = "") -> None: + self.emit( + ProgressEvent( + kind=ProgressKind.DATASOURCE_PHASE, + datasource_id=datasource_id, + phase=phase, + message=message or phase.value, + ) + ) + + def chunks_discovered(self, *, datasource_id: str, total_chunks: int) -> None: + self.emit( + ProgressEvent( + kind=ProgressKind.CHUNKS_DISCOVERED, + datasource_id=datasource_id, + total=total_chunks, + message=f"Discovered {total_chunks} chunks", + ) + ) + + def embedding_started(self, *, datasource_id: str, total_chunks: int) -> None: + self.emit( + ProgressEvent( + kind=ProgressKind.EMBEDDING_STARTED, + datasource_id=datasource_id, + total=total_chunks, + message="Embedding started", + ) + ) + + def embedding_progress(self, *, datasource_id: str, done: int, total: int | None, message: str = "") -> None: + self.emit( + ProgressEvent( + kind=ProgressKind.EMBEDDING_PROGRESS, + datasource_id=datasource_id, + done=done, + total=total, + message=message, + ) + ) + + def embedding_finished(self, *, datasource_id: str) -> None: + self.emit(ProgressEvent(kind=ProgressKind.EMBEDDING_FINISHED, datasource_id=datasource_id, message="Embedding finished")) + + def persist_started(self, *, datasource_id: str, total_items: int) -> None: + self.emit( + ProgressEvent( + kind=ProgressKind.PERSIST_STARTED, + datasource_id=datasource_id, + total=total_items, + message="Persist started", + ) + ) + + def persist_progress(self, *, datasource_id: str, done: int, total: int | None, message: str = "") -> None: + self.emit( + ProgressEvent( + kind=ProgressKind.PERSIST_PROGRESS, + datasource_id=datasource_id, + done=done, + total=total, + message=message, + ) + ) + + def persist_finished(self, *, datasource_id: str) -> None: + self.emit(ProgressEvent(kind=ProgressKind.PERSIST_FINISHED, datasource_id=datasource_id, message="Persist finished")) + + def datasource_finished( + self, + *, + datasource_id: str, + index: int, + total: int, + status: DatasourceStatus, + error: str | None = None, + ) -> None: + self.emit( + ProgressEvent( + kind=ProgressKind.DATASOURCE_FINISHED, + datasource_id=datasource_id, + datasource_index=index, + datasource_total=total, + status=status, + error=error, + message=(f"Finished {datasource_id}" if status == DatasourceStatus.OK else f"{status.value}: {datasource_id}"), + ) + ) \ No newline at end of file diff --git a/src/databao_context_engine/progress/rich_progress.py b/src/databao_context_engine/progress/rich_progress.py new file mode 100644 index 00000000..aca011a7 --- /dev/null +++ b/src/databao_context_engine/progress/rich_progress.py @@ -0,0 +1,220 @@ +# databao_context_engine/cli/rich_progress.py +from __future__ import annotations + +from contextlib import contextmanager +from typing import Iterator + +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TaskID, + TaskProgressColumn, + TextColumn, + TimeRemainingColumn, +) + +from databao_context_engine.progress.progress import ( + DatasourcePhase, + ProgressCallback, + ProgressEvent, + ProgressKind, +) + + +@contextmanager +def rich_progress() -> Iterator[ProgressCallback]: + """Context manager that provides a ProgressCallback rendering a live UI via Rich. + + Usage: + with rich_progress() as progress_cb: + manager.build_context(..., progress=progress_cb) + """ # noqa: DOC402 + console = Console(stderr=True) + progress = Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeRemainingColumn(), + transient=False, # keep the progress display after finishing (change to True if you prefer) + console=console, + ) + + # Task IDs (created lazily when we receive events) + tasks: dict[str, TaskID] = {} + + # Some state so we can update descriptions nicely + state = { + "current_datasource_id": None, + "current_phase": None, + "datasource_index": None, + "datasource_total": None, + "embed_total": None, + "persist_total": None, + } + + def _ensure_overall_task(total: int | None) -> TaskID: + if "overall" not in tasks: + tasks["overall"] = progress.add_task("Datasources", total=total) + else: + # allow setting/updating totals later + if total is not None: + progress.update(tasks["overall"], total=total) + return tasks["overall"] + + def _ensure_embed_task(total: int | None) -> TaskID: + ds = state["current_datasource_id"] or "datasource" + phase = state["current_phase"] or DatasourcePhase.EMBED + desc = f"{ds} • {phase.value}" + if "embed" not in tasks: + tasks["embed"] = progress.add_task(desc, total=total) + else: + progress.update(tasks["embed"], description=desc) + if total is not None: + progress.update(tasks["embed"], total=total) + return tasks["embed"] + + def _ensure_persist_task(total: int | None) -> TaskID: + ds = state["current_datasource_id"] or "datasource" + phase = state["current_phase"] or DatasourcePhase.PERSIST + desc = f"{ds} • {phase.value}" + if "persist" not in tasks: + tasks["persist"] = progress.add_task(desc, total=total) + else: + progress.update(tasks["persist"], description=desc) + if total is not None: + progress.update(tasks["persist"], total=total) + return tasks["persist"] + + def _update_overall_description() -> None: + if "overall" not in tasks: + return + idx = state["datasource_index"] + tot = state["datasource_total"] + ds = state["current_datasource_id"] + ph = state["current_phase"].value if state["current_phase"] else None + + if idx is not None and tot is not None and ds: + suffix = f" • {ds}" + if ph: + suffix += f" • {ph}" + progress.update(tasks["overall"], description=f"Datasources {idx}/{tot}{suffix}") + + def on_event(ev: ProgressEvent) -> None: + # ---- Build-level ---- + if ev.kind == ProgressKind.BUILD_STARTED: + _ensure_overall_task(ev.datasource_total) + return + + if ev.kind == ProgressKind.BUILD_FINISHED: + # Optional: mark overall as completed if totals are known. + # (We’re already advancing it on datasource finished.) + if ev.message: + progress.console.print(f"{ev.message}") + return + + # ---- Datasource lifecycle ---- + if ev.kind == ProgressKind.DATASOURCE_STARTED: + state["current_datasource_id"] = ev.datasource_id + state["datasource_index"] = ev.datasource_index + state["datasource_total"] = ev.datasource_total + state["embed_total"] = None + state["persist_total"] = None + _ensure_overall_task(ev.datasource_total) + _update_overall_description() + + # Reset per-datasource tasks so each datasource starts fresh without flicker. + # Keep the tasks and just reset progress/description. + ds = state["current_datasource_id"] or "datasource" + + if "embed" in tasks: + progress.update( + tasks["embed"], completed=0, total=None, description=f"{ds} • {DatasourcePhase.EMBED.value}" + ) + + if "persist" in tasks: + progress.update( + tasks["persist"], completed=0, total=None, description=f"{ds} • {DatasourcePhase.PERSIST.value}" + ) + return + + if ev.kind == ProgressKind.DATASOURCE_PHASE: + # Update current phase and refresh the overall description line + state["current_phase"] = ev.phase + _update_overall_description() + return + + if ev.kind == ProgressKind.DATASOURCE_FINISHED: + idx = ev.datasource_index or 0 + tot = ev.datasource_total or 0 + ds = ev.datasource_id or "(unknown datasource)" + status = ev.status.value if ev.status else "done" + progress.console.print(f"{status.upper()} {idx}/{tot}: {ds}") + + + # Advance overall by 1 (ok/failed/skipped are all "completed" from a progress perspective) + _ensure_overall_task(ev.datasource_total) + progress.advance(tasks["overall"], 1) + + # Clear phase display for next datasource + state["current_phase"] = None + _update_overall_description() + return + + # ---- Chunk discovery is informative; embed total will arrive in EMBEDDING_STARTED anyway ---- + if ev.kind == ProgressKind.CHUNKS_DISCOVERED: + # You could surface this as a message or update descriptions; optional. + return + + # ---- Embedding ---- + if ev.kind == ProgressKind.EMBEDDING_STARTED: + state["current_phase"] = DatasourcePhase.EMBED + _update_overall_description() + _ensure_embed_task(ev.total) + state["embed_total"] = ev.total + return + + if ev.kind == ProgressKind.EMBEDDING_PROGRESS: + task_id = _ensure_embed_task(ev.total) + # Use completed to avoid drift if events are throttled + if ev.done is not None: + progress.update(task_id, completed=ev.done) + if ev.total is not None: + progress.update(task_id, total=ev.total) + state["embed_total"] = ev.total + return + + if ev.kind == ProgressKind.EMBEDDING_FINISHED: + # If we have an embed task, mark it complete (if total known) + if "embed" in tasks and state["embed_total"] is not None: + progress.update(tasks["embed"], completed=state["embed_total"]) + return + + # ---- Persistence ---- + if ev.kind == ProgressKind.PERSIST_STARTED: + state["current_phase"] = DatasourcePhase.PERSIST + _update_overall_description() + _ensure_persist_task(ev.total) + state["persist_total"] = ev.total + return + + if ev.kind == ProgressKind.PERSIST_PROGRESS: + task_id = _ensure_persist_task(ev.total) + if ev.done is not None: + progress.update(task_id, completed=ev.done) + if ev.total is not None: + progress.update(task_id, total=ev.total) + state["persist_total"] = ev.total + return + + if ev.kind == ProgressKind.PERSIST_FINISHED: + if "persist" in tasks and state["persist_total"] is not None: + progress.update(tasks["persist"], completed=state["persist_total"]) + return + + # Ignore unknown events by default + + with progress: + yield on_event diff --git a/src/databao_context_engine/services/chunk_embedding_service.py b/src/databao_context_engine/services/chunk_embedding_service.py index 803acd11..d15ca2c2 100644 --- a/src/databao_context_engine/services/chunk_embedding_service.py +++ b/src/databao_context_engine/services/chunk_embedding_service.py @@ -5,6 +5,7 @@ from databao_context_engine.llm.descriptions.provider import DescriptionProvider from databao_context_engine.llm.embeddings.provider import EmbeddingProvider from databao_context_engine.pluginlib.build_plugin import EmbeddableChunk +from databao_context_engine.progress.progress import ProgressCallback, ProgressEmitter from databao_context_engine.serialization.yaml import to_yaml_string from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver from databao_context_engine.services.models import ChunkEmbedding @@ -57,7 +58,15 @@ def __init__( if self._chunk_embedding_mode.should_generate_description() and description_provider is None: raise ValueError("A DescriptionProvider must be provided when generating descriptions") - def embed_chunks(self, *, chunks: list[EmbeddableChunk], result: str, full_type: str, datasource_id: str) -> None: + def embed_chunks( + self, + *, + chunks: list[EmbeddableChunk], + result: str, + full_type: str, + datasource_id: str, + progress: ProgressCallback | None = None, + ) -> None: """Turn plugin chunks into persisted chunks and embeddings. Flow: @@ -68,12 +77,16 @@ def embed_chunks(self, *, chunks: list[EmbeddableChunk], result: str, full_type: if not chunks: return + emitter = ProgressEmitter(progress) + emitter.embedding_started(datasource_id=datasource_id, total_chunks=len(chunks)) + logger.debug( f"Embedding {len(chunks)} chunks for datasource {datasource_id}, with chunk_embedding_mode={self._chunk_embedding_mode}" ) enriched_embeddings: list[ChunkEmbedding] = [] - for chunk in chunks: + emit_every = 10 # avoid overly chatty UIs + for i, chunk in enumerate(chunks, start=1): chunk_display_text = chunk.content if isinstance(chunk.content, str) else to_yaml_string(chunk.content) generated_description = "" @@ -101,6 +114,15 @@ def embed_chunks(self, *, chunks: list[EmbeddableChunk], result: str, full_type: generated_description=generated_description, ) ) + if i % emit_every == 0 or i == len(chunks): + emitter.embedding_progress( + datasource_id=datasource_id, + done=i, + total=len(chunks), + message=f"Embedded {i}/{len(chunks)} chunks", + ) + + emitter.embedding_finished(datasource_id=datasource_id) table_name = self._shard_resolver.resolve_or_create( embedder=self._embedding_provider.embedder, @@ -113,4 +135,5 @@ def embed_chunks(self, *, chunks: list[EmbeddableChunk], result: str, full_type: table_name=table_name, full_type=full_type, datasource_id=datasource_id, + progress=progress, ) diff --git a/src/databao_context_engine/services/persistence_service.py b/src/databao_context_engine/services/persistence_service.py index ea4c63ae..102cd305 100644 --- a/src/databao_context_engine/services/persistence_service.py +++ b/src/databao_context_engine/services/persistence_service.py @@ -2,6 +2,7 @@ import duckdb +from databao_context_engine.progress.progress import ProgressCallback, ProgressEmitter from databao_context_engine.services.models import ChunkEmbedding from databao_context_engine.storage.models import ChunkDTO from databao_context_engine.storage.repositories.chunk_repository import ChunkRepository @@ -24,7 +25,13 @@ def __init__( self._dim = dim def write_chunks_and_embeddings( - self, *, chunk_embeddings: list[ChunkEmbedding], table_name: str, full_type: str, datasource_id: str + self, + *, + chunk_embeddings: list[ChunkEmbedding], + table_name: str, + full_type: str, + datasource_id: str, + progress: ProgressCallback | None = None, ): """Atomically persist chunks and their vectors. @@ -35,8 +42,13 @@ def write_chunks_and_embeddings( if not chunk_embeddings: raise ValueError("chunk_embeddings must be a non-empty list") + emitter = ProgressEmitter(progress) + total_items = len(chunk_embeddings) + emitter.persist_started(datasource_id=datasource_id, total_items=total_items) + with transaction(self._conn): - for chunk_embedding in chunk_embeddings: + emit_every = 25 + for i, chunk_embedding in enumerate(chunk_embeddings, start=1): chunk_dto = self.create_chunk( full_type=full_type, datasource_id=datasource_id, @@ -45,6 +57,15 @@ def write_chunks_and_embeddings( ) self.create_embedding(table_name=table_name, chunk_id=chunk_dto.chunk_id, vec=chunk_embedding.vec) + if i % emit_every == 0 or i == total_items: + emitter.persist_progress( + datasource_id=datasource_id, + done=i, + total=total_items, + message=f"Persisted {i}/{total_items} chunks", + ) + emitter.persist_finished(datasource_id=datasource_id) + def create_chunk(self, *, full_type: str, datasource_id: str, embeddable_text: str, display_text: str) -> ChunkDTO: return self._chunk_repo.create( full_type=full_type, diff --git a/uv.lock b/uv.lock index 443fa802..52fbc68d 100644 --- a/uv.lock +++ b/uv.lock @@ -420,6 +420,7 @@ dependencies = [ { name = "pydantic" }, { name = "pyyaml" }, { name = "requests" }, + { name = "rich" }, ] [package.optional-dependencies] @@ -470,6 +471,7 @@ requires-dist = [ { name = "pymysql", marker = "extra == 'mysql'", specifier = ">=1.1.2" }, { name = "pyyaml", specifier = ">=6.0.3" }, { name = "requests", specifier = ">=2.32.5" }, + { name = "rich", specifier = ">=14.3.2" }, { name = "snowflake-connector-python", marker = "extra == 'snowflake'", specifier = ">=4.2.0" }, ] provides-extras = ["mssql", "clickhouse", "athena", "snowflake", "mysql", "postgresql"] @@ -771,6 +773,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ca/28/2635a8141c9a4f4bc23f5135a92bbcf48d928d8ca094088c962df1879d64/lz4-4.4.5-cp314-cp314-win_arm64.whl", hash = "sha256:d994b87abaa7a88ceb7a37c90f547b8284ff9da694e6afcfaa8568d739faf3f7", size = 93812, upload-time = "2025-11-03T13:02:26.133Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -859,6 +873,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fd/d9/eaa1f80170d2b7c5ba23f3b59f766f3a0bb41155fbc32a69adfa1adaaef9/mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca", size = 233615, upload-time = "2026-01-24T19:40:30.652Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "msal" version = "1.34.0" @@ -1359,6 +1382,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, ] +[[package]] +name = "rich" +version = "14.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/74/99/a4cab2acbb884f80e558b0771e97e21e939c5dfb460f488d19df485e8298/rich-14.3.2.tar.gz", hash = "sha256:e712f11c1a562a11843306f5ed999475f09ac31ffb64281f73ab29ffdda8b3b8", size = 230143, upload-time = "2026-02-01T16:20:47.908Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/45/615f5babd880b4bd7d405cc0dc348234c5ffb6ed1ea33e152ede08b2072d/rich-14.3.2-py3-none-any.whl", hash = "sha256:08e67c3e90884651da3239ea668222d19bea7b589149d8014a21c633420dbb69", size = 309963, upload-time = "2026-02-01T16:20:46.078Z" }, +] + [[package]] name = "rpds-py" version = "0.30.0" From 02842b55f7045e6e87e353f33df51a3d2e90193a Mon Sep 17 00:00:00 2001 From: MateusCordeiro Date: Tue, 10 Feb 2026 10:50:16 +0000 Subject: [PATCH 2/6] v2 --- .../build_sources/build_runner.py | 17 +- .../build_sources/build_service.py | 25 +- .../databao_context_project_manager.py | 1 + .../progress/progress.py | 156 ++------- .../progress/rich_progress.py | 296 ++++++++---------- .../services/chunk_embedding_service.py | 20 +- .../services/persistence_service.py | 20 +- 7 files changed, 190 insertions(+), 345 deletions(-) diff --git a/src/databao_context_engine/build_sources/build_runner.py b/src/databao_context_engine/build_sources/build_runner.py index 42193424..20cf5342 100644 --- a/src/databao_context_engine/build_sources/build_runner.py +++ b/src/databao_context_engine/build_sources/build_runner.py @@ -67,6 +67,8 @@ def build( emitter.build_started(total_datasources=len(datasources)) number_of_failed_builds = 0 + number_of_skipped_builds = 0 + build_result = [] reset_all_results(project_layout.output_dir) for datasource_index, discovered_datasource in enumerate(datasources, start=1): @@ -74,14 +76,12 @@ def build( try: prepared_source = prepare_source(discovered_datasource) - # logger.info( - # f'Found datasource of type "{prepared_source.datasource_type.full_type}" with name {prepared_source.path.stem}' - # ) + logger.info( + f'Found datasource of type "{prepared_source.datasource_type.full_type}" with name {prepared_source.path.stem}' + ) emitter.datasource_started( datasource_id=datasource_id, - datasource_type=prepared_source.datasource_type.full_type, - datasource_path=str(prepared_source.path), index=datasource_index, total=len(datasources), ) @@ -98,7 +98,7 @@ def build( total=len(datasources), status=DatasourceStatus.SKIPPED, ) - number_of_failed_builds += 1 + number_of_skipped_builds += 1 continue result = build_service.process_prepared_source( @@ -139,15 +139,16 @@ def build( number_of_failed_builds += 1 logger.debug( - "Successfully built %d datasources. %s", + "Successfully built %d datasources. %s %s", len(build_result), + f"Skipped {number_of_skipped_builds}." if number_of_skipped_builds > 0 else "", f"Failed to build {number_of_failed_builds}." if number_of_failed_builds > 0 else "", ) emitter.build_finished( ok=len(build_result), failed=number_of_failed_builds, - skipped=0, + skipped=number_of_skipped_builds, ) return build_result diff --git a/src/databao_context_engine/build_sources/build_service.py b/src/databao_context_engine/build_sources/build_service.py index 4e0f181b..10d0b669 100644 --- a/src/databao_context_engine/build_sources/build_service.py +++ b/src/databao_context_engine/build_sources/build_service.py @@ -7,7 +7,7 @@ from databao_context_engine.pluginlib.build_plugin import ( BuildPlugin, ) -from databao_context_engine.progress.progress import ProgressCallback, ProgressEmitter, DatasourcePhase +from databao_context_engine.progress.progress import ProgressCallback from databao_context_engine.serialization.yaml import to_yaml_string from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingService @@ -38,40 +38,19 @@ def process_prepared_source( Returns: The built context. """ - emitter = ProgressEmitter(progress) - - emitter.datasource_phase( - datasource_id=prepared_source.datasource_id if hasattr(prepared_source, 'datasource_id') else '', - phase=DatasourcePhase.EXECUTE_PLUGIN, - message="Executing plugin", - ) - result = execute(prepared_source, plugin) - emitter.datasource_phase( - datasource_id=result.datasource_id, - phase=DatasourcePhase.DIVIDE_CHUNKS, - message="Dividing context into chunks", - ) - chunks = plugin.divide_context_into_chunks(result.context) - emitter.chunks_discovered(datasource_id=result.datasource_id, total_chunks=len(chunks)) if not chunks: logger.info("No chunks for %s — skipping.", prepared_source.path.name) return result - emitter.datasource_phase( - datasource_id=result.datasource_id, - phase=DatasourcePhase.EMBED, - message="Embedding chunks", - ) - self._chunk_embedding_service.embed_chunks( chunks=chunks, result=to_yaml_string(result.context), full_type=prepared_source.datasource_type.full_type, datasource_id=result.datasource_id, - progress=progress + progress=progress, ) return result diff --git a/src/databao_context_engine/databao_context_project_manager.py b/src/databao_context_engine/databao_context_project_manager.py index 4271d5bf..4cd653d7 100644 --- a/src/databao_context_engine/databao_context_project_manager.py +++ b/src/databao_context_engine/databao_context_project_manager.py @@ -85,6 +85,7 @@ def build_context( Args: datasource_ids: The list of datasource ids to build. If None, all datasources will be built. chunk_embedding_mode: The mode to use for chunk embedding. + progress: Optional callback that receives progress events during execution. Returns: The list of all built results. diff --git a/src/databao_context_engine/progress/progress.py b/src/databao_context_engine/progress/progress.py index 3f462e42..cd9f6f1c 100644 --- a/src/databao_context_engine/progress/progress.py +++ b/src/databao_context_engine/progress/progress.py @@ -4,29 +4,15 @@ from enum import Enum from typing import Callable +EMIT_EVERY = 10 + class ProgressKind(str, Enum): - # Build-level BUILD_STARTED = "build_started" BUILD_FINISHED = "build_finished" - - # Datasource-level lifecycle DATASOURCE_STARTED = "datasource_started" DATASOURCE_FINISHED = "datasource_finished" - - # Datasource-level detail / phases - DATASOURCE_PHASE = "datasource_phase" - CHUNKS_DISCOVERED = "chunks_discovered" - - # Chunk embedding - EMBEDDING_STARTED = "embedding_started" - EMBEDDING_PROGRESS = "embedding_progress" - EMBEDDING_FINISHED = "embedding_finished" - - # Persistence - PERSIST_STARTED = "persist_started" - PERSIST_PROGRESS = "persist_progress" - PERSIST_FINISHED = "persist_finished" + DATASOURCE_PROGRESS = "datasource_progress" class DatasourceStatus(str, Enum): @@ -35,45 +21,15 @@ class DatasourceStatus(str, Enum): FAILED = "failed" -class DatasourcePhase(str, Enum): - EXECUTE_PLUGIN = "execute_plugin" - DIVIDE_CHUNKS = "divide_chunks" - EMBED = "embed" - PERSIST = "persist" - EXPORT = "export" - - @dataclass(frozen=True, slots=True) class ProgressEvent: - """A structured progress update emitted by the build pipeline. - - Conventions: - - `total` may be None when unknown. - - For *_PROGRESS events, `done` should be monotonic increasing up to `total` (if total known). - - `message` is optional human-readable text; callers can ignore it and render their own. - """ - kind: ProgressKind - - # Identifiers (often present, but not required for build-level events) datasource_id: str | None = None - datasource_type: str | None = None # e.g. full_type - datasource_path: str | None = None # if helpful for debugging/logging - - # Build progress (datasources) - datasource_index: int | None = None # 1-based index - datasource_total: int | None = None # total datasources - - # Sub-progress (chunks / persistence) - done: int | None = None - total: int | None = None - - # Extra structured fields - phase: DatasourcePhase | None = None + datasource_index: int | None = None + datasource_total: int | None = None + percent: int | None = None status: DatasourceStatus | None = None error: str | None = None - - # Human text message: str = "" @@ -81,11 +37,6 @@ class ProgressEvent: class ProgressEmitter: - """Small helper so you don't sprinkle `if progress: progress(...)` everywhere. - - You can also extend this later with throttling helpers. - """ - def __init__(self, cb: ProgressCallback | None): self._cb = cb @@ -93,7 +44,6 @@ def emit(self, event: ProgressEvent) -> None: if self._cb is not None: self._cb(event) - # Convenience builders (optional, but nice for consistency) def build_started(self, *, total_datasources: int | None) -> None: self.emit(ProgressEvent(kind=ProgressKind.BUILD_STARTED, datasource_total=total_datasources)) @@ -109,8 +59,6 @@ def datasource_started( self, *, datasource_id: str, - datasource_type: str | None, - datasource_path: str | None, index: int, total: int, ) -> None: @@ -118,82 +66,42 @@ def datasource_started( ProgressEvent( kind=ProgressKind.DATASOURCE_STARTED, datasource_id=datasource_id, - datasource_type=datasource_type, - datasource_path=datasource_path, datasource_index=index, datasource_total=total, message=f"Starting {datasource_id}", ) ) - def datasource_phase(self, *, datasource_id: str, phase: DatasourcePhase, message: str = "") -> None: - self.emit( - ProgressEvent( - kind=ProgressKind.DATASOURCE_PHASE, - datasource_id=datasource_id, - phase=phase, - message=message or phase.value, - ) - ) - - def chunks_discovered(self, *, datasource_id: str, total_chunks: int) -> None: - self.emit( - ProgressEvent( - kind=ProgressKind.CHUNKS_DISCOVERED, - datasource_id=datasource_id, - total=total_chunks, - message=f"Discovered {total_chunks} chunks", - ) - ) - - def embedding_started(self, *, datasource_id: str, total_chunks: int) -> None: - self.emit( - ProgressEvent( - kind=ProgressKind.EMBEDDING_STARTED, - datasource_id=datasource_id, - total=total_chunks, - message="Embedding started", - ) - ) - - def embedding_progress(self, *, datasource_id: str, done: int, total: int | None, message: str = "") -> None: - self.emit( - ProgressEvent( - kind=ProgressKind.EMBEDDING_PROGRESS, - datasource_id=datasource_id, - done=done, - total=total, - message=message, - ) - ) - - def embedding_finished(self, *, datasource_id: str) -> None: - self.emit(ProgressEvent(kind=ProgressKind.EMBEDDING_FINISHED, datasource_id=datasource_id, message="Embedding finished")) - - def persist_started(self, *, datasource_id: str, total_items: int) -> None: - self.emit( - ProgressEvent( - kind=ProgressKind.PERSIST_STARTED, - datasource_id=datasource_id, - total=total_items, - message="Persist started", - ) - ) - - def persist_progress(self, *, datasource_id: str, done: int, total: int | None, message: str = "") -> None: + def datasource_progress_units( + self, + *, + datasource_id: str, + completed_units: int, + total_units: int, + message: str = "", + ) -> None: + if total_units <= 0: + self.datasource_progress(datasource_id=datasource_id, percent=100, message=message) + return + + completed_units = max(0, min(completed_units, total_units)) + percent = round((completed_units / total_units) * 100) + self.datasource_progress(datasource_id=datasource_id, percent=percent, message=message) + + def datasource_progress(self, *, datasource_id: str, percent: int, message: str = "") -> None: + if percent < 0: + percent = 0 + if percent > 100: + percent = 100 self.emit( ProgressEvent( - kind=ProgressKind.PERSIST_PROGRESS, + kind=ProgressKind.DATASOURCE_PROGRESS, datasource_id=datasource_id, - done=done, - total=total, + percent=percent, message=message, ) ) - def persist_finished(self, *, datasource_id: str) -> None: - self.emit(ProgressEvent(kind=ProgressKind.PERSIST_FINISHED, datasource_id=datasource_id, message="Persist finished")) - def datasource_finished( self, *, @@ -211,6 +119,8 @@ def datasource_finished( datasource_total=total, status=status, error=error, - message=(f"Finished {datasource_id}" if status == DatasourceStatus.OK else f"{status.value}: {datasource_id}"), + message=( + f"Finished {datasource_id}" if status == DatasourceStatus.OK else f"{status.value}: {datasource_id}" + ), ) - ) \ No newline at end of file + ) diff --git a/src/databao_context_engine/progress/rich_progress.py b/src/databao_context_engine/progress/rich_progress.py index aca011a7..b1563d14 100644 --- a/src/databao_context_engine/progress/rich_progress.py +++ b/src/databao_context_engine/progress/rich_progress.py @@ -1,220 +1,170 @@ -# databao_context_engine/cli/rich_progress.py from __future__ import annotations +import logging +import sys from contextlib import contextmanager -from typing import Iterator +from typing import Callable, Iterator, Optional from rich.console import Console from rich.progress import ( BarColumn, Progress, + ProgressColumn, SpinnerColumn, TaskID, TaskProgressColumn, TextColumn, TimeRemainingColumn, ) +from rich.table import Column +from rich.text import Text from databao_context_engine.progress.progress import ( - DatasourcePhase, ProgressCallback, ProgressEvent, ProgressKind, ) +_DESCRIPTION_COL_WIDTH = 50 + + +def _datasource_label(ds_id: str | None) -> str: + return ds_id or "datasource" + + +class _EtaExceptOverallColumn(ProgressColumn): + def __init__(self, overall_task_id_getter: Callable[[], Optional[TaskID]]): + super().__init__() + self._overall_task_id_getter = overall_task_id_getter + self._eta = TimeRemainingColumn() + + def render(self, task) -> Text: + overall_id = self._overall_task_id_getter() + if overall_id is not None and task.id == overall_id: + return Text("") + return self._eta.render(task) + @contextmanager def rich_progress() -> Iterator[ProgressCallback]: - """Context manager that provides a ProgressCallback rendering a live UI via Rich. + interactive = sys.stderr.isatty() + if not interactive: + + def noop(_: ProgressEvent) -> None: + return + + yield noop + return - Usage: - with rich_progress() as progress_cb: - manager.build_context(..., progress=progress_cb) - """ # noqa: DOC402 console = Console(stderr=True) + + tasks: dict[str, TaskID] = {} + ui_state = { + "datasource_index": None, + "datasource_total": None, + "last_percent": 0, + } + progress = Progress( SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), + TextColumn( + "[progress.description]{task.description}", + table_column=Column(width=_DESCRIPTION_COL_WIDTH, overflow="ellipsis", no_wrap=True), + ), BarColumn(), TaskProgressColumn(), - TimeRemainingColumn(), - transient=False, # keep the progress display after finishing (change to True if you prefer) + _EtaExceptOverallColumn(lambda: tasks.get("overall")), + transient=True, console=console, + redirect_stdout=True, + redirect_stderr=True, ) - # Task IDs (created lazily when we receive events) - tasks: dict[str, TaskID] = {} - - # Some state so we can update descriptions nicely - state = { - "current_datasource_id": None, - "current_phase": None, - "datasource_index": None, - "datasource_total": None, - "embed_total": None, - "persist_total": None, - } - - def _ensure_overall_task(total: int | None) -> TaskID: + def _get_or_create_overall_task(total: int | None) -> TaskID: if "overall" not in tasks: tasks["overall"] = progress.add_task("Datasources", total=total) else: - # allow setting/updating totals later if total is not None: progress.update(tasks["overall"], total=total) return tasks["overall"] - def _ensure_embed_task(total: int | None) -> TaskID: - ds = state["current_datasource_id"] or "datasource" - phase = state["current_phase"] or DatasourcePhase.EMBED - desc = f"{ds} • {phase.value}" - if "embed" not in tasks: - tasks["embed"] = progress.add_task(desc, total=total) - else: - progress.update(tasks["embed"], description=desc) - if total is not None: - progress.update(tasks["embed"], total=total) - return tasks["embed"] - - def _ensure_persist_task(total: int | None) -> TaskID: - ds = state["current_datasource_id"] or "datasource" - phase = state["current_phase"] or DatasourcePhase.PERSIST - desc = f"{ds} • {phase.value}" - if "persist" not in tasks: - tasks["persist"] = progress.add_task(desc, total=total) - else: - progress.update(tasks["persist"], description=desc) - if total is not None: - progress.update(tasks["persist"], total=total) - return tasks["persist"] + def _get_or_create_datasource_task() -> TaskID: + if "datasource" not in tasks: + tasks["datasource"] = progress.add_task("Datasource", total=100.0) + return tasks["datasource"] + + def _set_datasource_percent(percent: float) -> None: + task_id = _get_or_create_datasource_task() + clamped = max(0.0, min(100.0, percent)) + progress.update(task_id, completed=clamped) def _update_overall_description() -> None: if "overall" not in tasks: return - idx = state["datasource_index"] - tot = state["datasource_total"] - ds = state["current_datasource_id"] - ph = state["current_phase"].value if state["current_phase"] else None + idx = ui_state["datasource_index"] + tot = ui_state["datasource_total"] - if idx is not None and tot is not None and ds: - suffix = f" • {ds}" - if ph: - suffix += f" • {ph}" - progress.update(tasks["overall"], description=f"Datasources {idx}/{tot}{suffix}") + if idx is not None and tot is not None: + progress.update(tasks["overall"], description=f"Datasources {idx}/{tot}") def on_event(ev: ProgressEvent) -> None: - # ---- Build-level ---- - if ev.kind == ProgressKind.BUILD_STARTED: - _ensure_overall_task(ev.datasource_total) - return - - if ev.kind == ProgressKind.BUILD_FINISHED: - # Optional: mark overall as completed if totals are known. - # (We’re already advancing it on datasource finished.) - if ev.message: - progress.console.print(f"{ev.message}") - return - - # ---- Datasource lifecycle ---- - if ev.kind == ProgressKind.DATASOURCE_STARTED: - state["current_datasource_id"] = ev.datasource_id - state["datasource_index"] = ev.datasource_index - state["datasource_total"] = ev.datasource_total - state["embed_total"] = None - state["persist_total"] = None - _ensure_overall_task(ev.datasource_total) - _update_overall_description() - - # Reset per-datasource tasks so each datasource starts fresh without flicker. - # Keep the tasks and just reset progress/description. - ds = state["current_datasource_id"] or "datasource" - - if "embed" in tasks: - progress.update( - tasks["embed"], completed=0, total=None, description=f"{ds} • {DatasourcePhase.EMBED.value}" - ) - - if "persist" in tasks: - progress.update( - tasks["persist"], completed=0, total=None, description=f"{ds} • {DatasourcePhase.PERSIST.value}" - ) - return - - if ev.kind == ProgressKind.DATASOURCE_PHASE: - # Update current phase and refresh the overall description line - state["current_phase"] = ev.phase - _update_overall_description() - return - - if ev.kind == ProgressKind.DATASOURCE_FINISHED: - idx = ev.datasource_index or 0 - tot = ev.datasource_total or 0 - ds = ev.datasource_id or "(unknown datasource)" - status = ev.status.value if ev.status else "done" - progress.console.print(f"{status.upper()} {idx}/{tot}: {ds}") - - - # Advance overall by 1 (ok/failed/skipped are all "completed" from a progress perspective) - _ensure_overall_task(ev.datasource_total) - progress.advance(tasks["overall"], 1) - - # Clear phase display for next datasource - state["current_phase"] = None - _update_overall_description() - return - - # ---- Chunk discovery is informative; embed total will arrive in EMBEDDING_STARTED anyway ---- - if ev.kind == ProgressKind.CHUNKS_DISCOVERED: - # You could surface this as a message or update descriptions; optional. - return - - # ---- Embedding ---- - if ev.kind == ProgressKind.EMBEDDING_STARTED: - state["current_phase"] = DatasourcePhase.EMBED - _update_overall_description() - _ensure_embed_task(ev.total) - state["embed_total"] = ev.total - return - - if ev.kind == ProgressKind.EMBEDDING_PROGRESS: - task_id = _ensure_embed_task(ev.total) - # Use completed to avoid drift if events are throttled - if ev.done is not None: - progress.update(task_id, completed=ev.done) - if ev.total is not None: - progress.update(task_id, total=ev.total) - state["embed_total"] = ev.total - return - - if ev.kind == ProgressKind.EMBEDDING_FINISHED: - # If we have an embed task, mark it complete (if total known) - if "embed" in tasks and state["embed_total"] is not None: - progress.update(tasks["embed"], completed=state["embed_total"]) - return - - # ---- Persistence ---- - if ev.kind == ProgressKind.PERSIST_STARTED: - state["current_phase"] = DatasourcePhase.PERSIST - _update_overall_description() - _ensure_persist_task(ev.total) - state["persist_total"] = ev.total - return - - if ev.kind == ProgressKind.PERSIST_PROGRESS: - task_id = _ensure_persist_task(ev.total) - if ev.done is not None: - progress.update(task_id, completed=ev.done) - if ev.total is not None: - progress.update(task_id, total=ev.total) - state["persist_total"] = ev.total - return - - if ev.kind == ProgressKind.PERSIST_FINISHED: - if "persist" in tasks and state["persist_total"] is not None: - progress.update(tasks["persist"], completed=state["persist_total"]) - return - - # Ignore unknown events by default - - with progress: - yield on_event + match ev.kind: + case ProgressKind.BUILD_STARTED: + _get_or_create_overall_task(ev.datasource_total) + return + case ProgressKind.BUILD_FINISHED: + if ev.message: + progress.console.print(f"{ev.message}") + return + case ProgressKind.DATASOURCE_STARTED: + ui_state["datasource_index"] = ev.datasource_index + ui_state["datasource_total"] = ev.datasource_total + ui_state["last_percent"] = 0 + _get_or_create_overall_task(ev.datasource_total) + _update_overall_description() + + ds = _datasource_label(ev.datasource_id) + + task_id = _get_or_create_datasource_task() + progress.reset(task_id, completed=0, total=100.0, description=f"{ds}") + return + case ProgressKind.DATASOURCE_FINISHED: + idx = ev.datasource_index or 0 + tot = ev.datasource_total or 0 + ds = ev.datasource_id or "(unknown datasource)" + status = ev.status.value if ev.status else "done" + progress.console.print(f"{status.upper()} {idx}/{tot}: {ds}") + + _set_datasource_percent(100.0) + + _get_or_create_overall_task(ev.datasource_total) + progress.advance(tasks["overall"], 1) + + _update_overall_description() + return + case ProgressKind.DATASOURCE_PROGRESS: + if ev.percent is not None: + pct = int(ev.percent) + pct = max(ui_state["last_percent"], pct) + ui_state["last_percent"] = pct + _set_datasource_percent(float(pct)) + return + + root = logging.getLogger() + prev_level = root.level + prev_handlers = list(root.handlers) + + prev_disable_level = logging.root.manager.disable + logging.disable(logging.CRITICAL) + + try: + with progress: + yield on_event + finally: + logging.disable(prev_disable_level) + + for h in list(root.handlers): + root.removeHandler(h) + for h in prev_handlers: + root.addHandler(h) + root.setLevel(prev_level) diff --git a/src/databao_context_engine/services/chunk_embedding_service.py b/src/databao_context_engine/services/chunk_embedding_service.py index d15ca2c2..265182a1 100644 --- a/src/databao_context_engine/services/chunk_embedding_service.py +++ b/src/databao_context_engine/services/chunk_embedding_service.py @@ -5,7 +5,7 @@ from databao_context_engine.llm.descriptions.provider import DescriptionProvider from databao_context_engine.llm.embeddings.provider import EmbeddingProvider from databao_context_engine.pluginlib.build_plugin import EmbeddableChunk -from databao_context_engine.progress.progress import ProgressCallback, ProgressEmitter +from databao_context_engine.progress.progress import EMIT_EVERY, ProgressCallback, ProgressEmitter from databao_context_engine.serialization.yaml import to_yaml_string from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver from databao_context_engine.services.models import ChunkEmbedding @@ -78,14 +78,12 @@ def embed_chunks( return emitter = ProgressEmitter(progress) - emitter.embedding_started(datasource_id=datasource_id, total_chunks=len(chunks)) logger.debug( f"Embedding {len(chunks)} chunks for datasource {datasource_id}, with chunk_embedding_mode={self._chunk_embedding_mode}" ) enriched_embeddings: list[ChunkEmbedding] = [] - emit_every = 10 # avoid overly chatty UIs for i, chunk in enumerate(chunks, start=1): chunk_display_text = chunk.content if isinstance(chunk.content, str) else to_yaml_string(chunk.content) @@ -114,15 +112,19 @@ def embed_chunks( generated_description=generated_description, ) ) - if i % emit_every == 0 or i == len(chunks): - emitter.embedding_progress( + if i % EMIT_EVERY == 0 or i == len(chunks): + total_units = len(chunks) * 2 + emitter.datasource_progress_units( datasource_id=datasource_id, - done=i, - total=len(chunks), - message=f"Embedded {i}/{len(chunks)} chunks", + completed_units=i, + total_units=total_units, ) - emitter.embedding_finished(datasource_id=datasource_id) + emitter.datasource_progress_units( + datasource_id=datasource_id, + completed_units=len(chunks), + total_units=len(chunks) * 2, + ) table_name = self._shard_resolver.resolve_or_create( embedder=self._embedding_provider.embedder, diff --git a/src/databao_context_engine/services/persistence_service.py b/src/databao_context_engine/services/persistence_service.py index 102cd305..7994c896 100644 --- a/src/databao_context_engine/services/persistence_service.py +++ b/src/databao_context_engine/services/persistence_service.py @@ -2,7 +2,7 @@ import duckdb -from databao_context_engine.progress.progress import ProgressCallback, ProgressEmitter +from databao_context_engine.progress.progress import EMIT_EVERY, ProgressCallback, ProgressEmitter from databao_context_engine.services.models import ChunkEmbedding from databao_context_engine.storage.models import ChunkDTO from databao_context_engine.storage.repositories.chunk_repository import ChunkRepository @@ -44,10 +44,8 @@ def write_chunks_and_embeddings( emitter = ProgressEmitter(progress) total_items = len(chunk_embeddings) - emitter.persist_started(datasource_id=datasource_id, total_items=total_items) with transaction(self._conn): - emit_every = 25 for i, chunk_embedding in enumerate(chunk_embeddings, start=1): chunk_dto = self.create_chunk( full_type=full_type, @@ -57,14 +55,18 @@ def write_chunks_and_embeddings( ) self.create_embedding(table_name=table_name, chunk_id=chunk_dto.chunk_id, vec=chunk_embedding.vec) - if i % emit_every == 0 or i == total_items: - emitter.persist_progress( + if i % EMIT_EVERY == 0 or i == total_items: + total_units = total_items * 2 + emitter.datasource_progress_units( datasource_id=datasource_id, - done=i, - total=total_items, - message=f"Persisted {i}/{total_items} chunks", + completed_units=total_items + i, + total_units=total_units, ) - emitter.persist_finished(datasource_id=datasource_id) + emitter.datasource_progress_units( + datasource_id=datasource_id, + completed_units=total_items * 2, + total_units=total_items * 2, + ) def create_chunk(self, *, full_type: str, datasource_id: str, embeddable_text: str, display_text: str) -> ChunkDTO: return self._chunk_repo.create( From 702e15069d8bc7bd3136894a0cc3edc7bf572b1f Mon Sep 17 00:00:00 2001 From: MateusCordeiro Date: Tue, 10 Feb 2026 11:04:36 +0000 Subject: [PATCH 3/6] Add opt dependency --- pyproject.toml | 4 +- .../build_sources/export_results.py | 2 +- src/databao_context_engine/cli/commands.py | 4 +- .../progress/rich_progress.py | 69 +++++++++++-------- uv.lock | 8 ++- 5 files changed, 51 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 029f9d4a..68d3184d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,6 @@ dependencies = [ "mcp>=1.23.3", "pydantic>=2.12.4", "jinja2>=3.1.6", - "rich>=14.3.2" ] default-optional-dependency-keys = [ "mysql", @@ -20,6 +19,9 @@ default-optional-dependency-keys = [ ] [project.optional-dependencies] +cli = [ + "rich>=14.3.2" +] mssql = [ "mssql-python>=1.0.0" ] diff --git a/src/databao_context_engine/build_sources/export_results.py b/src/databao_context_engine/build_sources/export_results.py index b00de834..3fd1814b 100644 --- a/src/databao_context_engine/build_sources/export_results.py +++ b/src/databao_context_engine/build_sources/export_results.py @@ -20,7 +20,7 @@ def export_build_result(output_dir: Path, result: BuiltDatasourceContext) -> Pat with export_file_path.open("w") as export_file: write_yaml_to_stream(data=result, file_stream=export_file) - # logger.info(f"Exported result to {export_file_path.resolve()}") + logger.info(f"Exported result to {export_file_path.resolve()}") return export_file_path diff --git a/src/databao_context_engine/cli/commands.py b/src/databao_context_engine/cli/commands.py index 956c6636..c4a55b07 100644 --- a/src/databao_context_engine/cli/commands.py +++ b/src/databao_context_engine/cli/commands.py @@ -154,7 +154,9 @@ def build( """ with rich_progress() as progress_cb: result = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).build_context( - datasource_ids=None, chunk_embedding_mode=ChunkEmbeddingMode(chunk_embedding_mode.upper()), progress=progress_cb + datasource_ids=None, + chunk_embedding_mode=ChunkEmbeddingMode(chunk_embedding_mode.upper()), + progress=progress_cb, ) click.echo(f"Build complete. Processed {len(result)} datasources.") diff --git a/src/databao_context_engine/progress/rich_progress.py b/src/databao_context_engine/progress/rich_progress.py index b1563d14..4f1adfb3 100644 --- a/src/databao_context_engine/progress/rich_progress.py +++ b/src/databao_context_engine/progress/rich_progress.py @@ -3,21 +3,7 @@ import logging import sys from contextlib import contextmanager -from typing import Callable, Iterator, Optional - -from rich.console import Console -from rich.progress import ( - BarColumn, - Progress, - ProgressColumn, - SpinnerColumn, - TaskID, - TaskProgressColumn, - TextColumn, - TimeRemainingColumn, -) -from rich.table import Column -from rich.text import Text +from typing import Callable, Iterator, Optional, TypedDict from databao_context_engine.progress.progress import ( ProgressCallback, @@ -32,34 +18,57 @@ def _datasource_label(ds_id: str | None) -> str: return ds_id or "datasource" -class _EtaExceptOverallColumn(ProgressColumn): - def __init__(self, overall_task_id_getter: Callable[[], Optional[TaskID]]): - super().__init__() - self._overall_task_id_getter = overall_task_id_getter - self._eta = TimeRemainingColumn() +def _noop_progress_cb(_: ProgressEvent) -> None: + return + - def render(self, task) -> Text: - overall_id = self._overall_task_id_getter() - if overall_id is not None and task.id == overall_id: - return Text("") - return self._eta.render(task) +class _UIState(TypedDict): + datasource_index: int | None + datasource_total: int | None + last_percent: int @contextmanager def rich_progress() -> Iterator[ProgressCallback]: + try: + from rich.console import Console + from rich.progress import ( + BarColumn, + Progress, + ProgressColumn, + SpinnerColumn, + TaskID, + TaskProgressColumn, + TextColumn, + TimeRemainingColumn, + ) + from rich.table import Column + from rich.text import Text + except ImportError: + yield _noop_progress_cb + return + interactive = sys.stderr.isatty() if not interactive: + yield _noop_progress_cb + return - def noop(_: ProgressEvent) -> None: - return + class _EtaExceptOverallColumn(ProgressColumn): + def __init__(self, overall_task_id_getter: Callable[[], Optional[TaskID]]): + super().__init__() + self._overall_task_id_getter = overall_task_id_getter + self._eta = TimeRemainingColumn() - yield noop - return + def render(self, task) -> Text: + overall_id = self._overall_task_id_getter() + if overall_id is not None and task.id == overall_id: + return Text("") + return self._eta.render(task) console = Console(stderr=True) tasks: dict[str, TaskID] = {} - ui_state = { + ui_state: _UIState = { "datasource_index": None, "datasource_total": None, "last_percent": 0, diff --git a/uv.lock b/uv.lock index 52fbc68d..86ec042c 100644 --- a/uv.lock +++ b/uv.lock @@ -420,13 +420,15 @@ dependencies = [ { name = "pydantic" }, { name = "pyyaml" }, { name = "requests" }, - { name = "rich" }, ] [package.optional-dependencies] athena = [ { name = "pyathena" }, ] +cli = [ + { name = "rich" }, +] clickhouse = [ { name = "clickhouse-connect" }, ] @@ -471,10 +473,10 @@ requires-dist = [ { name = "pymysql", marker = "extra == 'mysql'", specifier = ">=1.1.2" }, { name = "pyyaml", specifier = ">=6.0.3" }, { name = "requests", specifier = ">=2.32.5" }, - { name = "rich", specifier = ">=14.3.2" }, + { name = "rich", marker = "extra == 'cli'", specifier = ">=14.3.2" }, { name = "snowflake-connector-python", marker = "extra == 'snowflake'", specifier = ">=4.2.0" }, ] -provides-extras = ["mssql", "clickhouse", "athena", "snowflake", "mysql", "postgresql"] +provides-extras = ["cli", "mssql", "clickhouse", "athena", "snowflake", "mysql", "postgresql"] [package.metadata.requires-dev] dev = [ From 57705f02efac026a2157c365fd41fa3193b1576e Mon Sep 17 00:00:00 2001 From: MateusCordeiro Date: Tue, 10 Feb 2026 12:54:34 +0000 Subject: [PATCH 4/6] Fix broken test --- tests/build_sources/test_build_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/build_sources/test_build_service.py b/tests/build_sources/test_build_service.py index 3a23946a..a75f4b93 100644 --- a/tests/build_sources/test_build_service.py +++ b/tests/build_sources/test_build_service.py @@ -72,6 +72,7 @@ def test_process_prepared_source_happy_path_creates_row_and_embeds(svc, chunk_em result=f"context: ok{os.linesep}", datasource_id="files/two.md", full_type="files/md", + progress=None, ) assert out is result From 73bc5fd07671eed6ce1008358173c590e46389ea Mon Sep 17 00:00:00 2001 From: MateusCordeiro Date: Tue, 10 Feb 2026 16:34:58 +0000 Subject: [PATCH 5/6] Rename build started/ended to task started ended in progress bar --- .../build_sources/build_runner.py | 16 ++++++++-------- .../build_sources/build_wiring.py | 4 +++- src/databao_context_engine/progress/progress.py | 12 ++++++------ .../progress/rich_progress.py | 4 ++-- tests/build_sources/test_build_runner.py | 6 +++--- tests/build_sources/test_build_service.py | 1 + tests/test_databao_context_project_manager.py | 2 ++ 7 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/databao_context_engine/build_sources/build_runner.py b/src/databao_context_engine/build_sources/build_runner.py index 739aedc4..ace2481b 100644 --- a/src/databao_context_engine/build_sources/build_runner.py +++ b/src/databao_context_engine/build_sources/build_runner.py @@ -73,11 +73,11 @@ def build( if not datasource_ids: logger.info("No sources discovered under %s", project_layout.src_dir) - emitter.build_started(total_datasources=0) - emitter.build_finished(ok=0, failed=0, skipped=0) + emitter.task_started(total_datasources=0) + emitter.task_finished(ok=0, failed=0, skipped=0) return [] - emitter.build_started(total_datasources=len(datasource_ids)) + emitter.task_started(total_datasources=len(datasource_ids)) number_of_failed_builds = 0 number_of_skipped_builds = 0 @@ -157,7 +157,7 @@ def build( f"Failed to build {number_of_failed_builds}." if number_of_failed_builds > 0 else "", ) - emitter.build_finished( + emitter.task_finished( ok=len(build_result), failed=number_of_failed_builds, skipped=number_of_skipped_builds, @@ -189,11 +189,11 @@ def run_indexing( emitter = ProgressEmitter(progress) if not contexts: - emitter.build_started(total_datasources=0) - emitter.build_finished(ok=0, failed=0, skipped=0) + emitter.task_started(total_datasources=0) + emitter.task_finished(ok=0, failed=0, skipped=0) return summary - emitter.build_started(total_datasources=len(contexts)) + emitter.task_started(total_datasources=len(contexts)) for datasource_index, context in enumerate(contexts, start=1): try: @@ -251,5 +251,5 @@ def run_indexing( f"Skipped {summary.skipped}. Failed {summary.failed}." if (summary.skipped or summary.failed) else "", ) - emitter.build_finished(ok=summary.indexed, failed=summary.failed, skipped=summary.skipped) + emitter.task_finished(ok=summary.indexed, failed=summary.failed, skipped=summary.skipped) return summary diff --git a/src/databao_context_engine/build_sources/build_wiring.py b/src/databao_context_engine/build_sources/build_wiring.py index 29c2af7b..024ef2fa 100644 --- a/src/databao_context_engine/build_sources/build_wiring.py +++ b/src/databao_context_engine/build_sources/build_wiring.py @@ -108,7 +108,9 @@ def index_built_contexts( description_provider=description_provider, chunk_embedding_mode=chunk_embedding_mode, ) - return run_indexing(project_layout=project_layout, build_service=build_service, contexts=contexts, progress=progress) + return run_indexing( + project_layout=project_layout, build_service=build_service, contexts=contexts, progress=progress + ) def _create_build_service( diff --git a/src/databao_context_engine/progress/progress.py b/src/databao_context_engine/progress/progress.py index cd9f6f1c..efc4050c 100644 --- a/src/databao_context_engine/progress/progress.py +++ b/src/databao_context_engine/progress/progress.py @@ -8,8 +8,8 @@ class ProgressKind(str, Enum): - BUILD_STARTED = "build_started" - BUILD_FINISHED = "build_finished" + TASK_STARTED = "task_started" + TASK_FINISHED = "task_finished" DATASOURCE_STARTED = "datasource_started" DATASOURCE_FINISHED = "datasource_finished" DATASOURCE_PROGRESS = "datasource_progress" @@ -44,13 +44,13 @@ def emit(self, event: ProgressEvent) -> None: if self._cb is not None: self._cb(event) - def build_started(self, *, total_datasources: int | None) -> None: - self.emit(ProgressEvent(kind=ProgressKind.BUILD_STARTED, datasource_total=total_datasources)) + def task_started(self, *, total_datasources: int | None) -> None: + self.emit(ProgressEvent(kind=ProgressKind.TASK_STARTED, datasource_total=total_datasources)) - def build_finished(self, *, ok: int, failed: int, skipped: int) -> None: + def task_finished(self, *, ok: int, failed: int, skipped: int) -> None: self.emit( ProgressEvent( - kind=ProgressKind.BUILD_FINISHED, + kind=ProgressKind.TASK_FINISHED, message=f"Finished (ok={ok}, failed={failed}, skipped={skipped})", ) ) diff --git a/src/databao_context_engine/progress/rich_progress.py b/src/databao_context_engine/progress/rich_progress.py index 4f1adfb3..898d48f9 100644 --- a/src/databao_context_engine/progress/rich_progress.py +++ b/src/databao_context_engine/progress/rich_progress.py @@ -118,10 +118,10 @@ def _update_overall_description() -> None: def on_event(ev: ProgressEvent) -> None: match ev.kind: - case ProgressKind.BUILD_STARTED: + case ProgressKind.TASK_STARTED: _get_or_create_overall_task(ev.datasource_total) return - case ProgressKind.BUILD_FINISHED: + case ProgressKind.TASK_FINISHED: if ev.message: progress.console.print(f"{ev.message}") return diff --git a/tests/build_sources/test_build_runner.py b/tests/build_sources/test_build_runner.py index af84e5d0..4c47be20 100644 --- a/tests/build_sources/test_build_runner.py +++ b/tests/build_sources/test_build_runner.py @@ -149,7 +149,7 @@ def test_run_indexing_indexes_when_plugin_exists(mocker, mock_build_service, pro build_runner.run_indexing(project_layout=project_layout, build_service=mock_build_service, contexts=[ctx]) - mock_build_service.index_built_context.assert_called_once_with(context=ctx, plugin=plugin) + mock_build_service.index_built_context.assert_called_once_with(context=ctx, plugin=plugin, progress=None) def test_run_indexing_skips_when_plugin_missing(mocker, mock_build_service, project_layout, caplog): @@ -183,5 +183,5 @@ def test_run_indexing_continues_on_exception(mocker, mock_build_service, project build_runner.run_indexing(project_layout=project_layout, build_service=mock_build_service, contexts=[c1, c2]) assert mock_build_service.index_built_context.call_count == 2 - mock_build_service.index_built_context.assert_any_call(context=c1, plugin=plugin) - mock_build_service.index_built_context.assert_any_call(context=c2, plugin=plugin) + mock_build_service.index_built_context.assert_any_call(context=c1, plugin=plugin, progress=None) + mock_build_service.index_built_context.assert_any_call(context=c2, plugin=plugin, progress=None) diff --git a/tests/build_sources/test_build_service.py b/tests/build_sources/test_build_service.py index ebc5b512..be677452 100644 --- a/tests/build_sources/test_build_service.py +++ b/tests/build_sources/test_build_service.py @@ -136,6 +136,7 @@ def test_index_built_context_happy_path_embeds(svc, chunk_embed_svc, mocker): full_type="files/md", datasource_id="files/two.md", override=True, + progress=None, ) diff --git a/tests/test_databao_context_project_manager.py b/tests/test_databao_context_project_manager.py index 13dcd3f1..2495b031 100644 --- a/tests/test_databao_context_project_manager.py +++ b/tests/test_databao_context_project_manager.py @@ -132,6 +132,7 @@ def test_databao_context_project_manager__index_built_contexts_indexes_all_when_ project_layout=pm._project_layout, contexts=[c1, c2], chunk_embedding_mode=ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY, + progress=None, ) @@ -164,6 +165,7 @@ def test_databao_context_project_manager__index_built_contexts_filters_by_dataso project_layout=pm._project_layout, contexts=[c1, c3], chunk_embedding_mode=ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY, + progress=None, ) From b9e744260ecd89b8718916cc0e2cadd4019cf65f Mon Sep 17 00:00:00 2001 From: MateusCordeiro Date: Wed, 11 Feb 2026 14:40:57 +0000 Subject: [PATCH 6/6] Redirect logging via rich if interactive mode is enabled --- .../build_sources/build_runner.py | 6 --- .../progress/rich_progress.py | 50 ++++++++++++------- .../services/chunk_embedding_service.py | 6 --- 3 files changed, 33 insertions(+), 29 deletions(-) diff --git a/src/databao_context_engine/build_sources/build_runner.py b/src/databao_context_engine/build_sources/build_runner.py index ace2481b..f258dbfe 100644 --- a/src/databao_context_engine/build_sources/build_runner.py +++ b/src/databao_context_engine/build_sources/build_runner.py @@ -187,12 +187,6 @@ def run_indexing( summary = IndexSummary(total=len(contexts), indexed=0, skipped=0, failed=0) emitter = ProgressEmitter(progress) - - if not contexts: - emitter.task_started(total_datasources=0) - emitter.task_finished(ok=0, failed=0, skipped=0) - return summary - emitter.task_started(total_datasources=len(contexts)) for datasource_index, context in enumerate(contexts, start=1): diff --git a/src/databao_context_engine/progress/rich_progress.py b/src/databao_context_engine/progress/rich_progress.py index 898d48f9..b5e87319 100644 --- a/src/databao_context_engine/progress/rich_progress.py +++ b/src/databao_context_engine/progress/rich_progress.py @@ -5,6 +5,8 @@ from contextlib import contextmanager from typing import Callable, Iterator, Optional, TypedDict +from rich.logging import RichHandler + from databao_context_engine.progress.progress import ( ProgressCallback, ProgressEvent, @@ -67,6 +69,35 @@ def render(self, task) -> Text: console = Console(stderr=True) + @contextmanager + def _use_rich_console_logging() -> Iterator[None]: + app_logger = logging.getLogger("databao_context_engine") + + prev_handlers = list(app_logger.handlers) + prev_propagate = app_logger.propagate + + def _is_console_handler(h: logging.Handler) -> bool: + return isinstance(h, logging.StreamHandler) and getattr(h, "stream", None) in (sys.stderr, sys.stdout) + + kept_handlers = [h for h in prev_handlers if not _is_console_handler(h)] + + rich_handler = RichHandler( + console=console, + show_time=False, + show_level=True, + show_path=False, + rich_tracebacks=False, + ) + + try: + app_logger.handlers = kept_handlers + [rich_handler] + app_logger.propagate = False + yield + finally: + app_logger.handlers = prev_handlers + app_logger.propagate = prev_propagate + + tasks: dict[str, TaskID] = {} ui_state: _UIState = { "datasource_index": None, @@ -159,21 +190,6 @@ def on_event(ev: ProgressEvent) -> None: _set_datasource_percent(float(pct)) return - root = logging.getLogger() - prev_level = root.level - prev_handlers = list(root.handlers) - - prev_disable_level = logging.root.manager.disable - logging.disable(logging.CRITICAL) - - try: + with _use_rich_console_logging(): with progress: - yield on_event - finally: - logging.disable(prev_disable_level) - - for h in list(root.handlers): - root.removeHandler(h) - for h in prev_handlers: - root.addHandler(h) - root.setLevel(prev_level) + yield on_event \ No newline at end of file diff --git a/src/databao_context_engine/services/chunk_embedding_service.py b/src/databao_context_engine/services/chunk_embedding_service.py index f9e64427..594936b1 100644 --- a/src/databao_context_engine/services/chunk_embedding_service.py +++ b/src/databao_context_engine/services/chunk_embedding_service.py @@ -121,12 +121,6 @@ def embed_chunks( total_units=total_units, ) - emitter.datasource_progress_units( - datasource_id=datasource_id, - completed_units=len(chunks), - total_units=len(chunks) * 2, - ) - table_name = self._shard_resolver.resolve_or_create( embedder=self._embedding_provider.embedder, model_id=self._embedding_provider.model_id,