77
88from dbt .adapters .base import BaseRelation
99from dbt .contracts .relation import RelationType
10+ from jinja2 import nodes
11+ from jinja2 .exceptions import UndefinedError
1012from pydantic import Field , validator
1113from sqlglot .helper import ensure_list
1214
@@ -140,6 +142,14 @@ def _validate_columns(cls, v: t.Any) -> t.Dict[str, ColumnConfig]:
140142
141143 @property
142144 def all_sql (self ) -> SqlStr :
145+ return SqlStr ("\n " .join (self .pre_hook + [self .sql_no_config ] + self .post_hook ))
146+
147+ @property
148+ def sql_no_config (self ) -> SqlStr :
149+ return SqlStr ("" )
150+
151+ @property
152+ def sql_embedded_config (self ) -> SqlStr :
143153 return SqlStr ("" )
144154
145155 @property
@@ -190,13 +200,17 @@ def relation_info(self) -> AttributeDict[str, t.Any]:
190200 }
191201 )
192202
203+ def attribute_dict (self ) -> AttributeDict [str , t .Any ]:
204+ return AttributeDict (self .dict ())
205+
193206 def sqlmesh_model_kwargs (self , model_context : DbtContext ) -> t .Dict [str , t .Any ]:
194207 """Get common sqlmesh model parameters"""
195208 jinja_macros = model_context .jinja_macros .trim (self ._dependencies .macros )
196209 jinja_macros .global_objs .update (
197210 {
198211 "this" : self .relation_info ,
199212 "schema" : self .table_schema ,
213+ "config" : self .attribute_dict (),
200214 ** model_context .jinja_globals , # type: ignore
201215 }
202216 )
@@ -220,7 +234,6 @@ def sqlmesh_model_kwargs(self, model_context: DbtContext) -> t.Dict[str, t.Any]:
220234
221235 def render_config (self : BMC , context : DbtContext ) -> BMC :
222236 rendered = super ().render_config (context )
223- rendered ._dependencies = Dependencies (macros = extract_macro_references (rendered .all_sql ))
224237 rendered = ModelSqlRenderer (context , rendered ).enriched_config
225238
226239 rendered_dependencies = rendered ._dependencies
@@ -275,7 +288,7 @@ def __init__(self, context: DbtContext, config: BMC):
275288 jinja_globals = {
276289 ** context .jinja_globals ,
277290 ** date_dict (c .EPOCH , c .EPOCH , c .EPOCH ),
278- "config" : self . _config ,
291+ "config" : lambda * args , ** kwargs : "" ,
279292 "ref" : self ._ref ,
280293 "var" : self ._var ,
281294 "source" : self ._source ,
@@ -293,9 +306,15 @@ def __init__(self, context: DbtContext, config: BMC):
293306 dialect = context .engine_adapter .dialect if context .engine_adapter else "" ,
294307 )
295308
309+ self .jinja_env = self .context .jinja_macros .build_environment (** self ._jinja_globals )
310+
296311 @property
297312 def enriched_config (self ) -> BMC :
298313 if self ._rendered_sql is None :
314+ self ._enriched_config = self ._update_with_sql_config (self ._enriched_config )
315+ self ._enriched_config ._dependencies = Dependencies (
316+ macros = extract_macro_references (self ._enriched_config .all_sql )
317+ )
299318 self .render ()
300319 self ._enriched_config ._dependencies = self ._enriched_config ._dependencies .union (
301320 self ._captured_dependencies
@@ -304,14 +323,42 @@ def enriched_config(self) -> BMC:
304323
305324 def render (self ) -> str :
306325 if self ._rendered_sql is None :
307- registry = self . context . jinja_macros
308- self ._rendered_sql = (
309- registry . build_environment ( ** self ._jinja_globals )
310- . from_string ( self . config . all_sql )
311- . render ()
312- )
326+ try :
327+ self ._rendered_sql = self . jinja_env . from_string (
328+ self ._enriched_config . all_sql
329+ ). render ( )
330+ except UndefinedError as e :
331+ raise ConfigError ( e . message )
313332 return self ._rendered_sql
314333
334+ def _update_with_sql_config (self , config : BMC ) -> BMC :
335+ def _extract_value (node : t .Any ) -> t .Any :
336+ if not isinstance (node , nodes .Node ):
337+ return node
338+ if isinstance (node , nodes .Const ):
339+ return _extract_value (node .value )
340+ if isinstance (node , nodes .TemplateData ):
341+ return _extract_value (node .data )
342+ if isinstance (node , nodes .List ):
343+ return [_extract_value (val ) for val in node .items ]
344+ if isinstance (node , nodes .Dict ):
345+ return {_extract_value (pair .key ): _extract_value (pair .value ) for pair in node .items }
346+ if isinstance (node , nodes .Tuple ):
347+ return tuple (_extract_value (val ) for val in node .items )
348+
349+ return self .jinja_env .from_string (nodes .Template ([nodes .Output ([node ])])).render ()
350+
351+ for call in self .jinja_env .parse (self ._enriched_config .sql_embedded_config ).find_all (
352+ nodes .Call
353+ ):
354+ if not isinstance (call .node , nodes .Name ) or call .node .name != "config" :
355+ continue
356+ config = config .update_with (
357+ {kwarg .key : _extract_value (kwarg .value ) for kwarg in call .kwargs }
358+ )
359+
360+ return config
361+
315362 def _ref (self , package_name : str , model_name : t .Optional [str ] = None ) -> BaseRelation :
316363 if package_name in self .context .models :
317364 relation = BaseRelation .create (** self .context .models [package_name ].relation_info )
@@ -341,13 +388,6 @@ def _source(self, source_name: str, table_name: str) -> BaseRelation:
341388 self ._captured_dependencies .sources .add (full_name )
342389 return BaseRelation .create (** self .context .sources [full_name ].relation_info )
343390
344- def _config (self , * args : t .Any , ** kwargs : t .Any ) -> str :
345- if args and isinstance (args [0 ], dict ):
346- self ._enriched_config = self ._enriched_config .update_with (args [0 ])
347- if kwargs :
348- self ._enriched_config = self ._enriched_config .update_with (kwargs )
349- return ""
350-
351391 class TrackingAdapter (ParsetimeAdapter ):
352392 def __init__ (self , outer_self : ModelSqlRenderer , * args : t .Any , ** kwargs : t .Any ):
353393 super ().__init__ (* args , ** kwargs )
0 commit comments