From b357e7817f3ce1b6797ab0c149a17fcfebb21045 Mon Sep 17 00:00:00 2001
From: Erin Drummond
Date: Tue, 11 Feb 2025 11:35:15 +1300
Subject: [PATCH] Feat!: Adjust physical_properties evaluation and add macro
 to resolve physical table names (#3772)

---
 docs/concepts/macros/sqlmesh_macros.md        |  40 +++++++
 docs/integrations/engines/trino.md            |   8 +-
 sqlmesh/core/macros.py                        |  53 +++++++
 sqlmesh/core/model/definition.py              |  47 +++++---
 sqlmesh/core/snapshot/evaluator.py            |  95 ++++++-------
 .../integration/test_integration_trino.py     | 110 ++++++++++++++++++
 tests/core/test_macros.py                     |  58 +++++++++
 tests/core/test_model.py                      |  58 +++++++++
 tests/core/test_test.py                       |  42 +++++++
 9 files changed, 456 insertions(+), 55 deletions(-)
 create mode 100644 tests/core/engine_adapter/integration/test_integration_trino.py

diff --git a/docs/concepts/macros/sqlmesh_macros.md b/docs/concepts/macros/sqlmesh_macros.md
index fe3caffefc..59089f33ab 100644
--- a/docs/concepts/macros/sqlmesh_macros.md
+++ b/docs/concepts/macros/sqlmesh_macros.md
@@ -1000,6 +1000,46 @@ Note: This is DuckDB SQL and other dialects will be transpiled accordingly.
 - Recursive CTEs (common table expressions) will be used for `Redshift / MySQL / MSSQL`.
 - For `MSSQL` in particular, there's a recursion limit of approximately 100. If this becomes a problem, you can add an `OPTION (MAXRECURSION 0)` clause after the date spine macro logic to remove the limit. This applies for long date ranges.
 
+### @RESOLVE_TEMPLATE
+
+`@resolve_template` is a helper macro for situations where you need access to the *components* of the physical object name. It's intended for the following use cases:
+
+- Providing explicit control over table locations on a per-model basis for engines that decouple storage and compute (such as Athena, Trino, and Spark).
+- Generating references to engine-specific metadata tables that are derived from the physical table name, such as the [`$properties`](https://trino.io/docs/current/connector/iceberg.html#metadata-tables) metadata table in Trino.
+
+Under the hood, it uses the `@this_model` variable, so it can only be used during the `creating` and `evaluating` [runtime stages](./macro_variables.md#runtime-variables). Using it at the `loading` runtime stage is a no-op.
+
+The `@resolve_template` macro supports the following arguments:
+
+ - `template` - The string template to render into an AST node.
+ - `mode` - What type of SQLGlot AST node to return after rendering the template. Valid values are `literal` or `table`. Defaults to `literal`.
+
+The `template` can contain the following placeholders that will be substituted:
+
+ - `@{catalog_name}` - The name of the catalog, e.g. `datalake`
+ - `@{schema_name}` - The name of the physical schema that SQLMesh is using for the model version table, e.g. `sqlmesh__landing`
+ - `@{table_name}` - The name of the physical table that SQLMesh is using for the model version, e.g. `landing__customers__2517971505`
+
+It can be used in a `MODEL` block:
+
+```sql linenums="1" hl_lines="5"
+MODEL (
+  name datalake.landing.customers,
+  ...
+  physical_properties (
+    location = @resolve_template('s3://warehouse-data/@{catalog_name}/prod/@{schema_name}/@{table_name}')
+  )
+);
+-- CREATE TABLE "datalake"."sqlmesh__landing"."landing__customers__2517971505" ...
+-- WITH (location = 's3://warehouse-data/datalake/prod/sqlmesh__landing/landing__customers__2517971505')
+```
+
+And also within a query, using `mode := 'table'`:
+
+```sql linenums="1"
+SELECT * FROM @resolve_template('@{catalog_name}.@{schema_name}.@{table_name}$properties', mode := 'table')
+-- SELECT * FROM "datalake"."sqlmesh__landing"."landing__customers__2517971505$properties"
+```
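+
+Because `physical_properties` are rendered at evaluation time rather than load time, `@resolve_template` can also be combined with other macros. For example, the table location could vary per gateway (a sketch; the gateway and bucket names here are illustrative):
+
+```sql linenums="1"
+MODEL (
+  name datalake.landing.customers,
+  ...
+  physical_properties (
+    location = @IF(
+      @gateway = 'dev',
+      @resolve_template('s3://dev-warehouse-data/@{schema_name}/@{table_name}'),
+      @resolve_template('s3://warehouse-data/@{schema_name}/@{table_name}')
+    )
+  )
+);
+```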
 
 ### @AND
 
diff --git a/docs/integrations/engines/trino.md b/docs/integrations/engines/trino.md
index db196576ad..46c1c623cf 100644
--- a/docs/integrations/engines/trino.md
+++ b/docs/integrations/engines/trino.md
@@ -185,14 +185,16 @@ This would perform the following mappings:
 
 Often, you don't need to configure an explicit table location because if you have configured explicit schema locations, table locations are automatically inferred by Trino to be a subdirectory under the schema location.
 
-However, if you need to, you can configure an explicit table location by adding a `location` property to the model `physical_properties`:
+However, if you need to, you can configure an explicit table location by adding a `location` property to the model `physical_properties`.
 
-```
+Note that you need to use the [@resolve_template](../../concepts/macros/sqlmesh_macros.md#resolve_template) macro to generate a unique table location for each model version. Otherwise, all model versions will be written to the same location and clobber each other.
+
+```sql hl_lines="5"
 MODEL (
   name staging.customers,
   kind FULL,
   physical_properties (
-    location = 's3://warehouse/staging/customers'
+    location = @resolve_template('s3://warehouse/@{catalog_name}/@{schema_name}/@{table_name}')
   )
 );
 ```

diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py
index ce284d109f..89ec79e5ba 100644
--- a/sqlmesh/core/macros.py
+++ b/sqlmesh/core/macros.py
@@ -1185,6 +1185,59 @@ def date_spine(
     return exp.select(alias_name).from_(exploded)
 
 
+@macro()
+def resolve_template(
+    evaluator: MacroEvaluator,
+    template: exp.Literal,
+    mode: str = "literal",
+) -> t.Union[exp.Literal, exp.Table]:
+    """
+    Generates either a string literal or an exp.Table representing a physical table location, by rendering the provided template string literal.
+
+    Note: It relies on the @this_model variable being available in the evaluation context (@this_model resolves to an exp.Table object
+    representing the current physical table).
+    Therefore, the @resolve_template macro must be used at creation or evaluation time and not at load time.
+
+    Args:
+        template: Template string literal. Can contain the following placeholders:
+            @{catalog_name} -> replaced with the catalog of the exp.Table returned from @this_model
+            @{schema_name} -> replaced with the schema of the exp.Table returned from @this_model
+            @{table_name} -> replaced with the name of the exp.Table returned from @this_model
+        mode: What to return.
+ 'literal' -> return an exp.Literal string + 'table' -> return an exp.Table + + Example: + >>> from sqlglot import parse_one, exp + >>> from sqlmesh.core.macros import MacroEvaluator, RuntimeStage + >>> sql = "@resolve_template('s3://data-bucket/prod/@{catalog_name}/@{schema_name}/@{table_name}')" + >>> evaluator = MacroEvaluator(runtime_stage=RuntimeStage.CREATING) + >>> evaluator.locals.update({"this_model": exp.to_table("test_catalog.sqlmesh__test.test__test_model__2517971505")}) + >>> evaluator.transform(parse_one(sql)).sql() + "'s3://data-bucket/prod/test_catalog/sqlmesh__test/test__test_model__2517971505'" + """ + if "this_model" in evaluator.locals: + this_model = exp.to_table(evaluator.locals["this_model"], dialect=evaluator.dialect) + template_str: str = template.this + result = ( + template_str.replace("@{catalog_name}", this_model.catalog) + .replace("@{schema_name}", this_model.db) + .replace("@{table_name}", this_model.name) + ) + + if mode.lower() == "table": + return exp.to_table(result, dialect=evaluator.dialect) + return exp.Literal.string(result) + elif evaluator.runtime_stage != RuntimeStage.LOADING.value: + # only error if we are CREATING, EVALUATING or TESTING and @this_model is not present; this could indicate a bug + # otherwise, for LOADING, it's a no-op + raise SQLMeshError( + "@this_model must be present in the macro evaluation context in order to use @resolve_template" + ) + + return template + + def normalize_macro_name(name: str) -> str: """Prefix macro name with @ and upcase""" return f"@{name.upper()}" diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index d8fc276dd2..c3bfefb47a 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -639,6 +639,26 @@ def render_merge_filter( raise SQLMeshError(f"Expected one expression but got {len(rendered_exprs)}") return rendered_exprs[0].transform(d.replace_merge_table_aliases) + def render_physical_properties(self, **render_kwargs: t.Any) -> t.Dict[str, exp.Expression]: + def _render(expression: exp.Expression) -> exp.Expression: + # note: we use the _statement_renderer instead of _create_renderer because it sets model_fqn which + # in turn makes @this_model available in the evaluation context + rendered_exprs = self._statement_renderer(expression).render(**render_kwargs) + + if not rendered_exprs: + raise SQLMeshError( + f"Expected rendering '{expression.sql(dialect=self.dialect)}' to return an expression" + ) + + if len(rendered_exprs) != 1: + raise SQLMeshError( + f"Expected one result when rendering '{expression.sql(dialect=self.dialect)}' but got {len(rendered_exprs)}" + ) + + return rendered_exprs[0] + + return {k: _render(v) for k, v in self.physical_properties.items()} + def _create_renderer(self, expression: exp.Expression) -> ExpressionRenderer: return ExpressionRenderer( expression, @@ -1814,16 +1834,16 @@ def load_sql_based_model( meta = d.Model(expressions=[]) # Dummy meta node expressions.insert(0, meta) + # We deliberately hold off rendering some properties at load time because there is not enough information available + # at load time to render them. 
They will get rendered later at evaluation time + unrendered_properties = {} unrendered_merge_filter = None - unrendered_signals = None - unrendered_audits = None for prop in meta.expressions: - if prop.name.lower() == "signals": - unrendered_signals = prop.args.get("value") - if prop.name.lower() == "audits": - unrendered_audits = prop.args.get("value") - if ( + prop_name = prop.name.lower() + if prop_name in ("signals", "audits", "physical_properties"): + unrendered_properties[prop_name] = prop.args.get("value") + elif ( prop.name.lower() == "kind" and (value := prop.args.get("value")) and value.name.lower() == "incremental_by_unique_key" @@ -1868,12 +1888,9 @@ def load_sql_based_model( **kwargs, } - # signals, audits and merge_filter must remain unrendered, so that they can be rendered later at evaluation runtime - if unrendered_signals: - meta_fields["signals"] = unrendered_signals - - if unrendered_audits: - meta_fields["audits"] = unrendered_audits + # Discard the potentially half-rendered versions of these properties and replace them with the + # original unrendered versions. They will get rendered properly at evaluation time + meta_fields.update(unrendered_properties) if unrendered_merge_filter: for idx, kind_prop in enumerate(meta_fields["kind"].expressions): @@ -2166,6 +2183,10 @@ def _create_model( statements.extend(kwargs["post_statements"]) if "on_virtual_update" in kwargs: statements.extend(kwargs["on_virtual_update"]) + if physical_properties := kwargs.get("physical_properties"): + # to allow variables like @gateway to be used in physical_properties + # since rendering shifted from load time to run time + statements.extend(physical_properties) jinja_macro_references, used_variables = extract_macro_references_and_variables( *(gen(e) for e in statements) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 8770314807..6e3d23ec24 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -601,6 +601,28 @@ def _evaluate_snapshot( # If there are no existing intervals yet; only consider this a first insert for the first snapshot in the batch is_first_insert = not _intervals(snapshot, deployability_index) and batch_index == 0 + from sqlmesh.core.context import ExecutionContext + + common_render_kwargs = dict( + start=start, + end=end, + execution_time=execution_time, + snapshot=snapshot, + runtime_stage=RuntimeStage.EVALUATING, + **kwargs, + ) + + render_statements_kwargs = dict( + engine_adapter=adapter, + snapshots=snapshots, + deployability_index=deployability_index, + **common_render_kwargs, + ) + + rendered_physical_properties = snapshot.model.render_physical_properties( + **render_statements_kwargs + ) + def apply(query_or_df: QueryOrDF, index: int = 0) -> None: if index > 0: evaluation_strategy.append( @@ -614,6 +636,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: start=start, end=end, execution_time=execution_time, + physical_properties=rendered_physical_properties, ) else: logger.info( @@ -634,26 +657,9 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: start=start, end=end, execution_time=execution_time, + physical_properties=rendered_physical_properties, ) - from sqlmesh.core.context import ExecutionContext - - common_render_kwargs = dict( - start=start, - end=end, - execution_time=execution_time, - snapshot=snapshot, - runtime_stage=RuntimeStage.EVALUATING, - **kwargs, - ) - - render_statements_kwargs = dict( - engine_adapter=adapter, - snapshots=snapshots, - 
deployability_index=deployability_index, - **common_render_kwargs, - ) - with adapter.transaction(), adapter.session(snapshot.model.session_properties): wap_id: t.Optional[str] = None if ( @@ -765,6 +771,9 @@ def _create_snapshot( with adapter.transaction(), adapter.session(snapshot.model.session_properties): adapter.execute(snapshot.model.render_pre_statements(**pre_post_render_kwargs)) + rendered_physical_properties = snapshot.model.render_physical_properties( + **create_render_kwargs + ) if ( snapshot.is_forward_only @@ -792,6 +801,7 @@ def _create_snapshot( ), is_snapshot_deployable=is_snapshot_deployable, is_snapshot_representative=is_snapshot_representative, + physical_properties=rendered_physical_properties, ) try: adapter.clone_table(target_table_name, snapshot.table_name(), replace=True) @@ -831,6 +841,7 @@ def _create_snapshot( is_snapshot_deployable=is_snapshot_deployable, is_snapshot_representative=is_snapshot_representative, dry_run=dry_run, + physical_properties=rendered_physical_properties, ) adapter.execute(snapshot.model.render_post_statements(**pre_post_render_kwargs)) @@ -894,6 +905,7 @@ def _migrate_snapshot( is_snapshot_deployable=True, is_snapshot_representative=True, dry_run=False, + physical_properties=snapshot.model.render_physical_properties(**render_kwargs), ) adapter.execute(snapshot.model.render_post_statements(**render_kwargs)) @@ -1217,7 +1229,9 @@ def demote(self, view_name: str, **kwargs: t.Any) -> None: view_name: The name of the target view in the virtual layer. """ - def _replace_query_for_model(self, model: Model, name: str, query_or_df: QueryOrDF) -> None: + def _replace_query_for_model( + self, model: Model, name: str, query_or_df: QueryOrDF, **kwargs: t.Any + ) -> None: """Replaces the table for the given model. 
Args: @@ -1239,7 +1253,7 @@ def _replace_query_for_model(self, model: Model, name: str, query_or_df: QueryOr partitioned_by=model.partitioned_by, partition_interval_unit=model.partition_interval_unit, clustered_by=model.clustered_by, - table_properties=model.physical_properties, + table_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description, column_descriptions=model.column_descriptions, columns_to_types=columns_to_types, @@ -1358,6 +1372,7 @@ def create( **kwargs: t.Any, ) -> None: ctas_query = model.ctas_query(**render_kwargs) + physical_properties = kwargs.get("physical_properties", model.physical_properties) logger.info("Creating table '%s'", table_name) if model.annotated: @@ -1369,7 +1384,7 @@ def create( partitioned_by=model.partitioned_by, partition_interval_unit=model.partition_interval_unit, clustered_by=model.clustered_by, - table_properties=model.physical_properties, + table_properties=physical_properties, table_description=model.description if is_table_deployable else None, column_descriptions=model.column_descriptions if is_table_deployable else None, ) @@ -1393,7 +1408,7 @@ def create( partitioned_by=model.partitioned_by, partition_interval_unit=model.partition_interval_unit, clustered_by=model.clustered_by, - table_properties=model.physical_properties, + table_properties=physical_properties, table_description=model.description if is_table_deployable else None, column_descriptions=model.column_descriptions if is_table_deployable else None, ) @@ -1428,7 +1443,7 @@ def insert( **kwargs: t.Any, ) -> None: if is_first_insert: - self._replace_query_for_model(model, table_name, query_or_df) + self._replace_query_for_model(model, table_name, query_or_df, **kwargs) else: self.adapter.insert_overwrite_by_partition( table_name, @@ -1468,7 +1483,7 @@ def insert( **kwargs: t.Any, ) -> None: if is_first_insert: - self._replace_query_for_model(model, table_name, query_or_df) + self._replace_query_for_model(model, table_name, query_or_df, **kwargs) else: self.adapter.merge( table_name, @@ -1514,7 +1529,7 @@ def insert( **kwargs: t.Any, ) -> None: if is_first_insert: - self._replace_query_for_model(model, table_name, query_or_df) + self._replace_query_for_model(model, table_name, query_or_df, **kwargs) elif isinstance(model.kind, IncrementalUnmanagedKind) and model.kind.insert_overwrite: self.adapter.insert_overwrite_by_partition( table_name, @@ -1540,7 +1555,7 @@ def insert( is_first_insert: bool, **kwargs: t.Any, ) -> None: - self._replace_query_for_model(model, table_name, query_or_df) + self._replace_query_for_model(model, table_name, query_or_df, **kwargs) class SeedStrategy(MaterializableStrategy): @@ -1569,7 +1584,7 @@ def create( try: for index, df in enumerate(model.render_seed()): if index == 0: - self._replace_query_for_model(model, table_name, df) + self._replace_query_for_model(model, table_name, df, **kwargs) else: self.adapter.insert_append( table_name, df, columns_to_types=model.columns_to_types @@ -1613,7 +1628,7 @@ def create( partitioned_by=model.partitioned_by, partition_interval_unit=model.partition_interval_unit, clustered_by=model.clustered_by, - table_properties=model.physical_properties, + table_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description if is_table_deployable else None, column_descriptions=model.column_descriptions if is_table_deployable else None, ) @@ -1759,7 +1774,7 @@ def insert( model.columns_to_types, replace=not 
self.adapter.HAS_VIEW_BINDING, materialized=self._is_materialized_view(model), - view_properties=model.physical_properties, + view_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description, column_descriptions=model.column_descriptions, ) @@ -1814,7 +1829,7 @@ def create( replace=False, materialized=self._is_materialized_view(model), materialized_properties=materialized_properties, - view_properties=model.physical_properties, + view_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description if is_table_deployable else None, column_descriptions=model.column_descriptions if is_table_deployable else None, ) @@ -1828,16 +1843,16 @@ def migrate( ) -> None: logger.info("Migrating view '%s'", target_table_name) model = snapshot.model + render_kwargs = dict( + execution_time=now(), snapshots=kwargs["snapshots"], engine_adapter=self.adapter + ) + self.adapter.create_view( target_table_name, - model.render_query_or_raise( - execution_time=now(), - snapshots=kwargs["snapshots"], - engine_adapter=self.adapter, - ), + model.render_query_or_raise(**render_kwargs), model.columns_to_types, materialized=self._is_materialized_view(model), - view_properties=model.physical_properties, + view_properties=model.render_physical_properties(**render_kwargs), table_description=model.description, column_descriptions=model.column_descriptions, ) @@ -1941,7 +1956,7 @@ def create( columns_to_types=model.columns_to_types, partitioned_by=model.partitioned_by, clustered_by=model.clustered_by, - table_properties=model.physical_properties, + table_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description, column_descriptions=model.column_descriptions, ) @@ -1977,7 +1992,7 @@ def insert( columns_to_types=model.columns_to_types, partitioned_by=model.partitioned_by, clustered_by=model.clustered_by, - table_properties=model.physical_properties, + table_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description, column_descriptions=model.column_descriptions, ) @@ -1989,7 +2004,9 @@ def insert( table_name, model.name, ) - self._replace_query_for_model(model=model, name=table_name, query_or_df=query_or_df) + self._replace_query_for_model( + model=model, name=table_name, query_or_df=query_or_df, **kwargs + ) def append( self, diff --git a/tests/core/engine_adapter/integration/test_integration_trino.py b/tests/core/engine_adapter/integration/test_integration_trino.py new file mode 100644 index 0000000000..cdebd290ec --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration_trino.py @@ -0,0 +1,110 @@ +import typing as t +import pytest +from pathlib import Path +from sqlmesh.core.engine_adapter import TrinoEngineAdapter +from tests.core.engine_adapter.integration import TestContext +from sqlglot import parse_one, exp + +pytestmark = [pytest.mark.docker, pytest.mark.engine, pytest.mark.trino] + + +@pytest.fixture( + params=[ + pytest.param( + "trino", + marks=[ + pytest.mark.docker, + pytest.mark.engine, + pytest.mark.trino, + ], + ), + pytest.param( + "trino_iceberg", + marks=[ + pytest.mark.docker, + pytest.mark.engine, + pytest.mark.trino_iceberg, + ], + ), + pytest.param( + "trino_delta", + marks=[ + pytest.mark.docker, + pytest.mark.engine, + pytest.mark.trino_delta, + ], + ), + pytest.param( + "trino_nessie", + marks=[ + pytest.mark.docker, + pytest.mark.engine, + pytest.mark.trino_nessie, + ], + ), + ] +) +def 
mark_gateway(request) -> t.Tuple[str, str]:
+    return request.param, f"inttest_{request.param}"
+
+
+@pytest.fixture
+def test_type() -> str:
+    return "query"
+
+
+def test_macros_in_physical_properties(
+    tmp_path: Path, ctx: TestContext, engine_adapter: TrinoEngineAdapter
+):
+    if "iceberg" not in ctx.gateway:
+        pytest.skip("This test only needs to be run once")
+
+    models_dir = tmp_path / "models"
+    models_dir.mkdir(parents=True)
+
+    schema = ctx.schema()
+
+    with open(models_dir / "test_model.sql", "w") as f:
+        f.write(
+            """
+            MODEL (
+                name SCHEMA.test,
+                kind FULL,
+                physical_properties (
+                    location = @resolve_template('s3://trino/@{catalog_name}/@{schema_name}/@{table_name}'),
+                    sorted_by = @if(@gateway = 'inttest_trino_iceberg', ARRAY['col_a'], ARRAY['col_b'])
+                )
+            );
+
+            select 1 as col_a, 2 as col_b;
+            """.replace("SCHEMA", schema)
+        )
+
+    context = ctx.create_context(path=tmp_path)
+    assert len(context.models) == 1
+
+    plan_result = context.plan(auto_apply=True, no_prompts=True)
+
+    assert len(plan_result.new_snapshots) == 1
+
+    snapshot = plan_result.new_snapshots[0]
+
+    physical_table_str = snapshot.table_name()
+    physical_table = exp.to_table(physical_table_str)
+    create_sql = list(engine_adapter.fetchone(f"show create table {physical_table}") or [])[0]
+
+    parsed_create_sql = parse_one(create_sql, dialect="trino")
+
+    location_property = parsed_create_sql.find(exp.LocationProperty)
+    assert location_property
+
+    assert "@{table_name}" not in location_property.sql(dialect="trino")
+    assert (
+        location_property.text("this")
+        == f"s3://trino/{physical_table.catalog}/{physical_table.db}/{physical_table.name}"
+    )
+
+    sorted_by_property = next(
+        p for p in parsed_create_sql.find_all(exp.Property) if "sorted_by" in p.sql(dialect="trino")
+    )
+    assert sorted_by_property.sql(dialect="trino") == "sorted_by=ARRAY['col_a ASC NULLS FIRST']"
diff --git a/tests/core/test_macros.py b/tests/core/test_macros.py
index e185e205e5..510462adbe 100644
--- a/tests/core/test_macros.py
+++ b/tests/core/test_macros.py
@@ -8,6 +8,7 @@ from sqlmesh.core.macros import SQL, MacroEvalError, MacroEvaluator, macro
 from sqlmesh.utils.errors import SQLMeshError
 from sqlmesh.utils.metaprogramming import Executable
+from sqlmesh.core.macros import RuntimeStage
 
 
 @pytest.fixture
@@ -1020,3 +1021,60 @@ def test_macro_union(assert_exp_eq, macro_evaluator: MacroEvaluator):
     expected_sql = "SELECT 1 AS col UNION ALL SELECT 1 AS col"
 
     assert_exp_eq(macro_evaluator.transform(parse_one(sql)), expected_sql)
+
+
+def test_resolve_template_literal():
+    parsed_sql = parse_one(
+        "@resolve_template('s3://data-bucket/prod/@{catalog_name}/@{schema_name}/@{table_name}')"
+    )
+
+    # Loading
+    # During loading, this should pass through as a no-op
+    # This is because SQLMesh renders everything on load to figure out model dependencies and we don't want to throw an error
+    evaluator = MacroEvaluator(runtime_stage=RuntimeStage.LOADING)
+    assert evaluator.transform(parsed_sql) == exp.Literal.string(
+        "s3://data-bucket/prod/@{catalog_name}/@{schema_name}/@{table_name}"
+    )
+
+    # Creating
+    # This macro can work during creating / evaluating but only if @this_model is present in the context
+    evaluator = MacroEvaluator(runtime_stage=RuntimeStage.CREATING)
+    with pytest.raises(SQLMeshError) as e:
+        evaluator.transform(parsed_sql)
+
+    assert "this_model must be present" in str(e.value.__cause__)
+
+    evaluator.locals.update(
+        {"this_model": exp.to_table("test_catalog.sqlmesh__test.test__test_model__2517971505")}
+    )
+
+    assert (
+        
evaluator.transform(parsed_sql).sql() + == "'s3://data-bucket/prod/test_catalog/sqlmesh__test/test__test_model__2517971505'" + ) + + # Evaluating + evaluator = MacroEvaluator(runtime_stage=RuntimeStage.EVALUATING) + evaluator.locals.update( + {"this_model": exp.to_table("test_catalog.sqlmesh__test.test__test_model__2517971505")} + ) + assert ( + evaluator.transform(parsed_sql).sql() + == "'s3://data-bucket/prod/test_catalog/sqlmesh__test/test__test_model__2517971505'" + ) + + +def test_resolve_template_table(): + parsed_sql = parse_one( + "SELECT * FROM @resolve_template('@{catalog_name}.@{schema_name}.@{table_name}$partitions', mode := 'table')" + ) + + evaluator = MacroEvaluator(runtime_stage=RuntimeStage.CREATING) + evaluator.locals.update( + {"this_model": exp.to_table("test_catalog.sqlmesh__test.test__test_model__2517971505")} + ) + + assert ( + evaluator.transform(parsed_sql).sql(identify=True) + == 'SELECT * FROM "test_catalog"."sqlmesh__test"."test__test_model__2517971505$partitions"' + ) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 8d7fd729d9..eacd1f5c6a 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -29,6 +29,7 @@ from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter from sqlmesh.core.macros import MacroEvaluator, macro +from sqlmesh.core import constants as c from sqlmesh.core.model import ( CustomKind, PythonModel, @@ -58,6 +59,7 @@ from sqlmesh.utils.errors import ConfigError, SQLMeshError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo, MacroExtractor from sqlmesh.utils.metaprogramming import Executable +from sqlmesh.core.macros import RuntimeStage def missing_schema_warning_msg(model, deps): @@ -5437,6 +5439,62 @@ def this_model_resolves_to_quoted_table(evaluator): ) +def test_macros_in_physical_properties(make_snapshot): + expressions = d.parse( + """ + MODEL ( + name test.test_model, + kind FULL, + physical_properties ( + location1 = @resolve_template('s3://bucket/prefix/@{schema_name}/@{table_name}'), + location2 = @IF( + @gateway = 'dev', + @resolve_template('hdfs://@{catalog_name}/@{schema_name}/dev/@{table_name}'), + @resolve_template('s3://prod/@{table_name}') + ), + sort_order = @IF(@gateway = 'prod', 'desc', 'asc') + ) + ); + + SELECT 1; + """ + ) + + model = load_sql_based_model( + expressions, variables={"gateway": "dev"}, default_catalog="unit_test" + ) + + assert model.name == "test.test_model" + assert "location1" in model.physical_properties + assert "location2" in model.physical_properties + assert "sort_order" in model.physical_properties + + # load time is a no-op + assert isinstance(model.physical_properties["location1"], d.MacroFunc) + assert isinstance(model.physical_properties["location2"], d.MacroFunc) + assert isinstance(model.physical_properties["sort_order"], d.MacroFunc) + + # substitution occurs at runtime + snapshot: Snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + rendered_physical_properties = model.render_physical_properties( + snapshots={model.fqn: snapshot}, # to trigger @this_model generation + runtime_stage=RuntimeStage.CREATING, + python_env=model.python_env, + ) + + assert ( + rendered_physical_properties["location1"].text("this") + == f"s3://bucket/prefix/sqlmesh__test/test__test_model__{snapshot.version}" + ) + assert ( + rendered_physical_properties["location2"].text("this") + == 
f"hdfs://unit_test/sqlmesh__test/dev/test__test_model__{snapshot.version}" + ) + assert rendered_physical_properties["sort_order"].text("this") == "asc" + + def test_macros_in_model_statement(sushi_context, assert_exp_eq): @macro() def session_properties(evaluator, value): diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 5f632236e0..b9c0562d36 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -2086,3 +2086,45 @@ def test_test_with_gateway_specific_model(tmp_path: Path, mocker: MockerFixture) '"memory"."sqlmesh_example"."input_model"': [{"c": 5}] } assert test["test_gw_model"]["outputs"] == {"query": [{"c": 5}]} + + +def test_test_with_resolve_template_macro(tmp_path: Path): + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + + models_dir = tmp_path / "models" + models_dir.mkdir() + (models_dir / "foo.sql").write_text( + """ + MODEL ( + name test.foo, + kind full, + physical_properties ( + location = @resolve_template('file:///tmp/@{table_name}') + ) + ); + + SELECT t.a + 1 as a + FROM @resolve_template('@{schema_name}.dev_@{table_name}', mode := 'table') as t + """ + ) + + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + (tests_dir / "test_foo.yaml").write_text( + """ +test_resolve_template_macro: + model: test.foo + inputs: + test.dev_foo: + - a: 1 + outputs: + query: + - a: 2 + """ + ) + + context = Context(paths=tmp_path, config=config) + _check_successful_or_raise(context.test())