|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import abc |
| 4 | +import importlib |
| 5 | +import logging |
| 6 | +import pkgutil |
4 | 7 | import typing as t |
5 | 8 |
|
| 9 | +from sqlglot import __version__ as SQLGLOT_VERSION |
| 10 | + |
| 11 | +from sqlmesh import migrations |
6 | 12 | from sqlmesh.core import scheduler |
7 | 13 | from sqlmesh.core.environment import Environment |
8 | 14 | from sqlmesh.core.snapshot import ( |
|
14 | 20 | SnapshotNameVersionLike, |
15 | 21 | SnapshotTableInfo, |
16 | 22 | ) |
| 23 | +from sqlmesh.utils import major_minor |
17 | 24 | from sqlmesh.utils.date import TimeLike, now, to_datetime |
18 | 25 | from sqlmesh.utils.errors import SQLMeshError |
| 26 | +from sqlmesh.utils.pydantic import PydanticModel |
| 27 | + |
| 28 | +logger = logging.getLogger(__name__) |
| 29 | + |
| 30 | + |
class Versions(PydanticModel):
    """Versions of SQLMesh and its dependencies as tracked by the state sync.

    Attributes:
        schema_version: The SQLMesh schema version.
        sqlglot_version: The full SQLGlot version string.
    """

    schema_version: int
    sqlglot_version: str

    @property
    def minor_sqlglot_version(self) -> t.Tuple[int, int]:
        """The SQLGlot version after being reduced by `major_minor`."""
        reduced_version = major_minor(self.sqlglot_version)
        return reduced_version
| 41 | + |
# Every module found in the `sqlmesh.migrations` package, imported in ascending
# order of module name.
MIGRATIONS = [
    importlib.import_module(f"sqlmesh.migrations.{module_info.name}")
    for module_info in sorted(
        pkgutil.iter_modules(migrations.__path__), key=lambda found: found.name
    )
]
# The schema version of the running code is the number of available migrations.
SCHEMA_VERSION: int = len(MIGRATIONS)
19 | 47 |
|
20 | 48 |
|
21 | 49 | class StateReader(abc.ABC): |
@@ -168,13 +196,55 @@ def missing_intervals( |
168 | 196 | missing[snapshot] = intervals |
169 | 197 | return missing |
170 | 198 |
|
def get_versions(self, validate: bool = True) -> Versions:
    """Fetch the versions recorded in state and compare them with the running ones.

    A running SQLMesh schema version or SQLGlot version (as reduced by
    `major_minor`) that is behind the recorded one is always an error. Being
    ahead of the recorded one is an error only when `validate` is set.

    Args:
        validate: Whether or not to raise an error if the running version is ahead of state.

    Returns:
        The versions object read from state.

    Raises:
        SQLMeshError: If a running version is behind state, or ahead of it while validating.
    """
    stored = self._get_versions()

    def mismatch(
        lib: str, local: str | int, remote: str | int, ahead: bool = False
    ) -> SQLMeshError:
        # Build (rather than raise) the error so each call site reads as an explicit `raise`.
        if not ahead:
            return SQLMeshError(
                f"{lib} (local) is using version '{local}' which is behind '{remote}' (remote). Please upgrade {lib}."
            )
        return SQLMeshError(
            f"{lib} (local) is using version '{local}' which is ahead of '{remote}' (remote). Please run a migration."
        )

    # Running code that is behind the recorded state is rejected unconditionally.
    if SCHEMA_VERSION < stored.schema_version:
        raise mismatch("SQLMesh", SCHEMA_VERSION, stored.schema_version)

    running_sqlglot = major_minor(SQLGLOT_VERSION)
    stored_sqlglot = major_minor(stored.sqlglot_version)
    if running_sqlglot < stored_sqlglot:
        raise mismatch("SQLGlot", SQLGLOT_VERSION, stored.sqlglot_version)

    if not validate:
        return stored

    # Running code that is ahead of the recorded state is only rejected when validating.
    if SCHEMA_VERSION > stored.schema_version:
        raise mismatch("SQLMesh", SCHEMA_VERSION, stored.schema_version, ahead=True)

    if running_sqlglot > stored_sqlglot:
        raise mismatch("SQLGlot", SQLGLOT_VERSION, stored.sqlglot_version, ahead=True)

    return stored
174 | 233 |
|
175 | 234 | @abc.abstractmethod |
176 | | - def init_schema(self) -> None: |
177 | | - """Optional initialization of the sync.""" |
def _get_versions(self, lock_for_update: bool = False) -> Versions:
    """Read the recorded versions of SQLMesh and its dependencies from the store.

    Args:
        lock_for_update: Set when the caller plans to update the versions row afterwards.

    Returns:
        The versions object.
    """
| 244 | + |
| 245 | + |
| 246 | +class StateSync(StateReader, abc.ABC): |
| 247 | + """Abstract base class for snapshot and environment state management.""" |
178 | 248 |
|
179 | 249 | @abc.abstractmethod |
180 | 250 | def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: |
@@ -291,3 +361,30 @@ def unpause_snapshots( |
291 | 361 | unpaused_dt: The datetime object which indicates when target snapshots |
292 | 362 | were unpaused. |
293 | 363 | """ |
| 364 | + |
def migrate(self) -> None:
    """Migrate the state sync to the latest SQLMesh / SQLGlot version.

    Applies, in order, every migration module beyond the recorded schema version,
    then migrates the rows and records the running versions. Returns without doing
    anything when no migrations are pending and the recorded SQLGlot version equals
    the running one after both are reduced by `major_minor`.
    """
    # Validation is disabled because the running code being ahead of state is
    # exactly the situation a migration is meant to resolve.
    versions = self.get_versions(validate=False)
    # Named `pending` so the imported `migrations` module is not shadowed.
    pending = MIGRATIONS[versions.schema_version :]

    if not pending and major_minor(SQLGLOT_VERSION) == versions.minor_sqlglot_version:
        return

    for migration in pending:
        # Lazy %-style arguments defer formatting until the record is actually emitted.
        logger.info("Applying migration %s", migration)
        migration.migrate(self)

    self._migrate_rows()
    self._update_versions()
| 379 | + |
| 380 | + @abc.abstractmethod |
| 381 | + def _migrate_rows(self) -> None: |
| 382 | + """Migrate all rows in the state sync, including snapshots and environments.""" |
| 383 | + |
@abc.abstractmethod
def _update_versions(
    self,
    schema_version: int = SCHEMA_VERSION,
    sqlglot_version: str = SQLGLOT_VERSION,
) -> None:
    """Update the versions recorded in the state sync.

    Args:
        schema_version: The schema version to record. Defaults to the running SCHEMA_VERSION.
        sqlglot_version: The SQLGlot version to record. Defaults to the running SQLGLOT_VERSION.
    """
0 commit comments