Skip to content

Commit d1d8d9e

Browse files
authored
add migration framework (#654)
* add migration framework
* move init to migration and refactor for pr
* make postgres engine adapter so it can inherit table_exists
* add logging
1 parent 2d30d7f commit d1d8d9e

File tree

18 files changed

+504
-64
lines changed

18 files changed

+504
-64
lines changed

sqlmesh/core/context.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,9 @@ def state_sync(self) -> StateSync:
322322
raise ConfigError(
323323
"The operation is not supported when using a read-only state sync"
324324
)
325-
self._state_sync.init_schema()
325+
326+
if self._state_sync.get_versions(validate=False).schema_version == 0:
327+
self._state_sync.migrate()
326328
return self._state_sync
327329

328330
@property
@@ -807,6 +809,13 @@ def audit(
807809
self.console.show_sql(f"{error.query}")
808810
self.console.log_status_update("Done.")
809811

812+
    def migrate(self) -> None:
        """Migrates SQLMesh to the current running version.

        Please contact your SQLMesh administrator before doing this.
        """
        # Delegates to the state sync, which applies any pending schema
        # migrations and records the new version metadata.
        self.state_sync.migrate()
818+
810819
def close(self) -> None:
811820
"""Releases all resources allocated by this context."""
812821
self.snapshot_evaluator.close()

sqlmesh/core/engine_adapter/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sqlmesh.core.engine_adapter.databricks import DatabricksSparkSessionEngineAdapter
1010
from sqlmesh.core.engine_adapter.databricks_api import DatabricksSQLEngineAdapter
1111
from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter
12+
from sqlmesh.core.engine_adapter.postgres import PostgresEngineAdapter
1213
from sqlmesh.core.engine_adapter.redshift import RedshiftEngineAdapter
1314
from sqlmesh.core.engine_adapter.shared import TransactionType
1415
from sqlmesh.core.engine_adapter.snowflake import SnowflakeEngineAdapter
@@ -21,7 +22,7 @@
2122
"snowflake": SnowflakeEngineAdapter,
2223
"databricks": DatabricksSparkSessionEngineAdapter,
2324
"redshift": RedshiftEngineAdapter,
24-
"postgres": EngineAdapterWithIndexSupport,
25+
"postgres": PostgresEngineAdapter,
2526
"mysql": EngineAdapterWithIndexSupport,
2627
"mssql": EngineAdapterWithIndexSupport,
2728
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
5+
from sqlglot import exp
6+
7+
from sqlmesh.core.engine_adapter.base import (
8+
EngineAdapter,
9+
EngineAdapterWithIndexSupport,
10+
)
11+
12+
if t.TYPE_CHECKING:
13+
from sqlmesh.core._typing import TableName
14+
15+
16+
class PostgresBaseEngineAdapter(EngineAdapter):
17+
def table_exists(self, table_name: TableName) -> bool:
18+
"""
19+
Redshift/Postgres doesn't support describe so I'm using what the redshift cursor does to check if a table
20+
exists. We don't use this directly because we still want all execution to go through our execute method
21+
22+
Reference: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/cursor.py#L528-L553
23+
"""
24+
table = exp.to_table(table_name)
25+
26+
# Redshift doesn't support catalog
27+
if table.args.get("catalog"):
28+
return False
29+
30+
query = exp.select("1").from_("information_schema.tables")
31+
where = exp.condition(f"table_name = '{table.alias_or_name}'")
32+
33+
schema = table.text("db")
34+
if schema:
35+
where = where.and_(f"table_schema = '{schema}'")
36+
37+
self.execute(query.where(where))
38+
39+
result = self.cursor.fetchone()
40+
41+
return result[0] == 1 if result is not None else False
42+
43+
44+
class PostgresEngineAdapter(PostgresBaseEngineAdapter, EngineAdapterWithIndexSupport):
45+
DIALECT = "postgres"

sqlmesh/core/engine_adapter/redshift.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88

99
from sqlmesh.core.dialect import pandas_to_sql
1010
from sqlmesh.core.engine_adapter._typing import DF_TYPES, Query
11-
from sqlmesh.core.engine_adapter.base import EngineAdapter
11+
from sqlmesh.core.engine_adapter.postgres import PostgresBaseEngineAdapter
1212
from sqlmesh.core.engine_adapter.shared import DataObject
1313

1414
if t.TYPE_CHECKING:
1515
from sqlmesh.core._typing import TableName
1616
from sqlmesh.core.engine_adapter._typing import QueryOrDF
1717

1818

19-
class RedshiftEngineAdapter(EngineAdapter):
19+
class RedshiftEngineAdapter(PostgresBaseEngineAdapter):
2020
DIALECT = "redshift"
2121
DEFAULT_BATCH_SIZE = 1000
2222

@@ -145,32 +145,6 @@ def replace_query(
145145
def _short_hash(self) -> str:
146146
return uuid.uuid4().hex[:8]
147147

148-
def table_exists(self, table_name: TableName) -> bool:
149-
"""
150-
Redshift doesn't support describe so I'm using what the redshift cursor does to check if a table
151-
exists. We don't use this directly because we still want all execution to go through our execute method
152-
153-
Reference: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/cursor.py#L528-L553
154-
"""
155-
table = exp.to_table(table_name)
156-
157-
# Redshift doesn't support catalog
158-
if table.args.get("catalog"):
159-
return False
160-
161-
q: str = (
162-
f"SELECT 1 FROM information_schema.tables WHERE table_name = '{table.alias_or_name}'"
163-
)
164-
database_name = table.args.get("db")
165-
if database_name:
166-
q += f" AND table_schema = '{database_name}'"
167-
168-
self.execute(q)
169-
170-
result = self.cursor.fetchone()
171-
172-
return result[0] == 1 if result is not None else False
173-
174148
def _get_data_objects(
175149
self, schema_name: str, catalog_name: t.Optional[str] = None
176150
) -> t.List[DataObject]:

sqlmesh/core/state_sync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@
1313
The provided `sqlmesh.core.state_sync.EngineAdapterStateSync` leverages an existing engine
1414
adapter to read and write state to the underlying data store.
1515
"""
16-
from sqlmesh.core.state_sync.base import StateReader, StateSync
16+
from sqlmesh.core.state_sync.base import StateReader, StateSync, Versions
1717
from sqlmesh.core.state_sync.common import CommonStateSyncMixin
1818
from sqlmesh.core.state_sync.engine_adapter import EngineAdapterStateSync

sqlmesh/core/state_sync/base.py

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
from __future__ import annotations
22

33
import abc
4+
import importlib
5+
import logging
6+
import pkgutil
47
import typing as t
58

9+
from sqlglot import __version__ as SQLGLOT_VERSION
10+
11+
from sqlmesh import migrations
612
from sqlmesh.core import scheduler
713
from sqlmesh.core.environment import Environment
814
from sqlmesh.core.snapshot import (
@@ -14,8 +20,30 @@
1420
SnapshotNameVersionLike,
1521
SnapshotTableInfo,
1622
)
23+
from sqlmesh.utils import major_minor
1724
from sqlmesh.utils.date import TimeLike, now, to_datetime
1825
from sqlmesh.utils.errors import SQLMeshError
26+
from sqlmesh.utils.pydantic import PydanticModel
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class Versions(PydanticModel):
    """Represents the various versions of dependencies in the state sync."""

    # Number of schema migrations that have been applied to the state store;
    # compared against the running SCHEMA_VERSION (len(MIGRATIONS)).
    schema_version: int
    # SQLGlot version recorded in the state store — presumably the version
    # that last wrote state; compared against the locally installed SQLGlot.
    sqlglot_version: str

    @property
    def minor_sqlglot_version(self) -> t.Tuple[int, int]:
        """Return the (major, minor) components of ``sqlglot_version``."""
        return major_minor(self.sqlglot_version)
40+
41+
42+
# All migration modules found under sqlmesh/migrations, imported in sorted
# (lexicographic) module-name order so they are applied in sequence.
MIGRATIONS = [
    importlib.import_module(f"sqlmesh.migrations.{migration}")
    for migration in sorted(info.name for info in pkgutil.iter_modules(migrations.__path__))
]
# Current schema version: one version step per migration module.
SCHEMA_VERSION: int = len(MIGRATIONS)
1947

2048

2149
class StateReader(abc.ABC):
@@ -168,13 +196,55 @@ def missing_intervals(
168196
missing[snapshot] = intervals
169197
return missing
170198

199+
    def get_versions(self, validate: bool = True) -> Versions:
        """Get the current versions of the SQLMesh schema and libraries.

        Args:
            validate: Whether or not to raise error if the running version is ahead of state.

        Returns:
            The versions object.
        """
        versions = self._get_versions()

        def raise_error(lib: str, local: str | int, remote: str | int, ahead: bool = False) -> None:
            # "ahead" = local library is newer than the stored state (fixable
            # by migrating); otherwise the local library is older than the
            # state and must itself be upgraded.
            if ahead:
                raise SQLMeshError(
                    f"{lib} (local) is using version '{local}' which is ahead of '{remote}' (remote). Please run a migration."
                )
            raise SQLMeshError(
                f"{lib} (local) is using version '{local}' which is behind '{remote}' (remote). Please upgrade {lib}."
            )

        # State written by a newer SQLMesh / SQLGlot can never be read by an
        # older local install, so these two checks are unconditional errors.
        if SCHEMA_VERSION < versions.schema_version:
            raise_error("SQLMesh", SCHEMA_VERSION, versions.schema_version)

        if major_minor(SQLGLOT_VERSION) < major_minor(versions.sqlglot_version):
            raise_error("SQLGlot", SQLGLOT_VERSION, versions.sqlglot_version)

        if validate:
            # Local ahead of remote is recoverable via migration, so it only
            # raises when the caller asked for strict validation.
            if SCHEMA_VERSION > versions.schema_version:
                raise_error("SQLMesh", SCHEMA_VERSION, versions.schema_version, ahead=True)

            if major_minor(SQLGLOT_VERSION) > major_minor(versions.sqlglot_version):
                raise_error("SQLGlot", SQLGLOT_VERSION, versions.sqlglot_version, ahead=True)

        return versions
174233

175234
    @abc.abstractmethod
    def _get_versions(self, lock_for_update: bool = False) -> Versions:
        """Queries the store to get the current versions of SQLMesh and deps.

        Args:
            lock_for_update: Whether or not the usage of this method plans to update the row.

        Returns:
            The versions object.
        """
244+
245+
246+
class StateSync(StateReader, abc.ABC):
247+
"""Abstract base class for snapshot and environment state management."""
178248

179249
@abc.abstractmethod
180250
def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None:
@@ -291,3 +361,30 @@ def unpause_snapshots(
291361
unpaused_dt: The datetime object which indicates when target snapshots
292362
were unpaused.
293363
"""
364+
365+
    def migrate(self) -> None:
        """Migrate the state sync to the latest SQLMesh / SQLGlot version."""
        versions = self.get_versions(validate=False)
        # Only migrations beyond the already-applied schema version run.
        migrations = MIGRATIONS[versions.schema_version :]

        # Nothing to do when the schema is current and the local SQLGlot
        # major.minor matches what the state was written with.
        if not migrations and major_minor(SQLGLOT_VERSION) == versions.minor_sqlglot_version:
            return

        for migration in migrations:
            logger.info(f"Applying migration {migration}")
            migration.migrate(self)

        # Rewrite stored rows (snapshots/environments) and only then record
        # the new schema/SQLGlot versions, so a failure mid-way does not mark
        # the state as fully migrated.
        self._migrate_rows()
        self._update_versions()
379+
380+
    @abc.abstractmethod
    def _migrate_rows(self) -> None:
        """Migrate all rows in the state sync, including snapshots and environments."""

    @abc.abstractmethod
    def _update_versions(
        self,
        # Defaults are the running versions, i.e. the targets of a migration.
        schema_version: int = SCHEMA_VERSION,
        sqlglot_version: str = SQLGLOT_VERSION,
    ) -> None:
        """Update the schema versions to the latest running versions."""

0 commit comments

Comments
 (0)