Skip to content

Commit 3c1511d

Browse files
authored
Feat: table-diff, easily diff two tables. (#907)
* Feat: table-diff, easily diff two tables. * pr feedback * Fix: test no longer creates duckdb file. * add cli and magics
1 parent f4dec46 commit 3c1511d

File tree

9 files changed

+541
-21
lines changed

9 files changed

+541
-21
lines changed

sqlmesh/cli/main.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,54 @@ def create_external_models(obj: Context) -> None:
392392
obj.create_external_models()
393393

394394

395+
@cli.command("table_diff")
396+
@click.option(
397+
"--source",
398+
"-s",
399+
type=str,
400+
required=True,
401+
help="The source environment or table.",
402+
)
403+
@click.option(
404+
"--target",
405+
"-t",
406+
type=str,
407+
required=True,
408+
help="The target environment or table.",
409+
)
410+
@click.option(
411+
"--on",
412+
type=str,
413+
nargs="+",
414+
required=True,
415+
help='The SQL join condition or list of columns to use as keys. Table aliases must be "s" and "t" for source and target.',
416+
)
417+
@click.option(
418+
"--model",
419+
type=str,
420+
help="The model to diff against when source and target are environments and not tables.",
421+
)
422+
@click.option(
423+
"--where",
424+
type=str,
425+
help="An optional where statement to filter results.",
426+
)
427+
@click.option(
428+
"--limit",
429+
type=int,
430+
help="The limit of the sample dataframe.",
431+
)
432+
@click.pass_obj
433+
@error_handler
434+
def table_diff(obj: Context, **kwargs: t.Any) -> None:
435+
"""Show the diff between two tables.
436+
437+
Can either be two tables or two environments and a model.
438+
"""
439+
kwargs["model_or_snapshot"] = kwargs.pop("model", None)
440+
obj.table_diff(**kwargs)
441+
442+
395443
@cli.command("prompt")
396444
@click.argument("prompt")
397445
@click.option(

sqlmesh/core/config/loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ def load_config_from_paths(
3232

3333
extension = path.name.split(".")[-1].lower()
3434
if extension in ("yml", "yaml"):
35+
if config_name != "config":
36+
raise ConfigError(
37+
f"YAML configs do not support multiple configs. Use Python instead."
38+
)
3539
next_config = load_config_from_yaml(path)
3640
elif extension == "py":
3741
next_config = load_config_from_python_module(path, config_name=config_name)

sqlmesh/core/console.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from sqlmesh.core.context_diff import ContextDiff
2626
from sqlmesh.core.plan import Plan
27+
from sqlmesh.core.table_diff import RowDiff, SchemaDiff
2728

2829
LayoutWidget = t.TypeVar("LayoutWidget", bound=t.Union[widgets.VBox, widgets.HBox])
2930

@@ -122,6 +123,14 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
122123
def loading_stop(self, id: uuid.UUID) -> None:
123124
"""Stop loading for the given id"""
124125

126+
@abc.abstractmethod
127+
def show_schema_diff(self, schema_diff: SchemaDiff) -> None:
128+
"""Show table schema diff"""
129+
130+
@abc.abstractmethod
131+
def show_row_diff(self, row_diff: RowDiff) -> None:
132+
"""Show table summary diff"""
133+
125134

126135
class TerminalConsole(Console):
127136
"""A rich based implementation of the console"""
@@ -449,6 +458,35 @@ def loading_stop(self, id: uuid.UUID) -> None:
449458
self.loading_status[id].stop()
450459
del self.loading_status[id]
451460

461+
def show_schema_diff(self, schema_diff: SchemaDiff) -> None:
462+
tree = Tree(f"[bold]Schema Diff Between '{schema_diff.source}' and '{schema_diff.target}':")
463+
464+
if schema_diff.added:
465+
added = Tree("[green]Added Columns:")
466+
for c, t in schema_diff.added:
467+
added.add(f"[green]{c} ({t})")
468+
tree.add(added)
469+
470+
if schema_diff.removed:
471+
removed = Tree("[red]Removed Columns:")
472+
for c, t in schema_diff.removed:
473+
removed.add(f"[red]{c} ({t})")
474+
tree.add(removed)
475+
476+
if schema_diff.modified:
477+
modified = Tree("[magenta]Modified Columns:")
478+
for c, (ft, tt) in schema_diff.modified.items():
479+
modified.add(f"[magenta]{c} ({ft} -> {tt})")
480+
tree.add(modified)
481+
482+
self.console.print(tree)
483+
484+
def show_row_diff(self, row_diff: RowDiff) -> None:
485+
self.console.print(
486+
f"[bold]Row Count:[/bold] {row_diff.source}: {row_diff.source_count}, {row_diff.target}: {row_diff.target_count} -- {row_diff.count_pct_change}%"
487+
)
488+
self.console.print(row_diff.sample.to_string(index=False))
489+
452490
def _get_snapshot_change_category(
453491
self, snapshot: Snapshot, plan: Plan, auto_apply: bool
454492
) -> None:

sqlmesh/core/context.py

Lines changed: 141 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,23 @@
6666
to_table_mapping,
6767
)
6868
from sqlmesh.core.state_sync import StateReader, StateSync
69+
from sqlmesh.core.table_diff import TableDiff
6970
from sqlmesh.core.test import get_all_model_tests, run_model_tests, run_tests
7071
from sqlmesh.core.user import User
7172
from sqlmesh.utils import UniqueKeyDict, sys_path
7273
from sqlmesh.utils.dag import DAG
7374
from sqlmesh.utils.date import TimeLike, yesterday_ds
74-
from sqlmesh.utils.errors import ConfigError, MissingDependencyError, PlanError
75+
from sqlmesh.utils.errors import (
76+
ConfigError,
77+
MissingDependencyError,
78+
PlanError,
79+
SQLMeshError,
80+
)
7581
from sqlmesh.utils.jinja import JinjaMacroRegistry
7682

7783
if t.TYPE_CHECKING:
7884
import graphviz
85+
from typing_extensions import Literal
7986

8087
from sqlmesh.core.engine_adapter._typing import DF, PySparkDataFrame, PySparkSession
8188

@@ -252,6 +259,12 @@ def engine_adapter(self) -> EngineAdapter:
252259
"""Returns an engine adapter."""
253260
return self._engine_adapter
254261

262+
def execution_context(self, is_dev: bool = False) -> ExecutionContext:
263+
"""Returns an execution context."""
264+
return ExecutionContext(
265+
engine_adapter=self._engine_adapter, snapshots=self.snapshots, is_dev=is_dev
266+
)
267+
255268
def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:
256269
"""Update or insert a model.
257270
@@ -371,9 +384,76 @@ def run(
371384
if not skip_janitor and environment.lower() == c.PROD:
372385
self._run_janitor()
373386

374-
def get_model(self, name: str) -> t.Optional[Model]:
375-
"""Returns a model with the given name or None if a model with such name doesn't exist."""
376-
return self._models.get(name)
387+
@t.overload
388+
def get_model(
389+
self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: Literal[True] = True
390+
) -> Model:
391+
...
392+
393+
@t.overload
394+
def get_model(
395+
self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: Literal[False] = False
396+
) -> t.Optional[Model]:
397+
...
398+
399+
def get_model(
400+
self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: bool = False
401+
) -> t.Optional[Model]:
402+
"""Returns a model with the given name or None if a model with such name doesn't exist.
403+
404+
Args:
405+
model_or_snapshot: A model name, model, or snapshot.
406+
raise_if_missing: Raises an error if a model is not found.
407+
408+
Returns:
409+
The expected model.
410+
"""
411+
if isinstance(model_or_snapshot, str):
412+
model = self._models.get(model_or_snapshot)
413+
elif isinstance(model_or_snapshot, Snapshot):
414+
model = model_or_snapshot.model
415+
else:
416+
model = model_or_snapshot
417+
418+
if raise_if_missing and not model:
419+
raise SQLMeshError(f"Cannot find model for '{model_or_snapshot}'")
420+
return model
421+
422+
@t.overload
423+
def get_snapshot(
424+
self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: Literal[True]
425+
) -> Snapshot:
426+
...
427+
428+
@t.overload
429+
def get_snapshot(
430+
self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: Literal[False]
431+
) -> t.Optional[Snapshot]:
432+
...
433+
434+
def get_snapshot(
435+
self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: bool = False
436+
) -> t.Optional[Snapshot]:
437+
"""Returns a snapshot with the given name or None if a snapshot with such name doesn't exist.
438+
439+
Args:
440+
model_or_snapshot: A model name, model, or snapshot.
441+
raise_if_missing: Raises an error if a snapshot is not found.
442+
443+
Returns:
444+
The expected snapshot.
445+
"""
446+
if isinstance(model_or_snapshot, str):
447+
snapshot = self.snapshots.get(model_or_snapshot)
448+
elif isinstance(model_or_snapshot, Snapshot):
449+
snapshot = model_or_snapshot
450+
else:
451+
snapshot = self.snapshots.get(model_or_snapshot.name)
452+
453+
if raise_if_missing and not snapshot:
454+
raise SQLMeshError(f"Cannot find snapshot for '{model_or_snapshot}'")
455+
456+
return snapshot
377457

378458
def config_for_path(self, path: Path) -> Config:
379459
for config_path, config in self.configs.items():
@@ -478,18 +558,17 @@ def render(
478558
"""
479559
latest = latest or yesterday_ds()
480560

481-
if isinstance(model_or_snapshot, str):
482-
model = self._models[model_or_snapshot]
483-
elif isinstance(model_or_snapshot, Snapshot):
484-
model = model_or_snapshot.model
485-
else:
486-
model = model_or_snapshot
561+
model = self.get_model(model_or_snapshot, raise_if_missing=True)
487562

488563
expand = self.dag.upstream(model.name) if expand is True else expand or []
489564

490565
if model.is_seed:
491-
df = next(model.render(self, start=start, end=end, latest=latest, **kwargs))
492-
return next(pandas_to_sql(df, model.columns_to_types))
566+
df = next(
567+
model.render(
568+
context=self.execution_context(), start=start, end=end, latest=latest, **kwargs
569+
)
570+
)
571+
return next(pandas_to_sql(t.cast(pd.DataFrame, df), model.columns_to_types))
493572

494573
return model.render_query(
495574
start=start,
@@ -520,12 +599,7 @@ def evaluate(
520599
latest: The latest time used for non incremental datasets.
521600
limit: A limit applied to the model.
522601
"""
523-
if isinstance(model_or_snapshot, str):
524-
snapshot = self.snapshots[model_or_snapshot]
525-
elif isinstance(model_or_snapshot, Snapshot):
526-
snapshot = model_or_snapshot
527-
else:
528-
snapshot = self.snapshots[model_or_snapshot.name]
602+
snapshot = self.get_snapshot(model_or_snapshot, raise_if_missing=True)
529603

530604
df = self.snapshot_evaluator.evaluate(
531605
snapshot,
@@ -671,6 +745,55 @@ def diff(self, environment: t.Optional[str] = None, detailed: bool = False) -> N
671745
self._context_diff(environment or c.PROD), detailed
672746
)
673747

748+
def table_diff(
749+
self,
750+
source: str,
751+
target: str,
752+
on: t.List[str] | exp.Condition,
753+
model_or_snapshot: t.Optional[ModelOrSnapshot] = None,
754+
where: t.Optional[str | exp.Condition] = None,
755+
limit: int = 20,
756+
show: bool = True,
757+
) -> TableDiff:
758+
"""Show a diff between two tables.
759+
760+
Args:
761+
source: The source environment or table.
762+
target: The target environment or table.
763+
on: The join condition, table aliases must be "s" and "t" for source and target.
764+
model_or_snapshot: The model or snapshot to use when environments are passed in.
765+
where: An optional where statement to filter results.
766+
limit: The limit of the sample dataframe.
767+
show: Show the table diff in the console.
768+
769+
Returns:
770+
The TableDiff object containing schema and summary differences.
771+
"""
772+
if model_or_snapshot:
773+
model = self.get_model(model_or_snapshot, raise_if_missing=True)
774+
source_env = self.state_reader.get_environment(source)
775+
target_env = self.state_reader.get_environment(target)
776+
777+
if not source_env:
778+
raise SQLMeshError(f"Could not find environment '{source}'")
779+
if not target_env:
780+
raise SQLMeshError(f"Could not find environment '{target}')")
781+
782+
source = next(
783+
snapshot for snapshot in source_env.snapshots if snapshot.name == model.name
784+
).table_name()
785+
target = next(
786+
snapshot for snapshot in target_env.snapshots if snapshot.name == model.name
787+
).table_name()
788+
789+
table_diff = TableDiff(
790+
adapter=self._engine_adapter, source=source, target=target, on=on, where=where
791+
)
792+
if show:
793+
self.console.show_schema_diff(table_diff.schema_diff())
794+
self.console.show_row_diff(table_diff.row_diff())
795+
return table_diff
796+
674797
def get_dag(self, format: str = "svg") -> graphviz.Digraph:
675798
"""Gets a graphviz dag.
676799

sqlmesh/core/engine_adapter/base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,29 @@ def execute(
781781
logger.debug(f"Executing SQL:\n{sql}")
782782
self.cursor.execute(sql, **kwargs)
783783

784+
@contextlib.contextmanager
785+
def temp_table(self, query_or_df: QueryOrDF, name: str = "diff") -> t.Iterator[exp.Table]:
786+
"""A context manager for working a temp table.
787+
788+
The table will be created with a random guid and cleaned up after the block.
789+
790+
Args:
791+
query_or_df: The query or df to create a temp table for.
792+
name: The base name of the temp table.
793+
794+
Yields:
795+
The table expression
796+
"""
797+
798+
with self.transaction(TransactionType.DDL):
799+
table = self._get_temp_table(name)
800+
self.ctas(table, query_or_df)
801+
802+
try:
803+
yield table
804+
finally:
805+
self.drop_table(table)
806+
784807
def _create_table_properties(
785808
self,
786809
storage_format: t.Optional[str] = None,

0 commit comments

Comments
 (0)