From e9a09ae56e54d8d8431903f4c26176b52ff1c8fd Mon Sep 17 00:00:00 2001 From: Christophe Blefari Date: Tue, 10 Feb 2026 17:12:39 +0100 Subject: [PATCH 1/7] Add integration tests for redshift and duckdb --- cli/.gitignore | 3 + cli/nao_core/commands/debug.py | 29 +- cli/nao_core/commands/sync/__init__.py | 73 +++-- cli/nao_core/commands/sync/accessors.py | 290 ----------------- .../commands/sync/providers/__init__.py | 70 ++++- .../sync/providers/databases/__init__.py | 17 +- .../sync/providers/databases/bigquery.py | 79 ----- .../sync/providers/databases/context.py | 70 +++++ .../sync/providers/databases/databricks.py | 79 ----- .../sync/providers/databases/duckdb.py | 78 ----- .../sync/providers/databases/postgres.py | 79 ----- .../sync/providers/databases/provider.py | 139 ++++++--- .../sync/providers/databases/redshift.py | 79 ----- .../sync/providers/databases/snowflake.py | 79 ----- cli/nao_core/commands/sync/registry.py | 23 -- cli/nao_core/config/__init__.py | 2 - cli/nao_core/config/databases/__init__.py | 3 +- cli/nao_core/config/databases/base.py | 34 +- cli/nao_core/config/databases/bigquery.py | 21 +- cli/nao_core/config/databases/databricks.py | 21 +- cli/nao_core/config/databases/duckdb.py | 10 +- cli/nao_core/config/databases/postgres.py | 21 +- cli/nao_core/config/databases/redshift.py | 162 +++++++++- cli/nao_core/config/databases/snowflake.py | 21 +- .../defaults/databases/columns.md.j2 | 19 +- .../defaults/databases/description.md.j2 | 19 +- .../defaults/databases/preview.md.j2 | 13 +- .../defaults/databases/profiling.md.j2 | 34 -- cli/nao_core/templates/engine.py | 22 ++ .../commands/sync/integration/.env.example | 11 + .../commands/sync/integration/__init__.py | 0 .../commands/sync/integration/conftest.py | 20 ++ .../sync/integration/dml/redshift.sql | 29 ++ .../commands/sync/integration/test_duckdb.py | 219 +++++++++++++ .../sync/integration/test_redshift.py | 293 ++++++++++++++++++ .../nao_core/commands/sync/test_accessors.py | 87 ------ .../nao_core/commands/sync/test_context.py | 74 +++++ .../nao_core/commands/sync/test_providers.py | 6 +- .../nao_core/commands/sync/test_registry.py | 33 -- cli/tests/nao_core/commands/sync/test_sync.py | 36 +-- cli/tests/nao_core/commands/test_debug.py | 61 ++-- cli/tests/nao_core/templates/test_engine.py | 36 +-- cli/uv.lock | 28 +- .../table=customers/description.md | 14 + .../schema=main/table=customers/preview.md | 30 +- .../schema=main/table=orders/description.md | 14 + .../schema=main/table=orders/preview.md | 30 +- .../table=raw_customers/description.md | 14 + .../table=raw_customers/preview.md | 30 +- .../table=raw_orders/description.md | 14 + .../schema=main/table=raw_orders/preview.md | 30 +- .../table=raw_payments/description.md | 14 + .../schema=main/table=raw_payments/preview.md | 30 +- .../table=stg_customers/description.md | 14 + .../table=stg_customers/preview.md | 30 +- .../table=stg_orders/description.md | 14 + .../schema=main/table=stg_orders/preview.md | 30 +- .../table=stg_payments/description.md | 14 + .../schema=main/table=stg_payments/preview.md | 30 +- example/nao_config.yaml | 5 +- example/repos/dbt | 2 +- example/templates/databases/preview.md.j2 | 32 +- 62 files changed, 1564 insertions(+), 1349 deletions(-) delete mode 100644 cli/nao_core/commands/sync/accessors.py delete mode 100644 cli/nao_core/commands/sync/providers/databases/bigquery.py create mode 100644 cli/nao_core/commands/sync/providers/databases/context.py delete mode 100644 cli/nao_core/commands/sync/providers/databases/databricks.py delete 
mode 100644 cli/nao_core/commands/sync/providers/databases/duckdb.py delete mode 100644 cli/nao_core/commands/sync/providers/databases/postgres.py delete mode 100644 cli/nao_core/commands/sync/providers/databases/redshift.py delete mode 100644 cli/nao_core/commands/sync/providers/databases/snowflake.py delete mode 100644 cli/nao_core/commands/sync/registry.py delete mode 100644 cli/nao_core/templates/defaults/databases/profiling.md.j2 create mode 100644 cli/tests/nao_core/commands/sync/integration/.env.example create mode 100644 cli/tests/nao_core/commands/sync/integration/__init__.py create mode 100644 cli/tests/nao_core/commands/sync/integration/conftest.py create mode 100644 cli/tests/nao_core/commands/sync/integration/dml/redshift.sql create mode 100644 cli/tests/nao_core/commands/sync/integration/test_duckdb.py create mode 100644 cli/tests/nao_core/commands/sync/integration/test_redshift.py delete mode 100644 cli/tests/nao_core/commands/sync/test_accessors.py create mode 100644 cli/tests/nao_core/commands/sync/test_context.py delete mode 100644 cli/tests/nao_core/commands/sync/test_registry.py create mode 100644 example/databases/type=duckdb/database=jaffle_shop/schema=main/table=customers/description.md create mode 100644 example/databases/type=duckdb/database=jaffle_shop/schema=main/table=orders/description.md create mode 100644 example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_customers/description.md create mode 100644 example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_orders/description.md create mode 100644 example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_payments/description.md create mode 100644 example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_customers/description.md create mode 100644 example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_orders/description.md create mode 100644 example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_payments/description.md diff --git a/cli/.gitignore b/cli/.gitignore index f60c88de..b02c1ac6 100644 --- a/cli/.gitignore +++ b/cli/.gitignore @@ -17,3 +17,6 @@ build/ .venv/ venv/ +# Integration test secrets +tests/**/.env + diff --git a/cli/nao_core/commands/debug.py b/cli/nao_core/commands/debug.py index 67716c20..60af4f16 100644 --- a/cli/nao_core/commands/debug.py +++ b/cli/nao_core/commands/debug.py @@ -4,7 +4,6 @@ from rich.table import Table from nao_core.config import NaoConfig -from nao_core.config.databases import AnyDatabaseConfig from nao_core.tracking import track_command console = Console() @@ -46,32 +45,6 @@ def _check_available_models(provider: str, api_key: str) -> Tuple[bool, str]: return True, f"Connected successfully ({model_count} models available)" -def check_database_connection(db_config: AnyDatabaseConfig) -> tuple[bool, str]: - """Test connectivity to a database. 
- - Returns: - Tuple of (success, message) - """ - try: - conn = db_config.connect() - # Run a simple query to verify the connection works - if hasattr(db_config, "dataset_id") and db_config.dataset_id: - # If dataset is specified, list tables in that dataset - tables = conn.list_tables() - table_count = len(tables) - return True, f"Connected successfully ({table_count} tables found)" - elif list_databases := getattr(conn, "list_databases", None): - # If no dataset, list schemas in the database instead - schemas = list_databases() - schema_count = len(schemas) - return True, f"Connected successfully ({schema_count} schemas found)" - else: - # Fallback for backends that don't support list_tables and list_databases - return True, "Connected but unable to list neither datasets nor schemas" - except Exception as e: - return False, str(e) - - def check_llm_connection(llm_config) -> tuple[bool, str]: """Test connectivity to an LLM provider. @@ -110,7 +83,7 @@ def debug(): for db in config.databases: console.print(f" Testing [cyan]{db.name}[/cyan]...", end=" ") - success, message = check_database_connection(db) + success, message = db.check_connection() if success: console.print("[bold green]✓[/bold green]") diff --git a/cli/nao_core/commands/sync/__init__.py b/cli/nao_core/commands/sync/__init__.py index dbccbf7c..6d6b02bb 100644 --- a/cli/nao_core/commands/sync/__init__.py +++ b/cli/nao_core/commands/sync/__init__.py @@ -2,22 +2,38 @@ import sys from pathlib import Path +from typing import Annotated +from cyclopts import Parameter from rich.console import Console from nao_core.config import NaoConfig from nao_core.templates.render import render_all_templates from nao_core.tracking import track_command -from .providers import SyncProvider, SyncResult, get_all_providers +from .providers import ( + PROVIDER_CHOICES, + ProviderSelection, + SyncResult, + get_all_providers, + get_providers_by_names, +) console = Console() @track_command("sync") def sync( - output_dirs: dict[str, str] | None = None, - providers: list[SyncProvider] | None = None, + *, + provider: Annotated[ + list[str] | None, + Parameter( + name=["-p", "--provider"], + help=f"Provider(s) to sync. Use `-p provider:name` to sync a specific connection (e.g. databases:my-db). Or just `-p databases` to sync all connections. Options: {', '.join(PROVIDER_CHOICES)}", + ), + ] = None, + output_dirs: Annotated[dict[str, str] | None, Parameter(show=False)] = None, + _providers: Annotated[list[ProviderSelection] | None, Parameter(show=False)] = None, render_templates: bool = True, ): """Sync resources using configured providers. @@ -29,14 +45,6 @@ def sync( After syncing providers, renders any Jinja templates (*.j2 files) found in the project directory, making the `nao` context object available for accessing provider data. - - Args: - output_dirs: Optional dict mapping provider names to custom output directories. - If not specified, uses each provider's default_output_dir. - providers: Optional list of providers to use. If not specified, uses all - registered providers. - render_templates: Whether to render Jinja templates after syncing providers. - Defaults to True. 
""" console.print("\n[bold cyan]🔄 nao sync[/bold cyan]\n") @@ -48,31 +56,52 @@ def sync( console.print(f"[dim]Project:[/dim] {config.project_name}") - # Use provided providers or default to all registered providers - active_providers = providers if providers is not None else get_all_providers() + # Resolve providers: CLI names > programmatic providers > all providers + if provider: + try: + active_providers = get_providers_by_names(provider) + except ValueError as e: + console.print(f"[red]Error:[/red] {e}") + sys.exit(1) + elif _providers is not None: + active_providers = _providers + else: + active_providers = get_all_providers() + output_dirs = output_dirs or {} # Run each provider results: list[SyncResult] = [] - for provider in active_providers: + for selection in active_providers: + sync_provider = selection.provider + connection_filter = selection.connection_name + # Get output directory (custom or default) - output_dir = output_dirs.get(provider.name, provider.default_output_dir) + output_dir = output_dirs.get(sync_provider.name, sync_provider.default_output_dir) output_path = Path(output_dir) try: - provider.pre_sync(config, output_path) + sync_provider.pre_sync(config, output_path) - if not provider.should_sync(config): + if not sync_provider.should_sync(config): continue - # Get items and sync - items = provider.get_items(config) - result = provider.sync(items, output_path, project_path=project_path) + # Get items and filter by connection name if specified + items = sync_provider.get_items(config) + if connection_filter: + items = [item for item in items if getattr(item, "name", None) == connection_filter] + if not items: + console.print( + f"[yellow]Warning:[/yellow] No connection named '{connection_filter}' found for {sync_provider.name}" + ) + continue + + result = sync_provider.sync(items, output_path, project_path=project_path) results.append(result) except Exception as e: # Capture error but continue with other providers - results.append(SyncResult.from_error(provider.name, e)) - console.print(f" [yellow]⚠[/yellow] {provider.emoji} {provider.name}: [red]{e}[/red]") + results.append(SyncResult.from_error(sync_provider.name, e)) + console.print(f" [yellow]⚠[/yellow] {sync_provider.emoji} {sync_provider.name}: [red]{e}[/red]") # Render user Jinja templates template_result = None diff --git a/cli/nao_core/commands/sync/accessors.py b/cli/nao_core/commands/sync/accessors.py deleted file mode 100644 index 5b30287a..00000000 --- a/cli/nao_core/commands/sync/accessors.py +++ /dev/null @@ -1,290 +0,0 @@ -"""Data accessor classes for generating markdown documentation from database tables.""" - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any - -from ibis import BaseBackend - -from nao_core.templates import get_template_engine - - -class DataAccessor(ABC): - """Base class for data accessors that generate markdown files for tables. - - Accessors use Jinja2 templates for generating output. Default templates - are shipped with nao and can be overridden by users by placing templates - with the same name in their project's `templates/` directory. - - Example: - To override the preview template, create: - `/templates/databases/preview.md.j2` - """ - - # Path to the nao project root (set by sync provider) - _project_path: Path | None = None - - @property - @abstractmethod - def filename(self) -> str: - """The filename this accessor writes to (e.g., 'columns.md').""" - ... 
- - @property - @abstractmethod - def template_name(self) -> str: - """The template file to use (e.g., 'databases/columns.md.j2').""" - ... - - @abstractmethod - def get_context(self, conn: BaseBackend, dataset: str, table: str) -> dict[str, Any]: - """Get the template context for rendering. - - Args: - conn: The Ibis database connection - dataset: The dataset/schema name - table: The table name - - Returns: - Dictionary of variables to pass to the template - """ - ... - - def generate(self, conn: BaseBackend, dataset: str, table: str) -> str: - """Generate the markdown content for a table using templates. - - Args: - conn: The Ibis database connection - dataset: The dataset/schema name - table: The table name - - Returns: - Markdown string content - """ - try: - context = self.get_context(conn, dataset, table) - engine = get_template_engine(self._project_path) - return engine.render(self.template_name, **context) - except Exception as e: - return f"# {table}\n\nError generating content: {e}" - - def get_table(self, conn: BaseBackend, dataset: str, table: str): - """Helper to get an Ibis table reference.""" - return conn.table(table, database=dataset) - - @classmethod - def set_project_path(cls, path: Path | None) -> None: - """Set the project path for template resolution. - - Args: - path: Path to the nao project root - """ - cls._project_path = path - - -def truncate_middle(text: str, max_length: int) -> str: - """Truncate text in the middle if it exceeds max_length.""" - if len(text) <= max_length: - return text - half = (max_length - 3) // 2 - return text[:half] + "..." + text[-half:] - - -class ColumnsAccessor(DataAccessor): - """Generates columns.md with column names, types, and nullable info. - - Template variables: - - table_name: Name of the table - - dataset: Schema/dataset name - - columns: List of dicts with 'name', 'type', 'nullable', 'description' - - column_count: Total number of columns - """ - - def __init__(self, max_description_length: int = 256): - self.max_description_length = max_description_length - - @property - def filename(self) -> str: - return "columns.md" - - @property - def template_name(self) -> str: - return "databases/columns.md.j2" - - def get_context(self, conn: BaseBackend, dataset: str, table: str) -> dict[str, Any]: - t = self.get_table(conn, dataset, table) - schema = t.schema() - - columns = [] - for name, dtype in schema.items(): - columns.append( - { - "name": name, - "type": str(dtype), - "nullable": dtype.nullable if hasattr(dtype, "nullable") else True, - "description": None, # Could be populated from metadata - } - ) - - return { - "table_name": table, - "dataset": dataset, - "columns": columns, - "column_count": len(columns), - } - - -class PreviewAccessor(DataAccessor): - """Generates preview.md with the first N rows of data as JSONL. 
- - Template variables: - - table_name: Name of the table - - dataset: Schema/dataset name - - rows: List of row dictionaries - - row_count: Number of preview rows - - columns: List of column info dicts - """ - - def __init__(self, num_rows: int = 10): - self.num_rows = num_rows - - @property - def filename(self) -> str: - return "preview.md" - - @property - def template_name(self) -> str: - return "databases/preview.md.j2" - - def get_context(self, conn: BaseBackend, dataset: str, table: str) -> dict[str, Any]: - t = self.get_table(conn, dataset, table) - schema = t.schema() - preview_df = t.limit(self.num_rows).execute() - - rows = [] - for _, row in preview_df.iterrows(): - row_dict = row.to_dict() - # Convert non-serializable types to strings - for key, val in row_dict.items(): - if val is not None and not isinstance(val, (str, int, float, bool, list, dict)): - row_dict[key] = str(val) - rows.append(row_dict) - - columns = [{"name": name, "type": str(dtype)} for name, dtype in schema.items()] - - return { - "table_name": table, - "dataset": dataset, - "rows": rows, - "row_count": len(rows), - "columns": columns, - } - - -class DescriptionAccessor(DataAccessor): - """Generates description.md with table metadata (row count, column count, etc.). - - Template variables: - - table_name: Name of the table - - dataset: Schema/dataset name - - row_count: Total rows in the table - - column_count: Number of columns - - description: Table description (if available) - - columns: List of column info dicts - """ - - @property - def filename(self) -> str: - return "description.md" - - @property - def template_name(self) -> str: - return "databases/description.md.j2" - - def get_context(self, conn: BaseBackend, dataset: str, table: str) -> dict[str, Any]: - t = self.get_table(conn, dataset, table) - schema = t.schema() - - row_count = t.count().execute() - columns = [{"name": name, "type": str(dtype)} for name, dtype in schema.items()] - - return { - "table_name": table, - "dataset": dataset, - "row_count": row_count, - "column_count": len(schema), - "description": None, # Could be populated from metadata - "columns": columns, - } - - -class ProfilingAccessor(DataAccessor): - """Generates profiling.md with column statistics and data profiling. 
- - Template variables: - - table_name: Name of the table - - dataset: Schema/dataset name - - column_stats: List of dicts with stats for each column: - - name: Column name - - type: Data type - - null_count: Number of nulls - - unique_count: Number of unique values - - min_value: Min value (numeric/temporal) - - max_value: Max value (numeric/temporal) - - error: Error message if stats couldn't be computed - - columns: List of column info dicts - """ - - @property - def filename(self) -> str: - return "profiling.md" - - @property - def template_name(self) -> str: - return "databases/profiling.md.j2" - - def get_context(self, conn: BaseBackend, dataset: str, table: str) -> dict[str, Any]: - t = self.get_table(conn, dataset, table) - schema = t.schema() - - column_stats = [] - columns = [] - - for name, dtype in schema.items(): - columns.append({"name": name, "type": str(dtype)}) - col = t[name] - dtype_str = str(dtype) - - stat = { - "name": name, - "type": dtype_str, - "null_count": 0, - "unique_count": 0, - "min_value": None, - "max_value": None, - "error": None, - } - - try: - stat["null_count"] = t.filter(col.isnull()).count().execute() - stat["unique_count"] = col.nunique().execute() - - if dtype.is_numeric() or dtype.is_temporal(): - try: - min_val = str(col.min().execute()) - max_val = str(col.max().execute()) - stat["min_value"] = truncate_middle(min_val, 20) - stat["max_value"] = truncate_middle(max_val, 20) - except Exception: - pass - except Exception as col_error: - stat["error"] = str(col_error) - - column_stats.append(stat) - - return { - "table_name": table, - "dataset": dataset, - "column_stats": column_stats, - "columns": columns, - } diff --git a/cli/nao_core/commands/sync/providers/__init__.py b/cli/nao_core/commands/sync/providers/__init__.py index afe986b1..a59ae412 100644 --- a/cli/nao_core/commands/sync/providers/__init__.py +++ b/cli/nao_core/commands/sync/providers/__init__.py @@ -1,32 +1,82 @@ """Sync providers for different resource types.""" +from dataclasses import dataclass + from .base import SyncProvider, SyncResult from .databases.provider import DatabaseSyncProvider from .notion.provider import NotionSyncProvider from .repositories.provider import RepositorySyncProvider +# Provider registry mapping CLI-friendly names to provider instances +PROVIDER_REGISTRY: dict[str, SyncProvider] = { + "notion": NotionSyncProvider(), + "repositories": RepositorySyncProvider(), + "databases": DatabaseSyncProvider(), +} + # Default providers in order of execution -DEFAULT_PROVIDERS: list[SyncProvider] = [ - NotionSyncProvider(), - RepositorySyncProvider(), - DatabaseSyncProvider(), -] +DEFAULT_PROVIDERS: list[SyncProvider] = list(PROVIDER_REGISTRY.values()) + +# Valid provider names for CLI help text +PROVIDER_CHOICES: list[str] = list(PROVIDER_REGISTRY.keys()) + + +@dataclass +class ProviderSelection: + """A provider with an optional connection name filter.""" + + provider: SyncProvider + connection_name: str | None = None + + +def get_all_providers() -> list[ProviderSelection]: + """Get all registered sync providers.""" + return [ProviderSelection(p) for p in DEFAULT_PROVIDERS] + + +def parse_provider_arg(arg: str) -> ProviderSelection: + """Parse a provider argument like 'databases' or 'databases:my-connection'. + + Raises: + ValueError: If the provider name is not valid. 
+ """ + if ":" in arg: + provider_name, connection_name = arg.split(":", 1) + else: + provider_name = arg + connection_name = None + + provider_name_lower = provider_name.lower() + if provider_name_lower not in PROVIDER_REGISTRY: + valid = ", ".join(PROVIDER_CHOICES) + raise ValueError(f"Unknown provider '{provider_name}'. Valid options: {valid}") + + return ProviderSelection( + provider=PROVIDER_REGISTRY[provider_name_lower], + connection_name=connection_name, + ) + +def get_providers_by_names(names: list[str]) -> list[ProviderSelection]: + """Get provider selections by their CLI-friendly names. -def get_all_providers() -> list[SyncProvider]: - """Get all registered sync providers. + Supports 'provider' or 'provider:connection_name' syntax. - Returns: - List of sync provider instances + Raises: + ValueError: If any provider name is not valid. """ - return DEFAULT_PROVIDERS.copy() + return [parse_provider_arg(name) for name in names] __all__ = [ "SyncProvider", "SyncResult", + "ProviderSelection", "DatabaseSyncProvider", "RepositorySyncProvider", + "PROVIDER_REGISTRY", + "PROVIDER_CHOICES", "DEFAULT_PROVIDERS", "get_all_providers", + "get_providers_by_names", ] diff --git a/cli/nao_core/commands/sync/providers/databases/__init__.py b/cli/nao_core/commands/sync/providers/databases/__init__.py index eb925ce6..e99bef1a 100644 --- a/cli/nao_core/commands/sync/providers/databases/__init__.py +++ b/cli/nao_core/commands/sync/providers/databases/__init__.py @@ -1,19 +1,10 @@ """Database syncing functionality for generating markdown documentation from database schemas.""" -from .bigquery import sync_bigquery -from .databricks import sync_databricks -from .duckdb import sync_duckdb -from .postgres import sync_postgres -from .provider import DatabaseSyncProvider -from .redshift import sync_redshift -from .snowflake import sync_snowflake +from .context import DatabaseContext +from .provider import DatabaseSyncProvider, sync_database __all__ = [ + "DatabaseContext", "DatabaseSyncProvider", - "sync_bigquery", - "sync_databricks", - "sync_duckdb", - "sync_postgres", - "sync_redshift", - "sync_snowflake", + "sync_database", ] diff --git a/cli/nao_core/commands/sync/providers/databases/bigquery.py b/cli/nao_core/commands/sync/providers/databases/bigquery.py deleted file mode 100644 index 96377374..00000000 --- a/cli/nao_core/commands/sync/providers/databases/bigquery.py +++ /dev/null @@ -1,79 +0,0 @@ -from pathlib import Path - -from rich.progress import Progress - -from nao_core.commands.sync.accessors import DataAccessor -from nao_core.commands.sync.cleanup import DatabaseSyncState - - -def sync_bigquery( - db_config, - base_path: Path, - progress: Progress, - accessors: list[DataAccessor], -) -> DatabaseSyncState: - """Sync BigQuery database schema to markdown files. 
- - Args: - db_config: The database configuration - base_path: Base output path - progress: Rich progress instance - accessors: List of data accessors to run - - Returns: - DatabaseSyncState with sync results and tracked paths - """ - conn = db_config.connect() - db_name = db_config.get_database_name() - db_path = base_path / "type=bigquery" / f"database={db_name}" - state = DatabaseSyncState(db_path=db_path) - - if db_config.dataset_id: - datasets = [db_config.dataset_id] - else: - datasets = conn.list_databases() - - dataset_task = progress.add_task( - f"[dim]{db_config.name}[/dim]", - total=len(datasets), - ) - - for dataset in datasets: - try: - all_tables = conn.list_tables(database=dataset) - except Exception: - progress.update(dataset_task, advance=1) - continue - - # Filter tables based on include/exclude patterns - tables = [t for t in all_tables if db_config.matches_pattern(dataset, t)] - - # Skip dataset if no tables match - if not tables: - progress.update(dataset_task, advance=1) - continue - - dataset_path = db_path / f"schema={dataset}" - dataset_path.mkdir(parents=True, exist_ok=True) - state.add_schema(dataset) - - table_task = progress.add_task( - f" [cyan]{dataset}[/cyan]", - total=len(tables), - ) - - for table in tables: - table_path = dataset_path / f"table={table}" - table_path.mkdir(parents=True, exist_ok=True) - - for accessor in accessors: - content = accessor.generate(conn, dataset, table) - output_file = table_path / accessor.filename - output_file.write_text(content) - - state.add_table(dataset, table) - progress.update(table_task, advance=1) - - progress.update(dataset_task, advance=1) - - return state diff --git a/cli/nao_core/commands/sync/providers/databases/context.py b/cli/nao_core/commands/sync/providers/databases/context.py new file mode 100644 index 00000000..4f71bf5a --- /dev/null +++ b/cli/nao_core/commands/sync/providers/databases/context.py @@ -0,0 +1,70 @@ +"""Database context exposing methods available in templates during sync.""" + +from typing import Any + +from ibis import BaseBackend + + +class DatabaseContext: + """Context object passed to Jinja2 templates during database sync. + + Exposes data-fetching methods that templates can call to retrieve + column metadata, row previews, table descriptions, etc. + """ + + def __init__(self, conn: BaseBackend, schema: str, table_name: str): + self._conn = conn + self._schema = schema + self._table_name = table_name + self._table_ref = None + + @property + def table(self): + if self._table_ref is None: + self._table_ref = self._conn.table(self._table_name, database=self._schema) + return self._table_ref + + def columns(self) -> list[dict[str, Any]]: + """Return column metadata: name, type, nullable, description.""" + schema = self.table.schema() + return [ + { + "name": name, + "type": self._format_type(dtype), + "nullable": dtype.nullable if hasattr(dtype, "nullable") else True, + "description": None, + } + for name, dtype in schema.items() + ] + + @staticmethod + def _format_type(dtype) -> str: + """Convert Ibis type to a human-readable string (e.g. 
!int32 -> int32 NOT NULL).""" + raw = str(dtype) + if raw.startswith("!"): + return f"{raw[1:]} NOT NULL" + return raw + + def preview(self, limit: int = 10) -> list[dict[str, Any]]: + """Return the first N rows as a list of dictionaries.""" + df = self.table.limit(limit).execute() + rows = [] + for _, row in df.iterrows(): + row_dict = row.to_dict() + for key, val in row_dict.items(): + if val is not None and not isinstance(val, (str, int, float, bool, list, dict)): + row_dict[key] = str(val) + rows.append(row_dict) + return rows + + def row_count(self) -> int: + """Return the total number of rows in the table.""" + return self.table.count().execute() + + def column_count(self) -> int: + """Return the number of columns in the table.""" + return len(self.table.schema()) + + def description(self) -> str | None: + """Return the table description if available.""" + return None diff --git a/cli/nao_core/commands/sync/providers/databases/databricks.py b/cli/nao_core/commands/sync/providers/databases/databricks.py deleted file mode 100644 index 7a37f72f..00000000 --- a/cli/nao_core/commands/sync/providers/databases/databricks.py +++ /dev/null @@ -1,79 +0,0 @@ -from pathlib import Path - -from rich.progress import Progress - -from nao_core.commands.sync.accessors import DataAccessor -from nao_core.commands.sync.cleanup import DatabaseSyncState - - -def sync_databricks( - db_config, - base_path: Path, - progress: Progress, - accessors: list[DataAccessor], -) -> DatabaseSyncState: - """Sync Databricks database schema to markdown files. - - Args: - db_config: The database configuration - base_path: Base output path - progress: Rich progress instance - accessors: List of data accessors to run - - Returns: - DatabaseSyncState with sync results and tracked paths - """ - conn = db_config.connect() - db_name = db_config.get_database_name() - db_path = base_path / "type=databricks" / f"database={db_name}" - state = DatabaseSyncState(db_path=db_path) - - if db_config.schema: - schemas = [db_config.schema] - else: - schemas = conn.list_databases() - - schema_task = progress.add_task( - f"[dim]{db_config.name}[/dim]", - total=len(schemas), - ) - - for schema in schemas: - try: - all_tables = conn.list_tables(database=schema) - except Exception: - progress.update(schema_task, advance=1) - continue - - # Filter tables based on include/exclude patterns - tables = [t for t in all_tables if db_config.matches_pattern(schema, t)] - - # Skip schema if no tables match - if not tables: - progress.update(schema_task, advance=1) - continue - - schema_path = db_path / f"schema={schema}" - schema_path.mkdir(parents=True, exist_ok=True) - state.add_schema(schema) - - table_task = progress.add_task( - f" [cyan]{schema}[/cyan]", - total=len(tables), - ) - - for table in tables: - table_path = schema_path / f"table={table}" - table_path.mkdir(parents=True, exist_ok=True) - - for accessor in accessors: - content = accessor.generate(conn, schema, table) - output_file = table_path / accessor.filename - output_file.write_text(content) - - state.add_table(schema, table) - progress.update(table_task, advance=1) - - progress.update(schema_task, advance=1) - - return state diff --git a/cli/nao_core/commands/sync/providers/databases/duckdb.py b/cli/nao_core/commands/sync/providers/databases/duckdb.py deleted file mode 100644 index a2449260..00000000 --- a/cli/nao_core/commands/sync/providers/databases/duckdb.py +++ /dev/null @@ -1,78 +0,0 @@ -from pathlib import Path - -from rich.progress import Progress - -from 
nao_core.commands.sync.accessors import DataAccessor -from nao_core.commands.sync.cleanup import DatabaseSyncState - - -def sync_duckdb( - db_config, - base_path: Path, - progress: Progress, - accessors: list[DataAccessor], -) -> DatabaseSyncState: - """Sync DuckDB database schema to markdown files. - - Args: - db_config: The database configuration - base_path: Base output path - progress: Rich progress instance - accessors: List of data accessors to run - - Returns: - DatabaseSyncState with sync results and tracked paths - """ - conn = db_config.connect() - - db_name = db_config.get_database_name() - db_path = base_path / "type=duckdb" / f"database={db_name}" - state = DatabaseSyncState(db_path=db_path) - - # List all schemas in DuckDB - schemas = conn.list_databases() - - schema_task = progress.add_task( - f"[dim]{db_config.name}[/dim]", - total=len(schemas), - ) - - for schema in schemas: - try: - all_tables = conn.list_tables(database=schema) - except Exception: - progress.update(schema_task, advance=1) - continue - - # Filter tables based on include/exclude patterns - tables = [t for t in all_tables if db_config.matches_pattern(schema, t)] - - # Skip schema if no tables match - if not tables: - progress.update(schema_task, advance=1) - continue - - schema_path = db_path / f"schema={schema}" - schema_path.mkdir(parents=True, exist_ok=True) - state.add_schema(schema) - - table_task = progress.add_task( - f" [cyan]{schema}[/cyan]", - total=len(tables), - ) - - for table in tables: - table_path = schema_path / f"table={table}" - table_path.mkdir(parents=True, exist_ok=True) - - for accessor in accessors: - content = accessor.generate(conn, schema, table) - output_file = table_path / accessor.filename - output_file.write_text(content) - - state.add_table(schema, table) - progress.update(table_task, advance=1) - - progress.update(schema_task, advance=1) - - return state diff --git a/cli/nao_core/commands/sync/providers/databases/postgres.py b/cli/nao_core/commands/sync/providers/databases/postgres.py deleted file mode 100644 index b559215f..00000000 --- a/cli/nao_core/commands/sync/providers/databases/postgres.py +++ /dev/null @@ -1,79 +0,0 @@ -from pathlib import Path - -from rich.progress import Progress - -from nao_core.commands.sync.accessors import DataAccessor -from nao_core.commands.sync.cleanup import DatabaseSyncState - - -def sync_postgres( - db_config, - base_path: Path, - progress: Progress, - accessors: list[DataAccessor], -) -> DatabaseSyncState: - """Sync PostgreSQL database schema to markdown files. 
- - Args: - db_config: The database configuration - base_path: Base output path - progress: Rich progress instance - accessors: List of data accessors to run - - Returns: - DatabaseSyncState with sync results and tracked paths - """ - conn = db_config.connect() - db_name = db_config.get_database_name() - db_path = base_path / "type=postgres" / f"database={db_name}" - state = DatabaseSyncState(db_path=db_path) - - if db_config.schema_name: - schemas = [db_config.schema_name] - else: - schemas = conn.list_databases() - - schema_task = progress.add_task( - f"[dim]{db_config.name}[/dim]", - total=len(schemas), - ) - - for schema in schemas: - try: - all_tables = conn.list_tables(database=schema) - except Exception: - progress.update(schema_task, advance=1) - continue - - # Filter tables based on include/exclude patterns - tables = [t for t in all_tables if db_config.matches_pattern(schema, t)] - - # Skip schema if no tables match - if not tables: - progress.update(schema_task, advance=1) - continue - - schema_path = db_path / f"schema={schema}" - schema_path.mkdir(parents=True, exist_ok=True) - state.add_schema(schema) - - table_task = progress.add_task( - f" [cyan]{schema}[/cyan]", - total=len(tables), - ) - - for table in tables: - table_path = schema_path / f"table={table}" - table_path.mkdir(parents=True, exist_ok=True) - - for accessor in accessors: - content = accessor.generate(conn, schema, table) - output_file = table_path / accessor.filename - output_file.write_text(content) - - state.add_table(schema, table) - progress.update(table_task, advance=1) - - progress.update(schema_task, advance=1) - - return state diff --git a/cli/nao_core/commands/sync/providers/databases/provider.py b/cli/nao_core/commands/sync/providers/databases/provider.py index e208df90..d956b5b7 100644 --- a/cli/nao_core/commands/sync/providers/databases/provider.py +++ b/cli/nao_core/commands/sync/providers/databases/provider.py @@ -6,30 +6,91 @@ from rich.console import Console from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn -from nao_core.commands.sync.accessors import DataAccessor from nao_core.commands.sync.cleanup import DatabaseSyncState, cleanup_stale_databases, cleanup_stale_paths -from nao_core.commands.sync.registry import get_accessors from nao_core.config import AnyDatabaseConfig, NaoConfig +from nao_core.config.databases.base import DatabaseConfig +from nao_core.templates.engine import get_template_engine from ..base import SyncProvider, SyncResult -from .bigquery import sync_bigquery -from .databricks import sync_databricks -from .duckdb import sync_duckdb -from .postgres import sync_postgres -from .redshift import sync_redshift -from .snowflake import sync_snowflake +from .context import DatabaseContext console = Console() -# Registry mapping database types to their sync functions -DATABASE_SYNC_FUNCTIONS = { - "bigquery": sync_bigquery, - "duckdb": sync_duckdb, - "databricks": sync_databricks, - "snowflake": sync_snowflake, - "postgres": sync_postgres, - "redshift": sync_redshift, -} +TEMPLATE_PREFIX = "databases" + + +def sync_database( + db_config: DatabaseConfig, + base_path: Path, + progress: Progress, + project_path: Path | None = None, +) -> DatabaseSyncState: + """Sync a single database by rendering all database templates for each table.""" + engine = get_template_engine(project_path) + templates = engine.list_templates(TEMPLATE_PREFIX) + + conn = db_config.connect() + db_name = db_config.get_database_name() + db_path = base_path / 
f"type={db_config.type}" / f"database={db_name}" + state = DatabaseSyncState(db_path=db_path) + + schemas = db_config.get_schemas(conn) + + schema_task = progress.add_task( + f"[dim]{db_config.name}[/dim]", + total=len(schemas), + ) + + for schema in schemas: + try: + all_tables = conn.list_tables(database=schema) + except Exception: + progress.update(schema_task, advance=1) + continue + + tables = [t for t in all_tables if db_config.matches_pattern(schema, t)] + + if not tables: + progress.update(schema_task, advance=1) + continue + + schema_path = db_path / f"schema={schema}" + schema_path.mkdir(parents=True, exist_ok=True) + state.add_schema(schema) + + table_task = progress.add_task( + f" [cyan]{schema}[/cyan]", + total=len(tables), + ) + + for table in tables: + table_path = schema_path / f"table={table}" + table_path.mkdir(parents=True, exist_ok=True) + + # Use custom context if database config provides one (e.g., for Redshift) + create_context = getattr(db_config, "create_context", None) + if create_context and callable(create_context): + ctx = create_context(conn, schema, table) + else: + ctx = DatabaseContext(conn, schema, table) + + for template_name in templates: + try: + content = engine.render(template_name, db=ctx, table_name=table, dataset=schema) + except Exception as e: + content = f"# {table}\n\nError generating content: {e}" + + # Derive output filename: "databases/columns.md.j2" → "columns.md" + output_filename = Path(template_name).stem # "columns.md" (stem strips .j2) + output_file = table_path / output_filename + output_file.write_text(content) + + state.add_table(schema, table) + progress.update(table_task, advance=1) + + progress.update(schema_task, advance=1) + + return state class DatabaseSyncProvider(SyncProvider): @@ -48,39 +109,29 @@ def default_output_dir(self) -> str: return "databases" def pre_sync(self, config: NaoConfig, output_path: Path) -> None: - """ - Always run before syncing. - """ cleanup_stale_databases(config.databases, output_path, verbose=True) def get_items(self, config: NaoConfig) -> list[AnyDatabaseConfig]: return config.databases def sync(self, items: list[Any], output_path: Path, project_path: Path | None = None) -> SyncResult: - """Sync all configured databases. 
- - Args: - items: List of database configurations - output_path: Base path where database schemas are stored - project_path: Path to the nao project root (for template resolution) - - Returns: - SyncResult with datasets and tables synced - """ if not items: console.print("\n[dim]No databases configured[/dim]") return SyncResult(provider_name=self.name, items_synced=0) - # Set project path for template resolution - DataAccessor.set_project_path(project_path) - total_datasets = 0 total_tables = 0 total_removed = 0 sync_states: list[DatabaseSyncState] = [] + # Show which templates will be used + engine = get_template_engine(project_path) + templates = engine.list_templates(TEMPLATE_PREFIX) + template_names = [Path(t).stem.replace(".md", "") for t in templates] + console.print(f"\n[bold cyan]{self.emoji} Syncing {self.name}[/bold cyan]") - console.print(f"[dim]Location:[/dim] {output_path.absolute()}\n") + console.print(f"[dim]Location:[/dim] {output_path.absolute()}") + console.print(f"[dim]Templates:[/dim] {', '.join(template_names)}\n") with Progress( SpinnerColumn(style="dim"), @@ -91,30 +142,18 @@ def sync(self, items: list[Any], output_path: Path, project_path: Path | None = transient=False, ) as progress: for db in items: - # Get accessors from database config - db_accessors = get_accessors(db.accessors) - accessor_names = [a.filename.replace(".md", "") for a in db_accessors] - try: - console.print(f"[dim]{db.name} accessors:[/dim] {', '.join(accessor_names)}") - - sync_fn = DATABASE_SYNC_FUNCTIONS.get(db.type) - if sync_fn: - state = sync_fn(db, output_path, progress, db_accessors) - sync_states.append(state) - total_datasets += state.schemas_synced - total_tables += state.tables_synced - else: - console.print(f"[yellow]⚠ Unsupported database type: {db.type}[/yellow]") + state = sync_database(db, output_path, progress, project_path) + sync_states.append(state) + total_datasets += state.schemas_synced + total_tables += state.tables_synced except Exception as e: console.print(f"[bold red]✗[/bold red] Failed to sync {db.name}: {e}") - # Clean up stale files after all syncs complete for state in sync_states: removed = cleanup_stale_paths(state, verbose=True) total_removed += removed - # Build summary summary = f"{total_tables} tables across {total_datasets} datasets" if total_removed > 0: summary += f", {total_removed} stale removed" diff --git a/cli/nao_core/commands/sync/providers/databases/redshift.py b/cli/nao_core/commands/sync/providers/databases/redshift.py deleted file mode 100644 index 25c02030..00000000 --- a/cli/nao_core/commands/sync/providers/databases/redshift.py +++ /dev/null @@ -1,79 +0,0 @@ -from pathlib import Path - -from rich.progress import Progress - -from nao_core.commands.sync.accessors import DataAccessor -from nao_core.commands.sync.cleanup import DatabaseSyncState - - -def sync_redshift( - db_config, - base_path: Path, - progress: Progress, - accessors: list[DataAccessor], -) -> DatabaseSyncState: - """Sync Redshift database schema to markdown files. 
- - Args: - db_config: The database configuration - base_path: Base output path - progress: Rich progress instance - accessors: List of data accessors to run - - Returns: - DatabaseSyncState with sync results and tracked paths - """ - conn = db_config.connect() - db_name = db_config.get_database_name() - db_path = base_path / "type=redshift" / f"database={db_name}" - state = DatabaseSyncState(db_path=db_path) - - if db_config.schema_name: - schemas = [db_config.schema_name] - else: - schemas = conn.list_databases() - - schema_task = progress.add_task( - f"[dim]{db_config.name}[/dim]", - total=len(schemas), - ) - - for schema in schemas: - try: - all_tables = conn.list_tables(database=schema) - except Exception: - progress.update(schema_task, advance=1) - continue - - # Filter tables based on include/exclude patterns - tables = [t for t in all_tables if db_config.matches_pattern(schema, t)] - - # Skip schema if no tables match - if not tables: - progress.update(schema_task, advance=1) - continue - - schema_path = db_path / f"schema={schema}" - schema_path.mkdir(parents=True, exist_ok=True) - state.add_schema(schema) - - table_task = progress.add_task( - f" [cyan]{schema}[/cyan]", - total=len(tables), - ) - - for table in tables: - table_path = schema_path / f"table={table}" - table_path.mkdir(parents=True, exist_ok=True) - - for accessor in accessors: - content = accessor.generate(conn, schema, table) - output_file = table_path / accessor.filename - output_file.write_text(content) - - state.add_table(schema, table) - progress.update(table_task, advance=1) - - progress.update(schema_task, advance=1) - - return state diff --git a/cli/nao_core/commands/sync/providers/databases/snowflake.py b/cli/nao_core/commands/sync/providers/databases/snowflake.py deleted file mode 100644 index 06ff915a..00000000 --- a/cli/nao_core/commands/sync/providers/databases/snowflake.py +++ /dev/null @@ -1,79 +0,0 @@ -from pathlib import Path - -from rich.progress import Progress - -from nao_core.commands.sync.accessors import DataAccessor -from nao_core.commands.sync.cleanup import DatabaseSyncState - - -def sync_snowflake( - db_config, - base_path: Path, - progress: Progress, - accessors: list[DataAccessor], -) -> DatabaseSyncState: - """Sync Snowflake database schema to markdown files. 
- - Args: - db_config: The database configuration - base_path: Base output path - progress: Rich progress instance - accessors: List of data accessors to run - - Returns: - DatabaseSyncState with sync results and tracked paths - """ - conn = db_config.connect() - db_name = db_config.get_database_name() - db_path = base_path / "type=snowflake" / f"database={db_name}" - state = DatabaseSyncState(db_path=db_path) - - if db_config.schema: - schemas = [db_config.schema] - else: - schemas = conn.list_databases() - - schema_task = progress.add_task( - f"[dim]{db_config.name}[/dim]", - total=len(schemas), - ) - - for schema in schemas: - try: - all_tables = conn.list_tables(database=schema) - except Exception: - progress.update(schema_task, advance=1) - continue - - # Filter tables based on include/exclude patterns - tables = [t for t in all_tables if db_config.matches_pattern(schema, t)] - - # Skip schema if no tables match - if not tables: - progress.update(schema_task, advance=1) - continue - - schema_path = db_path / f"schema={schema}" - schema_path.mkdir(parents=True, exist_ok=True) - state.add_schema(schema) - - table_task = progress.add_task( - f" [cyan]{schema}[/cyan]", - total=len(tables), - ) - - for table in tables: - table_path = schema_path / f"table={table}" - table_path.mkdir(parents=True, exist_ok=True) - - for accessor in accessors: - content = accessor.generate(conn, schema, table) - output_file = table_path / accessor.filename - output_file.write_text(content) - - state.add_table(schema, table) - progress.update(table_task, advance=1) - - progress.update(schema_task, advance=1) - - return state diff --git a/cli/nao_core/commands/sync/registry.py b/cli/nao_core/commands/sync/registry.py deleted file mode 100644 index 5216362b..00000000 --- a/cli/nao_core/commands/sync/registry.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Accessor registry for mapping accessor types to implementations.""" - -from nao_core.config import AccessorType - -from .accessors import ( - ColumnsAccessor, - DataAccessor, - DescriptionAccessor, - PreviewAccessor, - ProfilingAccessor, -) - -ACCESSOR_REGISTRY: dict[AccessorType, DataAccessor] = { - AccessorType.COLUMNS: ColumnsAccessor(), - AccessorType.PREVIEW: PreviewAccessor(num_rows=10), - AccessorType.DESCRIPTION: DescriptionAccessor(), - AccessorType.PROFILING: ProfilingAccessor(), -} - - -def get_accessors(accessor_types: list[AccessorType]) -> list[DataAccessor]: - """Get accessor instances for the given types.""" - return [ACCESSOR_REGISTRY[t] for t in accessor_types if t in ACCESSOR_REGISTRY] diff --git a/cli/nao_core/config/__init__.py b/cli/nao_core/config/__init__.py index c8439ae6..8220267f 100644 --- a/cli/nao_core/config/__init__.py +++ b/cli/nao_core/config/__init__.py @@ -1,6 +1,5 @@ from .base import NaoConfig from .databases import ( - AccessorType, AnyDatabaseConfig, BigQueryConfig, DatabaseType, @@ -15,7 +14,6 @@ __all__ = [ "NaoConfig", - "AccessorType", "AnyDatabaseConfig", "BigQueryConfig", "DuckDBConfig", diff --git a/cli/nao_core/config/databases/__init__.py b/cli/nao_core/config/databases/__init__.py index 1e97bcc4..8db9594a 100644 --- a/cli/nao_core/config/databases/__init__.py +++ b/cli/nao_core/config/databases/__init__.py @@ -2,7 +2,7 @@ from pydantic import Discriminator, Tag -from .base import AccessorType, DatabaseConfig, DatabaseType +from .base import DatabaseConfig, DatabaseType from .bigquery import BigQueryConfig from .databricks import DatabricksConfig from .duckdb import DuckDBConfig @@ -58,7 +58,6 @@ def 
parse_database_config(data: dict) -> DatabaseConfig: __all__ = [ - "AccessorType", "AnyDatabaseConfig", "BigQueryConfig", "DATABASE_CONFIG_CLASSES", diff --git a/cli/nao_core/config/databases/base.py b/cli/nao_core/config/databases/base.py index b25cfb6d..ab1eeadf 100644 --- a/cli/nao_core/config/databases/base.py +++ b/cli/nao_core/config/databases/base.py @@ -25,25 +25,12 @@ def choices(cls) -> list[questionary.Choice]: return [questionary.Choice(db.value.capitalize(), value=db.value) for db in cls] -class AccessorType(str, Enum): - """Available data accessors for sync.""" - - COLUMNS = "columns" - PREVIEW = "preview" - DESCRIPTION = "description" - PROFILING = "profiling" - - class DatabaseConfig(BaseModel, ABC): """Base configuration for all database backends.""" + type: str # Narrowed to Literal in each subclass for discriminated union name: str = Field(description="A friendly name for this connection") - # Sync settings - accessors: list[AccessorType] = Field( - default=[AccessorType.COLUMNS, AccessorType.PREVIEW, AccessorType.DESCRIPTION], - description="List of accessors to run during sync (columns, preview, description, profiling)", - ) include: list[str] = Field( default_factory=list, description="Glob patterns for schemas/tables to include (e.g., 'prod_*.*', 'analytics.dim_*'). Empty means include all.", @@ -93,5 +80,22 @@ def matches_pattern(self, schema: str, table: str) -> bool: @abstractmethod def get_database_name(self) -> str: """Get the database name for this database type.""" - ... + + def get_schemas(self, conn: BaseBackend) -> list[str]: + """Return the list of schemas to sync. Override in subclasses for custom behavior.""" + list_databases = getattr(conn, "list_databases", None) + if list_databases: + return list_databases() + return [] + + def check_connection(self) -> tuple[bool, str]: + """Test connectivity to the database. 
Override in subclasses for custom behavior.""" + try: + conn = self.connect() + if list_databases := getattr(conn, "list_databases", None): + schemas = list_databases() + return True, f"Connected successfully ({len(schemas)} schemas found)" + return True, "Connected successfully" + except Exception as e: + return False, str(e) diff --git a/cli/nao_core/config/databases/bigquery.py b/cli/nao_core/config/databases/bigquery.py index d03b2454..ed7f4ed5 100644 --- a/cli/nao_core/config/databases/bigquery.py +++ b/cli/nao_core/config/databases/bigquery.py @@ -105,5 +105,24 @@ def connect(self) -> BaseBackend: def get_database_name(self) -> str: """Get the database name for BigQuery.""" - return self.project_id + + def get_schemas(self, conn: BaseBackend) -> list[str]: + if self.dataset_id: + return [self.dataset_id] + list_databases = getattr(conn, "list_databases", None) + return list_databases() if list_databases else [] + + def check_connection(self) -> tuple[bool, str]: + """Test connectivity to BigQuery.""" + try: + conn = self.connect() + if self.dataset_id: + tables = conn.list_tables() + return True, f"Connected successfully ({len(tables)} tables found)" + if list_databases := getattr(conn, "list_databases", None): + schemas = list_databases() + return True, f"Connected successfully ({len(schemas)} datasets found)" + return True, "Connected successfully" + except Exception as e: + return False, str(e) diff --git a/cli/nao_core/config/databases/databricks.py b/cli/nao_core/config/databases/databricks.py index 92b0f7c9..005d364d 100644 --- a/cli/nao_core/config/databases/databricks.py +++ b/cli/nao_core/config/databases/databricks.py @@ -61,5 +61,24 @@ def connect(self) -> BaseBackend: def get_database_name(self) -> str: """Get the database name for Databricks.""" - return self.catalog or "main" + + def get_schemas(self, conn: BaseBackend) -> list[str]: + if self.schema_name: + return [self.schema_name] + list_databases = getattr(conn, "list_databases", None) + return list_databases() if list_databases else [] + + def check_connection(self) -> tuple[bool, str]: + """Test connectivity to Databricks.""" + try: + conn = self.connect() + if self.schema_name: + tables = conn.list_tables() + return True, f"Connected successfully ({len(tables)} tables found)" + if list_databases := getattr(conn, "list_databases", None): + schemas = list_databases() + return True, f"Connected successfully ({len(schemas)} schemas found)" + return True, "Connected successfully" + except Exception as e: + return False, str(e) diff --git a/cli/nao_core/config/databases/duckdb.py b/cli/nao_core/config/databases/duckdb.py index 6c452950..cdb5e7a5 100644 --- a/cli/nao_core/config/databases/duckdb.py +++ b/cli/nao_core/config/databases/duckdb.py @@ -33,7 +33,15 @@ def connect(self) -> BaseBackend: def get_database_name(self) -> str: """Get the database name for DuckDB.""" - if self.path == ":memory:": return "memory" return Path(self.path).stem + + def check_connection(self) -> tuple[bool, str]: + """Test connectivity to DuckDB.""" + try: + conn = self.connect() + tables = conn.list_tables() + return True, f"Connected successfully ({len(tables)} tables found)" + except Exception as e: + return False, str(e) diff --git a/cli/nao_core/config/databases/postgres.py b/cli/nao_core/config/databases/postgres.py index 5b7eaa40..4eb6ef66 100644 --- a/cli/nao_core/config/databases/postgres.py +++ b/cli/nao_core/config/databases/postgres.py @@ -66,5 +66,24 @@ def connect(self) -> BaseBackend: def get_database_name(self) -> str: 
"""Get the database name for Postgres.""" - return self.database + + def get_schemas(self, conn: BaseBackend) -> list[str]: + if self.schema_name: + return [self.schema_name] + list_databases = getattr(conn, "list_databases", None) + return list_databases() if list_databases else [] + + def check_connection(self) -> tuple[bool, str]: + """Test connectivity to PostgreSQL.""" + try: + conn = self.connect() + if self.schema_name: + tables = conn.list_tables() + return True, f"Connected successfully ({len(tables)} tables found)" + if list_databases := getattr(conn, "list_databases", None): + schemas = list_databases() + return True, f"Connected successfully ({len(schemas)} schemas found)" + return True, "Connected successfully" + except Exception as e: + return False, str(e) diff --git a/cli/nao_core/config/databases/redshift.py b/cli/nao_core/config/databases/redshift.py index 45588c6c..965faca4 100644 --- a/cli/nao_core/config/databases/redshift.py +++ b/cli/nao_core/config/databases/redshift.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Literal +from typing import Any, Literal import ibis from ibis import BaseBackend @@ -12,6 +12,130 @@ from .base import DatabaseConfig +class RedshiftDatabaseContext: + """Redshift-specific context that bypasses Ibis's problematic pg_enum queries.""" + + def __init__(self, conn: BaseBackend, schema: str, table_name: str): + self._conn = conn + self._schema = schema + self._table_name = table_name + self._table_ref = None + + @property + def table(self): + if self._table_ref is None: + self._table_ref = self._conn.table(self._table_name, database=self._schema) + return self._table_ref + + def columns(self) -> list[dict[str, Any]]: + """Return column metadata by querying information_schema directly.""" + query = f""" + SELECT + column_name, + data_type, + is_nullable, + character_maximum_length, + numeric_precision, + numeric_scale + FROM information_schema.columns + WHERE table_schema = '{self._schema}' + AND table_name = '{self._table_name}' + ORDER BY ordinal_position + """ + result = self._conn.raw_sql(query).fetchall() # type: ignore[union-attr] + + columns = [] + for row in result: + col_name = row[0] + data_type = row[1] + is_nullable = row[2] == "YES" + char_length = row[3] + num_precision = row[4] + num_scale = row[5] + + # Map SQL types to Ibis-like type strings + formatted_type = self._format_redshift_type(data_type, is_nullable, char_length, num_precision, num_scale) + + columns.append( + { + "name": col_name, + "type": formatted_type, + "nullable": is_nullable, + "description": None, + } + ) + + return columns + + @staticmethod + def _format_redshift_type( + data_type: str, + is_nullable: bool, + char_length: int | None, + num_precision: int | None, + num_scale: int | None, + ) -> str: + """Convert Redshift SQL type to Ibis-like format.""" + # Map common Redshift types to Ibis types + type_map = { + "integer": "int32", + "bigint": "int64", + "smallint": "int16", + "boolean": "boolean", + "real": "float32", + "double precision": "float64", + "character varying": "string", + "character": "string", + "text": "string", + "date": "date", + "timestamp without time zone": "timestamp", + "timestamp with time zone": "timestamp", + } + + ibis_type = type_map.get(data_type, "string") + + if not is_nullable: + return f"{ibis_type} NOT NULL" + return ibis_type + + def preview(self, limit: int = 10) -> list[dict[str, Any]]: + """Return the first N rows as a list of dictionaries.""" + # Use raw SQL to avoid Ibis's pg_enum queries + query = 
f'SELECT * FROM "{self._schema}"."{self._table_name}" LIMIT {limit}' + result = self._conn.raw_sql(query).fetchall() # type: ignore[union-attr] + + # Get column names from the columns metadata + columns = self.columns() + col_names = [col["name"] for col in columns] + + rows = [] + for row in result: + row_dict = {} + for i, col_name in enumerate(col_names): + val = row[i] if i < len(row) else None + if val is not None and not isinstance(val, (str, int, float, bool, list, dict)): + row_dict[col_name] = str(val) + else: + row_dict[col_name] = val + rows.append(row_dict) + return rows + + def row_count(self) -> int: + """Return the total number of rows in the table.""" + # Use raw SQL to avoid Ibis's pg_enum queries + query = f'SELECT COUNT(*) FROM "{self._schema}"."{self._table_name}"' + result = self._conn.raw_sql(query).fetchone() # type: ignore[union-attr] + return result[0] if result else 0 + + def column_count(self) -> int: + """Return the number of columns in the table.""" + return len(self.columns()) + + def description(self) -> str | None: + """Return the table description if available.""" + return None + + class RedshiftSSHTunnelConfig(BaseModel): """SSH tunnel configuration for Redshift connection.""" @@ -129,5 +253,39 @@ def connect(self) -> BaseBackend: def get_database_name(self) -> str: """Get the database name for Redshift.""" - return self.database + + def get_schemas(self, conn: BaseBackend) -> list[str]: + if self.schema_name: + return [self.schema_name] + list_databases = getattr(conn, "list_databases", None) + schemas = list_databases() if list_databases else [] + return schemas + ["public"] + + def create_context(self, conn: BaseBackend, schema: str, table_name: str) -> RedshiftDatabaseContext: + """Create a Redshift-specific database context that avoids pg_enum queries.""" + return RedshiftDatabaseContext(conn, schema, table_name) + + def check_connection(self) -> tuple[bool, str]: + """Test connectivity to Redshift.""" + try: + conn = self.connect() + + if self.schema_name: + tables = conn.list_tables(database=self.schema_name) + return True, f"Connected successfully ({len(tables)} tables found)" + + if self.database: + if list_databases := getattr(conn, "list_databases", None): + schemas = list_databases() + ["public"] + else: + schemas = ["public"] + + tables = [] + for schema in schemas: + tables.extend(conn.list_tables(database=schema)) + return True, f"Connected successfully ({len(tables)} tables found)" + + return True, "Connected successfully" + except Exception as e: + return False, str(e) diff --git a/cli/nao_core/config/databases/snowflake.py b/cli/nao_core/config/databases/snowflake.py index 717d9044..9e5a0444 100644 --- a/cli/nao_core/config/databases/snowflake.py +++ b/cli/nao_core/config/databases/snowflake.py @@ -101,5 +101,24 @@ def connect(self) -> BaseBackend: def get_database_name(self) -> str: """Get the database name for Snowflake.""" - return self.database + + def get_schemas(self, conn: BaseBackend) -> list[str]: + if self.schema_name: + return [self.schema_name] + list_databases = getattr(conn, "list_databases", None) + return list_databases() if list_databases else [] + + def check_connection(self) -> tuple[bool, str]: + """Test connectivity to Snowflake.""" + try: + conn = self.connect() + if self.schema_name: + tables = conn.list_tables() + return True, f"Connected successfully ({len(tables)} tables found)" + if list_databases := getattr(conn, "list_databases", None): + schemas = list_databases() + return True, f"Connected successfully 
({len(schemas)} schemas found)" + return True, "Connected successfully" + except Exception as e: + return False, str(e) diff --git a/cli/nao_core/templates/defaults/databases/columns.md.j2 b/cli/nao_core/templates/defaults/databases/columns.md.j2 index 7dcba7b4..e80a0964 100644 --- a/cli/nao_core/templates/defaults/databases/columns.md.j2 +++ b/cli/nao_core/templates/defaults/databases/columns.md.j2 @@ -1,22 +1,23 @@ {# Template: columns.md.j2 Description: Generates column documentation for a database table - - Available variables: + + Available context: - table_name (str): Name of the table - dataset (str): Schema/dataset name - - columns (list): List of column dictionaries with: - - name (str): Column name - - type (str): Data type - - nullable (bool): Whether the column allows nulls - - description (str|None): Column description if available - - column_count (int): Total number of columns + - db (DatabaseContext): Database context with helper methods + - db.columns() -> list of dicts with: name, type, nullable, description + - db.preview(limit=10) -> list of row dicts + - db.row_count() -> int + - db.column_count() -> int + - db.description() -> str or None #} +{% set columns = db.columns() %} # {{ table_name }} **Dataset:** `{{ dataset }}` -## Columns ({{ column_count }}) +## Columns ({{ columns | length }}) {% for col in columns %} - {{ col.name }} ({{ col.type }}{% if col.description %}, "{{ col.description | truncate_middle(256) }}"{% endif %}) diff --git a/cli/nao_core/templates/defaults/databases/description.md.j2 b/cli/nao_core/templates/defaults/databases/description.md.j2 index 5f5824ac..391052eb 100644 --- a/cli/nao_core/templates/defaults/databases/description.md.j2 +++ b/cli/nao_core/templates/defaults/databases/description.md.j2 @@ -1,16 +1,11 @@ {# Template: description.md.j2 Description: Generates table metadata and description documentation - - Available variables: + + Available context: - table_name (str): Name of the table - dataset (str): Schema/dataset name - - row_count (int): Total number of rows in the table - - column_count (int): Number of columns in the table - - description (str|None): Table description if available - - columns (list): List of column dictionaries with: - - name (str): Column name - - type (str): Data type + - db (DatabaseContext): Database context with helper methods #} # {{ table_name }} @@ -20,13 +15,13 @@ | Property | Value | |----------|-------| -| **Row Count** | {{ "{:,}".format(row_count) }} | -| **Column Count** | {{ column_count }} | +| **Row Count** | {{ "{:,}".format(db.row_count()) }} | +| **Column Count** | {{ db.column_count() }} | ## Description -{% if description %} -{{ description }} +{% if db.description() %} +{{ db.description() }} {% else %} _No description available._ {% endif %} diff --git a/cli/nao_core/templates/defaults/databases/preview.md.j2 b/cli/nao_core/templates/defaults/databases/preview.md.j2 index f6ce56fa..f789ad23 100644 --- a/cli/nao_core/templates/defaults/databases/preview.md.j2 +++ b/cli/nao_core/templates/defaults/databases/preview.md.j2 @@ -1,21 +1,18 @@ {# Template: preview.md.j2 Description: Generates a preview of table rows in JSONL format - - Available variables: + + Available context: - table_name (str): Name of the table - dataset (str): Schema/dataset name - - rows (list): List of row dictionaries (first N rows of the table) - - row_count (int): Number of preview rows shown - - columns (list): List of column dictionaries with: - - name (str): Column name - - type (str): Data type + - db 
(DatabaseContext): Database context with helper methods #} +{% set rows = db.preview() %} # {{ table_name }} - Preview **Dataset:** `{{ dataset }}` -## Rows ({{ row_count }}) +## Rows ({{ rows | length }}) {% for row in rows %} - {{ row | to_json }} diff --git a/cli/nao_core/templates/defaults/databases/profiling.md.j2 b/cli/nao_core/templates/defaults/databases/profiling.md.j2 deleted file mode 100644 index baf155db..00000000 --- a/cli/nao_core/templates/defaults/databases/profiling.md.j2 +++ /dev/null @@ -1,34 +0,0 @@ -{# - Template: profiling.md.j2 - Description: Generates column-level statistics and profiling data - - Available variables: - - table_name (str): Name of the table - - dataset (str): Schema/dataset name - - column_stats (list): List of column statistics dictionaries with: - - name (str): Column name - - type (str): Data type - - null_count (int): Number of null values - - unique_count (int): Number of unique values - - min_value (str|None): Minimum value (for numeric/temporal columns) - - max_value (str|None): Maximum value (for numeric/temporal columns) - - error (str|None): Error message if stats couldn't be computed - - columns (list): List of column dictionaries with: - - name (str): Column name - - type (str): Data type -#} -# {{ table_name }} - Profiling - -**Dataset:** `{{ dataset }}` - -## Column Statistics - -| Column | Type | Nulls | Unique | Min | Max | -|--------|------|-------|--------|-----|-----| -{% for stat in column_stats %} -{% if stat.error %} -| `{{ stat.name }}` | `{{ stat.type }}` | Error: {{ stat.error }} | | | | -{% else %} -| `{{ stat.name }}` | `{{ stat.type }}` | {{ "{:,}".format(stat.null_count) }} | {{ "{:,}".format(stat.unique_count) }} | {{ stat.min_value or "" }} | {{ stat.max_value or "" }} | -{% endif %} -{% endfor %} diff --git a/cli/nao_core/templates/engine.py b/cli/nao_core/templates/engine.py index 09d0bc06..809d5a42 100644 --- a/cli/nao_core/templates/engine.py +++ b/cli/nao_core/templates/engine.py @@ -99,6 +99,28 @@ def has_template(self, template_name: str) -> bool: except Exception: return False + def list_templates(self, prefix: str) -> list[str]: + """List all available templates under a given prefix. + + Merges defaults and user overrides, returning unique template names. + """ + templates: set[str] = set() + + # Collect from default templates + default_dir = DEFAULT_TEMPLATES_DIR / prefix + if default_dir.exists(): + for path in default_dir.rglob("*.j2"): + templates.add(f"{prefix}/{path.relative_to(default_dir)}") + + # Collect from user templates (may add new ones or override defaults) + if self.user_templates_dir: + user_dir = self.user_templates_dir / prefix + if user_dir.exists(): + for path in user_dir.rglob("*.j2"): + templates.add(f"{prefix}/{path.relative_to(user_dir)}") + + return sorted(templates) + def is_user_override(self, template_name: str) -> bool: """Check if a template is a user override. diff --git a/cli/tests/nao_core/commands/sync/integration/.env.example b/cli/tests/nao_core/commands/sync/integration/.env.example new file mode 100644 index 00000000..e445d6c7 --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/.env.example @@ -0,0 +1,11 @@ +# Copy to .env and fill in real values to run Redshift integration tests. +# The .env file is gitignored — never commit credentials. 
+ +# Redshift +REDSHIFT_HOST=your-cluster.region.redshift.amazonaws.com +REDSHIFT_PORT=5439 +REDSHIFT_DATABASE=dev +REDSHIFT_USER=awsuser +REDSHIFT_PASSWORD=changeme +REDSHIFT_SCHEMA=public +REDSHIFT_SSLMODE=require diff --git a/cli/tests/nao_core/commands/sync/integration/__init__.py b/cli/tests/nao_core/commands/sync/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cli/tests/nao_core/commands/sync/integration/conftest.py b/cli/tests/nao_core/commands/sync/integration/conftest.py new file mode 100644 index 00000000..c62f792e --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/conftest.py @@ -0,0 +1,20 @@ +"""Shared fixtures for database sync integration tests.""" + +from pathlib import Path + +import pytest +from dotenv import load_dotenv + +import nao_core.templates.engine as engine_module + +# Auto-load .env sitting next to this conftest so env vars are available +# before pytest collects test modules (where skipif reads them). +load_dotenv(Path(__file__).parent / ".env") + + +@pytest.fixture(autouse=True) +def reset_template_engine(): + """Reset the global template engine between tests.""" + engine_module._engine = None + yield + engine_module._engine = None diff --git a/cli/tests/nao_core/commands/sync/integration/dml/redshift.sql b/cli/tests/nao_core/commands/sync/integration/dml/redshift.sql new file mode 100644 index 00000000..1c1fb49d --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/dml/redshift.sql @@ -0,0 +1,29 @@ +CREATE TABLE nao_unit_tests.public.users ( +id INTEGER NOT NULL, +name VARCHAR NOT NULL, +email VARCHAR, +active BOOLEAN DEFAULT TRUE +); + +INSERT INTO nao_unit_tests.public.users VALUES +(1, 'Alice', 'alice@example.com', true), +(2, 'Bob', NULL, false), +(3, 'Charlie', 'charlie@example.com', true); + + +CREATE TABLE nao_unit_tests.public.orders ( +id INTEGER NOT NULL, +user_id INTEGER NOT NULL, +amount FLOAT4 NOT NULL +); + +INSERT INTO nao_unit_tests.public.orders VALUES +(1, 1, 99.99), +(2, 1, 24.50); + +CREATE SCHEMA nao_unit_tests.another; + +CREATE TABLE nao_unit_tests.another.whatever ( +id INTEGER NOT NULL, +price FLOAT4 NOT NULL +); \ No newline at end of file diff --git a/cli/tests/nao_core/commands/sync/integration/test_duckdb.py b/cli/tests/nao_core/commands/sync/integration/test_duckdb.py new file mode 100644 index 00000000..6d20f3ba --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/test_duckdb.py @@ -0,0 +1,219 @@ +"""Integration tests for the database sync pipeline using a real DuckDB database.""" + +import json + +import duckdb +import pytest +from rich.progress import Progress + +from nao_core.commands.sync.providers.databases.provider import sync_database +from nao_core.config.databases.duckdb import DuckDBConfig + + +@pytest.fixture +def duckdb_path(tmp_path): + """Create a DuckDB database with two tables: users and orders.""" + db_path = tmp_path / "test.duckdb" + conn = duckdb.connect(str(db_path)) + + conn.execute(""" + CREATE TABLE users ( + id INTEGER NOT NULL, + name VARCHAR NOT NULL, + email VARCHAR, + active BOOLEAN DEFAULT TRUE + ) + """) + conn.execute(""" + INSERT INTO users VALUES + (1, 'Alice', 'alice@example.com', true), + (2, 'Bob', NULL, false), + (3, 'Charlie', 'charlie@example.com', true) + """) + + conn.execute(""" + CREATE TABLE orders ( + id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + amount DOUBLE NOT NULL + ) + """) + conn.execute(""" + INSERT INTO orders VALUES + (1, 1, 99.99), + (2, 1, 24.50) + """) + + conn.close() + return db_path + + +class 
TestDuckDBSyncIntegration: + def _sync(self, duckdb_path, output_path): + config = DuckDBConfig(name="test-db", path=str(duckdb_path)) + + with Progress(transient=True) as progress: + state = sync_database(config, output_path, progress) + + return state + + def test_creates_expected_directory_tree(self, tmp_path, duckdb_path): + output = tmp_path / "output" + self._sync(duckdb_path, output) + + base = output / "type=duckdb" / "database=test" + + # Schema directory + assert (base / "schema=main").is_dir() + + # Table directories + assert (base / "schema=main" / "table=users").is_dir() + assert (base / "schema=main" / "table=orders").is_dir() + + # Each table should have exactly the 3 default template outputs + for table in ("users", "orders"): + table_dir = base / "schema=main" / f"table={table}" + files = sorted(f.name for f in table_dir.iterdir()) + assert files == ["columns.md", "description.md", "preview.md"] + + def test_columns_md_users(self, tmp_path, duckdb_path): + output = tmp_path / "output" + self._sync(duckdb_path, output) + + content = (output / "type=duckdb" / "database=test" / "schema=main" / "table=users" / "columns.md").read_text() + + # NOT NULL columns are prefixed with ! by Ibis (e.g. !int32) + assert content == ( + "# users\n" + "\n" + "**Dataset:** `main`\n" + "\n" + "## Columns (4)\n" + "\n" + "- id (int32 NOT NULL)\n" + "- name (string NOT NULL)\n" + "- email (string)\n" + "- active (boolean)\n" + ) + + def test_columns_md_orders(self, tmp_path, duckdb_path): + output = tmp_path / "output" + self._sync(duckdb_path, output) + + content = (output / "type=duckdb" / "database=test" / "schema=main" / "table=orders" / "columns.md").read_text() + + assert content == ( + "# orders\n" + "\n" + "**Dataset:** `main`\n" + "\n" + "## Columns (3)\n" + "\n" + "- id (int32 NOT NULL)\n" + "- user_id (int32 NOT NULL)\n" + "- amount (float64 NOT NULL)\n" + ) + + def test_description_md_users(self, tmp_path, duckdb_path): + output = tmp_path / "output" + self._sync(duckdb_path, output) + + content = ( + output / "type=duckdb" / "database=test" / "schema=main" / "table=users" / "description.md" + ).read_text() + + assert content == ( + "# users\n" + "\n" + "**Dataset:** `main`\n" + "\n" + "## Table Metadata\n" + "\n" + "| Property | Value |\n" + "|----------|-------|\n" + "| **Row Count** | 3 |\n" + "| **Column Count** | 4 |\n" + "\n" + "## Description\n" + "\n" + "_No description available._\n" + ) + + def test_description_md_orders(self, tmp_path, duckdb_path): + output = tmp_path / "output" + self._sync(duckdb_path, output) + + content = ( + output / "type=duckdb" / "database=test" / "schema=main" / "table=orders" / "description.md" + ).read_text() + + assert "| **Row Count** | 2 |" in content + assert "| **Column Count** | 3 |" in content + + def test_preview_md_users(self, tmp_path, duckdb_path): + output = tmp_path / "output" + self._sync(duckdb_path, output) + + content = (output / "type=duckdb" / "database=test" / "schema=main" / "table=users" / "preview.md").read_text() + + assert "# users - Preview" in content + assert "**Dataset:** `main`" in content + assert "## Rows (3)" in content + + # Parse the JSONL rows from the markdown + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 3 + assert rows[0] == {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True} + assert rows[1] == {"id": 2, "name": "Bob", "email": None, "active": False} + assert rows[2] == {"id": 3, 
"name": "Charlie", "email": "charlie@example.com", "active": True} + + def test_preview_md_orders(self, tmp_path, duckdb_path): + output = tmp_path / "output" + self._sync(duckdb_path, output) + + content = (output / "type=duckdb" / "database=test" / "schema=main" / "table=orders" / "preview.md").read_text() + + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 2 + assert rows[0] == {"id": 1, "user_id": 1, "amount": 99.99} + assert rows[1] == {"id": 2, "user_id": 1, "amount": 24.5} + + def test_sync_state_tracks_schemas_and_tables(self, tmp_path, duckdb_path): + output = tmp_path / "output" + state = self._sync(duckdb_path, output) + + assert state.schemas_synced == 1 + assert state.tables_synced == 2 + assert "main" in state.synced_schemas + assert "users" in state.synced_tables["main"] + assert "orders" in state.synced_tables["main"] + + def test_include_filter(self, tmp_path, duckdb_path): + """Only tables matching include patterns should be synced.""" + config = DuckDBConfig(name="test-db", path=str(duckdb_path), include=["main.users"]) + + output = tmp_path / "output" + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = output / "type=duckdb" / "database=test" / "schema=main" + assert (base / "table=users").is_dir() + assert not (base / "table=orders").exists() + assert state.tables_synced == 1 + + def test_exclude_filter(self, tmp_path, duckdb_path): + """Tables matching exclude patterns should be skipped.""" + config = DuckDBConfig(name="test-db", path=str(duckdb_path), exclude=["main.orders"]) + + output = tmp_path / "output" + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = output / "type=duckdb" / "database=test" / "schema=main" + assert (base / "table=users").is_dir() + assert not (base / "table=orders").exists() + assert state.tables_synced == 1 diff --git a/cli/tests/nao_core/commands/sync/integration/test_redshift.py b/cli/tests/nao_core/commands/sync/integration/test_redshift.py new file mode 100644 index 00000000..d1a371f8 --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/test_redshift.py @@ -0,0 +1,293 @@ +"""Integration tests for the database sync pipeline against a real Redshift cluster. + +Connection is configured via environment variables: + REDSHIFT_HOST, REDSHIFT_PORT (default 5439), REDSHIFT_DATABASE, + REDSHIFT_USER, REDSHIFT_PASSWORD, REDSHIFT_SCHEMA (default public), + REDSHIFT_SSLMODE (default require). + +The test suite is skipped entirely when REDSHIFT_HOST is not set. 
+""" + +import json +import os + +import pytest +from rich.progress import Progress + +from nao_core.commands.sync.providers.databases.provider import sync_database +from nao_core.config.databases.redshift import RedshiftConfig + +REDSHIFT_HOST = os.environ.get("REDSHIFT_HOST") + +pytestmark = pytest.mark.skipif( + REDSHIFT_HOST is None, reason="REDSHIFT_HOST not set — skipping Redshift integration tests" +) + +# ibis uses pg_catalog.pg_enum which Redshift does not support +KNOWN_ERROR = "pg_catalog.pg_enum" + + +@pytest.fixture(scope="module") +def redshift_config(): + """Build a RedshiftConfig from environment variables.""" + return RedshiftConfig( + name="test-redshift", + host=os.environ["REDSHIFT_HOST"], + port=int(os.environ.get("REDSHIFT_PORT", "5439")), + database=os.environ["REDSHIFT_DATABASE"], + user=os.environ["REDSHIFT_USER"], + password=os.environ["REDSHIFT_PASSWORD"], + schema_name=os.environ.get("REDSHIFT_SCHEMA", "public"), + sslmode=os.environ.get("REDSHIFT_SSLMODE", "require"), + ) + + +@pytest.fixture(scope="module") +def synced(tmp_path_factory, redshift_config): + """Run sync once for the whole module and return (state, output_path, config).""" + output = tmp_path_factory.mktemp("redshift_sync") + + with Progress(transient=True) as progress: + state = sync_database(redshift_config, output, progress) + + return state, output, redshift_config + + +class TestRedshiftSyncIntegration: + """Verify the sync pipeline produces correct output against a live Redshift cluster.""" + + def test_creates_expected_directory_tree(self, synced): + state, output, config = synced + + base = output / "type=redshift" / "database=nao_unit_tests" / "schema=public" + + # Schema directory + assert base.is_dir() + + # Each table should have exactly the 3 default template outputs + for table in ("orders", "users"): + assert (base / f"table={table}").is_dir() + table_dir = base / f"table={table}" + files = sorted(f.name for f in table_dir.iterdir()) + assert files == ["columns.md", "description.md", "preview.md"] + + def test_columns_md_users(self, synced): + state, output, config = synced + + content = ( + output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=users" / "columns.md" + ).read_text() + + # NOT NULL columns are prefixed with ! by Ibis (e.g. 
!int32) + assert content == ( + "# users\n" + "\n" + "**Dataset:** `public`\n" + "\n" + "## Columns (4)\n" + "\n" + "- id (int32 NOT NULL)\n" + "- name (string NOT NULL)\n" + "- email (string)\n" + "- active (boolean)\n" + ) + + def test_columns_md_orders(self, synced): + state, output, config = synced + + content = ( + output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=orders" / "columns.md" + ).read_text() + + assert content == ( + "# orders\n" + "\n" + "**Dataset:** `public`\n" + "\n" + "## Columns (3)\n" + "\n" + "- id (int32 NOT NULL)\n" + "- user_id (int32 NOT NULL)\n" + "- amount (float32 NOT NULL)\n" + ) + + def test_description_md_users(self, synced): + state, output, config = synced + + content = ( + output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=users" / "description.md" + ).read_text() + + assert content == ( + "# users\n" + "\n" + "**Dataset:** `public`\n" + "\n" + "## Table Metadata\n" + "\n" + "| Property | Value |\n" + "|----------|-------|\n" + "| **Row Count** | 3 |\n" + "| **Column Count** | 4 |\n" + "\n" + "## Description\n" + "\n" + "_No description available._\n" + ) + + def test_description_md_orders(self, synced): + state, output, config = synced + + content = ( + output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=orders" / "description.md" + ).read_text() + + assert "| **Row Count** | 2 |" in content + assert "| **Column Count** | 3 |" in content + + def test_preview_md_users(self, synced): + state, output, config = synced + + content = ( + output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=users" / "preview.md" + ).read_text() + + assert "# users - Preview" in content + assert "**Dataset:** `public`" in content + assert "## Rows (3)" in content + + # Parse the JSONL rows from the markdown + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 3 + assert rows[0] == {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True} + assert rows[1] == {"id": 2, "name": "Bob", "email": None, "active": False} + assert rows[2] == {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True} + + def test_preview_md_orders(self, synced): + state, output, config = synced + + content = ( + output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=orders" / "preview.md" + ).read_text() + + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 2 + assert rows[0] == {"id": 1, "user_id": 1, "amount": 99.99} + assert rows[1] == {"id": 2, "user_id": 1, "amount": 24.5} + + def test_sync_state_tracks_schemas_and_tables(self, synced): + state, output, config = synced + + assert state.schemas_synced == 1 + assert state.tables_synced == 2 + assert "public" in state.synced_schemas + assert "users" in state.synced_tables["public"] + assert "orders" in state.synced_tables["public"] + + def test_include_filter(self, tmp_path_factory, redshift_config): + """Only tables matching include patterns should be synced.""" + config = RedshiftConfig( + name=redshift_config.name, + host=redshift_config.host, + port=redshift_config.port, + database=redshift_config.database, + user=redshift_config.user, + password=redshift_config.password, + schema_name=redshift_config.schema_name, + sslmode=redshift_config.sslmode, + include=["public.users"], + ) + + 
output = tmp_path_factory.mktemp("redshift_include") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = output / "type=redshift" / "database=nao_unit_tests" / "schema=public" + assert (base / "table=users").is_dir() + assert not (base / "table=orders").exists() + assert state.tables_synced == 1 + + def test_exclude_filter(self, tmp_path_factory, redshift_config): + """Tables matching exclude patterns should be skipped.""" + config = RedshiftConfig( + name=redshift_config.name, + host=redshift_config.host, + port=redshift_config.port, + database=redshift_config.database, + user=redshift_config.user, + password=redshift_config.password, + schema_name=redshift_config.schema_name, + sslmode=redshift_config.sslmode, + exclude=["public.orders"], + ) + + output = tmp_path_factory.mktemp("redshift_exclude") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = output / "type=redshift" / "database=nao_unit_tests" / "schema=public" + assert (base / "table=users").is_dir() + assert not (base / "table=orders").exists() + assert state.tables_synced == 1 + + def test_sync_all_schemas_when_schema_name_not_specified(self, tmp_path_factory, redshift_config): + """When schema_name is not provided, all schemas should be synced.""" + config = RedshiftConfig( + name=redshift_config.name, + host=redshift_config.host, + port=redshift_config.port, + database=redshift_config.database, + user=redshift_config.user, + password=redshift_config.password, + schema_name=None, + sslmode=redshift_config.sslmode, + ) + + output = tmp_path_factory.mktemp("redshift_all_schemas") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + # Verify public schema tables + assert (output / "type=redshift" / "database=nao_unit_tests" / "schema=public").is_dir() + assert (output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=users").is_dir() + assert (output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=orders").is_dir() + + # Verify public.users files + files = sorted( + f.name + for f in (output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=users").iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify public.orders files + files = sorted( + f.name + for f in (output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=orders").iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify another schema table + assert (output / "type=redshift" / "database=nao_unit_tests" / "schema=another").is_dir() + assert (output / "type=redshift" / "database=nao_unit_tests" / "schema=another" / "table=whatever").is_dir() + + # Verify another.whatever files + files = sorted( + f.name + for f in ( + output / "type=redshift" / "database=nao_unit_tests" / "schema=another" / "table=whatever" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify state + assert state.schemas_synced == 2 + assert state.tables_synced == 3 + assert "public" in state.synced_schemas + assert "another" in state.synced_schemas + assert "users" in state.synced_tables["public"] + assert "orders" in state.synced_tables["public"] + assert "whatever" in state.synced_tables["another"] diff --git a/cli/tests/nao_core/commands/sync/test_accessors.py b/cli/tests/nao_core/commands/sync/test_accessors.py deleted file mode 100644 index 
8fa211c4..00000000 --- a/cli/tests/nao_core/commands/sync/test_accessors.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Unit tests for sync accessors.""" - -from unittest.mock import MagicMock - -from nao_core.commands.sync.accessors import ( - ColumnsAccessor, - DescriptionAccessor, - PreviewAccessor, - ProfilingAccessor, - truncate_middle, -) - - -class TestTruncateMiddle: - def test_short_text_unchanged(self): - assert truncate_middle("hello", 10) == "hello" - - def test_exact_length_unchanged(self): - assert truncate_middle("hello", 5) == "hello" - - def test_long_text_truncated(self): - result = truncate_middle("hello world example", 10) - assert len(result) <= 10 - assert "..." in result - - def test_truncation_preserves_start_and_end(self): - result = truncate_middle("abcdefghijklmnop", 10) - assert result.startswith("abc") - assert result.endswith("nop") - - -class TestColumnsAccessor: - def test_filename(self): - accessor = ColumnsAccessor() - assert accessor.filename == "columns.md" - - def test_generate_creates_markdown(self): - accessor = ColumnsAccessor() - mock_conn = MagicMock() - mock_table = MagicMock() - mock_schema = MagicMock() - mock_schema.items.return_value = [ - ("id", "int64"), - ("name", "string"), - ] - mock_table.schema.return_value = mock_schema - mock_conn.table.return_value = mock_table - - result = accessor.generate(mock_conn, "my_dataset", "my_table") - - assert "# my_table" in result - assert "**Dataset:** `my_dataset`" in result - assert "## Columns (2)" in result - assert "- id (int64)" in result - assert "- name (string)" in result - - def test_generate_handles_error(self): - accessor = ColumnsAccessor() - mock_conn = MagicMock() - mock_conn.table.side_effect = Exception("Connection error") - - result = accessor.generate(mock_conn, "my_dataset", "my_table") - - assert "# my_table" in result - assert "Error generating content" in result or "Connection error" in result - - -class TestPreviewAccessor: - def test_filename(self): - accessor = PreviewAccessor() - assert accessor.filename == "preview.md" - - def test_custom_num_rows(self): - accessor = PreviewAccessor(num_rows=5) - assert accessor.num_rows == 5 - - -class TestDescriptionAccessor: - def test_filename(self): - accessor = DescriptionAccessor() - assert accessor.filename == "description.md" - - -class TestProfilingAccessor: - def test_filename(self): - accessor = ProfilingAccessor() - assert accessor.filename == "profiling.md" diff --git a/cli/tests/nao_core/commands/sync/test_context.py b/cli/tests/nao_core/commands/sync/test_context.py new file mode 100644 index 00000000..9d8c9b65 --- /dev/null +++ b/cli/tests/nao_core/commands/sync/test_context.py @@ -0,0 +1,74 @@ +"""Unit tests for DatabaseContext.""" + +from unittest.mock import MagicMock + +import pandas as pd + +from nao_core.commands.sync.providers.databases.context import DatabaseContext + + +class TestDatabaseContext: + def _make_context(self): + mock_conn = MagicMock() + mock_table = MagicMock() + mock_schema = MagicMock() + schema_items = [ + ("id", MagicMock(__str__=lambda s: "int64", nullable=False)), + ("name", MagicMock(__str__=lambda s: "string", nullable=True)), + ] + mock_schema.items.return_value = schema_items + mock_schema.__len__ = lambda s: len(schema_items) + mock_table.schema.return_value = mock_schema + mock_conn.table.return_value = mock_table + return DatabaseContext(mock_conn, "my_schema", "my_table"), mock_table + + def test_columns_returns_metadata(self): + ctx, _ = self._make_context() + columns = ctx.columns() + + assert 
len(columns) == 2 + assert columns[0]["name"] == "id" + assert columns[0]["type"] == "int64" + assert columns[1]["name"] == "name" + assert columns[1]["type"] == "string" + + def test_preview_returns_rows(self): + ctx, mock_table = self._make_context() + df = pd.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]}) + mock_table.limit.return_value.execute.return_value = df + + rows = ctx.preview(limit=2) + + assert len(rows) == 2 + assert rows[0]["name"] == "Alice" + mock_table.limit.assert_called_once_with(2) + + def test_row_count(self): + ctx, mock_table = self._make_context() + mock_table.count.return_value.execute.return_value = 42 + + assert ctx.row_count() == 42 + + def test_column_count(self): + ctx, _ = self._make_context() + assert ctx.column_count() == 2 + + def test_description_returns_none_by_default(self): + ctx, _ = self._make_context() + assert ctx.description() is None + + def test_table_is_lazily_loaded(self): + mock_conn = MagicMock() + ctx = DatabaseContext(mock_conn, "schema", "table") + + mock_conn.table.assert_not_called() + _ = ctx.table + mock_conn.table.assert_called_once_with("table", database="schema") + + def test_table_is_cached(self): + mock_conn = MagicMock() + ctx = DatabaseContext(mock_conn, "schema", "table") + + _ = ctx.table + _ = ctx.table + mock_conn.table.assert_called_once() diff --git a/cli/tests/nao_core/commands/sync/test_providers.py b/cli/tests/nao_core/commands/sync/test_providers.py index b5906a2e..ad92464a 100644 --- a/cli/tests/nao_core/commands/sync/test_providers.py +++ b/cli/tests/nao_core/commands/sync/test_providers.py @@ -48,9 +48,9 @@ def test_returns_list_of_providers(self): providers = get_all_providers() assert len(providers) == 3 - assert any(isinstance(p, RepositorySyncProvider) for p in providers) - assert any(isinstance(p, DatabaseSyncProvider) for p in providers) - assert any(isinstance(p, NotionSyncProvider) for p in providers) + assert any(isinstance(p.provider, RepositorySyncProvider) for p in providers) + assert any(isinstance(p.provider, DatabaseSyncProvider) for p in providers) + assert any(isinstance(p.provider, NotionSyncProvider) for p in providers) def test_returns_copy_of_providers(self): providers1 = get_all_providers() diff --git a/cli/tests/nao_core/commands/sync/test_registry.py b/cli/tests/nao_core/commands/sync/test_registry.py deleted file mode 100644 index da2a2320..00000000 --- a/cli/tests/nao_core/commands/sync/test_registry.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Unit tests for sync accessor registry.""" - -from nao_core.commands.sync.accessors import ( - ColumnsAccessor, - PreviewAccessor, -) -from nao_core.commands.sync.registry import get_accessors -from nao_core.config import AccessorType - - -class TestGetAccessors: - def test_returns_accessors_for_valid_types(self): - accessor_types = [AccessorType.COLUMNS, AccessorType.PREVIEW] - accessors = get_accessors(accessor_types) - - assert len(accessors) == 2 - assert isinstance(accessors[0], ColumnsAccessor) - assert isinstance(accessors[1], PreviewAccessor) - - def test_returns_empty_list_for_empty_input(self): - accessors = get_accessors([]) - assert accessors == [] - - def test_all_accessor_types(self): - all_types = [ - AccessorType.COLUMNS, - AccessorType.PREVIEW, - AccessorType.DESCRIPTION, - AccessorType.PROFILING, - ] - accessors = get_accessors(all_types) - - assert len(accessors) == 4 diff --git a/cli/tests/nao_core/commands/sync/test_sync.py b/cli/tests/nao_core/commands/sync/test_sync.py index 25e1242b..7ca4d2db 100644 --- 
a/cli/tests/nao_core/commands/sync/test_sync.py +++ b/cli/tests/nao_core/commands/sync/test_sync.py @@ -6,7 +6,7 @@ import pytest from nao_core.commands.sync import sync -from nao_core.commands.sync.providers import SyncProvider, SyncResult +from nao_core.commands.sync.providers import ProviderSelection, SyncProvider, SyncResult def _make_provider( @@ -32,7 +32,7 @@ def _make_provider( provider_name=name, items_synced=items_synced, ) - return provider + return ProviderSelection(provider) @pytest.mark.usefixtures("clean_env") @@ -50,42 +50,42 @@ def test_sync_exits_when_no_config_found(self, tmp_path: Path, monkeypatch): def test_sync_runs_providers_when_config_exists(self, create_config): create_config() - mock_provider = _make_provider() + selection = _make_provider() with patch("nao_core.commands.sync.console"): - sync(providers=[mock_provider]) + sync(_providers=[selection]) - mock_provider.should_sync.assert_called_once() + selection.provider.should_sync.assert_called_once() def test_sync_uses_custom_output_dirs(self, tmp_path: Path, create_config): create_config() - mock_provider = _make_provider(output_dir="default-output", items=["item1"], items_synced=1) + selection = _make_provider(output_dir="default-output", items=["item1"], items_synced=1) custom_output = str(tmp_path / "custom-output") with patch("nao_core.commands.sync.console"): - sync(output_dirs={"TestProvider": custom_output}, providers=[mock_provider]) + sync(output_dirs={"TestProvider": custom_output}, _providers=[selection]) # Verify sync was called with the custom output path - call_args = mock_provider.sync.call_args + call_args = selection.provider.sync.call_args assert str(call_args[0][1]) == custom_output def test_sync_skips_provider_when_should_sync_false(self, create_config): create_config() - mock_provider = _make_provider(should_sync=False) + selection = _make_provider(should_sync=False) with patch("nao_core.commands.sync.console"): - sync(providers=[mock_provider]) + sync(_providers=[selection]) # sync should not be called when should_sync returns False - mock_provider.sync.assert_not_called() + selection.provider.sync.assert_not_called() def test_sync_prints_nothing_to_sync_when_no_results(self, create_config): create_config() - mock_provider = _make_provider() + selection = _make_provider() with patch("nao_core.commands.sync.console") as mock_console: - sync(providers=[mock_provider]) + sync(_providers=[selection]) # Check that "Nothing to sync" was printed calls = [str(call) for call in mock_console.print.call_args_list] @@ -103,12 +103,12 @@ def test_sync_continues_when_provider_fails(self, create_config): with patch("nao_core.commands.sync.console"): with pytest.raises(SystemExit) as exc_info: - sync(providers=[failing, working]) + sync(_providers=[failing, working]) assert exc_info.value.code == 1 # Verify both providers were attempted - failing.sync.assert_called_once() - working.sync.assert_called_once() + failing.provider.sync.assert_called_once() + working.provider.sync.assert_called_once() def test_sync_shows_partial_success_when_some_providers_fail(self, create_config): """Test that sync shows partial success status when some providers fail.""" @@ -122,7 +122,7 @@ def test_sync_shows_partial_success_when_some_providers_fail(self, create_config with patch("nao_core.commands.sync.console") as mock_console: with pytest.raises(SystemExit) as exc_info: - sync(providers=[failing, working]) + sync(_providers=[failing, working]) assert exc_info.value.code == 1 calls = [str(call) for call in 
mock_console.print.call_args_list] @@ -140,7 +140,7 @@ def test_sync_shows_failure_when_all_providers_fail(self, create_config): with patch("nao_core.commands.sync.console") as mock_console: with pytest.raises(SystemExit) as exc_info: - sync(providers=[failing]) + sync(_providers=[failing]) assert exc_info.value.code == 1 calls = [str(call) for call in mock_console.print.call_args_list] diff --git a/cli/tests/nao_core/commands/test_debug.py b/cli/tests/nao_core/commands/test_debug.py index 77a7af38..00dfe416 100644 --- a/cli/tests/nao_core/commands/test_debug.py +++ b/cli/tests/nao_core/commands/test_debug.py @@ -2,7 +2,8 @@ import pytest -from nao_core.commands.debug import check_database_connection, check_llm_connection, debug +from nao_core.commands.debug import check_llm_connection, debug +from nao_core.config.databases import BigQueryConfig, DuckDBConfig, PostgresConfig from nao_core.config.llm import LLMConfig, LLMProvider @@ -133,46 +134,58 @@ def test_mistral_exception_returns_failure(self): class TestDatabaseConnection: - """Tests for check_database_connection.""" + """Tests for check_connection method on database configs.""" - def test_connection_with_tables(self): - mock_db = MagicMock() - mock_db.dataset_id = "my_dataset" + def test_bigquery_connection_with_dataset(self): + config = BigQueryConfig(name="test", project_id="my-project", dataset_id="my_dataset") mock_conn = MagicMock() mock_conn.list_tables.return_value = ["table1", "table2"] - mock_db.connect.return_value = mock_conn - success, message = check_database_connection(mock_db) + with patch.object(BigQueryConfig, "connect", return_value=mock_conn): + success, message = config.check_connection() assert success is True assert "2 tables found" in message - def test_connection_with_schemas(self): - mock_db = MagicMock(spec=["connect", "name", "type"]) # no dataset_id + def test_bigquery_connection_with_schemas(self): + config = BigQueryConfig(name="test", project_id="my-project") mock_conn = MagicMock() mock_conn.list_databases.return_value = ["schema1", "schema2", "schema3"] - mock_db.connect.return_value = mock_conn - success, message = check_database_connection(mock_db) + with patch.object(BigQueryConfig, "connect", return_value=mock_conn): + success, message = config.check_connection() assert success is True - assert "3 schemas found" in message + assert "3 datasets found" in message - def test_connection_fallback(self): - mock_db = MagicMock(spec=["connect", "name", "type"]) # no dataset_id - mock_conn = MagicMock(spec=[]) # no list_tables or list_databases - mock_db.connect.return_value = mock_conn + def test_duckdb_connection_with_tables(self): + config = DuckDBConfig(name="test", path=":memory:") + mock_conn = MagicMock() + mock_conn.list_tables.return_value = ["table1", "table2"] + + with patch.object(DuckDBConfig, "connect", return_value=mock_conn): + success, message = config.check_connection() + + assert success is True + assert "2 tables found" in message + + def test_postgres_connection_fallback(self): + config = PostgresConfig( + name="test", host="localhost", port=5432, database="testdb", user="user", password="pass" + ) + mock_conn = MagicMock(spec=[]) # no list_databases - success, message = check_database_connection(mock_db) + with patch.object(PostgresConfig, "connect", return_value=mock_conn): + success, message = config.check_connection() assert success is True - assert "unable to list" in message + assert "Connected successfully" in message def test_connection_failure(self): - mock_db = MagicMock() - 
mock_db.connect.side_effect = Exception("Connection refused") + config = DuckDBConfig(name="test", path=":memory:") - success, message = check_database_connection(mock_db) + with patch.object(DuckDBConfig, "connect", side_effect=Exception("Connection refused")): + success, message = config.check_connection() assert success is False assert "Connection refused" in message @@ -206,7 +219,8 @@ def test_debug_with_databases(self, create_config): """) with patch( - "nao_core.commands.debug.check_database_connection", return_value=(True, "Connected (5 tables found)") + "nao_core.config.databases.postgres.PostgresConfig.check_connection", + return_value=(True, "Connected (5 tables found)"), ): with patch("nao_core.commands.debug.console") as mock_console: debug() @@ -233,7 +247,8 @@ def test_debug_with_databases_error(self, create_config): """) with patch( - "nao_core.commands.debug.check_database_connection", return_value=(False, "Failed DB connection") + "nao_core.config.databases.postgres.PostgresConfig.check_connection", + return_value=(False, "Failed DB connection"), ) as mock_check: with patch("nao_core.commands.debug.console") as mock_console: debug() diff --git a/cli/tests/nao_core/templates/test_engine.py b/cli/tests/nao_core/templates/test_engine.py index 72041b3b..cbacbf69 100644 --- a/cli/tests/nao_core/templates/test_engine.py +++ b/cli/tests/nao_core/templates/test_engine.py @@ -2,6 +2,7 @@ import json from pathlib import Path +from unittest.mock import MagicMock from nao_core.templates.engine import ( DEFAULT_TEMPLATES_DIR, @@ -48,7 +49,6 @@ def test_default_templates_exist(self): "databases/columns.md.j2", "databases/preview.md.j2", "databases/description.md.j2", - "databases/profiling.md.j2", ] for template in expected_templates: @@ -75,18 +75,20 @@ def test_render_basic_template(self, tmp_path: Path): assert result == "Hello, World!" 
def test_render_with_default_template(self): - """Engine renders default database templates correctly.""" + """Engine renders default database templates correctly via db context.""" engine = TemplateEngine() + mock_db = MagicMock() + mock_db.columns.return_value = [ + {"name": "id", "type": "int64", "nullable": False, "description": None}, + {"name": "email", "type": "string", "nullable": True, "description": "User email"}, + ] + result = engine.render( "databases/columns.md.j2", table_name="users", dataset="main", - columns=[ - {"name": "id", "type": "int64", "nullable": False, "description": None}, - {"name": "email", "type": "string", "nullable": True, "description": "User email"}, - ], - column_count=2, + db=mock_db, ) assert "# users" in result @@ -96,28 +98,25 @@ def test_render_with_default_template(self): assert "- email (string" in result def test_render_preview_template(self): - """Engine renders preview template with rows correctly.""" + """Engine renders preview template with rows correctly via db context.""" engine = TemplateEngine() + mock_db = MagicMock() + mock_db.preview.return_value = [ + {"id": 1, "amount": 100.50}, + {"id": 2, "amount": 200.75}, + ] + result = engine.render( "databases/preview.md.j2", table_name="orders", dataset="sales", - rows=[ - {"id": 1, "amount": 100.50}, - {"id": 2, "amount": 200.75}, - ], - row_count=2, - columns=[ - {"name": "id", "type": "int64"}, - {"name": "amount", "type": "float64"}, - ], + db=mock_db, ) assert "# orders - Preview" in result assert "**Dataset:** `sales`" in result assert "## Rows (2)" in result - # Check JSON rows are present assert '"id": 1' in result assert '"amount": 100.5' in result @@ -306,7 +305,6 @@ def test_all_default_database_templates_present(self): "columns.md.j2", "preview.md.j2", "description.md.j2", - "profiling.md.j2", ] for filename in expected_files: diff --git a/cli/uv.lock b/cli/uv.lock index 0d258414..5077e90f 100644 --- a/cli/uv.lock +++ b/cli/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "python_full_version < '3.11'", @@ -180,30 +180,30 @@ wheels = [ [[package]] name = "boto3" -version = "1.42.44" +version = "1.42.45" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, { name = "jmespath" }, { name = "s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1d/88/de5c2a0ce069973345f9fac81200de5b58f503e231dbd566357a5b8c9109/boto3-1.42.44.tar.gz", hash = "sha256:d5601ea520d30674c1d15791a1f98b5c055e973c775e1d9952ccc09ee5913c4e", size = 112865, upload-time = "2026-02-06T20:28:05.647Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9e/e6/8fdd78825de6d8086aa3097955f83d8db3c5a3868b73da233c49977a7444/boto3-1.42.45.tar.gz", hash = "sha256:4db50b8b39321fab87ff7f40ab407887d436d004c1f2b0dfdf56e42b4884709b", size = 112846, upload-time = "2026-02-09T21:50:14.925Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/40/fb/0341da1482f7fa256d257cfba89383f6692570b741598d4e26d879b26c57/boto3-1.42.44-py3-none-any.whl", hash = "sha256:32e995b0d56e19422cff22f586f698e8924c792eb00943de9c517ff4607e4e18", size = 140604, upload-time = "2026-02-06T20:28:03.598Z" }, + { url = "https://files.pythonhosted.org/packages/6c/e0/d59a178799412cfe38c2757d6e49c337a5e71b18cdc3641dd6d9daf52151/boto3-1.42.45-py3-none-any.whl", hash = "sha256:5074e074a718a6f3c2b519cbb9ceab258f17b331a143d23351d487984f2a412f", size = 140604, upload-time = "2026-02-09T21:50:13.113Z" }, ] [[package]] name = 
"botocore" -version = "1.42.44" +version = "1.42.45" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/29/ff/54cef2c5ff4e1c77fabc0ed68781e48eb36f33433f82bba3605e9c0e45ce/botocore-1.42.44.tar.gz", hash = "sha256:47ba27360f2afd2c2721545d8909217f7be05fdee16dd8fc0b09589535a0701c", size = 14936071, upload-time = "2026-02-06T20:27:53.654Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7a/b1/c36ad705d67bb935eac3085052b5dc03ec22d5ac12e7aedf514f3d76cac8/botocore-1.42.45.tar.gz", hash = "sha256:40b577d07b91a0ed26879da9e4658d82d3a400382446af1014d6ad3957497545", size = 14941217, upload-time = "2026-02-09T21:50:01.966Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/9e/b45c54abfbb902ff174444a48558f97f9917143bc2e996729220f2631db1/botocore-1.42.44-py3-none-any.whl", hash = "sha256:ba406b9243a20591ee87d53abdb883d46416705cebccb639a7f1c923f9dd82df", size = 14611152, upload-time = "2026-02-06T20:27:49.565Z" }, + { url = "https://files.pythonhosted.org/packages/7e/ec/6681b8e4884f8663d7650220e702c503e4ba6bd09a5b91d44803b0b1d0a8/botocore-1.42.45-py3-none-any.whl", hash = "sha256:a5ea5d1b7c46c2d5d113879e45b21eaf7d60dc865f4bcb46dfcf0703fe3429f4", size = 14615557, upload-time = "2026-02-09T21:49:57.066Z" }, ] [[package]] @@ -612,7 +612,7 @@ wheels = [ [[package]] name = "databricks-sql-connector" -version = "4.2.4" +version = "4.2.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "lz4" }, @@ -626,9 +626,9 @@ dependencies = [ { name = "thrift" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/32/86/0e91b65ea75b376b3ebc5eea2b444984579b39362a5cb5d8dcaa44e2570d/databricks_sql_connector-4.2.4.tar.gz", hash = "sha256:e8ce4257ada2b6274ee1c17d4e29831c76cd9acc994c243bb9eb22314dac74ee", size = 186683, upload-time = "2026-01-08T09:52:04.589Z" } +sdist = { url = "https://files.pythonhosted.org/packages/42/0c/1e8179f427044a0c769e279b2c45b72a20cff902f4e92ca1bcca50549435/databricks_sql_connector-4.2.5.tar.gz", hash = "sha256:762df7568ef1998540f96b20cad6f1aaae87d1aad54e40e528f87e4524397291", size = 187223, upload-time = "2026-02-09T11:26:29.762Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/31/a9/1d4ee9b06b53be9e84b698874bcb5b3bcac911411e09b61cbd60c67caf4a/databricks_sql_connector-4.2.4-py3-none-any.whl", hash = "sha256:e9ea625df4ee3a8a4a7855fd7e0b5799a029c8965b2ff37694fff5e8294734df", size = 213124, upload-time = "2026-01-08T09:52:03.026Z" }, + { url = "https://files.pythonhosted.org/packages/67/a7/0d6dd8323cb2249a979cf4c6a45694e975668c53b19d52d7e15490bafb4c/databricks_sql_connector-4.2.5-py3-none-any.whl", hash = "sha256:31cee10552ce77a830318ce9488fc5e67daca7abbcdf0d8d34f12a180bc55039", size = 213906, upload-time = "2026-02-09T11:26:28.566Z" }, ] [[package]] @@ -749,7 +749,7 @@ name = "exceptiongroup" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } wheels = [ @@ -2729,11 +2729,11 @@ wheels = 
[ [[package]] name = "sqlglot" -version = "28.10.0" +version = "28.10.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/69/7d/1479ac3543caada917c3781893d3f846c810aec6355eb7b2f58df68f999b/sqlglot-28.10.0.tar.gz", hash = "sha256:f3d4759164ad854176980b3a47eb0c7ef699118dfa80beeb93e010885637b211", size = 5739594, upload-time = "2026-02-04T14:22:18.26Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/66/b2b300f325227044aa6f511ea7c9f3109a1dc74b13a0897931c1754b504e/sqlglot-28.10.1.tar.gz", hash = "sha256:66e0dae43b4bce23314b80e9aef41b8c88fea0e17ada62de095b45262084a8c5", size = 5739510, upload-time = "2026-02-09T23:36:23.671Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/34/c5de8f3c110bd066ebfa31b2d948dd33b691c7ccea39065e37f97f3f30a1/sqlglot-28.10.0-py3-none-any.whl", hash = "sha256:d442473bfd2340776dfc88382de3df456b9c5b66974623e554c6ad6426ba365e", size = 597042, upload-time = "2026-02-04T14:22:16.534Z" }, + { url = "https://files.pythonhosted.org/packages/55/ff/5a768b34202e1ee485737bfa167bd84592585aa40383f883a8e346d767cc/sqlglot-28.10.1-py3-none-any.whl", hash = "sha256:214aef51fd4ce16407022f81cfc80c173409dab6d0f6ae18c52b43f43b31d4dd", size = 597053, upload-time = "2026-02-09T23:36:21.385Z" }, ] [[package]] diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=customers/description.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=customers/description.md new file mode 100644 index 00000000..c870f060 --- /dev/null +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=customers/description.md @@ -0,0 +1,14 @@ +# customers + +**Dataset:** `main` + +## Table Metadata + +| Property | Value | +|----------|-------| +| **Row Count** | 100 | +| **Column Count** | 7 | + +## Description + +_No description available._ diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=customers/preview.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=customers/preview.md index f9e2e74b..b53c9df5 100644 --- a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=customers/preview.md +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=customers/preview.md @@ -1,21 +1,19 @@ -# 📊 customers +# customers - Preview -> **Schema:** `main` +**Dataset:** `main` -## Sample Data (10 rows) +## Rows (10) -| customer_id | first_name | last_name | first_order | most_recent_order | number_of_orders | customer_lifetime_value | -| --- | --- | --- | --- | --- | --- | --- | -| 1 | Michael | P. | 2018-01-01 00:00:00 | 2018-02-10 00:00:00 | 2 | 33.0 | -| 2 | Shawn | M. | 2018-01-11 00:00:00 | 2018-01-11 00:00:00 | 1 | 23.0 | -| 3 | Kathleen | P. | 2018-01-02 00:00:00 | 2018-03-11 00:00:00 | 3 | 65.0 | -| 6 | Sarah | R. | 2018-02-19 00:00:00 | 2018-02-19 00:00:00 | 1 | 8.0 | -| 7 | Martin | M. | 2018-01-14 00:00:00 | 2018-01-14 00:00:00 | 1 | 26.0 | -| 8 | Frank | R. | 2018-01-29 00:00:00 | 2018-03-12 00:00:00 | 2 | 45.0 | -| 9 | Jennifer | F. | 2018-03-17 00:00:00 | 2018-03-17 00:00:00 | 1 | 30.0 | -| 11 | Fred | S. | 2018-03-23 00:00:00 | 2018-03-23 00:00:00 | 1 | 3.0 | -| 12 | Amy | D. | 2018-03-03 00:00:00 | 2018-03-03 00:00:00 | 1 | 4.0 | -| 13 | Kathleen | M. 
| 2018-03-07 00:00:00 | 2018-03-07 00:00:00 | 1 | 26.0 | +- {"customer_id": 1, "first_name": "Michael", "last_name": "P.", "first_order": "2018-01-01 00:00:00", "most_recent_order": "2018-02-10 00:00:00", "number_of_orders": 2, "customer_lifetime_value": 33.0} +- {"customer_id": 2, "first_name": "Shawn", "last_name": "M.", "first_order": "2018-01-11 00:00:00", "most_recent_order": "2018-01-11 00:00:00", "number_of_orders": 1, "customer_lifetime_value": 23.0} +- {"customer_id": 3, "first_name": "Kathleen", "last_name": "P.", "first_order": "2018-01-02 00:00:00", "most_recent_order": "2018-03-11 00:00:00", "number_of_orders": 3, "customer_lifetime_value": 65.0} +- {"customer_id": 6, "first_name": "Sarah", "last_name": "R.", "first_order": "2018-02-19 00:00:00", "most_recent_order": "2018-02-19 00:00:00", "number_of_orders": 1, "customer_lifetime_value": 8.0} +- {"customer_id": 7, "first_name": "Martin", "last_name": "M.", "first_order": "2018-01-14 00:00:00", "most_recent_order": "2018-01-14 00:00:00", "number_of_orders": 1, "customer_lifetime_value": 26.0} +- {"customer_id": 8, "first_name": "Frank", "last_name": "R.", "first_order": "2018-01-29 00:00:00", "most_recent_order": "2018-03-12 00:00:00", "number_of_orders": 2, "customer_lifetime_value": 45.0} +- {"customer_id": 9, "first_name": "Jennifer", "last_name": "F.", "first_order": "2018-03-17 00:00:00", "most_recent_order": "2018-03-17 00:00:00", "number_of_orders": 1, "customer_lifetime_value": 30.0} +- {"customer_id": 11, "first_name": "Fred", "last_name": "S.", "first_order": "2018-03-23 00:00:00", "most_recent_order": "2018-03-23 00:00:00", "number_of_orders": 1, "customer_lifetime_value": 3.0} +- {"customer_id": 12, "first_name": "Amy", "last_name": "D.", "first_order": "2018-03-03 00:00:00", "most_recent_order": "2018-03-03 00:00:00", "number_of_orders": 1, "customer_lifetime_value": 4.0} +- {"customer_id": 13, "first_name": "Kathleen", "last_name": "M.", "first_order": "2018-03-07 00:00:00", "most_recent_order": "2018-03-07 00:00:00", "number_of_orders": 1, "customer_lifetime_value": 26.0} --- -*Generated by nao sync with custom template* +*Generated by nao sync with custom template (in example/templates/databases/preview.md.j2)* diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=orders/description.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=orders/description.md new file mode 100644 index 00000000..37b8e97b --- /dev/null +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=orders/description.md @@ -0,0 +1,14 @@ +# orders + +**Dataset:** `main` + +## Table Metadata + +| Property | Value | +|----------|-------| +| **Row Count** | 99 | +| **Column Count** | 9 | + +## Description + +_No description available._ diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=orders/preview.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=orders/preview.md index 4e8c9851..bf8420e8 100644 --- a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=orders/preview.md +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=orders/preview.md @@ -1,21 +1,19 @@ -# 📊 orders +# orders - Preview -> **Schema:** `main` +**Dataset:** `main` -## Sample Data (10 rows) +## Rows (10) -| order_id | customer_id | order_date | status | credit_card_amount | coupon_amount | bank_transfer_amount | gift_card_amount | amount | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | -| 1 | 1 | 
2018-01-01 00:00:00 | returned | 10.0 | 0.0 | 0.0 | 0.0 | 10.0 | -| 2 | 3 | 2018-01-02 00:00:00 | completed | 20.0 | 0.0 | 0.0 | 0.0 | 20.0 | -| 3 | 94 | 2018-01-04 00:00:00 | completed | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | -| 4 | 50 | 2018-01-05 00:00:00 | completed | 0.0 | 25.0 | 0.0 | 0.0 | 25.0 | -| 5 | 64 | 2018-01-05 00:00:00 | completed | 0.0 | 0.0 | 17.0 | 0.0 | 17.0 | -| 6 | 54 | 2018-01-07 00:00:00 | completed | 6.0 | 0.0 | 0.0 | 0.0 | 6.0 | -| 7 | 88 | 2018-01-09 00:00:00 | completed | 16.0 | 0.0 | 0.0 | 0.0 | 16.0 | -| 8 | 2 | 2018-01-11 00:00:00 | returned | 23.0 | 0.0 | 0.0 | 0.0 | 23.0 | -| 9 | 53 | 2018-01-12 00:00:00 | completed | 0.0 | 0.0 | 0.0 | 23.0 | 23.0 | -| 10 | 7 | 2018-01-14 00:00:00 | completed | 0.0 | 0.0 | 26.0 | 0.0 | 26.0 | +- {"order_id": 1, "customer_id": 1, "order_date": "2018-01-01 00:00:00", "status": "returned", "credit_card_amount": 10.0, "coupon_amount": 0.0, "bank_transfer_amount": 0.0, "gift_card_amount": 0.0, "amount": 10.0} +- {"order_id": 2, "customer_id": 3, "order_date": "2018-01-02 00:00:00", "status": "completed", "credit_card_amount": 20.0, "coupon_amount": 0.0, "bank_transfer_amount": 0.0, "gift_card_amount": 0.0, "amount": 20.0} +- {"order_id": 3, "customer_id": 94, "order_date": "2018-01-04 00:00:00", "status": "completed", "credit_card_amount": 0.0, "coupon_amount": 1.0, "bank_transfer_amount": 0.0, "gift_card_amount": 0.0, "amount": 1.0} +- {"order_id": 4, "customer_id": 50, "order_date": "2018-01-05 00:00:00", "status": "completed", "credit_card_amount": 0.0, "coupon_amount": 25.0, "bank_transfer_amount": 0.0, "gift_card_amount": 0.0, "amount": 25.0} +- {"order_id": 5, "customer_id": 64, "order_date": "2018-01-05 00:00:00", "status": "completed", "credit_card_amount": 0.0, "coupon_amount": 0.0, "bank_transfer_amount": 17.0, "gift_card_amount": 0.0, "amount": 17.0} +- {"order_id": 6, "customer_id": 54, "order_date": "2018-01-07 00:00:00", "status": "completed", "credit_card_amount": 6.0, "coupon_amount": 0.0, "bank_transfer_amount": 0.0, "gift_card_amount": 0.0, "amount": 6.0} +- {"order_id": 7, "customer_id": 88, "order_date": "2018-01-09 00:00:00", "status": "completed", "credit_card_amount": 16.0, "coupon_amount": 0.0, "bank_transfer_amount": 0.0, "gift_card_amount": 0.0, "amount": 16.0} +- {"order_id": 8, "customer_id": 2, "order_date": "2018-01-11 00:00:00", "status": "returned", "credit_card_amount": 23.0, "coupon_amount": 0.0, "bank_transfer_amount": 0.0, "gift_card_amount": 0.0, "amount": 23.0} +- {"order_id": 9, "customer_id": 53, "order_date": "2018-01-12 00:00:00", "status": "completed", "credit_card_amount": 0.0, "coupon_amount": 0.0, "bank_transfer_amount": 0.0, "gift_card_amount": 23.0, "amount": 23.0} +- {"order_id": 10, "customer_id": 7, "order_date": "2018-01-14 00:00:00", "status": "completed", "credit_card_amount": 0.0, "coupon_amount": 0.0, "bank_transfer_amount": 26.0, "gift_card_amount": 0.0, "amount": 26.0} --- -*Generated by nao sync with custom template* +*Generated by nao sync with custom template (in example/templates/databases/preview.md.j2)* diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_customers/description.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_customers/description.md new file mode 100644 index 00000000..174ea635 --- /dev/null +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_customers/description.md @@ -0,0 +1,14 @@ +# raw_customers + +**Dataset:** `main` + +## Table Metadata + +| Property | Value | 
+|----------|-------| +| **Row Count** | 100 | +| **Column Count** | 3 | + +## Description + +_No description available._ diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_customers/preview.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_customers/preview.md index 9500d40f..720b78b6 100644 --- a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_customers/preview.md +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_customers/preview.md @@ -1,21 +1,19 @@ -# 📊 raw_customers +# raw_customers - Preview -> **Schema:** `main` +**Dataset:** `main` -## Sample Data (10 rows) +## Rows (10) -| id | first_name | last_name | -| --- | --- | --- | -| 1 | Michael | P. | -| 2 | Shawn | M. | -| 3 | Kathleen | P. | -| 4 | Jimmy | C. | -| 5 | Katherine | R. | -| 6 | Sarah | R. | -| 7 | Martin | M. | -| 8 | Frank | R. | -| 9 | Jennifer | F. | -| 10 | Henry | W. | +- {"id": 1, "first_name": "Michael", "last_name": "P."} +- {"id": 2, "first_name": "Shawn", "last_name": "M."} +- {"id": 3, "first_name": "Kathleen", "last_name": "P."} +- {"id": 4, "first_name": "Jimmy", "last_name": "C."} +- {"id": 5, "first_name": "Katherine", "last_name": "R."} +- {"id": 6, "first_name": "Sarah", "last_name": "R."} +- {"id": 7, "first_name": "Martin", "last_name": "M."} +- {"id": 8, "first_name": "Frank", "last_name": "R."} +- {"id": 9, "first_name": "Jennifer", "last_name": "F."} +- {"id": 10, "first_name": "Henry", "last_name": "W."} --- -*Generated by nao sync with custom template* +*Generated by nao sync with custom template (in example/templates/databases/preview.md.j2)* diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_orders/description.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_orders/description.md new file mode 100644 index 00000000..06800a26 --- /dev/null +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_orders/description.md @@ -0,0 +1,14 @@ +# raw_orders + +**Dataset:** `main` + +## Table Metadata + +| Property | Value | +|----------|-------| +| **Row Count** | 99 | +| **Column Count** | 4 | + +## Description + +_No description available._ diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_orders/preview.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_orders/preview.md index 43f8c6a8..9e2870f3 100644 --- a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_orders/preview.md +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_orders/preview.md @@ -1,21 +1,19 @@ -# 📊 raw_orders +# raw_orders - Preview -> **Schema:** `main` +**Dataset:** `main` -## Sample Data (10 rows) +## Rows (10) -| id | user_id | order_date | status | -| --- | --- | --- | --- | -| 1 | 1 | 2018-01-01 00:00:00 | returned | -| 2 | 3 | 2018-01-02 00:00:00 | completed | -| 3 | 94 | 2018-01-04 00:00:00 | completed | -| 4 | 50 | 2018-01-05 00:00:00 | completed | -| 5 | 64 | 2018-01-05 00:00:00 | completed | -| 6 | 54 | 2018-01-07 00:00:00 | completed | -| 7 | 88 | 2018-01-09 00:00:00 | completed | -| 8 | 2 | 2018-01-11 00:00:00 | returned | -| 9 | 53 | 2018-01-12 00:00:00 | completed | -| 10 | 7 | 2018-01-14 00:00:00 | completed | +- {"id": 1, "user_id": 1, "order_date": "2018-01-01 00:00:00", "status": "returned"} +- {"id": 2, "user_id": 3, "order_date": "2018-01-02 00:00:00", "status": "completed"} +- {"id": 3, "user_id": 94, 
"order_date": "2018-01-04 00:00:00", "status": "completed"} +- {"id": 4, "user_id": 50, "order_date": "2018-01-05 00:00:00", "status": "completed"} +- {"id": 5, "user_id": 64, "order_date": "2018-01-05 00:00:00", "status": "completed"} +- {"id": 6, "user_id": 54, "order_date": "2018-01-07 00:00:00", "status": "completed"} +- {"id": 7, "user_id": 88, "order_date": "2018-01-09 00:00:00", "status": "completed"} +- {"id": 8, "user_id": 2, "order_date": "2018-01-11 00:00:00", "status": "returned"} +- {"id": 9, "user_id": 53, "order_date": "2018-01-12 00:00:00", "status": "completed"} +- {"id": 10, "user_id": 7, "order_date": "2018-01-14 00:00:00", "status": "completed"} --- -*Generated by nao sync with custom template* +*Generated by nao sync with custom template (in example/templates/databases/preview.md.j2)* diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_payments/description.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_payments/description.md new file mode 100644 index 00000000..e4c67b71 --- /dev/null +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_payments/description.md @@ -0,0 +1,14 @@ +# raw_payments + +**Dataset:** `main` + +## Table Metadata + +| Property | Value | +|----------|-------| +| **Row Count** | 113 | +| **Column Count** | 4 | + +## Description + +_No description available._ diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_payments/preview.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_payments/preview.md index e9e15cc2..dd483317 100644 --- a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_payments/preview.md +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=raw_payments/preview.md @@ -1,21 +1,19 @@ -# 📊 raw_payments +# raw_payments - Preview -> **Schema:** `main` +**Dataset:** `main` -## Sample Data (10 rows) +## Rows (10) -| id | order_id | payment_method | amount | -| --- | --- | --- | --- | -| 1 | 1 | credit_card | 1000 | -| 2 | 2 | credit_card | 2000 | -| 3 | 3 | coupon | 100 | -| 4 | 4 | coupon | 2500 | -| 5 | 5 | bank_transfer | 1700 | -| 6 | 6 | credit_card | 600 | -| 7 | 7 | credit_card | 1600 | -| 8 | 8 | credit_card | 2300 | -| 9 | 9 | gift_card | 2300 | -| 10 | 9 | bank_transfer | 0 | +- {"id": 1, "order_id": 1, "payment_method": "credit_card", "amount": 1000} +- {"id": 2, "order_id": 2, "payment_method": "credit_card", "amount": 2000} +- {"id": 3, "order_id": 3, "payment_method": "coupon", "amount": 100} +- {"id": 4, "order_id": 4, "payment_method": "coupon", "amount": 2500} +- {"id": 5, "order_id": 5, "payment_method": "bank_transfer", "amount": 1700} +- {"id": 6, "order_id": 6, "payment_method": "credit_card", "amount": 600} +- {"id": 7, "order_id": 7, "payment_method": "credit_card", "amount": 1600} +- {"id": 8, "order_id": 8, "payment_method": "credit_card", "amount": 2300} +- {"id": 9, "order_id": 9, "payment_method": "gift_card", "amount": 2300} +- {"id": 10, "order_id": 9, "payment_method": "bank_transfer", "amount": 0} --- -*Generated by nao sync with custom template* +*Generated by nao sync with custom template (in example/templates/databases/preview.md.j2)* diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_customers/description.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_customers/description.md new file mode 100644 index 00000000..d408dbc1 --- /dev/null +++ 
b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_customers/description.md @@ -0,0 +1,14 @@ +# stg_customers + +**Dataset:** `main` + +## Table Metadata + +| Property | Value | +|----------|-------| +| **Row Count** | 100 | +| **Column Count** | 3 | + +## Description + +_No description available._ diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_customers/preview.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_customers/preview.md index 71abd916..cb356695 100644 --- a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_customers/preview.md +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_customers/preview.md @@ -1,21 +1,19 @@ -# 📊 stg_customers +# stg_customers - Preview -> **Schema:** `main` +**Dataset:** `main` -## Sample Data (10 rows) +## Rows (10) -| customer_id | first_name | last_name | -| --- | --- | --- | -| 1 | Michael | P. | -| 2 | Shawn | M. | -| 3 | Kathleen | P. | -| 4 | Jimmy | C. | -| 5 | Katherine | R. | -| 6 | Sarah | R. | -| 7 | Martin | M. | -| 8 | Frank | R. | -| 9 | Jennifer | F. | -| 10 | Henry | W. | +- {"customer_id": 1, "first_name": "Michael", "last_name": "P."} +- {"customer_id": 2, "first_name": "Shawn", "last_name": "M."} +- {"customer_id": 3, "first_name": "Kathleen", "last_name": "P."} +- {"customer_id": 4, "first_name": "Jimmy", "last_name": "C."} +- {"customer_id": 5, "first_name": "Katherine", "last_name": "R."} +- {"customer_id": 6, "first_name": "Sarah", "last_name": "R."} +- {"customer_id": 7, "first_name": "Martin", "last_name": "M."} +- {"customer_id": 8, "first_name": "Frank", "last_name": "R."} +- {"customer_id": 9, "first_name": "Jennifer", "last_name": "F."} +- {"customer_id": 10, "first_name": "Henry", "last_name": "W."} --- -*Generated by nao sync with custom template* +*Generated by nao sync with custom template (in example/templates/databases/preview.md.j2)* diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_orders/description.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_orders/description.md new file mode 100644 index 00000000..8a4b306d --- /dev/null +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_orders/description.md @@ -0,0 +1,14 @@ +# stg_orders + +**Dataset:** `main` + +## Table Metadata + +| Property | Value | +|----------|-------| +| **Row Count** | 99 | +| **Column Count** | 4 | + +## Description + +_No description available._ diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_orders/preview.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_orders/preview.md index f7510875..2c6a8fc9 100644 --- a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_orders/preview.md +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_orders/preview.md @@ -1,21 +1,19 @@ -# 📊 stg_orders +# stg_orders - Preview -> **Schema:** `main` +**Dataset:** `main` -## Sample Data (10 rows) +## Rows (10) -| order_id | customer_id | order_date | status | -| --- | --- | --- | --- | -| 1 | 1 | 2018-01-01 00:00:00 | returned | -| 2 | 3 | 2018-01-02 00:00:00 | completed | -| 3 | 94 | 2018-01-04 00:00:00 | completed | -| 4 | 50 | 2018-01-05 00:00:00 | completed | -| 5 | 64 | 2018-01-05 00:00:00 | completed | -| 6 | 54 | 2018-01-07 00:00:00 | completed | -| 7 | 88 | 2018-01-09 00:00:00 | completed | -| 8 | 2 | 2018-01-11 
00:00:00 | returned | -| 9 | 53 | 2018-01-12 00:00:00 | completed | -| 10 | 7 | 2018-01-14 00:00:00 | completed | +- {"order_id": 1, "customer_id": 1, "order_date": "2018-01-01 00:00:00", "status": "returned"} +- {"order_id": 2, "customer_id": 3, "order_date": "2018-01-02 00:00:00", "status": "completed"} +- {"order_id": 3, "customer_id": 94, "order_date": "2018-01-04 00:00:00", "status": "completed"} +- {"order_id": 4, "customer_id": 50, "order_date": "2018-01-05 00:00:00", "status": "completed"} +- {"order_id": 5, "customer_id": 64, "order_date": "2018-01-05 00:00:00", "status": "completed"} +- {"order_id": 6, "customer_id": 54, "order_date": "2018-01-07 00:00:00", "status": "completed"} +- {"order_id": 7, "customer_id": 88, "order_date": "2018-01-09 00:00:00", "status": "completed"} +- {"order_id": 8, "customer_id": 2, "order_date": "2018-01-11 00:00:00", "status": "returned"} +- {"order_id": 9, "customer_id": 53, "order_date": "2018-01-12 00:00:00", "status": "completed"} +- {"order_id": 10, "customer_id": 7, "order_date": "2018-01-14 00:00:00", "status": "completed"} --- -*Generated by nao sync with custom template* +*Generated by nao sync with custom template (in example/templates/databases/preview.md.j2)* diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_payments/description.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_payments/description.md new file mode 100644 index 00000000..5fc2f285 --- /dev/null +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_payments/description.md @@ -0,0 +1,14 @@ +# stg_payments + +**Dataset:** `main` + +## Table Metadata + +| Property | Value | +|----------|-------| +| **Row Count** | 113 | +| **Column Count** | 4 | + +## Description + +_No description available._ diff --git a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_payments/preview.md b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_payments/preview.md index 032cfe03..9550b1dc 100644 --- a/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_payments/preview.md +++ b/example/databases/type=duckdb/database=jaffle_shop/schema=main/table=stg_payments/preview.md @@ -1,21 +1,19 @@ -# 📊 stg_payments +# stg_payments - Preview -> **Schema:** `main` +**Dataset:** `main` -## Sample Data (10 rows) +## Rows (10) -| payment_id | order_id | payment_method | amount | -| --- | --- | --- | --- | -| 1 | 1 | credit_card | 10.0 | -| 2 | 2 | credit_card | 20.0 | -| 3 | 3 | coupon | 1.0 | -| 4 | 4 | coupon | 25.0 | -| 5 | 5 | bank_transfer | 17.0 | -| 6 | 6 | credit_card | 6.0 | -| 7 | 7 | credit_card | 16.0 | -| 8 | 8 | credit_card | 23.0 | -| 9 | 9 | gift_card | 23.0 | -| 10 | 9 | bank_transfer | 0.0 | +- {"payment_id": 1, "order_id": 1, "payment_method": "credit_card", "amount": 10.0} +- {"payment_id": 2, "order_id": 2, "payment_method": "credit_card", "amount": 20.0} +- {"payment_id": 3, "order_id": 3, "payment_method": "coupon", "amount": 1.0} +- {"payment_id": 4, "order_id": 4, "payment_method": "coupon", "amount": 25.0} +- {"payment_id": 5, "order_id": 5, "payment_method": "bank_transfer", "amount": 17.0} +- {"payment_id": 6, "order_id": 6, "payment_method": "credit_card", "amount": 6.0} +- {"payment_id": 7, "order_id": 7, "payment_method": "credit_card", "amount": 16.0} +- {"payment_id": 8, "order_id": 8, "payment_method": "credit_card", "amount": 23.0} +- {"payment_id": 9, "order_id": 9, "payment_method": "gift_card", "amount": 23.0} +- 
{"payment_id": 10, "order_id": 9, "payment_method": "bank_transfer", "amount": 0.0} --- -*Generated by nao sync with custom template* +*Generated by nao sync with custom template (in example/templates/databases/preview.md.j2)* diff --git a/example/nao_config.yaml b/example/nao_config.yaml index a4f76716..208cfcc7 100644 --- a/example/nao_config.yaml +++ b/example/nao_config.yaml @@ -1,9 +1,6 @@ project_name: example databases: - name: duckdb-jaffle-shop - accessors: - - columns - - preview include: [] exclude: [] type: duckdb @@ -13,7 +10,7 @@ repos: url: https://github.com/dbt-labs/jaffle_shop_duckdb.git branch: null notion: - api_key: {{ env('NOTION_API_KEY') }} + api_key: yolo pages: - https://naolabs.notion.site/Jaffle-shop-information-2f8c7a70bc0680a4b7d0caf99f055360 llm: null diff --git a/example/repos/dbt b/example/repos/dbt index db6bffad..38904772 160000 --- a/example/repos/dbt +++ b/example/repos/dbt @@ -1 +1 @@ -Subproject commit db6bffad67174b33f27eb0e8f8cf3934aa4213b6 +Subproject commit 389047728a2debad897f26475d348eb8a416ba2a diff --git a/example/templates/databases/preview.md.j2 b/example/templates/databases/preview.md.j2 index 5e228ed3..2d69f17a 100644 --- a/example/templates/databases/preview.md.j2 +++ b/example/templates/databases/preview.md.j2 @@ -1,32 +1,22 @@ {# - Custom preview template example - - This file overrides the default preview.md.j2 template. - You can customize how preview data is displayed. - - Available variables: + Template: preview.md.j2 + Description: Generates a preview of table rows in JSONL format + + Available context: - table_name (str): Name of the table - dataset (str): Schema/dataset name - - rows (list): List of row dictionaries (first N rows of the table) - - row_count (int): Number of preview rows shown - - columns (list): List of column dictionaries with: - - name (str): Column name - - type (str): Data type + - db (DatabaseContext): Database context with helper methods #} -# 📊 {{ table_name }} - -> **Schema:** `{{ dataset }}` +{% set rows = db.preview() %} +# {{ table_name }} - Preview -## Sample Data ({{ row_count }} rows) +**Dataset:** `{{ dataset }}` -| {% for col in columns %}{{ col.name }} | {% endfor %} - -| {% for col in columns %}--- | {% endfor %} +## Rows ({{ rows | length }}) {% for row in rows %} -| {% for col in columns %}{{ row[col.name] | truncate_middle(30) }} | {% endfor %} - +- {{ row | to_json }} {% endfor %} --- -*Generated by nao sync with custom template* +*Generated by nao sync with custom template (in example/templates/databases/preview.md.j2)* From b197fd6a2082292c827fbda0c377379c85b801e4 Mon Sep 17 00:00:00 2001 From: Christophe Blefari Date: Tue, 10 Feb 2026 19:03:42 +0100 Subject: [PATCH 2/7] integration test for all providers --- .gitignore | 4 +- cli/nao_core/config/databases/postgres.py | 6 +- cli/nao_core/config/databases/snowflake.py | 46 ++- .../sync/integration/dml/bigquery.sql | 27 ++ .../sync/integration/dml/databricks.sql | 29 ++ .../sync/integration/dml/postgres.sql | 31 ++ .../sync/integration/dml/redshift.sql | 12 +- .../sync/integration/dml/snowflake.sql | 29 ++ .../sync/integration/test_bigquery.py | 385 ++++++++++++++++++ .../sync/integration/test_databricks.py | 324 +++++++++++++++ .../sync/integration/test_postgres.py | 355 ++++++++++++++++ .../sync/integration/test_redshift.py | 118 +++++- .../sync/integration/test_snowflake.py | 368 +++++++++++++++++ 13 files changed, 1699 insertions(+), 35 deletions(-) create mode 100644 cli/tests/nao_core/commands/sync/integration/dml/bigquery.sql create 
mode 100644 cli/tests/nao_core/commands/sync/integration/dml/databricks.sql create mode 100644 cli/tests/nao_core/commands/sync/integration/dml/postgres.sql create mode 100644 cli/tests/nao_core/commands/sync/integration/dml/snowflake.sql create mode 100644 cli/tests/nao_core/commands/sync/integration/test_bigquery.py create mode 100644 cli/tests/nao_core/commands/sync/integration/test_databricks.py create mode 100644 cli/tests/nao_core/commands/sync/integration/test_postgres.py create mode 100644 cli/tests/nao_core/commands/sync/integration/test_snowflake.py diff --git a/.gitignore b/.gitignore index f0fa7bfe..aa45a545 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,6 @@ chats nao-chat-server .pytest_cache/ -*.pyc \ No newline at end of file +*.pyc + +rsa_key.* \ No newline at end of file diff --git a/cli/nao_core/config/databases/postgres.py b/cli/nao_core/config/databases/postgres.py index 4eb6ef66..bc2df906 100644 --- a/cli/nao_core/config/databases/postgres.py +++ b/cli/nao_core/config/databases/postgres.py @@ -72,7 +72,11 @@ def get_schemas(self, conn: BaseBackend) -> list[str]: if self.schema_name: return [self.schema_name] list_databases = getattr(conn, "list_databases", None) - return list_databases() if list_databases else [] + if list_databases: + schemas = list_databases() + # Filter out system schemas + return [s for s in schemas if s not in ("pg_catalog", "information_schema") and not s.startswith("pg_")] + return [] def check_connection(self) -> tuple[bool, str]: """Test connectivity to PostgreSQL.""" diff --git a/cli/nao_core/config/databases/snowflake.py b/cli/nao_core/config/databases/snowflake.py index 9e5a0444..93d0ace8 100644 --- a/cli/nao_core/config/databases/snowflake.py +++ b/cli/nao_core/config/databases/snowflake.py @@ -21,8 +21,6 @@ class SnowflakeConfig(DatabaseConfig): database: str = Field(description="Snowflake database") schema_name: str | None = Field( default=None, - validation_alias="schema", - serialization_alias="schema", description="Snowflake schema (optional)", ) warehouse: str | None = Field(default=None, description="Snowflake warehouse to use (optional)") @@ -74,9 +72,9 @@ def connect(self) -> BaseBackend: kwargs: dict = {"user": self.username} kwargs["account"] = self.account_id - if self.database and self.schema_name: - kwargs["database"] = f"{self.database}/{self.schema_name}" - elif self.database: + # Always connect to just the database, not database/schema + # The sync provider will handle schema filtering via list_tables(database=schema) + if self.database: kwargs["database"] = self.database if self.warehouse: @@ -103,11 +101,45 @@ def get_database_name(self) -> str: """Get the database name for Snowflake.""" return self.database + def matches_pattern(self, schema: str, table: str) -> bool: + """Check if a schema.table matches the include/exclude patterns. + + Snowflake identifier matching is case-insensitive. 
+ + Args: + schema: The schema name (uppercase from Snowflake) + table: The table name (uppercase from Snowflake) + + Returns: + True if the table should be included, False if excluded + """ + from fnmatch import fnmatch + + full_name = f"{schema}.{table}" + full_name_lower = full_name.lower() + + # If include patterns exist, table must match at least one + if self.include: + included = any(fnmatch(full_name_lower, pattern.lower()) for pattern in self.include) + if not included: + return False + + # If exclude patterns exist, table must not match any + if self.exclude: + excluded = any(fnmatch(full_name_lower, pattern.lower()) for pattern in self.exclude) + if excluded: + return False + + return True + def get_schemas(self, conn: BaseBackend) -> list[str]: if self.schema_name: - return [self.schema_name] + # Snowflake schema names are case-insensitive but stored as uppercase + return [self.schema_name.upper()] list_databases = getattr(conn, "list_databases", None) - return list_databases() if list_databases else [] + schemas = list_databases() if list_databases else [] + # Filter out INFORMATION_SCHEMA which contains system tables + return [s for s in schemas if s != "INFORMATION_SCHEMA"] def check_connection(self) -> tuple[bool, str]: """Test connectivity to Snowflake.""" diff --git a/cli/tests/nao_core/commands/sync/integration/dml/bigquery.sql b/cli/tests/nao_core/commands/sync/integration/dml/bigquery.sql new file mode 100644 index 00000000..f8184eef --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/dml/bigquery.sql @@ -0,0 +1,27 @@ +CREATE TABLE {public_dataset}.users ( +id INT64 NOT NULL, +name STRING NOT NULL, +email STRING, +active BOOL DEFAULT TRUE +); + +INSERT INTO {public_dataset}.users VALUES +(1, 'Alice', 'alice@example.com', true), +(2, 'Bob', NULL, false), +(3, 'Charlie', 'charlie@example.com', true); + + +CREATE TABLE {public_dataset}.orders ( +id INT64 NOT NULL, +user_id INT64 NOT NULL, +amount FLOAT64 NOT NULL +); + +INSERT INTO {public_dataset}.orders VALUES +(1, 1, 99.99), +(2, 1, 24.50); + +CREATE TABLE {another_dataset}.whatever ( +id INT64 NOT NULL, +price FLOAT64 NOT NULL +); diff --git a/cli/tests/nao_core/commands/sync/integration/dml/databricks.sql b/cli/tests/nao_core/commands/sync/integration/dml/databricks.sql new file mode 100644 index 00000000..8397bd84 --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/dml/databricks.sql @@ -0,0 +1,29 @@ +CREATE TABLE {catalog}.public.users ( +id INTEGER NOT NULL, +name STRING NOT NULL, +email STRING, +active BOOLEAN DEFAULT TRUE +); + +INSERT INTO {catalog}.public.users VALUES +(1, 'Alice', 'alice@example.com', true), +(2, 'Bob', NULL, false), +(3, 'Charlie', 'charlie@example.com', true); + + +CREATE TABLE {catalog}.public.orders ( +id INTEGER NOT NULL, +user_id INTEGER NOT NULL, +amount DOUBLE NOT NULL +); + +INSERT INTO {catalog}.public.orders VALUES +(1, 1, 99.99), +(2, 1, 24.50); + +CREATE SCHEMA {catalog}.another; + +CREATE TABLE {catalog}.another.whatever ( +id INTEGER NOT NULL, +price DOUBLE NOT NULL +); diff --git a/cli/tests/nao_core/commands/sync/integration/dml/postgres.sql b/cli/tests/nao_core/commands/sync/integration/dml/postgres.sql new file mode 100644 index 00000000..6f4fb994 --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/dml/postgres.sql @@ -0,0 +1,31 @@ +CREATE SCHEMA IF NOT EXISTS public; + +CREATE TABLE public.users ( +id INTEGER NOT NULL, +name VARCHAR NOT NULL, +email VARCHAR, +active BOOLEAN DEFAULT TRUE +); + +INSERT INTO public.users VALUES +(1, 
'Alice', 'alice@example.com', true), +(2, 'Bob', NULL, false), +(3, 'Charlie', 'charlie@example.com', true); + + +CREATE TABLE public.orders ( +id INTEGER NOT NULL, +user_id INTEGER NOT NULL, +amount DOUBLE PRECISION NOT NULL +); + +INSERT INTO public.orders VALUES +(1, 1, 99.99), +(2, 1, 24.50); + +CREATE SCHEMA another; + +CREATE TABLE another.whatever ( +id INTEGER NOT NULL, +price DOUBLE PRECISION NOT NULL +); diff --git a/cli/tests/nao_core/commands/sync/integration/dml/redshift.sql b/cli/tests/nao_core/commands/sync/integration/dml/redshift.sql index 1c1fb49d..31dd083b 100644 --- a/cli/tests/nao_core/commands/sync/integration/dml/redshift.sql +++ b/cli/tests/nao_core/commands/sync/integration/dml/redshift.sql @@ -1,29 +1,29 @@ -CREATE TABLE nao_unit_tests.public.users ( +CREATE TABLE {database}.public.users ( id INTEGER NOT NULL, name VARCHAR NOT NULL, email VARCHAR, active BOOLEAN DEFAULT TRUE ); -INSERT INTO nao_unit_tests.public.users VALUES +INSERT INTO {database}.public.users VALUES (1, 'Alice', 'alice@example.com', true), (2, 'Bob', NULL, false), (3, 'Charlie', 'charlie@example.com', true); -CREATE TABLE nao_unit_tests.public.orders ( +CREATE TABLE {database}.public.orders ( id INTEGER NOT NULL, user_id INTEGER NOT NULL, amount FLOAT4 NOT NULL ); -INSERT INTO nao_unit_tests.public.orders VALUES +INSERT INTO {database}.public.orders VALUES (1, 1, 99.99), (2, 1, 24.50); -CREATE SCHEMA nao_unit_tests.another; +CREATE SCHEMA {database}.another; -CREATE TABLE nao_unit_tests.another.whatever ( +CREATE TABLE {database}.another.whatever ( id INTEGER NOT NULL, price FLOAT4 NOT NULL ); \ No newline at end of file diff --git a/cli/tests/nao_core/commands/sync/integration/dml/snowflake.sql b/cli/tests/nao_core/commands/sync/integration/dml/snowflake.sql new file mode 100644 index 00000000..676bc8eb --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/dml/snowflake.sql @@ -0,0 +1,29 @@ +CREATE TABLE {database}.public.users ( +id INTEGER NOT NULL, +name VARCHAR NOT NULL, +email VARCHAR, +active BOOLEAN DEFAULT TRUE +); + +INSERT INTO {database}.public.users VALUES +(1, 'Alice', 'alice@example.com', true), +(2, 'Bob', NULL, false), +(3, 'Charlie', 'charlie@example.com', true); + + +CREATE TABLE {database}.public.orders ( +id INTEGER NOT NULL, +user_id INTEGER NOT NULL, +amount FLOAT NOT NULL +); + +INSERT INTO {database}.public.orders VALUES +(1, 1, 99.99), +(2, 1, 24.50); + +CREATE SCHEMA {database}.another; + +CREATE TABLE {database}.another.whatever ( +id INTEGER NOT NULL, +price FLOAT NOT NULL +); diff --git a/cli/tests/nao_core/commands/sync/integration/test_bigquery.py b/cli/tests/nao_core/commands/sync/integration/test_bigquery.py new file mode 100644 index 00000000..d7496896 --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/test_bigquery.py @@ -0,0 +1,385 @@ +"""Integration tests for the database sync pipeline against a real BigQuery project. + +Connection is configured via environment variables: + BIGQUERY_PROJECT_ID, BIGQUERY_DATASET_ID (default public), + BIGQUERY_CREDENTIALS_JSON (JSON string of service account credentials). + +The test suite is skipped entirely when BIGQUERY_PROJECT_ID is not set. 
+""" + +import json +import os +from pathlib import Path + +import ibis +import pytest +from google.cloud import bigquery +from google.oauth2 import service_account +from rich.progress import Progress + +from nao_core.commands.sync.providers.databases.provider import sync_database +from nao_core.config.databases.bigquery import BigQueryConfig + +BIGQUERY_PROJECT_ID = os.environ.get("BIGQUERY_PROJECT_ID") + +pytestmark = pytest.mark.skipif( + BIGQUERY_PROJECT_ID is None, reason="BIGQUERY_PROJECT_ID not set — skipping BigQuery integration tests" +) + + +@pytest.fixture(scope="module") +def temp_datasets(): + """Create or reuse test datasets with test data.""" + public_dataset_id = "nao_integration_tests_public" + another_dataset_id = "nao_integration_tests_another" + + # Create BigQuery client for dataset management + credentials_json_str = os.environ.get("BIGQUERY_CREDENTIALS_JSON") + project_id = os.environ["BIGQUERY_PROJECT_ID"] + + credentials = None + if credentials_json_str: + credentials_json = json.loads(credentials_json_str) + credentials = service_account.Credentials.from_service_account_info( + credentials_json, + scopes=["https://www.googleapis.com/auth/bigquery"], + ) + + bq_client = bigquery.Client(project=project_id, credentials=credentials) + + # Create ibis connection for data operations + ibis_kwargs = {"project_id": project_id} + if credentials_json_str: + ibis_kwargs["credentials"] = credentials + + conn = ibis.bigquery.connect(**ibis_kwargs) + + try: + # Delete existing test datasets from previous runs to start fresh + bq_client.delete_dataset(f"{project_id}.{public_dataset_id}", delete_contents=True, not_found_ok=True) + bq_client.delete_dataset(f"{project_id}.{another_dataset_id}", delete_contents=True, not_found_ok=True) + + # Clean up any old nao_unit_tests_ datasets from failed runs + for dataset in bq_client.list_datasets(): + if dataset.dataset_id.startswith("nao_unit_tests_"): + bq_client.delete_dataset(f"{project_id}.{dataset.dataset_id}", delete_contents=True, not_found_ok=True) + + # Create datasets using BigQuery client + public_dataset = bigquery.Dataset(f"{project_id}.{public_dataset_id}") + public_dataset.location = "US" + bq_client.create_dataset(public_dataset) + + another_dataset = bigquery.Dataset(f"{project_id}.{another_dataset_id}") + another_dataset.location = "US" + bq_client.create_dataset(another_dataset) + + # Read and execute SQL script + sql_file = Path(__file__).parent / "dml" / "bigquery.sql" + sql_template = sql_file.read_text() + + # Inject dataset names into SQL + sql_content = sql_template.format( + public_dataset=public_dataset_id, + another_dataset=another_dataset_id, + ) + + # Execute SQL statements using ibis + for statement in sql_content.split(";"): + statement = statement.strip() + if statement: + conn.raw_sql(statement) + + yield {"public": public_dataset_id, "another": another_dataset_id} + + finally: + # Don't clean up datasets - keep them for reuse across test runs + conn.disconnect() + + +@pytest.fixture(scope="module") +def bigquery_config(temp_datasets): + """Build a BigQueryConfig from environment variables using the temporary dataset.""" + credentials_json_str = os.environ.get("BIGQUERY_CREDENTIALS_JSON") + credentials_json = json.loads(credentials_json_str) if credentials_json_str else None + + return BigQueryConfig( + name="test-bigquery", + project_id=os.environ["BIGQUERY_PROJECT_ID"], + dataset_id=temp_datasets["public"], + credentials_json=credentials_json, + ) + + +@pytest.fixture(scope="module") +def 
synced(tmp_path_factory, bigquery_config): + """Run sync once for the whole module and return (state, output_path, config).""" + output = tmp_path_factory.mktemp("bigquery_sync") + + with Progress(transient=True) as progress: + state = sync_database(bigquery_config, output, progress) + + return state, output, bigquery_config + + +class TestBigQuerySyncIntegration: + """Verify the sync pipeline produces correct output against a live BigQuery project.""" + + def test_creates_expected_directory_tree(self, synced, temp_datasets): + state, output, config = synced + + base = output / "type=bigquery" / f"database={config.project_id}" / f"schema={config.dataset_id}" + + # Schema directory + assert base.is_dir() + + # Each table should have exactly the 3 default template outputs + for table in ("orders", "users"): + assert (base / f"table={table}").is_dir() + table_dir = base / f"table={table}" + files = sorted(f.name for f in table_dir.iterdir()) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify that the "another" dataset was NOT synced + another_dataset_dir = ( + output / "type=bigquery" / f"database={config.project_id}" / f"schema={temp_datasets['another']}" + ) + assert not another_dataset_dir.exists() + + def test_columns_md_users(self, synced): + state, output, config = synced + + content = ( + output + / "type=bigquery" + / f"database={config.project_id}" + / f"schema={config.dataset_id}" + / "table=users" + / "columns.md" + ).read_text() + + # BigQuery uses int64, string, bool types + assert "# users" in content + assert f"**Dataset:** `{config.dataset_id}`" in content + assert "## Columns (4)" in content + assert "- id (int64 NOT NULL)" in content + assert "- name (string NOT NULL)" in content + assert "- email (string)" in content + assert "- active (boolean)" in content + + def test_columns_md_orders(self, synced): + state, output, config = synced + + content = ( + output + / "type=bigquery" + / f"database={config.project_id}" + / f"schema={config.dataset_id}" + / "table=orders" + / "columns.md" + ).read_text() + + assert "# orders" in content + assert f"**Dataset:** `{config.dataset_id}`" in content + assert "## Columns (3)" in content + assert "- id (int64 NOT NULL)" in content + assert "- user_id (int64 NOT NULL)" in content + assert "- amount (float64 NOT NULL)" in content + + def test_description_md_users(self, synced): + state, output, config = synced + + content = ( + output + / "type=bigquery" + / f"database={config.project_id}" + / f"schema={config.dataset_id}" + / "table=users" + / "description.md" + ).read_text() + + assert "# users" in content + assert f"**Dataset:** `{config.dataset_id}`" in content + assert "## Table Metadata" in content + assert "| **Row Count** | 3 |" in content + assert "| **Column Count** | 4 |" in content + + def test_description_md_orders(self, synced): + state, output, config = synced + + content = ( + output + / "type=bigquery" + / f"database={config.project_id}" + / f"schema={config.dataset_id}" + / "table=orders" + / "description.md" + ).read_text() + + assert "| **Row Count** | 2 |" in content + assert "| **Column Count** | 3 |" in content + + def test_preview_md_users(self, synced): + state, output, config = synced + + content = ( + output + / "type=bigquery" + / f"database={config.project_id}" + / f"schema={config.dataset_id}" + / "table=users" + / "preview.md" + ).read_text() + + assert "# users - Preview" in content + assert f"**Dataset:** `{config.dataset_id}`" in content + assert "## Rows (3)" in content + + # 
Parse the JSONL rows from the markdown + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 3 + # Sort by id since BigQuery doesn't guarantee order + rows_sorted = sorted(rows, key=lambda r: r["id"]) + assert rows_sorted[0] == {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True} + assert rows_sorted[1] == {"id": 2, "name": "Bob", "email": None, "active": False} + assert rows_sorted[2] == {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True} + + def test_preview_md_orders(self, synced): + state, output, config = synced + + content = ( + output + / "type=bigquery" + / f"database={config.project_id}" + / f"schema={config.dataset_id}" + / "table=orders" + / "preview.md" + ).read_text() + + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 2 + # Sort by id since BigQuery doesn't guarantee order + rows_sorted = sorted(rows, key=lambda r: r["id"]) + assert rows_sorted[0] == {"id": 1, "user_id": 1, "amount": 99.99} + assert rows_sorted[1] == {"id": 2, "user_id": 1, "amount": 24.5} + + def test_sync_state_tracks_schemas_and_tables(self, synced): + state, output, config = synced + + assert state.schemas_synced == 1 + assert state.tables_synced == 2 + assert config.dataset_id in state.synced_schemas + assert "users" in state.synced_tables[config.dataset_id] + assert "orders" in state.synced_tables[config.dataset_id] + + def test_include_filter(self, tmp_path_factory, bigquery_config): + """Only tables matching include patterns should be synced.""" + config = BigQueryConfig( + name=bigquery_config.name, + project_id=bigquery_config.project_id, + dataset_id=bigquery_config.dataset_id, + credentials_json=bigquery_config.credentials_json, + include=[f"{bigquery_config.dataset_id}.users"], + ) + + output = tmp_path_factory.mktemp("bigquery_include") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = output / "type=bigquery" / f"database={config.project_id}" / f"schema={config.dataset_id}" + assert (base / "table=users").is_dir() + assert not (base / "table=orders").exists() + assert state.tables_synced == 1 + + def test_exclude_filter(self, tmp_path_factory, bigquery_config): + """Tables matching exclude patterns should be skipped.""" + config = BigQueryConfig( + name=bigquery_config.name, + project_id=bigquery_config.project_id, + dataset_id=bigquery_config.dataset_id, + credentials_json=bigquery_config.credentials_json, + exclude=[f"{bigquery_config.dataset_id}.orders"], + ) + + output = tmp_path_factory.mktemp("bigquery_exclude") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = output / "type=bigquery" / f"database={config.project_id}" / f"schema={config.dataset_id}" + assert (base / "table=users").is_dir() + assert not (base / "table=orders").exists() + assert state.tables_synced == 1 + + def test_sync_all_schemas_when_dataset_id_not_specified(self, tmp_path_factory, bigquery_config, temp_datasets): + """When dataset_id is not provided, all datasets should be synced.""" + public_dataset = temp_datasets["public"] + another_dataset = temp_datasets["another"] + + config = BigQueryConfig( + name=bigquery_config.name, + project_id=bigquery_config.project_id, + dataset_id=None, + credentials_json=bigquery_config.credentials_json, + ) + + output = 
tmp_path_factory.mktemp("bigquery_all_schemas") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + # Verify public dataset tables + assert (output / "type=bigquery" / f"database={config.project_id}" / f"schema={public_dataset}").is_dir() + assert ( + output / "type=bigquery" / f"database={config.project_id}" / f"schema={public_dataset}" / "table=users" + ).is_dir() + assert ( + output / "type=bigquery" / f"database={config.project_id}" / f"schema={public_dataset}" / "table=orders" + ).is_dir() + + # Verify public.users files + files = sorted( + f.name + for f in ( + output / "type=bigquery" / f"database={config.project_id}" / f"schema={public_dataset}" / "table=users" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify public.orders files + files = sorted( + f.name + for f in ( + output / "type=bigquery" / f"database={config.project_id}" / f"schema={public_dataset}" / "table=orders" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify another dataset table + assert (output / "type=bigquery" / f"database={config.project_id}" / f"schema={another_dataset}").is_dir() + assert ( + output / "type=bigquery" / f"database={config.project_id}" / f"schema={another_dataset}" / "table=whatever" + ).is_dir() + + # Verify another.whatever files + files = sorted( + f.name + for f in ( + output + / "type=bigquery" + / f"database={config.project_id}" + / f"schema={another_dataset}" + / "table=whatever" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify state + assert state.schemas_synced == 2 + assert state.tables_synced == 3 + assert public_dataset in state.synced_schemas + assert another_dataset in state.synced_schemas + assert "users" in state.synced_tables[public_dataset] + assert "orders" in state.synced_tables[public_dataset] + assert "whatever" in state.synced_tables[another_dataset] diff --git a/cli/tests/nao_core/commands/sync/integration/test_databricks.py b/cli/tests/nao_core/commands/sync/integration/test_databricks.py new file mode 100644 index 00000000..03105d9e --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/test_databricks.py @@ -0,0 +1,324 @@ +"""Integration tests for the database sync pipeline against a real Databricks workspace. + +Connection is configured via environment variables: + DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, DATABRICKS_ACCESS_TOKEN, + DATABRICKS_CATALOG, DATABRICKS_SCHEMA (default public). + +The test suite is skipped entirely when DATABRICKS_SERVER_HOSTNAME is not set. 
+""" + +import json +import os +import uuid +from pathlib import Path + +import ibis +import pytest +from rich.progress import Progress + +from nao_core.commands.sync.providers.databases.provider import sync_database +from nao_core.config.databases.databricks import DatabricksConfig + +DATABRICKS_SERVER_HOSTNAME = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + +pytestmark = pytest.mark.skipif( + DATABRICKS_SERVER_HOSTNAME is None, + reason="DATABRICKS_SERVER_HOSTNAME not set — skipping Databricks integration tests", +) + + +@pytest.fixture(scope="module") +def temp_catalog(): + """Create a temporary catalog and populate it with test data, then clean up.""" + catalog_name = f"nao_unit_tests_{uuid.uuid4().hex[:8]}" + + # Connect to Databricks using ibis + conn = ibis.databricks.connect( + server_hostname=os.environ["DATABRICKS_SERVER_HOSTNAME"], + http_path=os.environ["DATABRICKS_HTTP_PATH"], + access_token=os.environ["DATABRICKS_ACCESS_TOKEN"], + ) + + try: + # Create temporary catalog + conn.raw_sql(f"CREATE CATALOG {catalog_name}").fetchall() + conn.raw_sql(f"USE CATALOG {catalog_name}").fetchall() + conn.raw_sql("CREATE SCHEMA public").fetchall() + conn.raw_sql("USE SCHEMA public").fetchall() + + # Read and execute SQL script + sql_file = Path(__file__).parent / "dml" / "databricks.sql" + sql_template = sql_file.read_text() + + # Inject catalog name into SQL + sql_content = sql_template.format(catalog=catalog_name) + + # Execute SQL statements + for statement in sql_content.split(";"): + statement = statement.strip() + if statement: + conn.raw_sql(statement).fetchall() + + yield catalog_name + + finally: + # Clean up: drop the temporary catalog + conn.raw_sql(f"DROP CATALOG IF EXISTS {catalog_name} CASCADE").fetchall() + conn.disconnect() + + +@pytest.fixture(scope="module") +def databricks_config(temp_catalog): + """Build a DatabricksConfig from environment variables using the temporary catalog.""" + return DatabricksConfig( + name="test-databricks", + server_hostname=os.environ["DATABRICKS_SERVER_HOSTNAME"], + http_path=os.environ["DATABRICKS_HTTP_PATH"], + access_token=os.environ["DATABRICKS_ACCESS_TOKEN"], + catalog=temp_catalog, + schema_name=os.environ.get("DATABRICKS_SCHEMA", "public"), + ) + + +@pytest.fixture(scope="module") +def synced(tmp_path_factory, databricks_config): + """Run sync once for the whole module and return (state, output_path, config).""" + output = tmp_path_factory.mktemp("databricks_sync") + + with Progress(transient=True) as progress: + state = sync_database(databricks_config, output, progress) + + return state, output, databricks_config + + +class TestDatabricksSyncIntegration: + """Verify the sync pipeline produces correct output against a live Databricks workspace.""" + + def test_creates_expected_directory_tree(self, synced): + state, output, config = synced + + base = output / "type=databricks" / f"database={config.catalog}" / "schema=public" + + # Schema directory + assert base.is_dir() + + # Each table should have exactly the 3 default template outputs + for table in ("orders", "users"): + assert (base / f"table={table}").is_dir() + table_dir = base / f"table={table}" + files = sorted(f.name for f in table_dir.iterdir()) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify that the "another" schema was NOT synced + another_schema_dir = output / "type=databricks" / f"database={config.catalog}" / "schema=another" + assert not another_schema_dir.exists() + + def test_columns_md_users(self, synced): + state, output, config = 
synced + + content = ( + output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=users" / "columns.md" + ).read_text() + + # Databricks column types + assert "# users" in content + assert "**Dataset:** `public`" in content + assert "## Columns (4)" in content + assert "- id" in content + assert "- name" in content + assert "- email" in content + assert "- active" in content + + def test_columns_md_orders(self, synced): + state, output, config = synced + + content = ( + output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=orders" / "columns.md" + ).read_text() + + assert "# orders" in content + assert "**Dataset:** `public`" in content + assert "## Columns (3)" in content + assert "- id" in content + assert "- user_id" in content + assert "- amount" in content + + def test_description_md_users(self, synced): + state, output, config = synced + + content = ( + output + / "type=databricks" + / f"database={config.catalog}" + / "schema=public" + / "table=users" + / "description.md" + ).read_text() + + assert "# users" in content + assert "**Dataset:** `public`" in content + assert "## Table Metadata" in content + assert "| **Row Count** | 3 |" in content + assert "| **Column Count** | 4 |" in content + + def test_description_md_orders(self, synced): + state, output, config = synced + + content = ( + output + / "type=databricks" + / f"database={config.catalog}" + / "schema=public" + / "table=orders" + / "description.md" + ).read_text() + + assert "| **Row Count** | 2 |" in content + assert "| **Column Count** | 3 |" in content + + def test_preview_md_users(self, synced): + state, output, config = synced + + content = ( + output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=users" / "preview.md" + ).read_text() + + assert "# users - Preview" in content + assert "**Dataset:** `public`" in content + assert "## Rows (3)" in content + + # Parse the JSONL rows from the markdown + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 3 + assert rows[0] == {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True} + assert rows[1] == {"id": 2, "name": "Bob", "email": None, "active": False} + assert rows[2] == {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True} + + def test_preview_md_orders(self, synced): + state, output, config = synced + + content = ( + output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=orders" / "preview.md" + ).read_text() + + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 2 + assert rows[0] == {"id": 1, "user_id": 1, "amount": 99.99} + assert rows[1] == {"id": 2, "user_id": 1, "amount": 24.5} + + def test_sync_state_tracks_schemas_and_tables(self, synced): + state, output, config = synced + + assert state.schemas_synced == 1 + assert state.tables_synced == 2 + assert "public" in state.synced_schemas + assert "users" in state.synced_tables["public"] + assert "orders" in state.synced_tables["public"] + + def test_include_filter(self, tmp_path_factory, databricks_config): + """Only tables matching include patterns should be synced.""" + config = DatabricksConfig( + name=databricks_config.name, + server_hostname=databricks_config.server_hostname, + http_path=databricks_config.http_path, + access_token=databricks_config.access_token, 
+ catalog=databricks_config.catalog, + schema_name=databricks_config.schema_name, + include=["public.users"], + ) + + output = tmp_path_factory.mktemp("databricks_include") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = output / "type=databricks" / f"database={config.catalog}" / "schema=public" + assert (base / "table=users").is_dir() + assert not (base / "table=orders").exists() + assert state.tables_synced == 1 + + def test_exclude_filter(self, tmp_path_factory, databricks_config): + """Tables matching exclude patterns should be skipped.""" + config = DatabricksConfig( + name=databricks_config.name, + server_hostname=databricks_config.server_hostname, + http_path=databricks_config.http_path, + access_token=databricks_config.access_token, + catalog=databricks_config.catalog, + schema_name=databricks_config.schema_name, + exclude=["public.orders"], + ) + + output = tmp_path_factory.mktemp("databricks_exclude") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = output / "type=databricks" / f"database={config.catalog}" / "schema=public" + assert (base / "table=users").is_dir() + assert not (base / "table=orders").exists() + assert state.tables_synced == 1 + + def test_sync_all_schemas_when_schema_name_not_specified(self, tmp_path_factory, databricks_config): + """When schema_name is not provided, all schemas should be synced.""" + config = DatabricksConfig( + name=databricks_config.name, + server_hostname=databricks_config.server_hostname, + http_path=databricks_config.http_path, + access_token=databricks_config.access_token, + catalog=databricks_config.catalog, + schema_name=None, + ) + + output = tmp_path_factory.mktemp("databricks_all_schemas") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + # Verify public schema tables + assert (output / "type=databricks" / f"database={config.catalog}" / "schema=public").is_dir() + assert (output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=users").is_dir() + assert (output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=orders").is_dir() + + # Verify public.users files + files = sorted( + f.name + for f in ( + output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=users" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify public.orders files + files = sorted( + f.name + for f in ( + output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=orders" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify another schema table + assert (output / "type=databricks" / f"database={config.catalog}" / "schema=another").is_dir() + assert ( + output / "type=databricks" / f"database={config.catalog}" / "schema=another" / "table=whatever" + ).is_dir() + + # Verify another.whatever files + files = sorted( + f.name + for f in ( + output / "type=databricks" / f"database={config.catalog}" / "schema=another" / "table=whatever" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify state + assert state.schemas_synced == 2 + assert state.tables_synced == 3 + assert "public" in state.synced_schemas + assert "another" in state.synced_schemas + assert "users" in state.synced_tables["public"] + assert "orders" in state.synced_tables["public"] + assert "whatever" in 
state.synced_tables["another"] diff --git a/cli/tests/nao_core/commands/sync/integration/test_postgres.py b/cli/tests/nao_core/commands/sync/integration/test_postgres.py new file mode 100644 index 00000000..1fa1e780 --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/test_postgres.py @@ -0,0 +1,355 @@ +"""Integration tests for the database sync pipeline against a real Postgres database. + +Connection is configured via environment variables: + POSTGRES_HOST, POSTGRES_PORT (default 5432), POSTGRES_DATABASE, + POSTGRES_USER, POSTGRES_PASSWORD, + POSTGRES_SCHEMA (default public). + +The test suite is skipped entirely when POSTGRES_HOST is not set. +""" + +import json +import os +import uuid +from pathlib import Path + +import ibis +import pytest +from rich.progress import Progress + +from nao_core.commands.sync.providers.databases.provider import sync_database +from nao_core.config.databases.postgres import PostgresConfig + +POSTGRES_HOST = os.environ.get("POSTGRES_HOST") + +pytestmark = pytest.mark.skipif( + POSTGRES_HOST is None, reason="POSTGRES_HOST not set — skipping Postgres integration tests" +) + + +@pytest.fixture(scope="module") +def temp_database(): + """Create a temporary database and populate it with test data, then clean up.""" + db_name = f"nao_unit_tests_{uuid.uuid4().hex[:8].lower()}" + + # Connect to default postgres database to create test database + conn = ibis.postgres.connect( + host=os.environ["POSTGRES_HOST"], + port=int(os.environ.get("POSTGRES_PORT", "5432")), + database="postgres", + user=os.environ["POSTGRES_USER"], + password=os.environ["POSTGRES_PASSWORD"], + ) + + try: + # Create temporary database + conn.raw_sql(f"CREATE DATABASE {db_name}") + conn.disconnect() + + # Connect to the new database + conn = ibis.postgres.connect( + host=os.environ["POSTGRES_HOST"], + port=int(os.environ.get("POSTGRES_PORT", "5432")), + database=db_name, + user=os.environ["POSTGRES_USER"], + password=os.environ["POSTGRES_PASSWORD"], + ) + + # Read and execute SQL script + sql_file = Path(__file__).parent / "dml" / "postgres.sql" + sql_content = sql_file.read_text() + + # Execute SQL statements + for statement in sql_content.split(";"): + statement = statement.strip() + if statement: + try: + conn.raw_sql(statement).fetchall() + except Exception: + # Some statements (like CREATE SCHEMA) don't return results + pass + + yield db_name + + finally: + # Clean up: disconnect and drop the temporary database + conn.disconnect() + + # Reconnect to postgres database to drop test database + conn = ibis.postgres.connect( + host=os.environ["POSTGRES_HOST"], + port=int(os.environ.get("POSTGRES_PORT", "5432")), + database="postgres", + user=os.environ["POSTGRES_USER"], + password=os.environ["POSTGRES_PASSWORD"], + ) + + # Terminate any active connections to the test database + conn.raw_sql(f""" + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{db_name}' + AND pid <> pg_backend_pid() + """) + + # Drop the database + conn.raw_sql(f"DROP DATABASE IF EXISTS {db_name}") + conn.disconnect() + + +@pytest.fixture(scope="module") +def postgres_config(temp_database): + """Build a PostgresConfig from environment variables using the temporary database.""" + return PostgresConfig( + name="test-postgres", + host=os.environ["POSTGRES_HOST"], + port=int(os.environ.get("POSTGRES_PORT", "5432")), + database=temp_database, + user=os.environ["POSTGRES_USER"], + password=os.environ["POSTGRES_PASSWORD"], + 
schema_name=os.environ.get("POSTGRES_SCHEMA", "public"), + ) + + +@pytest.fixture(scope="module") +def synced(tmp_path_factory, postgres_config): + """Run sync once for the whole module and return (state, output_path, config).""" + output = tmp_path_factory.mktemp("postgres_sync") + + with Progress(transient=True) as progress: + state = sync_database(postgres_config, output, progress) + + return state, output, postgres_config + + +class TestPostgresSyncIntegration: + """Verify the sync pipeline produces correct output against a live Postgres database.""" + + def test_creates_expected_directory_tree(self, synced): + state, output, config = synced + + base = output / "type=postgres" / f"database={config.database}" / "schema=public" + + # Schema directory + assert base.is_dir() + + # Each table should have exactly the 3 default template outputs + for table in ("orders", "users"): + assert (base / f"table={table}").is_dir() + table_dir = base / f"table={table}" + files = sorted(f.name for f in table_dir.iterdir()) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify that the "another" schema was NOT synced + another_schema_dir = output / "type=postgres" / f"database={config.database}" / "schema=another" + assert not another_schema_dir.exists() + + def test_columns_md_users(self, synced): + state, output, config = synced + + content = ( + output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=users" / "columns.md" + ).read_text() + + assert "# users" in content + assert "**Dataset:** `public`" in content + assert "## Columns (4)" in content + assert "- id" in content + assert "- name" in content + assert "- email" in content + assert "- active" in content + + def test_columns_md_orders(self, synced): + state, output, config = synced + + content = ( + output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=orders" / "columns.md" + ).read_text() + + assert "# orders" in content + assert "**Dataset:** `public`" in content + assert "## Columns (3)" in content + assert "- id" in content + assert "- user_id" in content + assert "- amount" in content + + def test_description_md_users(self, synced): + state, output, config = synced + + content = ( + output + / "type=postgres" + / f"database={config.database}" + / "schema=public" + / "table=users" + / "description.md" + ).read_text() + + assert "# users" in content + assert "**Dataset:** `public`" in content + assert "## Table Metadata" in content + assert "| **Row Count** | 3 |" in content + assert "| **Column Count** | 4 |" in content + + def test_description_md_orders(self, synced): + state, output, config = synced + + content = ( + output + / "type=postgres" + / f"database={config.database}" + / "schema=public" + / "table=orders" + / "description.md" + ).read_text() + + assert "| **Row Count** | 2 |" in content + assert "| **Column Count** | 3 |" in content + + def test_preview_md_users(self, synced): + state, output, config = synced + + content = ( + output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=users" / "preview.md" + ).read_text() + + assert "# users - Preview" in content + assert "**Dataset:** `public`" in content + assert "## Rows (3)" in content + + # Parse the JSONL rows from the markdown + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 3 + assert rows[0] == {"id": 1, "name": "Alice", "email": "alice@example.com", "active": 
True} + assert rows[1] == {"id": 2, "name": "Bob", "email": None, "active": False} + assert rows[2] == {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True} + + def test_preview_md_orders(self, synced): + state, output, config = synced + + content = ( + output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=orders" / "preview.md" + ).read_text() + + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 2 + assert rows[0] == {"id": 1, "user_id": 1, "amount": 99.99} + assert rows[1] == {"id": 2, "user_id": 1, "amount": 24.5} + + def test_sync_state_tracks_schemas_and_tables(self, synced): + state, output, config = synced + + assert state.schemas_synced == 1 + assert state.tables_synced == 2 + assert "public" in state.synced_schemas + assert "users" in state.synced_tables["public"] + assert "orders" in state.synced_tables["public"] + + def test_include_filter(self, tmp_path_factory, postgres_config): + """Only tables matching include patterns should be synced.""" + config = PostgresConfig( + name=postgres_config.name, + host=postgres_config.host, + port=postgres_config.port, + database=postgres_config.database, + user=postgres_config.user, + password=postgres_config.password, + schema_name=postgres_config.schema_name, + include=["public.users"], + ) + + output = tmp_path_factory.mktemp("postgres_include") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = output / "type=postgres" / f"database={config.database}" / "schema=public" + assert (base / "table=users").is_dir() + assert not (base / "table=orders").exists() + assert state.tables_synced == 1 + + def test_exclude_filter(self, tmp_path_factory, postgres_config): + """Tables matching exclude patterns should be skipped.""" + config = PostgresConfig( + name=postgres_config.name, + host=postgres_config.host, + port=postgres_config.port, + database=postgres_config.database, + user=postgres_config.user, + password=postgres_config.password, + schema_name=postgres_config.schema_name, + exclude=["public.orders"], + ) + + output = tmp_path_factory.mktemp("postgres_exclude") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = output / "type=postgres" / f"database={config.database}" / "schema=public" + assert (base / "table=users").is_dir() + assert not (base / "table=orders").exists() + assert state.tables_synced == 1 + + def test_sync_all_schemas_when_schema_name_not_specified(self, tmp_path_factory, postgres_config): + """When schema_name is not provided, all schemas should be synced.""" + config = PostgresConfig( + name=postgres_config.name, + host=postgres_config.host, + port=postgres_config.port, + database=postgres_config.database, + user=postgres_config.user, + password=postgres_config.password, + schema_name=None, + ) + + output = tmp_path_factory.mktemp("postgres_all_schemas") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + # Verify public schema tables + assert (output / "type=postgres" / f"database={config.database}" / "schema=public").is_dir() + assert (output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=users").is_dir() + assert (output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=orders").is_dir() + + # Verify public.users files + files = sorted( + f.name + for f in ( + 
output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=users" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify public.orders files + files = sorted( + f.name + for f in ( + output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=orders" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify another schema table + assert (output / "type=postgres" / f"database={config.database}" / "schema=another").is_dir() + assert (output / "type=postgres" / f"database={config.database}" / "schema=another" / "table=whatever").is_dir() + + # Verify another.whatever files + files = sorted( + f.name + for f in ( + output / "type=postgres" / f"database={config.database}" / "schema=another" / "table=whatever" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify state + assert state.schemas_synced == 2 + assert state.tables_synced == 3 + assert "public" in state.synced_schemas + assert "another" in state.synced_schemas + assert "users" in state.synced_tables["public"] + assert "orders" in state.synced_tables["public"] + assert "whatever" in state.synced_tables["another"] diff --git a/cli/tests/nao_core/commands/sync/integration/test_redshift.py b/cli/tests/nao_core/commands/sync/integration/test_redshift.py index d1a371f8..eb1b1f55 100644 --- a/cli/tests/nao_core/commands/sync/integration/test_redshift.py +++ b/cli/tests/nao_core/commands/sync/integration/test_redshift.py @@ -10,7 +10,10 @@ import json import os +import uuid +from pathlib import Path +import ibis import pytest from rich.progress import Progress @@ -28,13 +31,70 @@ @pytest.fixture(scope="module") -def redshift_config(): - """Build a RedshiftConfig from environment variables.""" +def temp_database(): + """Create a temporary database and populate it with test data, then clean up.""" + db_name = f"nao_unit_tests_{uuid.uuid4().hex[:8]}" + + # Connect to default database to create temp database + conn = ibis.postgres.connect( + host=os.environ["REDSHIFT_HOST"], + port=int(os.environ.get("REDSHIFT_PORT", "5439")), + database=os.environ.get("REDSHIFT_DATABASE", "dev"), + user=os.environ["REDSHIFT_USER"], + password=os.environ["REDSHIFT_PASSWORD"], + client_encoding="utf8", + sslmode=os.environ.get("REDSHIFT_SSLMODE", "require"), + ) + + try: + # Create temporary database + conn.raw_sql(f"CREATE DATABASE {db_name}") + + # Connect to the new database and run setup script + test_conn = ibis.postgres.connect( + host=os.environ["REDSHIFT_HOST"], + port=int(os.environ.get("REDSHIFT_PORT", "5439")), + database=db_name, + user=os.environ["REDSHIFT_USER"], + password=os.environ["REDSHIFT_PASSWORD"], + client_encoding="utf8", + sslmode=os.environ.get("REDSHIFT_SSLMODE", "require"), + ) + + # Read and execute SQL script + sql_file = Path(__file__).parent / "dml" / "redshift.sql" + sql_template = sql_file.read_text() + + # Inject database name into SQL + sql_content = sql_template.format(database=db_name) + + # Execute SQL statements + for statement in sql_content.split(";"): + statement = statement.strip() + if statement: + test_conn.raw_sql(statement) + + test_conn.disconnect() + + yield db_name + + finally: + # Clean up: drop the temporary database (Redshift doesn't support IF EXISTS) + try: + conn.raw_sql(f"DROP DATABASE {db_name}") + except Exception: + pass # Database might not exist if setup failed + conn.disconnect() + + +@pytest.fixture(scope="module") +def 
redshift_config(temp_database): + """Build a RedshiftConfig from environment variables using the temporary database.""" return RedshiftConfig( name="test-redshift", host=os.environ["REDSHIFT_HOST"], port=int(os.environ.get("REDSHIFT_PORT", "5439")), - database=os.environ["REDSHIFT_DATABASE"], + database=temp_database, user=os.environ["REDSHIFT_USER"], password=os.environ["REDSHIFT_PASSWORD"], schema_name=os.environ.get("REDSHIFT_SCHEMA", "public"), @@ -59,7 +119,7 @@ class TestRedshiftSyncIntegration: def test_creates_expected_directory_tree(self, synced): state, output, config = synced - base = output / "type=redshift" / "database=nao_unit_tests" / "schema=public" + base = output / "type=redshift" / f"database={config.database}" / "schema=public" # Schema directory assert base.is_dir() @@ -71,11 +131,15 @@ def test_creates_expected_directory_tree(self, synced): files = sorted(f.name for f in table_dir.iterdir()) assert files == ["columns.md", "description.md", "preview.md"] + # Verify that the "another" schema was NOT synced + another_schema_dir = output / "type=redshift" / f"database={config.database}" / "schema=another" + assert not another_schema_dir.exists() + def test_columns_md_users(self, synced): state, output, config = synced content = ( - output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=users" / "columns.md" + output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=users" / "columns.md" ).read_text() # NOT NULL columns are prefixed with ! by Ibis (e.g. !int32) @@ -96,7 +160,7 @@ def test_columns_md_orders(self, synced): state, output, config = synced content = ( - output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=orders" / "columns.md" + output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=orders" / "columns.md" ).read_text() assert content == ( @@ -115,7 +179,12 @@ def test_description_md_users(self, synced): state, output, config = synced content = ( - output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=users" / "description.md" + output + / "type=redshift" + / f"database={config.database}" + / "schema=public" + / "table=users" + / "description.md" ).read_text() assert content == ( @@ -139,7 +208,12 @@ def test_description_md_orders(self, synced): state, output, config = synced content = ( - output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=orders" / "description.md" + output + / "type=redshift" + / f"database={config.database}" + / "schema=public" + / "table=orders" + / "description.md" ).read_text() assert "| **Row Count** | 2 |" in content @@ -149,7 +223,7 @@ def test_preview_md_users(self, synced): state, output, config = synced content = ( - output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=users" / "preview.md" + output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=users" / "preview.md" ).read_text() assert "# users - Preview" in content @@ -169,7 +243,7 @@ def test_preview_md_orders(self, synced): state, output, config = synced content = ( - output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=orders" / "preview.md" + output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=orders" / "preview.md" ).read_text() lines = [line for line in content.splitlines() if line.startswith("- {")] @@ -206,7 +280,7 @@ def test_include_filter(self, tmp_path_factory, redshift_config): with 
Progress(transient=True) as progress: state = sync_database(config, output, progress) - base = output / "type=redshift" / "database=nao_unit_tests" / "schema=public" + base = output / "type=redshift" / f"database={config.database}" / "schema=public" assert (base / "table=users").is_dir() assert not (base / "table=orders").exists() assert state.tables_synced == 1 @@ -229,7 +303,7 @@ def test_exclude_filter(self, tmp_path_factory, redshift_config): with Progress(transient=True) as progress: state = sync_database(config, output, progress) - base = output / "type=redshift" / "database=nao_unit_tests" / "schema=public" + base = output / "type=redshift" / f"database={config.database}" / "schema=public" assert (base / "table=users").is_dir() assert not (base / "table=orders").exists() assert state.tables_synced == 1 @@ -252,33 +326,37 @@ def test_sync_all_schemas_when_schema_name_not_specified(self, tmp_path_factory, state = sync_database(config, output, progress) # Verify public schema tables - assert (output / "type=redshift" / "database=nao_unit_tests" / "schema=public").is_dir() - assert (output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=users").is_dir() - assert (output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=orders").is_dir() + assert (output / "type=redshift" / f"database={config.database}" / "schema=public").is_dir() + assert (output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=users").is_dir() + assert (output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=orders").is_dir() # Verify public.users files files = sorted( f.name - for f in (output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=users").iterdir() + for f in ( + output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=users" + ).iterdir() ) assert files == ["columns.md", "description.md", "preview.md"] # Verify public.orders files files = sorted( f.name - for f in (output / "type=redshift" / "database=nao_unit_tests" / "schema=public" / "table=orders").iterdir() + for f in ( + output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=orders" + ).iterdir() ) assert files == ["columns.md", "description.md", "preview.md"] # Verify another schema table - assert (output / "type=redshift" / "database=nao_unit_tests" / "schema=another").is_dir() - assert (output / "type=redshift" / "database=nao_unit_tests" / "schema=another" / "table=whatever").is_dir() + assert (output / "type=redshift" / f"database={config.database}" / "schema=another").is_dir() + assert (output / "type=redshift" / f"database={config.database}" / "schema=another" / "table=whatever").is_dir() # Verify another.whatever files files = sorted( f.name for f in ( - output / "type=redshift" / "database=nao_unit_tests" / "schema=another" / "table=whatever" + output / "type=redshift" / f"database={config.database}" / "schema=another" / "table=whatever" ).iterdir() ) assert files == ["columns.md", "description.md", "preview.md"] diff --git a/cli/tests/nao_core/commands/sync/integration/test_snowflake.py b/cli/tests/nao_core/commands/sync/integration/test_snowflake.py new file mode 100644 index 00000000..d1242bf5 --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/test_snowflake.py @@ -0,0 +1,368 @@ +"""Integration tests for the database sync pipeline against a real Snowflake database. 
+ +Connection is configured via environment variables: + SNOWFLAKE_ACCOUNT_ID, SNOWFLAKE_USERNAME + SNOWFLAKE_PRIVATE_KEY_PATH, SNOWFLAKE_PASSPHRASE (optional), + SNOWFLAKE_SCHEMA (default public), SNOWFLAKE_WAREHOUSE (optional). + +The test suite is skipped entirely when SNOWFLAKE_ACCOUNT_ID is not set. +""" + +import json +import os +import uuid +from pathlib import Path + +import ibis +import pytest +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from rich.progress import Progress + +from nao_core.commands.sync.providers.databases.provider import sync_database +from nao_core.config.databases.snowflake import SnowflakeConfig + +SNOWFLAKE_ACCOUNT_ID = os.environ.get("SNOWFLAKE_ACCOUNT_ID") + +pytestmark = pytest.mark.skipif( + SNOWFLAKE_ACCOUNT_ID is None, reason="SNOWFLAKE_ACCOUNT_ID not set — skipping Snowflake integration tests" +) + + +@pytest.fixture(scope="module") +def temp_database(): + """Create a temporary database and populate it with test data, then clean up.""" + db_name = f"NAO_UNIT_TESTS_{uuid.uuid4().hex[:8].upper()}" + + # Load private key for authentication + private_key_path = os.environ["SNOWFLAKE_PRIVATE_KEY_PATH"] + passphrase = os.environ.get("SNOWFLAKE_PASSPHRASE") + + with open(private_key_path, "rb") as key_file: + private_key = serialization.load_pem_private_key( + key_file.read(), + password=passphrase.encode() if passphrase else None, + backend=default_backend(), + ) + private_key_bytes = private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + # Connect to Snowflake (without specifying database) to create temp database + conn = ibis.snowflake.connect( + user=os.environ["SNOWFLAKE_USERNAME"], + account=os.environ["SNOWFLAKE_ACCOUNT_ID"], + private_key=private_key_bytes, + warehouse=os.environ.get("SNOWFLAKE_WAREHOUSE"), + create_object_udfs=False, + ) + + try: + # Create temporary database + conn.raw_sql(f"CREATE DATABASE {db_name}").fetchall() + + # Connect to the new database and run setup script + test_conn = ibis.snowflake.connect( + user=os.environ["SNOWFLAKE_USERNAME"], + account=os.environ["SNOWFLAKE_ACCOUNT_ID"], + private_key=private_key_bytes, + warehouse=os.environ.get("SNOWFLAKE_WAREHOUSE"), + database=db_name, + ) + + # Create schema + test_conn.raw_sql("CREATE SCHEMA IF NOT EXISTS public").fetchall() + + # Read and execute SQL script + sql_file = Path(__file__).parent / "dml" / "snowflake.sql" + sql_template = sql_file.read_text() + + # Inject database name into SQL + sql_content = sql_template.format(database=db_name) + + # Execute SQL statements + for statement in sql_content.split(";"): + statement = statement.strip() + if statement: + test_conn.raw_sql(statement).fetchall() + + test_conn.disconnect() + + yield db_name + + finally: + # Clean up: drop the temporary database + conn.raw_sql(f"DROP DATABASE IF EXISTS {db_name}").fetchall() + conn.disconnect() + + +@pytest.fixture(scope="module") +def snowflake_config(temp_database): + """Build a SnowflakeConfig from environment variables using the temporary database.""" + return SnowflakeConfig( + name="test-snowflake", + account_id=os.environ["SNOWFLAKE_ACCOUNT_ID"], + username=os.environ["SNOWFLAKE_USERNAME"], + database=temp_database, + private_key_path=os.environ["SNOWFLAKE_PRIVATE_KEY_PATH"], + passphrase=os.environ.get("SNOWFLAKE_PASSPHRASE"), + schema_name="public", + 
warehouse=os.environ.get("SNOWFLAKE_WAREHOUSE"), + ) + + +@pytest.fixture(scope="module") +def synced(tmp_path_factory, snowflake_config): + """Run sync once for the whole module and return (state, output_path, config).""" + output = tmp_path_factory.mktemp("snowflake_sync") + + with Progress(transient=True) as progress: + state = sync_database(snowflake_config, output, progress) + + return state, output, snowflake_config + + +class TestSnowflakeSyncIntegration: + """Verify the sync pipeline produces correct output against a live Snowflake database.""" + + def test_creates_expected_directory_tree(self, synced): + state, output, config = synced + + base = output / "type=snowflake" / f"database={config.database}" / "schema=public" + + # Schema directory + assert base.is_dir() + + # Each table should have exactly the 3 default template outputs + for table in ("orders", "users"): + assert (base / f"table={table}").is_dir() + table_dir = base / f"table={table}" + files = sorted(f.name for f in table_dir.iterdir()) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify that the "another" schema was NOT synced + another_schema_dir = output / "type=snowflake" / f"database={config.database}" / "schema=another" + assert not another_schema_dir.exists() + + def test_columns_md_users(self, synced): + state, output, config = synced + + content = ( + output / "type=snowflake" / f"database={config.database}" / "schema=public" / "table=users" / "columns.md" + ).read_text() + + # Snowflake stores identifiers in uppercase by default + assert "# USERS" in content + assert "**Dataset:** `PUBLIC`" in content + assert "## Columns (4)" in content + assert "- ID" in content + assert "- NAME" in content + assert "- EMAIL" in content + assert "- ACTIVE" in content + + def test_columns_md_orders(self, synced): + state, output, config = synced + + content = ( + output / "type=snowflake" / f"database={config.database}" / "schema=public" / "table=orders" / "columns.md" + ).read_text() + + assert "# ORDERS" in content + assert "**Dataset:** `PUBLIC`" in content + assert "## Columns (3)" in content + assert "- ID" in content + assert "- USER_ID" in content + assert "- AMOUNT" in content + + def test_description_md_users(self, synced): + state, output, config = synced + + content = ( + output + / "type=snowflake" + / f"database={config.database}" + / "schema=public" + / "table=users" + / "description.md" + ).read_text() + + assert "# USERS" in content + assert "**Dataset:** `PUBLIC`" in content + assert "## Table Metadata" in content + assert "| **Row Count** | 3 |" in content + assert "| **Column Count** | 4 |" in content + + def test_description_md_orders(self, synced): + state, output, config = synced + + content = ( + output + / "type=snowflake" + / f"database={config.database}" + / "schema=public" + / "table=orders" + / "description.md" + ).read_text() + + assert "| **Row Count** | 2 |" in content + assert "| **Column Count** | 3 |" in content + + def test_preview_md_users(self, synced): + state, output, config = synced + + content = ( + output / "type=snowflake" / f"database={config.database}" / "schema=public" / "table=users" / "preview.md" + ).read_text() + + assert "# USERS - Preview" in content + assert "**Dataset:** `PUBLIC`" in content + assert "## Rows (3)" in content + + # Parse the JSONL rows from the markdown + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 3 + # Snowflake returns 
column names in uppercase + assert rows[0] == {"ID": 1, "NAME": "Alice", "EMAIL": "alice@example.com", "ACTIVE": True} + assert rows[1] == {"ID": 2, "NAME": "Bob", "EMAIL": None, "ACTIVE": False} + assert rows[2] == {"ID": 3, "NAME": "Charlie", "EMAIL": "charlie@example.com", "ACTIVE": True} + + def test_preview_md_orders(self, synced): + state, output, config = synced + + content = ( + output / "type=snowflake" / f"database={config.database}" / "schema=public" / "table=orders" / "preview.md" + ).read_text() + + lines = [line for line in content.splitlines() if line.startswith("- {")] + rows = [json.loads(line[2:]) for line in lines] + + assert len(rows) == 2 + # Snowflake returns column names in uppercase and integers as floats + assert rows[0] == {"ID": 1.0, "USER_ID": 1.0, "AMOUNT": 99.99} + assert rows[1] == {"ID": 2.0, "USER_ID": 1.0, "AMOUNT": 24.5} + + def test_sync_state_tracks_schemas_and_tables(self, synced): + state, output, config = synced + + assert state.schemas_synced == 1 + assert state.tables_synced == 2 + # Snowflake stores schema and table names in uppercase + assert "PUBLIC" in state.synced_schemas + assert "USERS" in state.synced_tables["PUBLIC"] + assert "ORDERS" in state.synced_tables["PUBLIC"] + + def test_include_filter(self, tmp_path_factory, snowflake_config): + """Only tables matching include patterns should be synced.""" + config = SnowflakeConfig( + name=snowflake_config.name, + account_id=snowflake_config.account_id, + username=snowflake_config.username, + database=snowflake_config.database, + private_key_path=snowflake_config.private_key_path, + passphrase=snowflake_config.passphrase, + schema_name=snowflake_config.schema_name, + warehouse=snowflake_config.warehouse, + include=["public.users"], + ) + + output = tmp_path_factory.mktemp("snowflake_include") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + # Snowflake uses uppercase names for schemas and tables + base = output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" + assert (base / "table=USERS").is_dir() + assert not (base / "table=ORDERS").exists() + assert state.tables_synced == 1 + + def test_exclude_filter(self, tmp_path_factory, snowflake_config): + """Tables matching exclude patterns should be skipped.""" + config = SnowflakeConfig( + name=snowflake_config.name, + account_id=snowflake_config.account_id, + username=snowflake_config.username, + database=snowflake_config.database, + private_key_path=snowflake_config.private_key_path, + passphrase=snowflake_config.passphrase, + schema_name=snowflake_config.schema_name, + warehouse=snowflake_config.warehouse, + exclude=["public.orders"], + ) + + output = tmp_path_factory.mktemp("snowflake_exclude") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + # Snowflake uses uppercase names for schemas and tables + base = output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" + assert (base / "table=USERS").is_dir() + assert not (base / "table=ORDERS").exists() + assert state.tables_synced == 1 + + def test_sync_all_schemas_when_schema_name_not_specified(self, tmp_path_factory, snowflake_config): + """When schema_name is not provided, all schemas should be synced.""" + config = SnowflakeConfig( + name=snowflake_config.name, + account_id=snowflake_config.account_id, + username=snowflake_config.username, + database=snowflake_config.database, + private_key_path=snowflake_config.private_key_path, + 
passphrase=snowflake_config.passphrase, + schema_name=None, + warehouse=snowflake_config.warehouse, + ) + + output = tmp_path_factory.mktemp("snowflake_all_schemas") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + # Verify PUBLIC schema tables (Snowflake uses uppercase names) + assert (output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC").is_dir() + assert (output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" / "table=USERS").is_dir() + assert (output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" / "table=ORDERS").is_dir() + + # Verify PUBLIC.USERS files + files = sorted( + f.name + for f in ( + output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" / "table=USERS" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify PUBLIC.ORDERS files + files = sorted( + f.name + for f in ( + output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" / "table=ORDERS" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify ANOTHER schema table + assert (output / "type=snowflake" / f"database={config.database}" / "schema=ANOTHER").is_dir() + assert ( + output / "type=snowflake" / f"database={config.database}" / "schema=ANOTHER" / "table=WHATEVER" + ).is_dir() + + # Verify ANOTHER.WHATEVER files + files = sorted( + f.name + for f in ( + output / "type=snowflake" / f"database={config.database}" / "schema=ANOTHER" / "table=WHATEVER" + ).iterdir() + ) + assert files == ["columns.md", "description.md", "preview.md"] + + # Verify state + assert state.schemas_synced == 2 + assert state.tables_synced == 3 + assert "PUBLIC" in state.synced_schemas + assert "ANOTHER" in state.synced_schemas + assert "USERS" in state.synced_tables["PUBLIC"] + assert "ORDERS" in state.synced_tables["PUBLIC"] + assert "WHATEVER" in state.synced_tables["ANOTHER"] From bbfca2c2a9e57d6d8fd79f14013e557238ca95f7 Mon Sep 17 00:00:00 2001 From: Christophe Blefari Date: Tue, 10 Feb 2026 19:28:15 +0100 Subject: [PATCH 3/7] Refacto of tests so the code is less ugly --- .../commands/sync/integration/base.py | 240 ++++++++++++++ .../commands/sync/integration/conftest.py | 13 + .../sync/integration/test_bigquery.py | 311 +++--------------- .../sync/integration/test_databricks.py | 280 +++------------- .../commands/sync/integration/test_duckdb.py | 223 +++---------- .../sync/integration/test_postgres.py | 280 +++------------- .../sync/integration/test_redshift.py | 308 +++-------------- .../sync/integration/test_snowflake.py | 296 +++-------------- 8 files changed, 496 insertions(+), 1455 deletions(-) create mode 100644 cli/tests/nao_core/commands/sync/integration/base.py diff --git a/cli/tests/nao_core/commands/sync/integration/base.py b/cli/tests/nao_core/commands/sync/integration/base.py new file mode 100644 index 00000000..da738292 --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/base.py @@ -0,0 +1,240 @@ +"""Shared base class and spec for database sync integration tests. + +Each provider test file defines fixtures (`db_config`, `spec`) and a test class +that inherits from `BaseSyncIntegrationTests` to get all shared assertions for free. 
+""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from pathlib import Path + +import pytest +from rich.progress import Progress + +from nao_core.commands.sync.providers.databases.provider import sync_database + + +@dataclass(frozen=True) +class SyncTestSpec: + """Provider-specific expected values for database sync integration tests.""" + + db_type: str + primary_schema: str + + # Table names as they appear in output paths and state + users_table: str = "users" + orders_table: str = "orders" + + # Strings expected in columns.md (checked with `in content`) + users_column_assertions: tuple[str, ...] = () + orders_column_assertions: tuple[str, ...] = () + + # Expected preview rows (sorted by row_id_key when sort_rows is True) + users_preview_rows: list[dict] = field(default_factory=list) + orders_preview_rows: list[dict] = field(default_factory=list) + sort_rows: bool = False + row_id_key: str = "id" + + # Schema prefix for include/exclude filter patterns (defaults to primary_schema) + filter_schema: str | None = None + + # Multi-schema test (optional — skipped when schema_field is None) + schema_field: str | None = None + another_schema: str | None = None + another_table: str | None = None + + @property + def effective_filter_schema(self) -> str: + return self.filter_schema or self.primary_schema + + +class BaseSyncIntegrationTests: + """Shared sync integration tests. + + Subclasses must provide `synced`, `db_config`, and `spec` fixtures + (typically via module-level fixtures in each test file). + """ + + # ── helpers ────────────────────────────────────────────────────── + + def _base_path(self, output: Path, config, spec: SyncTestSpec) -> Path: + return ( + output / f"type={spec.db_type}" / f"database={config.get_database_name()}" / f"schema={spec.primary_schema}" + ) + + def _read_table_file(self, output, config, spec, table, filename): + return (self._base_path(output, config, spec) / f"table={table}" / filename).read_text() + + def _parse_preview_rows(self, content: str) -> list[dict]: + lines = [line for line in content.splitlines() if line.startswith("- {")] + return [json.loads(line[2:]) for line in lines] + + # ── directory tree ─────────────────────────────────────────────── + + def test_creates_expected_directory_tree(self, synced, spec): + state, output, config = synced + base = self._base_path(output, config, spec) + + assert base.is_dir() + + for table in (spec.users_table, spec.orders_table): + table_dir = base / f"table={table}" + assert table_dir.is_dir() + files = sorted(f.name for f in table_dir.iterdir()) + assert files == ["columns.md", "description.md", "preview.md"] + + # "another" schema was NOT synced (only when provider has one) + if spec.another_schema: + another_dir = ( + output + / f"type={spec.db_type}" + / f"database={config.get_database_name()}" + / f"schema={spec.another_schema}" + ) + assert not another_dir.exists() + + # ── columns.md ─────────────────────────────────────────────────── + + def test_columns_md_users(self, synced, spec): + _, output, config = synced + content = self._read_table_file(output, config, spec, spec.users_table, "columns.md") + + for expected in spec.users_column_assertions: + assert expected in content + + def test_columns_md_orders(self, synced, spec): + _, output, config = synced + content = self._read_table_file(output, config, spec, spec.orders_table, "columns.md") + + for expected in spec.orders_column_assertions: + assert expected in content + + # ── description.md 
─────────────────────────────────────────────── + + def test_description_md_users(self, synced, spec): + _, output, config = synced + content = self._read_table_file(output, config, spec, spec.users_table, "description.md") + + assert "## Table Metadata" in content + assert "| **Row Count** | 3 |" in content + assert "| **Column Count** | 4 |" in content + + def test_description_md_orders(self, synced, spec): + _, output, config = synced + content = self._read_table_file(output, config, spec, spec.orders_table, "description.md") + + assert "| **Row Count** | 2 |" in content + assert "| **Column Count** | 3 |" in content + + # ── preview.md ─────────────────────────────────────────────────── + + def test_preview_md_users(self, synced, spec): + _, output, config = synced + content = self._read_table_file(output, config, spec, spec.users_table, "preview.md") + + assert "## Rows (3)" in content + + rows = self._parse_preview_rows(content) + assert len(rows) == 3 + + if spec.sort_rows: + rows = sorted(rows, key=lambda r: r[spec.row_id_key]) + + assert rows == spec.users_preview_rows + + def test_preview_md_orders(self, synced, spec): + _, output, config = synced + content = self._read_table_file(output, config, spec, spec.orders_table, "preview.md") + + rows = self._parse_preview_rows(content) + assert len(rows) == 2 + + if spec.sort_rows: + rows = sorted(rows, key=lambda r: r[spec.row_id_key]) + + assert rows == spec.orders_preview_rows + + # ── sync state ─────────────────────────────────────────────────── + + def test_sync_state_tracks_schemas_and_tables(self, synced, spec): + state, _, _ = synced + + assert state.schemas_synced == 1 + assert state.tables_synced == 2 + assert spec.primary_schema in state.synced_schemas + assert spec.users_table in state.synced_tables[spec.primary_schema] + assert spec.orders_table in state.synced_tables[spec.primary_schema] + + # ── include / exclude filters ──────────────────────────────────── + + def test_include_filter(self, tmp_path_factory, db_config, spec): + """Only tables matching include patterns should be synced.""" + schema = spec.effective_filter_schema + config = db_config.model_copy(update={"include": [f"{schema}.{spec.users_table}"]}) + + output = tmp_path_factory.mktemp(f"{spec.db_type}_include") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = self._base_path(output, config, spec) + assert (base / f"table={spec.users_table}").is_dir() + assert not (base / f"table={spec.orders_table}").exists() + assert state.tables_synced == 1 + + def test_exclude_filter(self, tmp_path_factory, db_config, spec): + """Tables matching exclude patterns should be skipped.""" + schema = spec.effective_filter_schema + config = db_config.model_copy(update={"exclude": [f"{schema}.{spec.orders_table}"]}) + + output = tmp_path_factory.mktemp(f"{spec.db_type}_exclude") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + base = self._base_path(output, config, spec) + assert (base / f"table={spec.users_table}").is_dir() + assert not (base / f"table={spec.orders_table}").exists() + assert state.tables_synced == 1 + + # ── multi-schema sync ──────────────────────────────────────────── + + def test_sync_all_schemas(self, tmp_path_factory, db_config, spec): + """When schema is not specified, all schemas should be synced.""" + if spec.schema_field is None: + pytest.skip("Provider does not support multi-schema test") + + config = 
db_config.model_copy(update={spec.schema_field: None}) + + output = tmp_path_factory.mktemp(f"{spec.db_type}_all_schemas") + with Progress(transient=True) as progress: + state = sync_database(config, output, progress) + + db_name = config.get_database_name() + + # Primary schema tables + primary_base = output / f"type={spec.db_type}" / f"database={db_name}" / f"schema={spec.primary_schema}" + assert primary_base.is_dir() + assert (primary_base / f"table={spec.users_table}").is_dir() + assert (primary_base / f"table={spec.orders_table}").is_dir() + + for table in (spec.users_table, spec.orders_table): + files = sorted(f.name for f in (primary_base / f"table={table}").iterdir()) + assert files == ["columns.md", "description.md", "preview.md"] + + # Another schema + another_base = output / f"type={spec.db_type}" / f"database={db_name}" / f"schema={spec.another_schema}" + assert another_base.is_dir() + assert (another_base / f"table={spec.another_table}").is_dir() + + files = sorted(f.name for f in (another_base / f"table={spec.another_table}").iterdir()) + assert files == ["columns.md", "description.md", "preview.md"] + + # State + assert state.schemas_synced == 2 + assert state.tables_synced == 3 + assert spec.primary_schema in state.synced_schemas + assert spec.another_schema in state.synced_schemas + assert spec.users_table in state.synced_tables[spec.primary_schema] + assert spec.orders_table in state.synced_tables[spec.primary_schema] + assert spec.another_table in state.synced_tables[spec.another_schema] diff --git a/cli/tests/nao_core/commands/sync/integration/conftest.py b/cli/tests/nao_core/commands/sync/integration/conftest.py index c62f792e..60e67d11 100644 --- a/cli/tests/nao_core/commands/sync/integration/conftest.py +++ b/cli/tests/nao_core/commands/sync/integration/conftest.py @@ -4,8 +4,10 @@ import pytest from dotenv import load_dotenv +from rich.progress import Progress import nao_core.templates.engine as engine_module +from nao_core.commands.sync.providers.databases.provider import sync_database # Auto-load .env sitting next to this conftest so env vars are available # before pytest collects test modules (where skipif reads them). 
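For context, a minimal sketch of how a provider test module combines these pieces, assuming the `BaseSyncIntegrationTests`/`SyncTestSpec` base added above and the shared module-scoped `synced` fixture added to conftest below; the module name and connection values here are hypothetical, not part of the patch:

# tests/nao_core/commands/sync/integration/test_example_provider.py (hypothetical name)
import pytest

from nao_core.config.databases.postgres import PostgresConfig

from .base import BaseSyncIntegrationTests, SyncTestSpec


@pytest.fixture(scope="module")
def db_config():
    # Hypothetical connection values; the real modules build these from environment variables.
    return PostgresConfig(
        name="example",
        host="localhost",
        port=5432,
        database="example_db",
        user="example_user",
        password="example_password",
        schema_name="public",
    )


@pytest.fixture(scope="module")
def spec():
    # Only provider-specific expectations are filled in; everything else falls back to the spec defaults.
    return SyncTestSpec(
        db_type="postgres",
        primary_schema="public",
        users_preview_rows=[
            {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True},
            {"id": 2, "name": "Bob", "email": None, "active": False},
            {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True},
        ],
        orders_preview_rows=[
            {"id": 1, "user_id": 1, "amount": 99.99},
            {"id": 2, "user_id": 1, "amount": 24.5},
        ],
        schema_field="schema_name",
        another_schema="another",
        another_table="whatever",
    )


class TestExampleProviderSyncIntegration(BaseSyncIntegrationTests):
    """Inherits every shared assertion; the module-scoped `synced` fixture runs the sync once."""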
@@ -18,3 +20,14 @@ def reset_template_engine(): engine_module._engine = None yield engine_module._engine = None + + +@pytest.fixture(scope="module") +def synced(tmp_path_factory, db_config): + """Run sync once for the whole module and return (state, output_path, config).""" + output = tmp_path_factory.mktemp(f"{db_config.type}_sync") + + with Progress(transient=True) as progress: + state = sync_database(db_config, output, progress) + + return state, output, db_config diff --git a/cli/tests/nao_core/commands/sync/integration/test_bigquery.py b/cli/tests/nao_core/commands/sync/integration/test_bigquery.py index d7496896..b3bbddfc 100644 --- a/cli/tests/nao_core/commands/sync/integration/test_bigquery.py +++ b/cli/tests/nao_core/commands/sync/integration/test_bigquery.py @@ -15,11 +15,11 @@ import pytest from google.cloud import bigquery from google.oauth2 import service_account -from rich.progress import Progress -from nao_core.commands.sync.providers.databases.provider import sync_database from nao_core.config.databases.bigquery import BigQueryConfig +from .base import BaseSyncIntegrationTests, SyncTestSpec + BIGQUERY_PROJECT_ID = os.environ.get("BIGQUERY_PROJECT_ID") pytestmark = pytest.mark.skipif( @@ -97,7 +97,7 @@ def temp_datasets(): @pytest.fixture(scope="module") -def bigquery_config(temp_datasets): +def db_config(temp_datasets): """Build a BigQueryConfig from environment variables using the temporary dataset.""" credentials_json_str = os.environ.get("BIGQUERY_CREDENTIALS_JSON") credentials_json = json.loads(credentials_json_str) if credentials_json_str else None @@ -111,275 +111,42 @@ def bigquery_config(temp_datasets): @pytest.fixture(scope="module") -def synced(tmp_path_factory, bigquery_config): - """Run sync once for the whole module and return (state, output_path, config).""" - output = tmp_path_factory.mktemp("bigquery_sync") - - with Progress(transient=True) as progress: - state = sync_database(bigquery_config, output, progress) - - return state, output, bigquery_config +def spec(db_config, temp_datasets): + return SyncTestSpec( + db_type="bigquery", + primary_schema=db_config.dataset_id, + users_column_assertions=( + "# users", + f"**Dataset:** `{db_config.dataset_id}`", + "## Columns (4)", + "- id (int64 NOT NULL)", + "- name (string NOT NULL)", + "- email (string)", + "- active (boolean)", + ), + orders_column_assertions=( + "# orders", + f"**Dataset:** `{db_config.dataset_id}`", + "## Columns (3)", + "- id (int64 NOT NULL)", + "- user_id (int64 NOT NULL)", + "- amount (float64 NOT NULL)", + ), + users_preview_rows=[ + {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True}, + {"id": 2, "name": "Bob", "email": None, "active": False}, + {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True}, + ], + orders_preview_rows=[ + {"id": 1, "user_id": 1, "amount": 99.99}, + {"id": 2, "user_id": 1, "amount": 24.5}, + ], + sort_rows=True, + schema_field="dataset_id", + another_schema=temp_datasets["another"], + another_table="whatever", + ) -class TestBigQuerySyncIntegration: +class TestBigQuerySyncIntegration(BaseSyncIntegrationTests): """Verify the sync pipeline produces correct output against a live BigQuery project.""" - - def test_creates_expected_directory_tree(self, synced, temp_datasets): - state, output, config = synced - - base = output / "type=bigquery" / f"database={config.project_id}" / f"schema={config.dataset_id}" - - # Schema directory - assert base.is_dir() - - # Each table should have exactly the 3 default template outputs - for 
table in ("orders", "users"): - assert (base / f"table={table}").is_dir() - table_dir = base / f"table={table}" - files = sorted(f.name for f in table_dir.iterdir()) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify that the "another" dataset was NOT synced - another_dataset_dir = ( - output / "type=bigquery" / f"database={config.project_id}" / f"schema={temp_datasets['another']}" - ) - assert not another_dataset_dir.exists() - - def test_columns_md_users(self, synced): - state, output, config = synced - - content = ( - output - / "type=bigquery" - / f"database={config.project_id}" - / f"schema={config.dataset_id}" - / "table=users" - / "columns.md" - ).read_text() - - # BigQuery uses int64, string, bool types - assert "# users" in content - assert f"**Dataset:** `{config.dataset_id}`" in content - assert "## Columns (4)" in content - assert "- id (int64 NOT NULL)" in content - assert "- name (string NOT NULL)" in content - assert "- email (string)" in content - assert "- active (boolean)" in content - - def test_columns_md_orders(self, synced): - state, output, config = synced - - content = ( - output - / "type=bigquery" - / f"database={config.project_id}" - / f"schema={config.dataset_id}" - / "table=orders" - / "columns.md" - ).read_text() - - assert "# orders" in content - assert f"**Dataset:** `{config.dataset_id}`" in content - assert "## Columns (3)" in content - assert "- id (int64 NOT NULL)" in content - assert "- user_id (int64 NOT NULL)" in content - assert "- amount (float64 NOT NULL)" in content - - def test_description_md_users(self, synced): - state, output, config = synced - - content = ( - output - / "type=bigquery" - / f"database={config.project_id}" - / f"schema={config.dataset_id}" - / "table=users" - / "description.md" - ).read_text() - - assert "# users" in content - assert f"**Dataset:** `{config.dataset_id}`" in content - assert "## Table Metadata" in content - assert "| **Row Count** | 3 |" in content - assert "| **Column Count** | 4 |" in content - - def test_description_md_orders(self, synced): - state, output, config = synced - - content = ( - output - / "type=bigquery" - / f"database={config.project_id}" - / f"schema={config.dataset_id}" - / "table=orders" - / "description.md" - ).read_text() - - assert "| **Row Count** | 2 |" in content - assert "| **Column Count** | 3 |" in content - - def test_preview_md_users(self, synced): - state, output, config = synced - - content = ( - output - / "type=bigquery" - / f"database={config.project_id}" - / f"schema={config.dataset_id}" - / "table=users" - / "preview.md" - ).read_text() - - assert "# users - Preview" in content - assert f"**Dataset:** `{config.dataset_id}`" in content - assert "## Rows (3)" in content - - # Parse the JSONL rows from the markdown - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 3 - # Sort by id since BigQuery doesn't guarantee order - rows_sorted = sorted(rows, key=lambda r: r["id"]) - assert rows_sorted[0] == {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True} - assert rows_sorted[1] == {"id": 2, "name": "Bob", "email": None, "active": False} - assert rows_sorted[2] == {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True} - - def test_preview_md_orders(self, synced): - state, output, config = synced - - content = ( - output - / "type=bigquery" - / f"database={config.project_id}" - / f"schema={config.dataset_id}" - / 
"table=orders" - / "preview.md" - ).read_text() - - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 2 - # Sort by id since BigQuery doesn't guarantee order - rows_sorted = sorted(rows, key=lambda r: r["id"]) - assert rows_sorted[0] == {"id": 1, "user_id": 1, "amount": 99.99} - assert rows_sorted[1] == {"id": 2, "user_id": 1, "amount": 24.5} - - def test_sync_state_tracks_schemas_and_tables(self, synced): - state, output, config = synced - - assert state.schemas_synced == 1 - assert state.tables_synced == 2 - assert config.dataset_id in state.synced_schemas - assert "users" in state.synced_tables[config.dataset_id] - assert "orders" in state.synced_tables[config.dataset_id] - - def test_include_filter(self, tmp_path_factory, bigquery_config): - """Only tables matching include patterns should be synced.""" - config = BigQueryConfig( - name=bigquery_config.name, - project_id=bigquery_config.project_id, - dataset_id=bigquery_config.dataset_id, - credentials_json=bigquery_config.credentials_json, - include=[f"{bigquery_config.dataset_id}.users"], - ) - - output = tmp_path_factory.mktemp("bigquery_include") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - base = output / "type=bigquery" / f"database={config.project_id}" / f"schema={config.dataset_id}" - assert (base / "table=users").is_dir() - assert not (base / "table=orders").exists() - assert state.tables_synced == 1 - - def test_exclude_filter(self, tmp_path_factory, bigquery_config): - """Tables matching exclude patterns should be skipped.""" - config = BigQueryConfig( - name=bigquery_config.name, - project_id=bigquery_config.project_id, - dataset_id=bigquery_config.dataset_id, - credentials_json=bigquery_config.credentials_json, - exclude=[f"{bigquery_config.dataset_id}.orders"], - ) - - output = tmp_path_factory.mktemp("bigquery_exclude") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - base = output / "type=bigquery" / f"database={config.project_id}" / f"schema={config.dataset_id}" - assert (base / "table=users").is_dir() - assert not (base / "table=orders").exists() - assert state.tables_synced == 1 - - def test_sync_all_schemas_when_dataset_id_not_specified(self, tmp_path_factory, bigquery_config, temp_datasets): - """When dataset_id is not provided, all datasets should be synced.""" - public_dataset = temp_datasets["public"] - another_dataset = temp_datasets["another"] - - config = BigQueryConfig( - name=bigquery_config.name, - project_id=bigquery_config.project_id, - dataset_id=None, - credentials_json=bigquery_config.credentials_json, - ) - - output = tmp_path_factory.mktemp("bigquery_all_schemas") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - # Verify public dataset tables - assert (output / "type=bigquery" / f"database={config.project_id}" / f"schema={public_dataset}").is_dir() - assert ( - output / "type=bigquery" / f"database={config.project_id}" / f"schema={public_dataset}" / "table=users" - ).is_dir() - assert ( - output / "type=bigquery" / f"database={config.project_id}" / f"schema={public_dataset}" / "table=orders" - ).is_dir() - - # Verify public.users files - files = sorted( - f.name - for f in ( - output / "type=bigquery" / f"database={config.project_id}" / f"schema={public_dataset}" / "table=users" - ).iterdir() - ) - assert files == ["columns.md", "description.md", 
"preview.md"] - - # Verify public.orders files - files = sorted( - f.name - for f in ( - output / "type=bigquery" / f"database={config.project_id}" / f"schema={public_dataset}" / "table=orders" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify another dataset table - assert (output / "type=bigquery" / f"database={config.project_id}" / f"schema={another_dataset}").is_dir() - assert ( - output / "type=bigquery" / f"database={config.project_id}" / f"schema={another_dataset}" / "table=whatever" - ).is_dir() - - # Verify another.whatever files - files = sorted( - f.name - for f in ( - output - / "type=bigquery" - / f"database={config.project_id}" - / f"schema={another_dataset}" - / "table=whatever" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify state - assert state.schemas_synced == 2 - assert state.tables_synced == 3 - assert public_dataset in state.synced_schemas - assert another_dataset in state.synced_schemas - assert "users" in state.synced_tables[public_dataset] - assert "orders" in state.synced_tables[public_dataset] - assert "whatever" in state.synced_tables[another_dataset] diff --git a/cli/tests/nao_core/commands/sync/integration/test_databricks.py b/cli/tests/nao_core/commands/sync/integration/test_databricks.py index 03105d9e..4186fb19 100644 --- a/cli/tests/nao_core/commands/sync/integration/test_databricks.py +++ b/cli/tests/nao_core/commands/sync/integration/test_databricks.py @@ -7,18 +7,17 @@ The test suite is skipped entirely when DATABRICKS_SERVER_HOSTNAME is not set. """ -import json import os import uuid from pathlib import Path import ibis import pytest -from rich.progress import Progress -from nao_core.commands.sync.providers.databases.provider import sync_database from nao_core.config.databases.databricks import DatabricksConfig +from .base import BaseSyncIntegrationTests, SyncTestSpec + DATABRICKS_SERVER_HOSTNAME = os.environ.get("DATABRICKS_SERVER_HOSTNAME") pytestmark = pytest.mark.skipif( @@ -68,7 +67,7 @@ def temp_catalog(): @pytest.fixture(scope="module") -def databricks_config(temp_catalog): +def db_config(temp_catalog): """Build a DatabricksConfig from environment variables using the temporary catalog.""" return DatabricksConfig( name="test-databricks", @@ -81,244 +80,41 @@ def databricks_config(temp_catalog): @pytest.fixture(scope="module") -def synced(tmp_path_factory, databricks_config): - """Run sync once for the whole module and return (state, output_path, config).""" - output = tmp_path_factory.mktemp("databricks_sync") - - with Progress(transient=True) as progress: - state = sync_database(databricks_config, output, progress) - - return state, output, databricks_config +def spec(): + return SyncTestSpec( + db_type="databricks", + primary_schema="public", + users_column_assertions=( + "# users", + "**Dataset:** `public`", + "## Columns (4)", + "- id", + "- name", + "- email", + "- active", + ), + orders_column_assertions=( + "# orders", + "**Dataset:** `public`", + "## Columns (3)", + "- id", + "- user_id", + "- amount", + ), + users_preview_rows=[ + {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True}, + {"id": 2, "name": "Bob", "email": None, "active": False}, + {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True}, + ], + orders_preview_rows=[ + {"id": 1, "user_id": 1, "amount": 99.99}, + {"id": 2, "user_id": 1, "amount": 24.5}, + ], + schema_field="schema_name", + another_schema="another", + another_table="whatever", + 
) -class TestDatabricksSyncIntegration: +class TestDatabricksSyncIntegration(BaseSyncIntegrationTests): """Verify the sync pipeline produces correct output against a live Databricks workspace.""" - - def test_creates_expected_directory_tree(self, synced): - state, output, config = synced - - base = output / "type=databricks" / f"database={config.catalog}" / "schema=public" - - # Schema directory - assert base.is_dir() - - # Each table should have exactly the 3 default template outputs - for table in ("orders", "users"): - assert (base / f"table={table}").is_dir() - table_dir = base / f"table={table}" - files = sorted(f.name for f in table_dir.iterdir()) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify that the "another" schema was NOT synced - another_schema_dir = output / "type=databricks" / f"database={config.catalog}" / "schema=another" - assert not another_schema_dir.exists() - - def test_columns_md_users(self, synced): - state, output, config = synced - - content = ( - output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=users" / "columns.md" - ).read_text() - - # Databricks column types - assert "# users" in content - assert "**Dataset:** `public`" in content - assert "## Columns (4)" in content - assert "- id" in content - assert "- name" in content - assert "- email" in content - assert "- active" in content - - def test_columns_md_orders(self, synced): - state, output, config = synced - - content = ( - output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=orders" / "columns.md" - ).read_text() - - assert "# orders" in content - assert "**Dataset:** `public`" in content - assert "## Columns (3)" in content - assert "- id" in content - assert "- user_id" in content - assert "- amount" in content - - def test_description_md_users(self, synced): - state, output, config = synced - - content = ( - output - / "type=databricks" - / f"database={config.catalog}" - / "schema=public" - / "table=users" - / "description.md" - ).read_text() - - assert "# users" in content - assert "**Dataset:** `public`" in content - assert "## Table Metadata" in content - assert "| **Row Count** | 3 |" in content - assert "| **Column Count** | 4 |" in content - - def test_description_md_orders(self, synced): - state, output, config = synced - - content = ( - output - / "type=databricks" - / f"database={config.catalog}" - / "schema=public" - / "table=orders" - / "description.md" - ).read_text() - - assert "| **Row Count** | 2 |" in content - assert "| **Column Count** | 3 |" in content - - def test_preview_md_users(self, synced): - state, output, config = synced - - content = ( - output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=users" / "preview.md" - ).read_text() - - assert "# users - Preview" in content - assert "**Dataset:** `public`" in content - assert "## Rows (3)" in content - - # Parse the JSONL rows from the markdown - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 3 - assert rows[0] == {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True} - assert rows[1] == {"id": 2, "name": "Bob", "email": None, "active": False} - assert rows[2] == {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True} - - def test_preview_md_orders(self, synced): - state, output, config = synced - - content = ( - output / "type=databricks" / 
f"database={config.catalog}" / "schema=public" / "table=orders" / "preview.md" - ).read_text() - - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 2 - assert rows[0] == {"id": 1, "user_id": 1, "amount": 99.99} - assert rows[1] == {"id": 2, "user_id": 1, "amount": 24.5} - - def test_sync_state_tracks_schemas_and_tables(self, synced): - state, output, config = synced - - assert state.schemas_synced == 1 - assert state.tables_synced == 2 - assert "public" in state.synced_schemas - assert "users" in state.synced_tables["public"] - assert "orders" in state.synced_tables["public"] - - def test_include_filter(self, tmp_path_factory, databricks_config): - """Only tables matching include patterns should be synced.""" - config = DatabricksConfig( - name=databricks_config.name, - server_hostname=databricks_config.server_hostname, - http_path=databricks_config.http_path, - access_token=databricks_config.access_token, - catalog=databricks_config.catalog, - schema_name=databricks_config.schema_name, - include=["public.users"], - ) - - output = tmp_path_factory.mktemp("databricks_include") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - base = output / "type=databricks" / f"database={config.catalog}" / "schema=public" - assert (base / "table=users").is_dir() - assert not (base / "table=orders").exists() - assert state.tables_synced == 1 - - def test_exclude_filter(self, tmp_path_factory, databricks_config): - """Tables matching exclude patterns should be skipped.""" - config = DatabricksConfig( - name=databricks_config.name, - server_hostname=databricks_config.server_hostname, - http_path=databricks_config.http_path, - access_token=databricks_config.access_token, - catalog=databricks_config.catalog, - schema_name=databricks_config.schema_name, - exclude=["public.orders"], - ) - - output = tmp_path_factory.mktemp("databricks_exclude") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - base = output / "type=databricks" / f"database={config.catalog}" / "schema=public" - assert (base / "table=users").is_dir() - assert not (base / "table=orders").exists() - assert state.tables_synced == 1 - - def test_sync_all_schemas_when_schema_name_not_specified(self, tmp_path_factory, databricks_config): - """When schema_name is not provided, all schemas should be synced.""" - config = DatabricksConfig( - name=databricks_config.name, - server_hostname=databricks_config.server_hostname, - http_path=databricks_config.http_path, - access_token=databricks_config.access_token, - catalog=databricks_config.catalog, - schema_name=None, - ) - - output = tmp_path_factory.mktemp("databricks_all_schemas") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - # Verify public schema tables - assert (output / "type=databricks" / f"database={config.catalog}" / "schema=public").is_dir() - assert (output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=users").is_dir() - assert (output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=orders").is_dir() - - # Verify public.users files - files = sorted( - f.name - for f in ( - output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=users" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify public.orders files - files = 
sorted( - f.name - for f in ( - output / "type=databricks" / f"database={config.catalog}" / "schema=public" / "table=orders" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify another schema table - assert (output / "type=databricks" / f"database={config.catalog}" / "schema=another").is_dir() - assert ( - output / "type=databricks" / f"database={config.catalog}" / "schema=another" / "table=whatever" - ).is_dir() - - # Verify another.whatever files - files = sorted( - f.name - for f in ( - output / "type=databricks" / f"database={config.catalog}" / "schema=another" / "table=whatever" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify state - assert state.schemas_synced == 2 - assert state.tables_synced == 3 - assert "public" in state.synced_schemas - assert "another" in state.synced_schemas - assert "users" in state.synced_tables["public"] - assert "orders" in state.synced_tables["public"] - assert "whatever" in state.synced_tables["another"] diff --git a/cli/tests/nao_core/commands/sync/integration/test_duckdb.py b/cli/tests/nao_core/commands/sync/integration/test_duckdb.py index 6d20f3ba..63966edc 100644 --- a/cli/tests/nao_core/commands/sync/integration/test_duckdb.py +++ b/cli/tests/nao_core/commands/sync/integration/test_duckdb.py @@ -1,19 +1,17 @@ """Integration tests for the database sync pipeline using a real DuckDB database.""" -import json - import duckdb import pytest -from rich.progress import Progress -from nao_core.commands.sync.providers.databases.provider import sync_database from nao_core.config.databases.duckdb import DuckDBConfig +from .base import BaseSyncIntegrationTests, SyncTestSpec + -@pytest.fixture -def duckdb_path(tmp_path): +@pytest.fixture(scope="module") +def duckdb_path(tmp_path_factory): """Create a DuckDB database with two tables: users and orders.""" - db_path = tmp_path / "test.duckdb" + db_path = tmp_path_factory.mktemp("duckdb_data") / "test.duckdb" conn = duckdb.connect(str(db_path)) conn.execute(""" @@ -48,172 +46,45 @@ def duckdb_path(tmp_path): return db_path -class TestDuckDBSyncIntegration: - def _sync(self, duckdb_path, output_path): - config = DuckDBConfig(name="test-db", path=str(duckdb_path)) - - with Progress(transient=True) as progress: - state = sync_database(config, output_path, progress) - - return state - - def test_creates_expected_directory_tree(self, tmp_path, duckdb_path): - output = tmp_path / "output" - self._sync(duckdb_path, output) - - base = output / "type=duckdb" / "database=test" - - # Schema directory - assert (base / "schema=main").is_dir() - - # Table directories - assert (base / "schema=main" / "table=users").is_dir() - assert (base / "schema=main" / "table=orders").is_dir() - - # Each table should have exactly the 3 default template outputs - for table in ("users", "orders"): - table_dir = base / "schema=main" / f"table={table}" - files = sorted(f.name for f in table_dir.iterdir()) - assert files == ["columns.md", "description.md", "preview.md"] - - def test_columns_md_users(self, tmp_path, duckdb_path): - output = tmp_path / "output" - self._sync(duckdb_path, output) - - content = (output / "type=duckdb" / "database=test" / "schema=main" / "table=users" / "columns.md").read_text() - - # NOT NULL columns are prefixed with ! by Ibis (e.g. 
!int32) - assert content == ( - "# users\n" - "\n" - "**Dataset:** `main`\n" - "\n" - "## Columns (4)\n" - "\n" - "- id (int32 NOT NULL)\n" - "- name (string NOT NULL)\n" - "- email (string)\n" - "- active (boolean)\n" - ) - - def test_columns_md_orders(self, tmp_path, duckdb_path): - output = tmp_path / "output" - self._sync(duckdb_path, output) - - content = (output / "type=duckdb" / "database=test" / "schema=main" / "table=orders" / "columns.md").read_text() - - assert content == ( - "# orders\n" - "\n" - "**Dataset:** `main`\n" - "\n" - "## Columns (3)\n" - "\n" - "- id (int32 NOT NULL)\n" - "- user_id (int32 NOT NULL)\n" - "- amount (float64 NOT NULL)\n" - ) - - def test_description_md_users(self, tmp_path, duckdb_path): - output = tmp_path / "output" - self._sync(duckdb_path, output) - - content = ( - output / "type=duckdb" / "database=test" / "schema=main" / "table=users" / "description.md" - ).read_text() - - assert content == ( - "# users\n" - "\n" - "**Dataset:** `main`\n" - "\n" - "## Table Metadata\n" - "\n" - "| Property | Value |\n" - "|----------|-------|\n" - "| **Row Count** | 3 |\n" - "| **Column Count** | 4 |\n" - "\n" - "## Description\n" - "\n" - "_No description available._\n" - ) - - def test_description_md_orders(self, tmp_path, duckdb_path): - output = tmp_path / "output" - self._sync(duckdb_path, output) - - content = ( - output / "type=duckdb" / "database=test" / "schema=main" / "table=orders" / "description.md" - ).read_text() - - assert "| **Row Count** | 2 |" in content - assert "| **Column Count** | 3 |" in content - - def test_preview_md_users(self, tmp_path, duckdb_path): - output = tmp_path / "output" - self._sync(duckdb_path, output) - - content = (output / "type=duckdb" / "database=test" / "schema=main" / "table=users" / "preview.md").read_text() - - assert "# users - Preview" in content - assert "**Dataset:** `main`" in content - assert "## Rows (3)" in content - - # Parse the JSONL rows from the markdown - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 3 - assert rows[0] == {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True} - assert rows[1] == {"id": 2, "name": "Bob", "email": None, "active": False} - assert rows[2] == {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True} - - def test_preview_md_orders(self, tmp_path, duckdb_path): - output = tmp_path / "output" - self._sync(duckdb_path, output) - - content = (output / "type=duckdb" / "database=test" / "schema=main" / "table=orders" / "preview.md").read_text() - - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 2 - assert rows[0] == {"id": 1, "user_id": 1, "amount": 99.99} - assert rows[1] == {"id": 2, "user_id": 1, "amount": 24.5} - - def test_sync_state_tracks_schemas_and_tables(self, tmp_path, duckdb_path): - output = tmp_path / "output" - state = self._sync(duckdb_path, output) - - assert state.schemas_synced == 1 - assert state.tables_synced == 2 - assert "main" in state.synced_schemas - assert "users" in state.synced_tables["main"] - assert "orders" in state.synced_tables["main"] - - def test_include_filter(self, tmp_path, duckdb_path): - """Only tables matching include patterns should be synced.""" - config = DuckDBConfig(name="test-db", path=str(duckdb_path), include=["main.users"]) - - output = tmp_path / "output" - with Progress(transient=True) as 
progress: - state = sync_database(config, output, progress) - - base = output / "type=duckdb" / "database=test" / "schema=main" - assert (base / "table=users").is_dir() - assert not (base / "table=orders").exists() - assert state.tables_synced == 1 - - def test_exclude_filter(self, tmp_path, duckdb_path): - """Tables matching exclude patterns should be skipped.""" - config = DuckDBConfig(name="test-db", path=str(duckdb_path), exclude=["main.orders"]) - - output = tmp_path / "output" - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - base = output / "type=duckdb" / "database=test" / "schema=main" - assert (base / "table=users").is_dir() - assert not (base / "table=orders").exists() - assert state.tables_synced == 1 +@pytest.fixture(scope="module") +def db_config(duckdb_path): + """Build a DuckDBConfig pointing at the temporary database.""" + return DuckDBConfig(name="test-db", path=str(duckdb_path)) + + +@pytest.fixture(scope="module") +def spec(): + return SyncTestSpec( + db_type="duckdb", + primary_schema="main", + users_column_assertions=( + "# users", + "**Dataset:** `main`", + "## Columns (4)", + "- id (int32 NOT NULL)", + "- name (string NOT NULL)", + "- email (string)", + "- active (boolean)", + ), + orders_column_assertions=( + "# orders", + "**Dataset:** `main`", + "## Columns (3)", + "- id (int32 NOT NULL)", + "- user_id (int32 NOT NULL)", + "- amount (float64 NOT NULL)", + ), + users_preview_rows=[ + {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True}, + {"id": 2, "name": "Bob", "email": None, "active": False}, + {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True}, + ], + orders_preview_rows=[ + {"id": 1, "user_id": 1, "amount": 99.99}, + {"id": 2, "user_id": 1, "amount": 24.5}, + ], + ) + + +class TestDuckDBSyncIntegration(BaseSyncIntegrationTests): + """Verify the sync pipeline produces correct output against a local DuckDB database.""" diff --git a/cli/tests/nao_core/commands/sync/integration/test_postgres.py b/cli/tests/nao_core/commands/sync/integration/test_postgres.py index 1fa1e780..b9f4aeb7 100644 --- a/cli/tests/nao_core/commands/sync/integration/test_postgres.py +++ b/cli/tests/nao_core/commands/sync/integration/test_postgres.py @@ -8,18 +8,17 @@ The test suite is skipped entirely when POSTGRES_HOST is not set. 
""" -import json import os import uuid from pathlib import Path import ibis import pytest -from rich.progress import Progress -from nao_core.commands.sync.providers.databases.provider import sync_database from nao_core.config.databases.postgres import PostgresConfig +from .base import BaseSyncIntegrationTests, SyncTestSpec + POSTGRES_HOST = os.environ.get("POSTGRES_HOST") pytestmark = pytest.mark.skipif( @@ -98,7 +97,7 @@ def temp_database(): @pytest.fixture(scope="module") -def postgres_config(temp_database): +def db_config(temp_database): """Build a PostgresConfig from environment variables using the temporary database.""" return PostgresConfig( name="test-postgres", @@ -112,244 +111,41 @@ def postgres_config(temp_database): @pytest.fixture(scope="module") -def synced(tmp_path_factory, postgres_config): - """Run sync once for the whole module and return (state, output_path, config).""" - output = tmp_path_factory.mktemp("postgres_sync") - - with Progress(transient=True) as progress: - state = sync_database(postgres_config, output, progress) - - return state, output, postgres_config +def spec(): + return SyncTestSpec( + db_type="postgres", + primary_schema="public", + users_column_assertions=( + "# users", + "**Dataset:** `public`", + "## Columns (4)", + "- id", + "- name", + "- email", + "- active", + ), + orders_column_assertions=( + "# orders", + "**Dataset:** `public`", + "## Columns (3)", + "- id", + "- user_id", + "- amount", + ), + users_preview_rows=[ + {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True}, + {"id": 2, "name": "Bob", "email": None, "active": False}, + {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True}, + ], + orders_preview_rows=[ + {"id": 1, "user_id": 1, "amount": 99.99}, + {"id": 2, "user_id": 1, "amount": 24.5}, + ], + schema_field="schema_name", + another_schema="another", + another_table="whatever", + ) -class TestPostgresSyncIntegration: +class TestPostgresSyncIntegration(BaseSyncIntegrationTests): """Verify the sync pipeline produces correct output against a live Postgres database.""" - - def test_creates_expected_directory_tree(self, synced): - state, output, config = synced - - base = output / "type=postgres" / f"database={config.database}" / "schema=public" - - # Schema directory - assert base.is_dir() - - # Each table should have exactly the 3 default template outputs - for table in ("orders", "users"): - assert (base / f"table={table}").is_dir() - table_dir = base / f"table={table}" - files = sorted(f.name for f in table_dir.iterdir()) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify that the "another" schema was NOT synced - another_schema_dir = output / "type=postgres" / f"database={config.database}" / "schema=another" - assert not another_schema_dir.exists() - - def test_columns_md_users(self, synced): - state, output, config = synced - - content = ( - output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=users" / "columns.md" - ).read_text() - - assert "# users" in content - assert "**Dataset:** `public`" in content - assert "## Columns (4)" in content - assert "- id" in content - assert "- name" in content - assert "- email" in content - assert "- active" in content - - def test_columns_md_orders(self, synced): - state, output, config = synced - - content = ( - output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=orders" / "columns.md" - ).read_text() - - assert "# orders" in content - assert "**Dataset:** 
`public`" in content - assert "## Columns (3)" in content - assert "- id" in content - assert "- user_id" in content - assert "- amount" in content - - def test_description_md_users(self, synced): - state, output, config = synced - - content = ( - output - / "type=postgres" - / f"database={config.database}" - / "schema=public" - / "table=users" - / "description.md" - ).read_text() - - assert "# users" in content - assert "**Dataset:** `public`" in content - assert "## Table Metadata" in content - assert "| **Row Count** | 3 |" in content - assert "| **Column Count** | 4 |" in content - - def test_description_md_orders(self, synced): - state, output, config = synced - - content = ( - output - / "type=postgres" - / f"database={config.database}" - / "schema=public" - / "table=orders" - / "description.md" - ).read_text() - - assert "| **Row Count** | 2 |" in content - assert "| **Column Count** | 3 |" in content - - def test_preview_md_users(self, synced): - state, output, config = synced - - content = ( - output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=users" / "preview.md" - ).read_text() - - assert "# users - Preview" in content - assert "**Dataset:** `public`" in content - assert "## Rows (3)" in content - - # Parse the JSONL rows from the markdown - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 3 - assert rows[0] == {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True} - assert rows[1] == {"id": 2, "name": "Bob", "email": None, "active": False} - assert rows[2] == {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True} - - def test_preview_md_orders(self, synced): - state, output, config = synced - - content = ( - output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=orders" / "preview.md" - ).read_text() - - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 2 - assert rows[0] == {"id": 1, "user_id": 1, "amount": 99.99} - assert rows[1] == {"id": 2, "user_id": 1, "amount": 24.5} - - def test_sync_state_tracks_schemas_and_tables(self, synced): - state, output, config = synced - - assert state.schemas_synced == 1 - assert state.tables_synced == 2 - assert "public" in state.synced_schemas - assert "users" in state.synced_tables["public"] - assert "orders" in state.synced_tables["public"] - - def test_include_filter(self, tmp_path_factory, postgres_config): - """Only tables matching include patterns should be synced.""" - config = PostgresConfig( - name=postgres_config.name, - host=postgres_config.host, - port=postgres_config.port, - database=postgres_config.database, - user=postgres_config.user, - password=postgres_config.password, - schema_name=postgres_config.schema_name, - include=["public.users"], - ) - - output = tmp_path_factory.mktemp("postgres_include") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - base = output / "type=postgres" / f"database={config.database}" / "schema=public" - assert (base / "table=users").is_dir() - assert not (base / "table=orders").exists() - assert state.tables_synced == 1 - - def test_exclude_filter(self, tmp_path_factory, postgres_config): - """Tables matching exclude patterns should be skipped.""" - config = PostgresConfig( - name=postgres_config.name, - host=postgres_config.host, - 
port=postgres_config.port, - database=postgres_config.database, - user=postgres_config.user, - password=postgres_config.password, - schema_name=postgres_config.schema_name, - exclude=["public.orders"], - ) - - output = tmp_path_factory.mktemp("postgres_exclude") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - base = output / "type=postgres" / f"database={config.database}" / "schema=public" - assert (base / "table=users").is_dir() - assert not (base / "table=orders").exists() - assert state.tables_synced == 1 - - def test_sync_all_schemas_when_schema_name_not_specified(self, tmp_path_factory, postgres_config): - """When schema_name is not provided, all schemas should be synced.""" - config = PostgresConfig( - name=postgres_config.name, - host=postgres_config.host, - port=postgres_config.port, - database=postgres_config.database, - user=postgres_config.user, - password=postgres_config.password, - schema_name=None, - ) - - output = tmp_path_factory.mktemp("postgres_all_schemas") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - # Verify public schema tables - assert (output / "type=postgres" / f"database={config.database}" / "schema=public").is_dir() - assert (output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=users").is_dir() - assert (output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=orders").is_dir() - - # Verify public.users files - files = sorted( - f.name - for f in ( - output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=users" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify public.orders files - files = sorted( - f.name - for f in ( - output / "type=postgres" / f"database={config.database}" / "schema=public" / "table=orders" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify another schema table - assert (output / "type=postgres" / f"database={config.database}" / "schema=another").is_dir() - assert (output / "type=postgres" / f"database={config.database}" / "schema=another" / "table=whatever").is_dir() - - # Verify another.whatever files - files = sorted( - f.name - for f in ( - output / "type=postgres" / f"database={config.database}" / "schema=another" / "table=whatever" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify state - assert state.schemas_synced == 2 - assert state.tables_synced == 3 - assert "public" in state.synced_schemas - assert "another" in state.synced_schemas - assert "users" in state.synced_tables["public"] - assert "orders" in state.synced_tables["public"] - assert "whatever" in state.synced_tables["another"] diff --git a/cli/tests/nao_core/commands/sync/integration/test_redshift.py b/cli/tests/nao_core/commands/sync/integration/test_redshift.py index eb1b1f55..c26707bf 100644 --- a/cli/tests/nao_core/commands/sync/integration/test_redshift.py +++ b/cli/tests/nao_core/commands/sync/integration/test_redshift.py @@ -8,27 +8,23 @@ The test suite is skipped entirely when REDSHIFT_HOST is not set. 
""" -import json import os import uuid from pathlib import Path import ibis import pytest -from rich.progress import Progress -from nao_core.commands.sync.providers.databases.provider import sync_database from nao_core.config.databases.redshift import RedshiftConfig +from .base import BaseSyncIntegrationTests, SyncTestSpec + REDSHIFT_HOST = os.environ.get("REDSHIFT_HOST") pytestmark = pytest.mark.skipif( REDSHIFT_HOST is None, reason="REDSHIFT_HOST not set — skipping Redshift integration tests" ) -# ibis uses pg_catalog.pg_enum which Redshift does not support -KNOWN_ERROR = "pg_catalog.pg_enum" - @pytest.fixture(scope="module") def temp_database(): @@ -88,7 +84,7 @@ def temp_database(): @pytest.fixture(scope="module") -def redshift_config(temp_database): +def db_config(temp_database): """Build a RedshiftConfig from environment variables using the temporary database.""" return RedshiftConfig( name="test-redshift", @@ -103,269 +99,41 @@ def redshift_config(temp_database): @pytest.fixture(scope="module") -def synced(tmp_path_factory, redshift_config): - """Run sync once for the whole module and return (state, output_path, config).""" - output = tmp_path_factory.mktemp("redshift_sync") - - with Progress(transient=True) as progress: - state = sync_database(redshift_config, output, progress) - - return state, output, redshift_config +def spec(): + return SyncTestSpec( + db_type="redshift", + primary_schema="public", + users_column_assertions=( + "# users", + "**Dataset:** `public`", + "## Columns (4)", + "- id (int32 NOT NULL)", + "- name (string NOT NULL)", + "- email (string)", + "- active (boolean)", + ), + orders_column_assertions=( + "# orders", + "**Dataset:** `public`", + "## Columns (3)", + "- id (int32 NOT NULL)", + "- user_id (int32 NOT NULL)", + "- amount (float32 NOT NULL)", + ), + users_preview_rows=[ + {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True}, + {"id": 2, "name": "Bob", "email": None, "active": False}, + {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True}, + ], + orders_preview_rows=[ + {"id": 1, "user_id": 1, "amount": 99.99}, + {"id": 2, "user_id": 1, "amount": 24.5}, + ], + schema_field="schema_name", + another_schema="another", + another_table="whatever", + ) -class TestRedshiftSyncIntegration: +class TestRedshiftSyncIntegration(BaseSyncIntegrationTests): """Verify the sync pipeline produces correct output against a live Redshift cluster.""" - - def test_creates_expected_directory_tree(self, synced): - state, output, config = synced - - base = output / "type=redshift" / f"database={config.database}" / "schema=public" - - # Schema directory - assert base.is_dir() - - # Each table should have exactly the 3 default template outputs - for table in ("orders", "users"): - assert (base / f"table={table}").is_dir() - table_dir = base / f"table={table}" - files = sorted(f.name for f in table_dir.iterdir()) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify that the "another" schema was NOT synced - another_schema_dir = output / "type=redshift" / f"database={config.database}" / "schema=another" - assert not another_schema_dir.exists() - - def test_columns_md_users(self, synced): - state, output, config = synced - - content = ( - output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=users" / "columns.md" - ).read_text() - - # NOT NULL columns are prefixed with ! by Ibis (e.g. 
!int32) - assert content == ( - "# users\n" - "\n" - "**Dataset:** `public`\n" - "\n" - "## Columns (4)\n" - "\n" - "- id (int32 NOT NULL)\n" - "- name (string NOT NULL)\n" - "- email (string)\n" - "- active (boolean)\n" - ) - - def test_columns_md_orders(self, synced): - state, output, config = synced - - content = ( - output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=orders" / "columns.md" - ).read_text() - - assert content == ( - "# orders\n" - "\n" - "**Dataset:** `public`\n" - "\n" - "## Columns (3)\n" - "\n" - "- id (int32 NOT NULL)\n" - "- user_id (int32 NOT NULL)\n" - "- amount (float32 NOT NULL)\n" - ) - - def test_description_md_users(self, synced): - state, output, config = synced - - content = ( - output - / "type=redshift" - / f"database={config.database}" - / "schema=public" - / "table=users" - / "description.md" - ).read_text() - - assert content == ( - "# users\n" - "\n" - "**Dataset:** `public`\n" - "\n" - "## Table Metadata\n" - "\n" - "| Property | Value |\n" - "|----------|-------|\n" - "| **Row Count** | 3 |\n" - "| **Column Count** | 4 |\n" - "\n" - "## Description\n" - "\n" - "_No description available._\n" - ) - - def test_description_md_orders(self, synced): - state, output, config = synced - - content = ( - output - / "type=redshift" - / f"database={config.database}" - / "schema=public" - / "table=orders" - / "description.md" - ).read_text() - - assert "| **Row Count** | 2 |" in content - assert "| **Column Count** | 3 |" in content - - def test_preview_md_users(self, synced): - state, output, config = synced - - content = ( - output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=users" / "preview.md" - ).read_text() - - assert "# users - Preview" in content - assert "**Dataset:** `public`" in content - assert "## Rows (3)" in content - - # Parse the JSONL rows from the markdown - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 3 - assert rows[0] == {"id": 1, "name": "Alice", "email": "alice@example.com", "active": True} - assert rows[1] == {"id": 2, "name": "Bob", "email": None, "active": False} - assert rows[2] == {"id": 3, "name": "Charlie", "email": "charlie@example.com", "active": True} - - def test_preview_md_orders(self, synced): - state, output, config = synced - - content = ( - output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=orders" / "preview.md" - ).read_text() - - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 2 - assert rows[0] == {"id": 1, "user_id": 1, "amount": 99.99} - assert rows[1] == {"id": 2, "user_id": 1, "amount": 24.5} - - def test_sync_state_tracks_schemas_and_tables(self, synced): - state, output, config = synced - - assert state.schemas_synced == 1 - assert state.tables_synced == 2 - assert "public" in state.synced_schemas - assert "users" in state.synced_tables["public"] - assert "orders" in state.synced_tables["public"] - - def test_include_filter(self, tmp_path_factory, redshift_config): - """Only tables matching include patterns should be synced.""" - config = RedshiftConfig( - name=redshift_config.name, - host=redshift_config.host, - port=redshift_config.port, - database=redshift_config.database, - user=redshift_config.user, - password=redshift_config.password, - schema_name=redshift_config.schema_name, - sslmode=redshift_config.sslmode, 
- include=["public.users"], - ) - - output = tmp_path_factory.mktemp("redshift_include") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - base = output / "type=redshift" / f"database={config.database}" / "schema=public" - assert (base / "table=users").is_dir() - assert not (base / "table=orders").exists() - assert state.tables_synced == 1 - - def test_exclude_filter(self, tmp_path_factory, redshift_config): - """Tables matching exclude patterns should be skipped.""" - config = RedshiftConfig( - name=redshift_config.name, - host=redshift_config.host, - port=redshift_config.port, - database=redshift_config.database, - user=redshift_config.user, - password=redshift_config.password, - schema_name=redshift_config.schema_name, - sslmode=redshift_config.sslmode, - exclude=["public.orders"], - ) - - output = tmp_path_factory.mktemp("redshift_exclude") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - base = output / "type=redshift" / f"database={config.database}" / "schema=public" - assert (base / "table=users").is_dir() - assert not (base / "table=orders").exists() - assert state.tables_synced == 1 - - def test_sync_all_schemas_when_schema_name_not_specified(self, tmp_path_factory, redshift_config): - """When schema_name is not provided, all schemas should be synced.""" - config = RedshiftConfig( - name=redshift_config.name, - host=redshift_config.host, - port=redshift_config.port, - database=redshift_config.database, - user=redshift_config.user, - password=redshift_config.password, - schema_name=None, - sslmode=redshift_config.sslmode, - ) - - output = tmp_path_factory.mktemp("redshift_all_schemas") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - # Verify public schema tables - assert (output / "type=redshift" / f"database={config.database}" / "schema=public").is_dir() - assert (output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=users").is_dir() - assert (output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=orders").is_dir() - - # Verify public.users files - files = sorted( - f.name - for f in ( - output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=users" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify public.orders files - files = sorted( - f.name - for f in ( - output / "type=redshift" / f"database={config.database}" / "schema=public" / "table=orders" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify another schema table - assert (output / "type=redshift" / f"database={config.database}" / "schema=another").is_dir() - assert (output / "type=redshift" / f"database={config.database}" / "schema=another" / "table=whatever").is_dir() - - # Verify another.whatever files - files = sorted( - f.name - for f in ( - output / "type=redshift" / f"database={config.database}" / "schema=another" / "table=whatever" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify state - assert state.schemas_synced == 2 - assert state.tables_synced == 3 - assert "public" in state.synced_schemas - assert "another" in state.synced_schemas - assert "users" in state.synced_tables["public"] - assert "orders" in state.synced_tables["public"] - assert "whatever" in state.synced_tables["another"] diff --git 
a/cli/tests/nao_core/commands/sync/integration/test_snowflake.py b/cli/tests/nao_core/commands/sync/integration/test_snowflake.py index d1242bf5..b54bf6f1 100644 --- a/cli/tests/nao_core/commands/sync/integration/test_snowflake.py +++ b/cli/tests/nao_core/commands/sync/integration/test_snowflake.py @@ -8,7 +8,6 @@ The test suite is skipped entirely when SNOWFLAKE_ACCOUNT_ID is not set. """ -import json import os import uuid from pathlib import Path @@ -17,11 +16,11 @@ import pytest from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization -from rich.progress import Progress -from nao_core.commands.sync.providers.databases.provider import sync_database from nao_core.config.databases.snowflake import SnowflakeConfig +from .base import BaseSyncIntegrationTests, SyncTestSpec + SNOWFLAKE_ACCOUNT_ID = os.environ.get("SNOWFLAKE_ACCOUNT_ID") pytestmark = pytest.mark.skipif( @@ -99,7 +98,7 @@ def temp_database(): @pytest.fixture(scope="module") -def snowflake_config(temp_database): +def db_config(temp_database): """Build a SnowflakeConfig from environment variables using the temporary database.""" return SnowflakeConfig( name="test-snowflake", @@ -114,255 +113,46 @@ def snowflake_config(temp_database): @pytest.fixture(scope="module") -def synced(tmp_path_factory, snowflake_config): - """Run sync once for the whole module and return (state, output_path, config).""" - output = tmp_path_factory.mktemp("snowflake_sync") - - with Progress(transient=True) as progress: - state = sync_database(snowflake_config, output, progress) - - return state, output, snowflake_config +def spec(): + return SyncTestSpec( + db_type="snowflake", + primary_schema="PUBLIC", + users_table="USERS", + orders_table="ORDERS", + users_column_assertions=( + "# USERS", + "**Dataset:** `PUBLIC`", + "## Columns (4)", + "- ID", + "- NAME", + "- EMAIL", + "- ACTIVE", + ), + orders_column_assertions=( + "# ORDERS", + "**Dataset:** `PUBLIC`", + "## Columns (3)", + "- ID", + "- USER_ID", + "- AMOUNT", + ), + users_preview_rows=[ + {"ID": 1, "NAME": "Alice", "EMAIL": "alice@example.com", "ACTIVE": True}, + {"ID": 2, "NAME": "Bob", "EMAIL": None, "ACTIVE": False}, + {"ID": 3, "NAME": "Charlie", "EMAIL": "charlie@example.com", "ACTIVE": True}, + ], + orders_preview_rows=[ + {"ID": 1.0, "USER_ID": 1.0, "AMOUNT": 99.99}, + {"ID": 2.0, "USER_ID": 1.0, "AMOUNT": 24.5}, + ], + sort_rows=True, + row_id_key="ID", + filter_schema="public", + schema_field="schema_name", + another_schema="ANOTHER", + another_table="WHATEVER", + ) -class TestSnowflakeSyncIntegration: +class TestSnowflakeSyncIntegration(BaseSyncIntegrationTests): """Verify the sync pipeline produces correct output against a live Snowflake database.""" - - def test_creates_expected_directory_tree(self, synced): - state, output, config = synced - - base = output / "type=snowflake" / f"database={config.database}" / "schema=public" - - # Schema directory - assert base.is_dir() - - # Each table should have exactly the 3 default template outputs - for table in ("orders", "users"): - assert (base / f"table={table}").is_dir() - table_dir = base / f"table={table}" - files = sorted(f.name for f in table_dir.iterdir()) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify that the "another" schema was NOT synced - another_schema_dir = output / "type=snowflake" / f"database={config.database}" / "schema=another" - assert not another_schema_dir.exists() - - def test_columns_md_users(self, synced): - state, output, config = 
synced - - content = ( - output / "type=snowflake" / f"database={config.database}" / "schema=public" / "table=users" / "columns.md" - ).read_text() - - # Snowflake stores identifiers in uppercase by default - assert "# USERS" in content - assert "**Dataset:** `PUBLIC`" in content - assert "## Columns (4)" in content - assert "- ID" in content - assert "- NAME" in content - assert "- EMAIL" in content - assert "- ACTIVE" in content - - def test_columns_md_orders(self, synced): - state, output, config = synced - - content = ( - output / "type=snowflake" / f"database={config.database}" / "schema=public" / "table=orders" / "columns.md" - ).read_text() - - assert "# ORDERS" in content - assert "**Dataset:** `PUBLIC`" in content - assert "## Columns (3)" in content - assert "- ID" in content - assert "- USER_ID" in content - assert "- AMOUNT" in content - - def test_description_md_users(self, synced): - state, output, config = synced - - content = ( - output - / "type=snowflake" - / f"database={config.database}" - / "schema=public" - / "table=users" - / "description.md" - ).read_text() - - assert "# USERS" in content - assert "**Dataset:** `PUBLIC`" in content - assert "## Table Metadata" in content - assert "| **Row Count** | 3 |" in content - assert "| **Column Count** | 4 |" in content - - def test_description_md_orders(self, synced): - state, output, config = synced - - content = ( - output - / "type=snowflake" - / f"database={config.database}" - / "schema=public" - / "table=orders" - / "description.md" - ).read_text() - - assert "| **Row Count** | 2 |" in content - assert "| **Column Count** | 3 |" in content - - def test_preview_md_users(self, synced): - state, output, config = synced - - content = ( - output / "type=snowflake" / f"database={config.database}" / "schema=public" / "table=users" / "preview.md" - ).read_text() - - assert "# USERS - Preview" in content - assert "**Dataset:** `PUBLIC`" in content - assert "## Rows (3)" in content - - # Parse the JSONL rows from the markdown - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 3 - # Snowflake returns column names in uppercase - assert rows[0] == {"ID": 1, "NAME": "Alice", "EMAIL": "alice@example.com", "ACTIVE": True} - assert rows[1] == {"ID": 2, "NAME": "Bob", "EMAIL": None, "ACTIVE": False} - assert rows[2] == {"ID": 3, "NAME": "Charlie", "EMAIL": "charlie@example.com", "ACTIVE": True} - - def test_preview_md_orders(self, synced): - state, output, config = synced - - content = ( - output / "type=snowflake" / f"database={config.database}" / "schema=public" / "table=orders" / "preview.md" - ).read_text() - - lines = [line for line in content.splitlines() if line.startswith("- {")] - rows = [json.loads(line[2:]) for line in lines] - - assert len(rows) == 2 - # Snowflake returns column names in uppercase and integers as floats - assert rows[0] == {"ID": 1.0, "USER_ID": 1.0, "AMOUNT": 99.99} - assert rows[1] == {"ID": 2.0, "USER_ID": 1.0, "AMOUNT": 24.5} - - def test_sync_state_tracks_schemas_and_tables(self, synced): - state, output, config = synced - - assert state.schemas_synced == 1 - assert state.tables_synced == 2 - # Snowflake stores schema and table names in uppercase - assert "PUBLIC" in state.synced_schemas - assert "USERS" in state.synced_tables["PUBLIC"] - assert "ORDERS" in state.synced_tables["PUBLIC"] - - def test_include_filter(self, tmp_path_factory, snowflake_config): - """Only tables matching include patterns should be 
synced.""" - config = SnowflakeConfig( - name=snowflake_config.name, - account_id=snowflake_config.account_id, - username=snowflake_config.username, - database=snowflake_config.database, - private_key_path=snowflake_config.private_key_path, - passphrase=snowflake_config.passphrase, - schema_name=snowflake_config.schema_name, - warehouse=snowflake_config.warehouse, - include=["public.users"], - ) - - output = tmp_path_factory.mktemp("snowflake_include") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - # Snowflake uses uppercase names for schemas and tables - base = output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" - assert (base / "table=USERS").is_dir() - assert not (base / "table=ORDERS").exists() - assert state.tables_synced == 1 - - def test_exclude_filter(self, tmp_path_factory, snowflake_config): - """Tables matching exclude patterns should be skipped.""" - config = SnowflakeConfig( - name=snowflake_config.name, - account_id=snowflake_config.account_id, - username=snowflake_config.username, - database=snowflake_config.database, - private_key_path=snowflake_config.private_key_path, - passphrase=snowflake_config.passphrase, - schema_name=snowflake_config.schema_name, - warehouse=snowflake_config.warehouse, - exclude=["public.orders"], - ) - - output = tmp_path_factory.mktemp("snowflake_exclude") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - # Snowflake uses uppercase names for schemas and tables - base = output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" - assert (base / "table=USERS").is_dir() - assert not (base / "table=ORDERS").exists() - assert state.tables_synced == 1 - - def test_sync_all_schemas_when_schema_name_not_specified(self, tmp_path_factory, snowflake_config): - """When schema_name is not provided, all schemas should be synced.""" - config = SnowflakeConfig( - name=snowflake_config.name, - account_id=snowflake_config.account_id, - username=snowflake_config.username, - database=snowflake_config.database, - private_key_path=snowflake_config.private_key_path, - passphrase=snowflake_config.passphrase, - schema_name=None, - warehouse=snowflake_config.warehouse, - ) - - output = tmp_path_factory.mktemp("snowflake_all_schemas") - with Progress(transient=True) as progress: - state = sync_database(config, output, progress) - - # Verify PUBLIC schema tables (Snowflake uses uppercase names) - assert (output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC").is_dir() - assert (output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" / "table=USERS").is_dir() - assert (output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" / "table=ORDERS").is_dir() - - # Verify PUBLIC.USERS files - files = sorted( - f.name - for f in ( - output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" / "table=USERS" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify PUBLIC.ORDERS files - files = sorted( - f.name - for f in ( - output / "type=snowflake" / f"database={config.database}" / "schema=PUBLIC" / "table=ORDERS" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify ANOTHER schema table - assert (output / "type=snowflake" / f"database={config.database}" / "schema=ANOTHER").is_dir() - assert ( - output / "type=snowflake" / f"database={config.database}" / "schema=ANOTHER" / "table=WHATEVER" 
- ).is_dir() - - # Verify ANOTHER.WHATEVER files - files = sorted( - f.name - for f in ( - output / "type=snowflake" / f"database={config.database}" / "schema=ANOTHER" / "table=WHATEVER" - ).iterdir() - ) - assert files == ["columns.md", "description.md", "preview.md"] - - # Verify state - assert state.schemas_synced == 2 - assert state.tables_synced == 3 - assert "PUBLIC" in state.synced_schemas - assert "ANOTHER" in state.synced_schemas - assert "USERS" in state.synced_tables["PUBLIC"] - assert "ORDERS" in state.synced_tables["PUBLIC"] - assert "WHATEVER" in state.synced_tables["ANOTHER"] From dc9009d4440508c2974848b6c5ff67aee797f9e8 Mon Sep 17 00:00:00 2001 From: Christophe Blefari Date: Wed, 11 Feb 2026 00:13:44 +0100 Subject: [PATCH 4/7] Fix databricks certs --- cli/nao_core/config/databases/databricks.py | 7 +++++++ example/nao_config.yaml | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/cli/nao_core/config/databases/databricks.py b/cli/nao_core/config/databases/databricks.py index 005d364d..85a56b6f 100644 --- a/cli/nao_core/config/databases/databricks.py +++ b/cli/nao_core/config/databases/databricks.py @@ -1,5 +1,7 @@ +import os from typing import Literal +import certifi import ibis from ibis import BaseBackend from pydantic import Field @@ -8,6 +10,11 @@ from .base import DatabaseConfig +# Ensure Python uses certifi's CA bundle for SSL verification. +# This fixes "certificate verify failed" errors when Python's default CA path is empty. +os.environ.setdefault("SSL_CERT_FILE", certifi.where()) +os.environ.setdefault("REQUESTS_CA_BUNDLE", certifi.where()) + class DatabricksConfig(DatabaseConfig): """Databricks-specific configuration.""" diff --git a/example/nao_config.yaml b/example/nao_config.yaml index 208cfcc7..584c8115 100644 --- a/example/nao_config.yaml +++ b/example/nao_config.yaml @@ -10,7 +10,7 @@ repos: url: https://github.com/dbt-labs/jaffle_shop_duckdb.git branch: null notion: - api_key: yolo + api_key: {{ env('NOTION_API_KEY') }} pages: - https://naolabs.notion.site/Jaffle-shop-information-2f8c7a70bc0680a4b7d0caf99f055360 llm: null From c86156534b55ca86870b4e09cee0397f741e8f42 Mon Sep 17 00:00:00 2001 From: Christophe Blefari Date: Wed, 11 Feb 2026 00:33:55 +0100 Subject: [PATCH 5/7] Fix databricks tests --- cli/nao_core/config/databases/databricks.py | 2 -- cli/tests/nao_core/commands/sync/integration/dml/databricks.sql | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/cli/nao_core/config/databases/databricks.py b/cli/nao_core/config/databases/databricks.py index 85a56b6f..6d96f01f 100644 --- a/cli/nao_core/config/databases/databricks.py +++ b/cli/nao_core/config/databases/databricks.py @@ -26,8 +26,6 @@ class DatabricksConfig(DatabaseConfig): catalog: str | None = Field(default=None, description="Unity Catalog name (optional)") schema_name: str | None = Field( default=None, - validation_alias="schema", - serialization_alias="schema", description="Default schema (optional)", ) diff --git a/cli/tests/nao_core/commands/sync/integration/dml/databricks.sql b/cli/tests/nao_core/commands/sync/integration/dml/databricks.sql index 8397bd84..f3aa85d4 100644 --- a/cli/tests/nao_core/commands/sync/integration/dml/databricks.sql +++ b/cli/tests/nao_core/commands/sync/integration/dml/databricks.sql @@ -2,7 +2,7 @@ CREATE TABLE {catalog}.public.users ( id INTEGER NOT NULL, name STRING NOT NULL, email STRING, -active BOOLEAN DEFAULT TRUE +active BOOLEAN ); INSERT INTO {catalog}.public.users VALUES From bda49e6fa40096f4cb9c8424b5c1321abc0cc429 Mon 
Sep 17 00:00:00 2001 From: Christophe Blefari Date: Wed, 11 Feb 2026 10:15:11 +0100 Subject: [PATCH 6/7] Don't create UDFs in Snowflake --- cli/nao_core/config/databases/snowflake.py | 2 +- cli/tests/nao_core/commands/sync/integration/test_snowflake.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cli/nao_core/config/databases/snowflake.py b/cli/nao_core/config/databases/snowflake.py index 93d0ace8..b42987a2 100644 --- a/cli/nao_core/config/databases/snowflake.py +++ b/cli/nao_core/config/databases/snowflake.py @@ -95,7 +95,7 @@ def connect(self) -> BaseBackend: ) kwargs["password"] = self.password - return ibis.snowflake.connect(**kwargs) + return ibis.snowflake.connect(**kwargs, create_object_udfs=False) def get_database_name(self) -> str: """Get the database name for Snowflake.""" diff --git a/cli/tests/nao_core/commands/sync/integration/test_snowflake.py b/cli/tests/nao_core/commands/sync/integration/test_snowflake.py index b54bf6f1..04eea774 100644 --- a/cli/tests/nao_core/commands/sync/integration/test_snowflake.py +++ b/cli/tests/nao_core/commands/sync/integration/test_snowflake.py @@ -69,6 +69,7 @@ def temp_database(): private_key=private_key_bytes, warehouse=os.environ.get("SNOWFLAKE_WAREHOUSE"), database=db_name, + create_object_udfs=False, ) # Create schema From e747d6d3e09b8cd22d5a69e346b4fc901c939685 Mon Sep 17 00:00:00 2001 From: Christophe Blefari Date: Wed, 11 Feb 2026 10:19:17 +0100 Subject: [PATCH 7/7] Respect Claude.md after review --- cli/nao_core/config/databases/snowflake.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/cli/nao_core/config/databases/snowflake.py b/cli/nao_core/config/databases/snowflake.py index b42987a2..d87b8d12 100644 --- a/cli/nao_core/config/databases/snowflake.py +++ b/cli/nao_core/config/databases/snowflake.py @@ -105,13 +105,6 @@ def matches_pattern(self, schema: str, table: str) -> bool: """Check if a schema.table matches the include/exclude patterns. Snowflake identifier matching is case-insensitive. - - Args: - schema: The schema name (uppercase from Snowflake) - table: The table name (uppercase from Snowflake) - - Returns: - True if the table should be included, False if excluded """ from fnmatch import fnmatch