|
66 | 66 | to_table_mapping, |
67 | 67 | ) |
68 | 68 | from sqlmesh.core.state_sync import StateReader, StateSync |
| 69 | +from sqlmesh.core.table_diff import TableDiff |
69 | 70 | from sqlmesh.core.test import get_all_model_tests, run_model_tests, run_tests |
70 | 71 | from sqlmesh.core.user import User |
71 | 72 | from sqlmesh.utils import UniqueKeyDict, sys_path |
72 | 73 | from sqlmesh.utils.dag import DAG |
73 | 74 | from sqlmesh.utils.date import TimeLike, yesterday_ds |
74 | | -from sqlmesh.utils.errors import ConfigError, MissingDependencyError, PlanError |
| 75 | +from sqlmesh.utils.errors import ( |
| 76 | + ConfigError, |
| 77 | + MissingDependencyError, |
| 78 | + PlanError, |
| 79 | + SQLMeshError, |
| 80 | +) |
75 | 81 | from sqlmesh.utils.jinja import JinjaMacroRegistry |
76 | 82 |
|
77 | 83 | if t.TYPE_CHECKING: |
78 | 84 | import graphviz |
| 85 | + from typing_extensions import Literal |
79 | 86 |
|
80 | 87 | from sqlmesh.core.engine_adapter._typing import DF, PySparkDataFrame, PySparkSession |
81 | 88 |
|
@@ -252,6 +259,12 @@ def engine_adapter(self) -> EngineAdapter: |
252 | 259 | """Returns an engine adapter.""" |
253 | 260 | return self._engine_adapter |
254 | 261 |
|
| 262 | + def execution_context(self, is_dev: bool = False) -> ExecutionContext: |
| 263 | + """Returns an execution context.""" |
| 264 | + return ExecutionContext( |
| 265 | + engine_adapter=self._engine_adapter, snapshots=self.snapshots, is_dev=is_dev |
| 266 | + ) |
| 267 | + |
255 | 268 | def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model: |
256 | 269 | """Update or insert a model. |
257 | 270 |
|
@@ -371,9 +384,76 @@ def run( |
371 | 384 | if not skip_janitor and environment.lower() == c.PROD: |
372 | 385 | self._run_janitor() |
373 | 386 |
|
374 | | - def get_model(self, name: str) -> t.Optional[Model]: |
375 | | - """Returns a model with the given name or None if a model with such name doesn't exist.""" |
376 | | - return self._models.get(name) |
| 387 | + @t.overload |
| 388 | + def get_model( |
| 389 | + self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: Literal[True] = True |
| 390 | + ) -> Model: |
| 391 | + ... |
| 392 | + |
| 393 | + @t.overload |
| 394 | + def get_model( |
| 395 | + self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: Literal[False] = False |
| 396 | + ) -> t.Optional[Model]: |
| 397 | + ... |
| 398 | + |
| 399 | + def get_model( |
| 400 | + self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: bool = False |
| 401 | + ) -> t.Optional[Model]: |
| 402 | + """Returns a model with the given name or None if a model with such name doesn't exist. |
| 403 | +
|
| 404 | + Args: |
| 405 | + model_or_snapshot: A model name, model, or snapshot. |
| 406 | + raise_if_missing: Raises an error if a model is not found. |
| 407 | +
|
| 408 | + Returns: |
| 409 | + The expected model. |
| 410 | + """ |
| 411 | + if isinstance(model_or_snapshot, str): |
| 412 | + model = self._models.get(model_or_snapshot) |
| 413 | + elif isinstance(model_or_snapshot, Snapshot): |
| 414 | + model = model_or_snapshot.model |
| 415 | + else: |
| 416 | + model = model_or_snapshot |
| 417 | + |
| 418 | + if raise_if_missing and not model: |
| 419 | + raise SQLMeshError(f"Cannot find model for '{model_or_snapshot}'") |
| 420 | + return model |
| 421 | + |
| 422 | + @t.overload |
| 423 | + def get_snapshot( |
| 424 | + self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: Literal[True] |
| 425 | + ) -> Snapshot: |
| 426 | + ... |
| 427 | + |
| 428 | + @t.overload |
| 429 | + def get_snapshot( |
| 430 | + self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: Literal[False] |
| 431 | + ) -> t.Optional[Snapshot]: |
| 432 | + ... |
| 433 | + |
| 434 | + def get_snapshot( |
| 435 | + self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: bool = False |
| 436 | + ) -> t.Optional[Snapshot]: |
| 437 | + """Returns a snapshot with the given name or None if a snapshot with such name doesn't exist. |
| 438 | +
|
| 439 | + Args: |
| 440 | + model_or_snapshot: A model name, model, or snapshot. |
| 441 | + raise_if_missing: Raises an error if a snapshot is not found. |
| 442 | +
|
| 443 | + Returns: |
| 444 | + The expected snapshot. |
| 445 | + """ |
| 446 | + if isinstance(model_or_snapshot, str): |
| 447 | + snapshot = self.snapshots.get(model_or_snapshot) |
| 448 | + elif isinstance(model_or_snapshot, Snapshot): |
| 449 | + snapshot = model_or_snapshot |
| 450 | + else: |
| 451 | + snapshot = self.snapshots.get(model_or_snapshot.name) |
| 452 | + |
| 453 | + if raise_if_missing and not snapshot: |
| 454 | + raise SQLMeshError(f"Cannot find snapshot for '{model_or_snapshot}'") |
| 455 | + |
| 456 | + return snapshot |
377 | 457 |
|
378 | 458 | def config_for_path(self, path: Path) -> Config: |
379 | 459 | for config_path, config in self.configs.items(): |
@@ -478,18 +558,17 @@ def render( |
478 | 558 | """ |
479 | 559 | latest = latest or yesterday_ds() |
480 | 560 |
|
481 | | - if isinstance(model_or_snapshot, str): |
482 | | - model = self._models[model_or_snapshot] |
483 | | - elif isinstance(model_or_snapshot, Snapshot): |
484 | | - model = model_or_snapshot.model |
485 | | - else: |
486 | | - model = model_or_snapshot |
| 561 | + model = self.get_model(model_or_snapshot, raise_if_missing=True) |
487 | 562 |
|
488 | 563 | expand = self.dag.upstream(model.name) if expand is True else expand or [] |
489 | 564 |
|
490 | 565 | if model.is_seed: |
491 | | - df = next(model.render(self, start=start, end=end, latest=latest, **kwargs)) |
492 | | - return next(pandas_to_sql(df, model.columns_to_types)) |
| 566 | + df = next( |
| 567 | + model.render( |
| 568 | + context=self.execution_context(), start=start, end=end, latest=latest, **kwargs |
| 569 | + ) |
| 570 | + ) |
| 571 | + return next(pandas_to_sql(t.cast(pd.DataFrame, df), model.columns_to_types)) |
493 | 572 |
|
494 | 573 | return model.render_query( |
495 | 574 | start=start, |
@@ -520,12 +599,7 @@ def evaluate( |
520 | 599 | latest: The latest time used for non incremental datasets. |
521 | 600 | limit: A limit applied to the model. |
522 | 601 | """ |
523 | | - if isinstance(model_or_snapshot, str): |
524 | | - snapshot = self.snapshots[model_or_snapshot] |
525 | | - elif isinstance(model_or_snapshot, Snapshot): |
526 | | - snapshot = model_or_snapshot |
527 | | - else: |
528 | | - snapshot = self.snapshots[model_or_snapshot.name] |
| 602 | + snapshot = self.get_snapshot(model_or_snapshot, raise_if_missing=True) |
529 | 603 |
|
530 | 604 | df = self.snapshot_evaluator.evaluate( |
531 | 605 | snapshot, |
@@ -671,6 +745,55 @@ def diff(self, environment: t.Optional[str] = None, detailed: bool = False) -> N |
671 | 745 | self._context_diff(environment or c.PROD), detailed |
672 | 746 | ) |
673 | 747 |
|
| 748 | + def table_diff( |
| 749 | + self, |
| 750 | + source: str, |
| 751 | + target: str, |
| 752 | + on: t.List[str] | exp.Condition, |
| 753 | + model_or_snapshot: t.Optional[ModelOrSnapshot] = None, |
| 754 | + where: t.Optional[str | exp.Condition] = None, |
| 755 | + limit: int = 20, |
| 756 | + show: bool = True, |
| 757 | + ) -> TableDiff: |
| 758 | + """Show a diff between two tables. |
| 759 | +
|
| 760 | + Args: |
| 761 | + source: The source environment or table. |
| 762 | + target: The target environment or table. |
| 763 | + on: The join condition, table aliases must be "s" and "t" for source and target. |
| 764 | + model_or_snapshot: The model or snapshot to use when environments are passed in. |
| 765 | + where: An optional where statement to filter results. |
| 766 | + limit: The limit of the sample dataframe. |
| 767 | + show: Show the table diff in the console. |
| 768 | +
|
| 769 | + Returns: |
| 770 | + The TableDiff object containing schema and summary differences. |
| 771 | + """ |
| 772 | + if model_or_snapshot: |
| 773 | + model = self.get_model(model_or_snapshot, raise_if_missing=True) |
| 774 | + source_env = self.state_reader.get_environment(source) |
| 775 | + target_env = self.state_reader.get_environment(target) |
| 776 | + |
| 777 | + if not source_env: |
| 778 | + raise SQLMeshError(f"Could not find environment '{source}'") |
| 779 | + if not target_env: |
| 780 | + raise SQLMeshError(f"Could not find environment '{target}')") |
| 781 | + |
| 782 | + source = next( |
| 783 | + snapshot for snapshot in source_env.snapshots if snapshot.name == model.name |
| 784 | + ).table_name() |
| 785 | + target = next( |
| 786 | + snapshot for snapshot in target_env.snapshots if snapshot.name == model.name |
| 787 | + ).table_name() |
| 788 | + |
| 789 | + table_diff = TableDiff( |
| 790 | + adapter=self._engine_adapter, source=source, target=target, on=on, where=where |
| 791 | + ) |
| 792 | + if show: |
| 793 | + self.console.show_schema_diff(table_diff.schema_diff()) |
| 794 | + self.console.show_row_diff(table_diff.row_diff()) |
| 795 | + return table_diff |
| 796 | + |
674 | 797 | def get_dag(self, format: str = "svg") -> graphviz.Digraph: |
675 | 798 | """Gets a graphviz dag. |
676 | 799 |
|
|
0 commit comments