Skip to content

Commit

Permalink
Use Engine.begin() instead of Engine.connect() to ensure DML/DDL stat…
Browse files Browse the repository at this point in the history
…ements are committed in migration code (#1129)

# Description
One of the changes in SQLAlchemy from 1.4.X -> 2.X is that they've
removed automatic commits for DDL/DML statements fed into `Connection`s
created from `Engine.connect()`. This leads to a few minor bugs in the
DB migration code post `0.39.0` release (and the SQLAlchemy 2.X
migration in #1127).

## This PR
This PR fixes those bugs by using `Connection`s created by
`Engine.begin()` instead

### Testing
```bash
$ bazel run sematic/db/tests:test_migrate
```
+ manual testing in a repo that's using Sematic as a dependency
  • Loading branch information
bcalvert-graft authored Sep 10, 2024
1 parent fc96e11 commit 9b13728
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion sematic/db/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _is_migration_file(file_name: str) -> bool:


def _get_current_versions() -> List[str]:
with db().get_engine().connect() as conn:
with db().get_engine().begin() as conn:
conn.execute(
text(
"CREATE TABLE IF NOT EXISTS "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def up():


def down():
with db().get_engine().connect() as conn:
with db().get_engine().begin() as conn:
# TODO #303: standardize NULL vs 'null'
run_id_exception_json_pairs = conn.execute(
text(
Expand Down
4 changes: 2 additions & 2 deletions sematic/db/tests/test_migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_invalid_sql(_, test_db_empty): # noqa: F811
def test_migrate(_, test_db_empty): # noqa: F811

with pytest.raises(OperationalError):
with db().get_engine().connect() as conn:
with db().get_engine().begin() as conn:
conn.execute(text("SELECT version FROM schema_migrations;"))

migrate_up()
Expand All @@ -93,7 +93,7 @@ def test_migrate(_, test_db_empty): # noqa: F811
assert len(current_versions) > 0

# Test tables were created
with db().get_engine().connect() as conn:
with db().get_engine().begin() as conn:
run_count = conn.execute(text("SELECT COUNT(*) from runs;"))

assert list(run_count)[0][0] == 0
Expand Down

0 comments on commit 9b13728

Please sign in to comment.