Skip to content

Commit

Permalink
Fix: Add adapters for commands that are not using the snapshot evalua…
Browse files Browse the repository at this point in the history
…tor (#3531)
  • Loading branch information
themisvaltinos authored Jan 3, 2025
1 parent 7f650e1 commit e9efeff
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 20 deletions.
20 changes: 9 additions & 11 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,12 +418,10 @@ def engine_adapter(self) -> EngineAdapter:
@property
def snapshot_evaluator(self) -> SnapshotEvaluator:
if not self._snapshot_evaluator:
if self._snapshot_gateways:
self._create_engine_adapters(set(self._snapshot_gateways.values()))
self._snapshot_evaluator = SnapshotEvaluator(
{
gateway: adapter.with_log_level(logging.INFO)
for gateway, adapter in self._engine_adapters.items()
for gateway, adapter in self.engine_adapters.items()
},
ddl_concurrent_tasks=self.concurrent_tasks,
selected_gateway=self.selected_gateway,
Expand Down Expand Up @@ -1476,6 +1474,7 @@ def table_diff(
source_alias, target_alias = source, target

adapter = self.engine_adapter

if model_or_snapshot:
model = self.get_model(model_or_snapshot, raise_if_missing=True)
adapter = self._get_engine_adapter(model.gateway)
Expand Down Expand Up @@ -1641,6 +1640,7 @@ def create_test(
test_adapter = self._test_connection_config.create_engine_adapter(
register_comments_override=False
)

generate_test(
model=model_to_test,
input_queries=input_queries,
Expand Down Expand Up @@ -2021,21 +2021,19 @@ def _snapshot_gateways(self) -> t.Dict[str, str]:
if snapshot.is_model and snapshot.model.gateway
}

def _create_engine_adapters(self, gateways: t.Optional[t.Set] = None) -> None:
"""Create engine adapters for the gateways, when none provided include all defined in the configs."""

@cached_property
def engine_adapters(self) -> t.Dict[str, EngineAdapter]:
"""Returns all the engine adapters for the gateways defined in the configuration."""
for gateway_name in self.config.gateways:
if gateway_name != self.selected_gateway and (
gateways is None or gateway_name in gateways
):
if gateway_name != self.selected_gateway:
connection = self.config.get_connection(gateway_name)
adapter = connection.create_engine_adapter()
self.concurrent_tasks = min(self.concurrent_tasks, connection.concurrent_tasks)
self._engine_adapters[gateway_name] = adapter
return self._engine_adapters

def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
if gateway:
if adapter := self._engine_adapters.get(gateway):
if adapter := self.engine_adapters.get(gateway):
return adapter
raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.")
return self.engine_adapter
Expand Down
1 change: 1 addition & 0 deletions tests/core/engine_adapter/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def create_context(
)
if config_mutator:
config_mutator(self.gateway, config)
config.gateways = {self.gateway: config.gateways[self.gateway]}

gateway_config = config.gateways[self.gateway]
if (
Expand Down
4 changes: 4 additions & 0 deletions tests/core/engine_adapter/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,8 @@ def test_sushi(ctx: TestContext, tmp_path_factory: pytest.TempPathFactory):
personal_paths=[pathlib.Path("~/.sqlmesh/config.yaml").expanduser()],
)

# To enable parallelism in integration tests
config.gateways = {ctx.gateway: config.gateways[ctx.gateway]}
current_gateway_config = config.gateways[ctx.gateway]
current_gateway_config.state_schema = sushi_state_schema

Expand Down Expand Up @@ -1730,6 +1732,8 @@ def _normalize_snowflake(name: str, prefix_regex: str = "(sqlmesh__)(.*)"):
if config.model_defaults.dialect != ctx.dialect:
config.model_defaults = config.model_defaults.copy(update={"dialect": ctx.dialect})

# To enable parallelism in integration tests
config.gateways = {ctx.gateway: config.gateways[ctx.gateway]}
current_gateway_config = config.gateways[ctx.gateway]

if ctx.dialect == "athena":
Expand Down
8 changes: 3 additions & 5 deletions tests/core/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,10 +757,8 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture):
new_callable=mocker.PropertyMock(return_value={"snapshot": "athena"}),
)

ctx._create_engine_adapters()

assert isinstance(ctx._connection_config, RedshiftConnectionConfig)
assert len(ctx._engine_adapters) == 2
assert isinstance(ctx._engine_adapters["athena"], AthenaEngineAdapter)
assert isinstance(ctx._engine_adapters["redshift"], RedshiftEngineAdapter)
assert len(ctx.engine_adapters) == 2
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)
assert isinstance(ctx.engine_adapters["redshift"], RedshiftEngineAdapter)
assert ctx.engine_adapter == ctx._get_engine_adapter("redshift")
9 changes: 6 additions & 3 deletions tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,12 @@ def test_gateway_specific_adapters(copy_to_temp_path, mocker):
ctx = Context(paths=path, config="isolated_systems_config", gateway="prod")
assert len(ctx._engine_adapters) == 1
assert ctx.engine_adapter == ctx._engine_adapters["prod"]

with pytest.raises(SQLMeshError):
assert ctx._get_engine_adapter("dev")
assert ctx._get_engine_adapter("non_existing")

# This will create the requested engine adapter
assert ctx._get_engine_adapter("dev") == ctx._engine_adapters["dev"]

ctx = Context(paths=path, config="isolated_systems_config")
assert len(ctx._engine_adapters) == 1
Expand All @@ -337,8 +341,7 @@ def test_gateway_specific_adapters(copy_to_temp_path, mocker):

ctx = Context(paths=path, config="isolated_systems_config")

ctx._create_engine_adapters({"test"})
assert len(ctx._engine_adapters) == 2
assert len(ctx.engine_adapters) == 3
assert ctx.engine_adapter == ctx._get_engine_adapter()
assert ctx._get_engine_adapter("test") == ctx._engine_adapters["test"]

Expand Down
41 changes: 41 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sqlmesh.core.context import Context, ExecutionContext
from sqlmesh.core.dialect import parse
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter
from sqlmesh.core.macros import MacroEvaluator, macro
from sqlmesh.core.model import (
CustomKind,
Expand Down Expand Up @@ -6620,3 +6621,43 @@ def test_auto_restatement():
)
with pytest.raises(ValueError, match="Invalid cron expression '@invalid'.*"):
load_sql_based_model(parsed_definition)


def test_gateway_specific_render(assert_exp_eq) -> None:
gateways = {
"main": GatewayConfig(connection=DuckDBConnectionConfig()),
"duckdb": GatewayConfig(connection=DuckDBConnectionConfig()),
}
config = Config(
gateways=gateways,
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
default_gateway="main",
)
context = Context(config=config)
assert context.engine_adapter == context._engine_adapters["main"]

@model(
name="dummy_model",
is_sql=True,
kind="full",
gateway="duckdb",
grain='"x"',
)
def dummy_model_entry(evaluator: MacroEvaluator) -> exp.Select:
return exp.select("x").from_(exp.values([("1", 2)], "_v", ["x"]))

dummy_model = model.get_registry()["dummy_model"].model(module_path=Path("."), path=Path("."))
context.upsert_model(dummy_model)
assert isinstance(dummy_model, SqlModel)
assert dummy_model.gateway == "duckdb"

assert_exp_eq(
context.render("dummy_model"),
"""
SELECT
"_v"."x" AS "x",
FROM (VALUES ('1', 2)) AS "_v"("x")
""",
)
assert isinstance(context._get_engine_adapter("duckdb"), DuckDBEngineAdapter)
assert len(context._engine_adapters) == 2
52 changes: 51 additions & 1 deletion tests/core/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from sqlmesh.core.macros import MacroEvaluator, macro
from sqlmesh.core.model import Model, SqlModel, load_sql_based_model, model
from sqlmesh.core.test.definition import ModelTest, PythonModelTest, SqlModelTest
from sqlmesh.utils.errors import ConfigError, TestError
from sqlmesh.utils.errors import ConfigError, SQLMeshError, TestError
from sqlmesh.utils.yaml import dump as dump_yaml
from sqlmesh.utils.yaml import load as load_yaml

Expand Down Expand Up @@ -1989,3 +1989,53 @@ def test_test_generation_with_recursive_ctes(tmp_path: Path) -> None:
}

_check_successful_or_raise(context.test())


def test_test_with_gateway_specific_model(tmp_path: Path, mocker: MockerFixture) -> None:
init_example_project(tmp_path, dialect="duckdb")

config = Config(
gateways={
"main": GatewayConfig(connection=DuckDBConnectionConfig()),
"second": GatewayConfig(connection=DuckDBConnectionConfig()),
},
default_gateway="main",
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
)
gw_model_sql_file = tmp_path / "models" / "gw_model.sql"

# The model has a gateway specified which isn't the default
gw_model_sql_file.write_text(
"MODEL (name sqlmesh_example.gw_model, gateway second); SELECT c FROM sqlmesh_example.input_model;"
)
input_model_sql_file = tmp_path / "models" / "input_model.sql"
input_model_sql_file.write_text(
"MODEL (name sqlmesh_example.input_model); SELECT c FROM external_table;"
)

context = Context(paths=tmp_path, config=config)
input_queries = {'"memory"."sqlmesh_example"."input_model"': "SELECT 5 AS c"}
mocker.patch(
"sqlmesh.core.engine_adapter.base.EngineAdapter.fetchdf",
return_value=pd.DataFrame({"c": [5]}),
)

assert context.engine_adapter == context._engine_adapters["main"]
with pytest.raises(
SQLMeshError, match=r"Gateway 'wrong' not found in the available engine adapters."
):
context._get_engine_adapter("wrong")

# Create test should use the gateway specific engine adapter
context.create_test("sqlmesh_example.gw_model", input_queries=input_queries, overwrite=True)
assert context._get_engine_adapter("second") == context._engine_adapters["second"]
assert len(context._engine_adapters) == 2

test = load_yaml(context.path / c.TESTS / "test_gw_model.yaml")

assert len(test) == 1
assert "test_gw_model" in test
assert test["test_gw_model"]["inputs"] == {
'"memory"."sqlmesh_example"."input_model"': [{"c": 5}]
}
assert test["test_gw_model"]["outputs"] == {"query": [{"c": 5}]}

0 comments on commit e9efeff

Please sign in to comment.