Skip to content

Commit

Permalink
Migrated tests cases to CompareAndRunTestCase
Browse files Browse the repository at this point in the history
  • Loading branch information
RustyGuard committed Mar 17, 2024
1 parent 34aee50 commit 542d97b
Show file tree
Hide file tree
Showing 12 changed files with 374 additions and 356 deletions.
4 changes: 4 additions & 0 deletions tests/base/render_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def compare_and_run(
*,
expected_upgrade: str,
expected_downgrade: str,
disable_running: bool = False,
):
"""Compares generated migration script is equal to expected_upgrade and expected_downgrade, then runs it"""
migration_context = create_migration_context(connection, target_schema)
Expand All @@ -39,6 +40,9 @@ def compare_and_run(
assert upgrade_code == expected_upgrade, f"Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}"
assert downgrade_code == expected_downgrade, f"Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}"

if disable_running:
return

exec(
upgrade_code,
{ # todo Use imports from template_args
Expand Down
19 changes: 11 additions & 8 deletions tests/base/run_migration_test_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,26 @@


class CompareAndRunTestCase(ABC):
"""
Base class for all tests that expect specific alembic generated code
"""

disable_running = False

@abstractmethod
def get_database_schema(self) -> MetaData:
...
def get_database_schema(self) -> MetaData: ...

@abstractmethod
def get_target_schema(self) -> MetaData:
...
def get_target_schema(self) -> MetaData: ...

def insert_migration_data(self, connection: "Connection", database_schema: MetaData) -> None:
pass

@abstractmethod
def get_expected_upgrade(self) -> str:
...
def get_expected_upgrade(self) -> str: ...

@abstractmethod
def get_expected_downgrade(self) -> str:
...
def get_expected_downgrade(self) -> str: ...

def test_run(self, connection: "Connection"):
database_schema = self.get_database_schema()
Expand All @@ -40,4 +42,5 @@ def test_run(self, connection: "Connection"):
target_schema,
expected_upgrade=self.get_expected_upgrade(),
expected_downgrade=self.get_expected_downgrade(),
disable_running=self.disable_running,
)
110 changes: 57 additions & 53 deletions tests/test_enum_creation/test_add_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from alembic_postgresql_enum.operations import CreateEnumOp
from tests.base.run_migration_test_abc import CompareAndRunTestCase
from tests.base.render_and_run import compare_and_run
from tests.schemas import (
get_schema_with_enum_variants,
USER_TABLE_NAME,
Expand Down Expand Up @@ -52,31 +51,32 @@ def get_expected_downgrade(self) -> str:
"""


def test_create_enum_before_add_column(connection: "Connection"):
class TestCreateEnumBeforeAddColumn(CompareAndRunTestCase):
"""Check that library correctly creates enum before its use inside add_column"""
database_schema = get_schema_without_enum()
database_schema.create_all(connection)

new_enum_variants = ["active", "passive"]

target_schema = get_schema_with_enum_variants(new_enum_variants)
def get_database_schema(self) -> MetaData:
return get_schema_without_enum()

def get_target_schema(self) -> MetaData:
return get_schema_with_enum_variants(self.new_enum_variants)

compare_and_run(
connection,
target_schema,
expected_upgrade=f"""
def get_expected_upgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').create(op.get_bind())
op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', postgresql.ENUM({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', create_type=False), nullable=True))
sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').create(op.get_bind())
op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', postgresql.ENUM({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', create_type=False), nullable=True))
# ### end Alembic commands ###
""",
expected_downgrade=f"""
"""

def get_expected_downgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('{USER_TABLE_NAME}', '{USER_STATUS_COLUMN_NAME}')
sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind())
sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind())
# ### end Alembic commands ###
""",
)
"""


def test_create_enum_diff_tuple(connection: "Connection"):
Expand Down Expand Up @@ -137,63 +137,67 @@ def test_create_enum_diff_tuple_with_array(connection: "Connection"):
assert add_column_tuple[0] == "add_column"


def test_with_non_native_enum(connection: "Connection"):
class TestWithNonNativeEnum(CompareAndRunTestCase):
"""Check that library ignores sa.Enum that are not native"""
database_schema = get_schema_without_enum()
database_schema.create_all(connection)

new_enum_variants = ["active", "passive"]

target_schema = MetaData()
def get_database_schema(self) -> MetaData:
return get_schema_without_enum()

Table(
USER_TABLE_NAME,
target_schema,
Column("id", Integer, primary_key=True),
Column(
USER_STATUS_COLUMN_NAME,
sqlalchemy.Enum(*new_enum_variants, name=USER_STATUS_ENUM_NAME, native_enum=False),
),
)
def get_target_schema(self) -> MetaData:
target_schema = MetaData()

Table(
USER_TABLE_NAME,
target_schema,
Column("id", Integer, primary_key=True),
Column(
USER_STATUS_COLUMN_NAME,
sqlalchemy.Enum(*self.new_enum_variants, name=USER_STATUS_ENUM_NAME, native_enum=False),
),
)

return target_schema

compare_and_run(
connection,
target_schema,
expected_upgrade=f"""
def get_expected_upgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', native_enum=False), nullable=True))
op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', native_enum=False), nullable=True))
# ### end Alembic commands ###
""",
expected_downgrade=f"""
"""

def get_expected_downgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('{USER_TABLE_NAME}', '{USER_STATUS_COLUMN_NAME}')
# ### end Alembic commands ###
""",
)
"""


def test_create_enum_before_add_column_metadata_list(connection: "Connection"):
class TestCreateEnumBeforeAddColumnMetadataList(CompareAndRunTestCase):
"""Check that library correctly creates enum before its use inside add_column when metadata is in list"""
database_schema = get_schema_without_enum()
database_schema.create_all(connection)

new_enum_variants = ["active", "passive"]

target_schema = get_schema_with_enum_variants(new_enum_variants)
def get_database_schema(self) -> MetaData:
return get_schema_without_enum()

compare_and_run(
connection,
[target_schema],
expected_upgrade=f"""
def get_target_schema(self) -> MetaData:
return get_schema_with_enum_variants(self.new_enum_variants)

def get_expected_upgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').create(op.get_bind())
op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', postgresql.ENUM({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', create_type=False), nullable=True))
sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').create(op.get_bind())
op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', postgresql.ENUM({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', create_type=False), nullable=True))
# ### end Alembic commands ###
""",
expected_downgrade=f"""
"""

def get_expected_downgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('{USER_TABLE_NAME}', '{USER_STATUS_COLUMN_NAME}')
sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind())
sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind())
# ### end Alembic commands ###
""",
)
"""
40 changes: 19 additions & 21 deletions tests/test_enum_creation/test_create_array.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
from typing import TYPE_CHECKING, List
from typing import List

import sqlalchemy
from sqlalchemy import Table, Column, Integer, MetaData

from tests.base.render_and_run import compare_and_run
from tests.base.run_migration_test_abc import CompareAndRunTestCase
from tests.schemas import (
get_car_schema_without_enum,
CAR_TABLE_NAME,
CAR_COLORS_COLUMN_NAME,
CAR_COLORS_ENUM_NAME,
)

if TYPE_CHECKING:
from sqlalchemy import Connection
from sqlalchemy import Table, Column, Integer, MetaData


def get_schema_with_enum_in_sqlalchemy_array_variants(variants: List[str]) -> MetaData:
schema = MetaData()
Expand All @@ -31,28 +28,29 @@ def get_schema_with_enum_in_sqlalchemy_array_variants(variants: List[str]) -> Me
return schema


def test_create_enum_on_create_table_with_array(connection: "Connection"):
class TestCreateEnumOnCreateTableWithArray(CompareAndRunTestCase):
"""Check that library correctly creates enum before its use inside create_table. Enum is used in ARRAY"""
database_schema = get_car_schema_without_enum()
database_schema.create_all(connection)

new_enum_variants = ["black", "white", "red", "green", "blue", "other"]

target_schema = get_schema_with_enum_in_sqlalchemy_array_variants(new_enum_variants)
def get_database_schema(self) -> MetaData:
return get_car_schema_without_enum()

compare_and_run(
connection,
target_schema,
expected_upgrade=f"""
def get_target_schema(self) -> MetaData:
return get_schema_with_enum_in_sqlalchemy_array_variants(self.new_enum_variants)

def get_expected_upgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').create(op.get_bind())
op.add_column('{CAR_TABLE_NAME}', sa.Column('{CAR_COLORS_COLUMN_NAME}', sa.ARRAY(postgresql.ENUM({', '.join(map(repr, new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}', create_type=False)), nullable=True))
sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').create(op.get_bind())
op.add_column('{CAR_TABLE_NAME}', sa.Column('{CAR_COLORS_COLUMN_NAME}', sa.ARRAY(postgresql.ENUM({', '.join(map(repr, self.new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}', create_type=False)), nullable=True))
# ### end Alembic commands ###
""",
expected_downgrade=f"""
"""

def get_expected_downgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('{CAR_TABLE_NAME}', '{CAR_COLORS_COLUMN_NAME}')
sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').drop(op.get_bind())
sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').drop(op.get_bind())
# ### end Alembic commands ###
""",
)
"""
68 changes: 28 additions & 40 deletions tests/test_enum_creation/test_create_schema.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,40 @@
import textwrap
from typing import TYPE_CHECKING

from alembic import autogenerate
from sqlalchemy import Table, Column, Integer, MetaData
from sqlalchemy.dialects import postgresql

from tests.base.run_migration_test_abc import CompareAndRunTestCase
from tests.schemas import (
USER_TABLE_NAME,
USER_STATUS_ENUM_NAME,
USER_STATUS_COLUMN_NAME,
)
from tests.utils.migration_context import create_migration_context

if TYPE_CHECKING:
from sqlalchemy import Connection
from sqlalchemy import Table, Column, Integer, MetaData


def test_create_enum_on_create_table_inside_new_schema(connection: "Connection"):
class TestCreateEnumOnCreateTableInsideNewSchema(CompareAndRunTestCase):
"""Check that library correctly creates enum before its use inside create_table inside new schema"""

disable_running = True

new_enum_variants = ["active", "passive"]
non_existing_schema = "non_existing_schema"

target_schema = MetaData(schema=non_existing_schema)

Table(
USER_TABLE_NAME,
target_schema,
Column("id", Integer, primary_key=True),
Column(
USER_STATUS_COLUMN_NAME,
postgresql.ENUM(*new_enum_variants, name=USER_STATUS_ENUM_NAME, metadata=target_schema),
),
)
def get_database_schema(self) -> MetaData:
return MetaData()

def get_target_schema(self) -> MetaData:
target_schema = MetaData(schema=self.non_existing_schema)
Table(
USER_TABLE_NAME,
target_schema,
Column("id", Integer, primary_key=True),
Column(
USER_STATUS_COLUMN_NAME,
postgresql.ENUM(*self.new_enum_variants, name=USER_STATUS_ENUM_NAME, metadata=target_schema),
),
)
return target_schema

expected_upgrade = f"""
def get_expected_upgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum('active', 'passive', name='user_status', schema='non_existing_schema').create(op.get_bind())
op.create_table('users',
Expand All @@ -43,25 +44,12 @@ def test_create_enum_on_create_table_inside_new_schema(connection: "Connection")
schema='non_existing_schema'
)
# ### end Alembic commands ###
"""
expected_downgrade = f"""
"""

def get_expected_downgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('users', schema='non_existing_schema')
sa.Enum('active', 'passive', name='user_status', schema='non_existing_schema').drop(op.get_bind())
# ### end Alembic commands ###
"""

migration_context = create_migration_context(connection, target_schema)

template_args = {}
# todo _render_migration_diffs marked as legacy, maybe find something else
autogenerate._render_migration_diffs(migration_context, template_args)

upgrade_code = textwrap.dedent(" " + template_args["upgrades"])
downgrade_code = textwrap.dedent(" " + template_args["downgrades"])

expected_upgrade = textwrap.dedent(expected_upgrade).strip("\n ")
expected_downgrade = textwrap.dedent(expected_downgrade).strip("\n ")

assert upgrade_code == expected_upgrade, f"Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}"
assert downgrade_code == expected_downgrade, f"Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}"
"""
Loading

0 comments on commit 542d97b

Please sign in to comment.