diff --git a/alembic_postgresql_enum/add_create_type_false.py b/alembic_postgresql_enum/add_create_type_false.py index ea67246..9c90230 100644 --- a/alembic_postgresql_enum/add_create_type_false.py +++ b/alembic_postgresql_enum/add_create_type_false.py @@ -62,14 +62,10 @@ def add_create_type_false(upgrade_ops: UpgradeOps): if isinstance(operations_group, ModifyTableOps): for operation in operations_group.ops: if isinstance(operation, AddColumnOp): - column: Column = operation.column - - inject_repr_into_enums(column) - + inject_repr_into_enums(operation.column) elif isinstance(operation, DropColumnOp): - column: Column = operation._reverse.column - - inject_repr_into_enums(column) + assert operation._reverse is not None + inject_repr_into_enums(operation._reverse.column) elif isinstance(operations_group, CreateTableOp): for column in operations_group.columns: @@ -77,6 +73,7 @@ def add_create_type_false(upgrade_ops: UpgradeOps): inject_repr_into_enums(column) elif isinstance(operations_group, DropTableOp): + assert operations_group._reverse is not None for column in operations_group._reverse.columns: if isinstance(column, Column): inject_repr_into_enums(column) diff --git a/alembic_postgresql_enum/add_postgres_using_to_text.py b/alembic_postgresql_enum/add_postgres_using_to_text.py index 8e3116e..adf33ba 100644 --- a/alembic_postgresql_enum/add_postgres_using_to_text.py +++ b/alembic_postgresql_enum/add_postgres_using_to_text.py @@ -39,6 +39,7 @@ def _postgres_using_alter_column(autogen_context: AutogenContext, op: ops.AlterC def add_postgres_using_to_alter_operation(op: AlterColumnOp): + assert op.modify_type is not None op.kw["postgresql_using"] = f"{op.column_name}::{op.modify_type.name}" log.info("postgresql_using added to %r.%r alteration", op.table_name, op.column_name) op.__class__ = PostgresUsingAlterColumnOp diff --git a/alembic_postgresql_enum/compare_dispatch.py b/alembic_postgresql_enum/compare_dispatch.py index a61a90f..1f40172 100644 --- a/alembic_postgresql_enum/compare_dispatch.py +++ b/alembic_postgresql_enum/compare_dispatch.py @@ -1,3 +1,4 @@ +import logging from typing import Iterable, Union import alembic @@ -16,6 +17,9 @@ from alembic_postgresql_enum.get_enum_data import get_defined_enums, get_declared_enums +log = logging.getLogger(f"alembic.{__name__}") + + @alembic.autogenerate.comparators.dispatch_for("schema") def compare_enums( autogen_context: AutogenContext, @@ -28,6 +32,12 @@ def compare_enums( for each defined enum that has changed new entries when compared to its declared version. """ + if autogen_context.dialect.name != "postgresql": + log.warning( + f"This library only supports postgresql, but you are using {autogen_context.dialect.name}, skipping" + ) + return + add_create_type_false(upgrade_ops) add_postgres_using_to_text(upgrade_ops) @@ -39,6 +49,12 @@ def compare_enums( if isinstance(operations_group, CreateTableOp) and operations_group.schema not in schema_names: schema_names.append(operations_group.schema) + assert ( + autogen_context.dialect is not None + and autogen_context.dialect.default_schema_name is not None + and autogen_context.connection is not None + and autogen_context.metadata is not None + ) for schema in schema_names: default_schema = autogen_context.dialect.default_schema_name if schema is None: diff --git a/alembic_postgresql_enum/connection.py b/alembic_postgresql_enum/connection.py index e09b7ba..02d270d 100644 --- a/alembic_postgresql_enum/connection.py +++ b/alembic_postgresql_enum/connection.py @@ -1,10 +1,11 @@ from contextlib import contextmanager +from typing import Iterator import sqlalchemy @contextmanager -def get_connection(operations) -> sqlalchemy.engine.Connection: +def get_connection(operations) -> Iterator[sqlalchemy.engine.Connection]: """ SQLAlchemy 2.0 changes the operation binding location; bridge function to support both 1.x and 2.x. diff --git a/alembic_postgresql_enum/get_enum_data/declared_enums.py b/alembic_postgresql_enum/get_enum_data/declared_enums.py index 611da42..75706b8 100644 --- a/alembic_postgresql_enum/get_enum_data/declared_enums.py +++ b/alembic_postgresql_enum/get_enum_data/declared_enums.py @@ -1,5 +1,6 @@ from collections import defaultdict -from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING +from enum import Enum +from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast import sqlalchemy from sqlalchemy import MetaData @@ -93,16 +94,16 @@ def get_declared_enums( if not column_type_is_enum(column_type): continue - column_type_schema = column_type.schema or default_schema + column_type_schema = column_type.schema or default_schema # type: ignore[attr-defined] if column_type_schema != schema: continue - if column_type.name not in enum_name_to_values: - enum_name_to_values[column_type.name] = get_enum_values(column_type) + if column_type.name not in enum_name_to_values: # type: ignore[attr-defined] + enum_name_to_values[column_type.name] = get_enum_values(cast(sqlalchemy.Enum, column_type)) # type: ignore[attr-defined] table_schema = table.schema or default_schema column_default = get_column_default(connection, table_schema, table.name, column.name) - enum_name_to_table_references[column_type.name].add( + enum_name_to_table_references[column_type.name].add( # type: ignore[attr-defined] TableReference( table_schema=table_schema, table_name=table.name, diff --git a/alembic_postgresql_enum/get_enum_data/types.py b/alembic_postgresql_enum/get_enum_data/types.py index 42e299e..904c09b 100644 --- a/alembic_postgresql_enum/get_enum_data/types.py +++ b/alembic_postgresql_enum/get_enum_data/types.py @@ -20,9 +20,9 @@ def __repr__(self): class TableReference: table_name: str column_name: str - table_schema: Optional[str] = Unspecified # 'Unspecified' default is for migrations from older versions + table_schema: Optional[str] = Unspecified # type: ignore[assignment] # 'Unspecified' default is for migrations from older versions column_type: ColumnType = ColumnType.COMMON - existing_server_default: str = None + existing_server_default: Optional[str] = None def __repr__(self): result_str = "TableReference(" @@ -48,7 +48,7 @@ def table_name_with_schema(self): prefix = f"{self.table_schema}." else: prefix = "" - return f"{prefix}{self.table_name}" + return f'{prefix}"{self.table_name}"' EnumNamesToValues = Dict[str, Tuple[str, ...]] diff --git a/alembic_postgresql_enum/operations/create_enum.py b/alembic_postgresql_enum/operations/create_enum.py index 9b4afed..6d0ca46 100644 --- a/alembic_postgresql_enum/operations/create_enum.py +++ b/alembic_postgresql_enum/operations/create_enum.py @@ -19,6 +19,7 @@ def reverse(self): @alembic.autogenerate.render.renderers.dispatch_for(CreateEnumOp) def render_create_enum_op(autogen_context: AutogenContext, op: CreateEnumOp): + assert autogen_context.dialect is not None if op.schema != autogen_context.dialect.default_schema_name: return f""" sa.Enum({', '.join(map(repr, op.enum_values))}, name='{op.name}', schema='{op.schema}').create(op.get_bind()) diff --git a/alembic_postgresql_enum/operations/drop_enum.py b/alembic_postgresql_enum/operations/drop_enum.py index 6872482..1f7f2ff 100644 --- a/alembic_postgresql_enum/operations/drop_enum.py +++ b/alembic_postgresql_enum/operations/drop_enum.py @@ -19,6 +19,7 @@ def reverse(self): @alembic.autogenerate.render.renderers.dispatch_for(DropEnumOp) def render_drop_enum_op(autogen_context: AutogenContext, op: DropEnumOp): + assert autogen_context.dialect is not None if op.schema != autogen_context.dialect.default_schema_name: return f""" sa.Enum({', '.join(map(repr, op.enum_values))}, name='{op.name}', schema='{op.schema}').drop(op.get_bind()) diff --git a/alembic_postgresql_enum/operations/sync_enum_values.py b/alembic_postgresql_enum/operations/sync_enum_values.py index 896d53c..714fe9f 100644 --- a/alembic_postgresql_enum/operations/sync_enum_values.py +++ b/alembic_postgresql_enum/operations/sync_enum_values.py @@ -1,3 +1,4 @@ +import logging from typing import List, Tuple, Any, Iterable, TYPE_CHECKING import alembic.autogenerate @@ -31,6 +32,9 @@ from alembic_postgresql_enum.get_enum_data import TableReference, ColumnType +log = logging.getLogger(f"alembic.{__name__}") + + @alembic.operations.base.Operations.register_operation("sync_enum_values") class SyncEnumValuesOp(alembic.operations.ops.MigrateOperation): operation_name = "change_enum_variants" @@ -138,6 +142,12 @@ def sync_enum_values( ] If there was server default with old_name it will be renamed accordingly """ + if operations.migration_context.dialect.name != "postgresql": + log.warning( + f"This library only supports postgresql, but you are using {operations.migration_context.dialect.name}, skipping" + ) + return + enum_values_to_rename = list(enum_values_to_rename) with get_connection(operations) as connection: diff --git a/alembic_postgresql_enum/py.typed b/alembic_postgresql_enum/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 60fec30..3907f4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "alembic-postgresql-enum" -version = "1.1.2" +version = "1.2.0" description = "Alembic autogenerate support for creation, alteration and deletion of enums" authors = ["RustyGuard"] license = "MIT" diff --git a/tests/base/render_and_run.py b/tests/base/render_and_run.py index ff39f60..a4e5fe6 100644 --- a/tests/base/render_and_run.py +++ b/tests/base/render_and_run.py @@ -20,6 +20,7 @@ def compare_and_run( *, expected_upgrade: str, expected_downgrade: str, + disable_running: bool = False, ): """Compares generated migration script is equal to expected_upgrade and expected_downgrade, then runs it""" migration_context = create_migration_context(connection, target_schema) @@ -39,6 +40,9 @@ def compare_and_run( assert upgrade_code == expected_upgrade, f"Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}" assert downgrade_code == expected_downgrade, f"Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}" + if disable_running: + return + exec( upgrade_code, { # todo Use imports from template_args diff --git a/tests/base/run_migration_test_abc.py b/tests/base/run_migration_test_abc.py index fb2b7cb..eb6069a 100644 --- a/tests/base/run_migration_test_abc.py +++ b/tests/base/run_migration_test_abc.py @@ -9,12 +9,21 @@ class CompareAndRunTestCase(ABC): + """ + Base class for all tests that expect specific alembic generated code + """ + + disable_running = False + @abstractmethod def get_database_schema(self) -> MetaData: ... @abstractmethod def get_target_schema(self) -> MetaData: ... + def insert_migration_data(self, connection: "Connection", database_schema: MetaData) -> None: + pass + @abstractmethod def get_expected_upgrade(self) -> str: ... @@ -26,10 +35,12 @@ def test_run(self, connection: "Connection"): target_schema = self.get_target_schema() database_schema.create_all(connection) + self.insert_migration_data(connection, database_schema) compare_and_run( connection, target_schema, expected_upgrade=self.get_expected_upgrade(), expected_downgrade=self.get_expected_downgrade(), + disable_running=self.disable_running, ) diff --git a/tests/test_alter_column/test_text_column.py b/tests/test_alter_column/test_text_column.py index d213963..647c83c 100644 --- a/tests/test_alter_column/test_text_column.py +++ b/tests/test_alter_column/test_text_column.py @@ -4,11 +4,11 @@ from sqlalchemy import MetaData, Table, Column, TEXT, insert from sqlalchemy.dialects import postgresql +from tests.base.run_migration_test_abc import CompareAndRunTestCase + if TYPE_CHECKING: from sqlalchemy import Connection -from tests.base.render_and_run import compare_and_run - class NewEnum(Enum): A = "a" @@ -16,44 +16,49 @@ class NewEnum(Enum): C = "c" -def test_text_column(connection: "Connection"): - database_schema = MetaData() - a_table = Table("a", database_schema, Column("value", TEXT)) - database_schema.create_all(connection) - connection.execute( - insert(a_table).values( - [ - {"value": NewEnum.A.name}, - {"value": NewEnum.B.name}, - {"value": NewEnum.B.name}, - {"value": NewEnum.C.name}, - ] +class TestTextColumn(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + database_schema = MetaData() + Table("a", database_schema, Column("value", TEXT)) + return database_schema + + def get_target_schema(self) -> MetaData: + target_schema = MetaData() + Table("a", target_schema, Column("value", postgresql.ENUM(NewEnum))) + return target_schema + + def insert_migration_data(self, connection: "Connection", database_schema: MetaData) -> None: + a_table = database_schema.tables["a"] + connection.execute( + insert(a_table).values( + [ + {"value": NewEnum.A.name}, + {"value": NewEnum.B.name}, + {"value": NewEnum.B.name}, + {"value": NewEnum.C.name}, + ] + ) ) - ) - - target_schema = MetaData() - Table("a", target_schema, Column("value", postgresql.ENUM(NewEnum))) - - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" - # ### commands auto generated by Alembic - please adjust! ### - sa.Enum('A', 'B', 'C', name='newenum').create(op.get_bind()) - op.alter_column('a', 'value', - existing_type=sa.TEXT(), - type_=postgresql.ENUM('A', 'B', 'C', name='newenum'), - existing_nullable=True, - postgresql_using='value::newenum') - # ### end Alembic commands ### - """, - expected_downgrade=f""" - # ### commands auto generated by Alembic - please adjust! ### - op.alter_column('a', 'value', - existing_type=postgresql.ENUM('A', 'B', 'C', name='newenum'), - type_=sa.TEXT(), - existing_nullable=True) - sa.Enum('A', 'B', 'C', name='newenum').drop(op.get_bind()) - # ### end Alembic commands ### - """, - ) + + def get_expected_upgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + sa.Enum('A', 'B', 'C', name='newenum').create(op.get_bind()) + op.alter_column('a', 'value', + existing_type=sa.TEXT(), + type_=postgresql.ENUM('A', 'B', 'C', name='newenum'), + existing_nullable=True, + postgresql_using='value::newenum') + # ### end Alembic commands ### + """ + + def get_expected_downgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('a', 'value', + existing_type=postgresql.ENUM('A', 'B', 'C', name='newenum'), + type_=sa.TEXT(), + existing_nullable=True) + sa.Enum('A', 'B', 'C', name='newenum').drop(op.get_bind()) + # ### end Alembic commands ### + """ diff --git a/tests/test_enum_creation/test_add_column.py b/tests/test_enum_creation/test_add_column.py index 393fa6d..a97aa54 100644 --- a/tests/test_enum_creation/test_add_column.py +++ b/tests/test_enum_creation/test_add_column.py @@ -6,7 +6,7 @@ from alembic.operations import ops from alembic_postgresql_enum.operations import CreateEnumOp -from tests.base.render_and_run import compare_and_run +from tests.base.run_migration_test_abc import CompareAndRunTestCase from tests.schemas import ( get_schema_with_enum_variants, USER_TABLE_NAME, @@ -25,31 +25,58 @@ from sqlalchemy import MetaData, Table, Column, Integer -def test_create_enum_before_add_column(connection: "Connection"): +class TestCreateEnumBeforeAddColumn(CompareAndRunTestCase): + new_enum_variants = ["active", "passive"] + + def get_database_schema(self) -> MetaData: + return get_schema_without_enum() + + def get_target_schema(self) -> MetaData: + return get_schema_with_enum_variants(self.new_enum_variants) + + def get_expected_upgrade(self) -> str: + return f""" + # ### commands auto generated by Alembic - please adjust! ### + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').create(op.get_bind()) + op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', postgresql.ENUM({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', create_type=False), nullable=True)) + # ### end Alembic commands ### + """ + + def get_expected_downgrade(self) -> str: + return f""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('{USER_TABLE_NAME}', '{USER_STATUS_COLUMN_NAME}') + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) + # ### end Alembic commands ### + """ + + +class TestCreateEnumBeforeAddColumn(CompareAndRunTestCase): """Check that library correctly creates enum before its use inside add_column""" - database_schema = get_schema_without_enum() - database_schema.create_all(connection) new_enum_variants = ["active", "passive"] - target_schema = get_schema_with_enum_variants(new_enum_variants) + def get_database_schema(self) -> MetaData: + return get_schema_without_enum() + + def get_target_schema(self) -> MetaData: + return get_schema_with_enum_variants(self.new_enum_variants) - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').create(op.get_bind()) - op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', postgresql.ENUM({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', create_type=False), nullable=True)) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').create(op.get_bind()) + op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', postgresql.ENUM({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', create_type=False), nullable=True)) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_column('{USER_TABLE_NAME}', '{USER_STATUS_COLUMN_NAME}') - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) # ### end Alembic commands ### - """, - ) + """ def test_create_enum_diff_tuple(connection: "Connection"): @@ -110,63 +137,67 @@ def test_create_enum_diff_tuple_with_array(connection: "Connection"): assert add_column_tuple[0] == "add_column" -def test_with_non_native_enum(connection: "Connection"): +class TestWithNonNativeEnum(CompareAndRunTestCase): """Check that library ignores sa.Enum that are not native""" - database_schema = get_schema_without_enum() - database_schema.create_all(connection) new_enum_variants = ["active", "passive"] - target_schema = MetaData() + def get_database_schema(self) -> MetaData: + return get_schema_without_enum() - Table( - USER_TABLE_NAME, - target_schema, - Column("id", Integer, primary_key=True), - Column( - USER_STATUS_COLUMN_NAME, - sqlalchemy.Enum(*new_enum_variants, name=USER_STATUS_ENUM_NAME, native_enum=False), - ), - ) + def get_target_schema(self) -> MetaData: + target_schema = MetaData() + + Table( + USER_TABLE_NAME, + target_schema, + Column("id", Integer, primary_key=True), + Column( + USER_STATUS_COLUMN_NAME, + sqlalchemy.Enum(*self.new_enum_variants, name=USER_STATUS_ENUM_NAME, native_enum=False), + ), + ) - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + return target_schema + + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### - op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', native_enum=False), nullable=True)) + op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', native_enum=False), nullable=True)) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_column('{USER_TABLE_NAME}', '{USER_STATUS_COLUMN_NAME}') # ### end Alembic commands ### - """, - ) + """ -def test_create_enum_before_add_column_metadata_list(connection: "Connection"): +class TestCreateEnumBeforeAddColumnMetadataList(CompareAndRunTestCase): """Check that library correctly creates enum before its use inside add_column when metadata is in list""" - database_schema = get_schema_without_enum() - database_schema.create_all(connection) new_enum_variants = ["active", "passive"] - target_schema = get_schema_with_enum_variants(new_enum_variants) + def get_database_schema(self) -> MetaData: + return get_schema_without_enum() - compare_and_run( - connection, - [target_schema], - expected_upgrade=f""" + def get_target_schema(self) -> MetaData: + return get_schema_with_enum_variants(self.new_enum_variants) + + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').create(op.get_bind()) - op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', postgresql.ENUM({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', create_type=False), nullable=True)) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').create(op.get_bind()) + op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', postgresql.ENUM({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', create_type=False), nullable=True)) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_column('{USER_TABLE_NAME}', '{USER_STATUS_COLUMN_NAME}') - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) # ### end Alembic commands ### - """, - ) + """ diff --git a/tests/test_enum_creation/test_create_array.py b/tests/test_enum_creation/test_create_array.py index bc67845..c8d440d 100644 --- a/tests/test_enum_creation/test_create_array.py +++ b/tests/test_enum_creation/test_create_array.py @@ -1,8 +1,9 @@ -from typing import TYPE_CHECKING, List +from typing import List import sqlalchemy +from sqlalchemy import Table, Column, Integer, MetaData -from tests.base.render_and_run import compare_and_run +from tests.base.run_migration_test_abc import CompareAndRunTestCase from tests.schemas import ( get_car_schema_without_enum, CAR_TABLE_NAME, @@ -10,10 +11,6 @@ CAR_COLORS_ENUM_NAME, ) -if TYPE_CHECKING: - from sqlalchemy import Connection -from sqlalchemy import Table, Column, Integer, MetaData - def get_schema_with_enum_in_sqlalchemy_array_variants(variants: List[str]) -> MetaData: schema = MetaData() @@ -31,28 +28,29 @@ def get_schema_with_enum_in_sqlalchemy_array_variants(variants: List[str]) -> Me return schema -def test_create_enum_on_create_table_with_array(connection: "Connection"): +class TestCreateEnumOnCreateTableWithArray(CompareAndRunTestCase): """Check that library correctly creates enum before its use inside create_table. Enum is used in ARRAY""" - database_schema = get_car_schema_without_enum() - database_schema.create_all(connection) new_enum_variants = ["black", "white", "red", "green", "blue", "other"] - target_schema = get_schema_with_enum_in_sqlalchemy_array_variants(new_enum_variants) + def get_database_schema(self) -> MetaData: + return get_car_schema_without_enum() - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_target_schema(self) -> MetaData: + return get_schema_with_enum_in_sqlalchemy_array_variants(self.new_enum_variants) + + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').create(op.get_bind()) - op.add_column('{CAR_TABLE_NAME}', sa.Column('{CAR_COLORS_COLUMN_NAME}', sa.ARRAY(postgresql.ENUM({', '.join(map(repr, new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}', create_type=False)), nullable=True)) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').create(op.get_bind()) + op.add_column('{CAR_TABLE_NAME}', sa.Column('{CAR_COLORS_COLUMN_NAME}', sa.ARRAY(postgresql.ENUM({', '.join(map(repr, self.new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}', create_type=False)), nullable=True)) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_column('{CAR_TABLE_NAME}', '{CAR_COLORS_COLUMN_NAME}') - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').drop(op.get_bind()) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').drop(op.get_bind()) # ### end Alembic commands ### - """, - ) + """ diff --git a/tests/test_enum_creation/test_create_schema.py b/tests/test_enum_creation/test_create_schema.py index 0261e5f..df499e4 100644 --- a/tests/test_enum_creation/test_create_schema.py +++ b/tests/test_enum_creation/test_create_schema.py @@ -1,39 +1,40 @@ -import textwrap -from typing import TYPE_CHECKING - -from alembic import autogenerate +from sqlalchemy import Table, Column, Integer, MetaData from sqlalchemy.dialects import postgresql +from tests.base.run_migration_test_abc import CompareAndRunTestCase from tests.schemas import ( USER_TABLE_NAME, USER_STATUS_ENUM_NAME, USER_STATUS_COLUMN_NAME, ) -from tests.utils.migration_context import create_migration_context - -if TYPE_CHECKING: - from sqlalchemy import Connection -from sqlalchemy import Table, Column, Integer, MetaData -def test_create_enum_on_create_table_inside_new_schema(connection: "Connection"): +class TestCreateEnumOnCreateTableInsideNewSchema(CompareAndRunTestCase): """Check that library correctly creates enum before its use inside create_table inside new schema""" + + disable_running = True + new_enum_variants = ["active", "passive"] non_existing_schema = "non_existing_schema" - target_schema = MetaData(schema=non_existing_schema) - - Table( - USER_TABLE_NAME, - target_schema, - Column("id", Integer, primary_key=True), - Column( - USER_STATUS_COLUMN_NAME, - postgresql.ENUM(*new_enum_variants, name=USER_STATUS_ENUM_NAME, metadata=target_schema), - ), - ) + def get_database_schema(self) -> MetaData: + return MetaData() + + def get_target_schema(self) -> MetaData: + target_schema = MetaData(schema=self.non_existing_schema) + Table( + USER_TABLE_NAME, + target_schema, + Column("id", Integer, primary_key=True), + Column( + USER_STATUS_COLUMN_NAME, + postgresql.ENUM(*self.new_enum_variants, name=USER_STATUS_ENUM_NAME, metadata=target_schema), + ), + ) + return target_schema - expected_upgrade = f""" + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### sa.Enum('active', 'passive', name='user_status', schema='non_existing_schema').create(op.get_bind()) op.create_table('users', @@ -43,25 +44,12 @@ def test_create_enum_on_create_table_inside_new_schema(connection: "Connection") schema='non_existing_schema' ) # ### end Alembic commands ### - """ - expected_downgrade = f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_table('users', schema='non_existing_schema') sa.Enum('active', 'passive', name='user_status', schema='non_existing_schema').drop(op.get_bind()) # ### end Alembic commands ### - """ - - migration_context = create_migration_context(connection, target_schema) - - template_args = {} - # todo _render_migration_diffs marked as legacy, maybe find something else - autogenerate._render_migration_diffs(migration_context, template_args) - - upgrade_code = textwrap.dedent(" " + template_args["upgrades"]) - downgrade_code = textwrap.dedent(" " + template_args["downgrades"]) - - expected_upgrade = textwrap.dedent(expected_upgrade).strip("\n ") - expected_downgrade = textwrap.dedent(expected_downgrade).strip("\n ") - - assert upgrade_code == expected_upgrade, f"Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}" - assert downgrade_code == expected_downgrade, f"Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}" + """ diff --git a/tests/test_enum_creation/test_create_table.py b/tests/test_enum_creation/test_create_table.py index b798719..e727f94 100644 --- a/tests/test_enum_creation/test_create_table.py +++ b/tests/test_enum_creation/test_create_table.py @@ -1,9 +1,8 @@ -from typing import TYPE_CHECKING - import sqlalchemy +from sqlalchemy import Table, Column, Integer, MetaData from sqlalchemy.dialects import postgresql -from tests.base.render_and_run import compare_and_run +from tests.base.run_migration_test_abc import CompareAndRunTestCase from tests.schemas import ( get_schema_with_enum_variants, USER_TABLE_NAME, @@ -17,21 +16,20 @@ ANOTHER_SCHEMA_NAME, ) -if TYPE_CHECKING: - from sqlalchemy import Connection -from sqlalchemy import Table, Column, Integer, MetaData - -def test_create_enum_on_create_table(connection: "Connection"): +class TestCreateEnumOnCreateTable(CompareAndRunTestCase): """Check that library correctly creates enum before its use inside create_table""" + new_enum_variants = ["active", "passive"] - target_schema = get_schema_with_enum_variants(new_enum_variants) + def get_database_schema(self) -> MetaData: + return MetaData() + + def get_target_schema(self) -> MetaData: + return get_schema_with_enum_variants(self.new_enum_variants) - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### sa.Enum('active', 'passive', name='user_status').create(op.get_bind()) op.create_table('{USER_TABLE_NAME}', @@ -40,65 +38,70 @@ def test_create_enum_on_create_table(connection: "Connection"): sa.PrimaryKeyConstraint('id') ) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_table('{USER_TABLE_NAME}') - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) # ### end Alembic commands ### - """, - ) + """ -def test_create_enum_on_create_table_with_array(connection: "Connection"): +class TestCreateEnumOnCreateTableWithArray(CompareAndRunTestCase): """Check that library correctly creates enum before its use inside create_table. Enum is used in ARRAY""" - database_schema = get_car_schema_without_enum() - database_schema.create_all(connection) new_enum_variants = ["black", "white", "red", "green", "blue", "other"] - target_schema = get_schema_with_enum_in_array_variants(new_enum_variants) + def get_database_schema(self) -> MetaData: + return get_car_schema_without_enum() - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_target_schema(self) -> MetaData: + return get_schema_with_enum_in_array_variants(self.new_enum_variants) + + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').create(op.get_bind()) - op.add_column('{CAR_TABLE_NAME}', sa.Column('{CAR_COLORS_COLUMN_NAME}', postgresql.ARRAY(postgresql.ENUM({', '.join(map(repr, new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}', create_type=False)), nullable=True)) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').create(op.get_bind()) + op.add_column('{CAR_TABLE_NAME}', sa.Column('{CAR_COLORS_COLUMN_NAME}', postgresql.ARRAY(postgresql.ENUM({', '.join(map(repr, self.new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}', create_type=False)), nullable=True)) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_column('{CAR_TABLE_NAME}', '{CAR_COLORS_COLUMN_NAME}') - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').drop(op.get_bind()) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{CAR_COLORS_ENUM_NAME}').drop(op.get_bind()) # ### end Alembic commands ### - """, - ) + """ -def test_create_enum_on_create_table_with_sa_enum(connection: "Connection"): +class TestCreateEnumOnCreateTableWithSaEnum(CompareAndRunTestCase): """Check that library correctly creates enum before its use inside create_table. sqlalchemy.Enum is used sqlalchemy.Enum should be converted to postgresql.ENUM to specify create_type=False """ + new_enum_variants = ["active", "passive"] - target_schema = MetaData() - - Table( - USER_TABLE_NAME, - target_schema, - Column("id", Integer, primary_key=True), - Column( - USER_STATUS_COLUMN_NAME, - sqlalchemy.Enum(*new_enum_variants, name=USER_STATUS_ENUM_NAME), - ), - ) - - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_database_schema(self) -> MetaData: + return MetaData() + + def get_target_schema(self) -> MetaData: + target_schema = MetaData() + Table( + USER_TABLE_NAME, + target_schema, + Column("id", Integer, primary_key=True), + Column( + USER_STATUS_COLUMN_NAME, + sqlalchemy.Enum(*self.new_enum_variants, name=USER_STATUS_ENUM_NAME), + ), + ) + return target_schema + + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### sa.Enum('active', 'passive', name='user_status').create(op.get_bind()) op.create_table('{USER_TABLE_NAME}', @@ -107,96 +110,107 @@ def test_create_enum_on_create_table_with_sa_enum(connection: "Connection"): sa.PrimaryKeyConstraint('id') ) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_table('{USER_TABLE_NAME}') - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) # ### end Alembic commands ### - """, - ) + """ -def test_create_enum_on_create_table_with_another_schema(connection: "Connection"): +class TestCreateEnumOnCreateTableWithAnotherSchema(CompareAndRunTestCase): """Check that library correctly creates enum before its use inside create_table inside another table_schema""" + new_enum_variants = ["active", "passive"] - target_schema = MetaData() - - Table( - USER_TABLE_NAME, - target_schema, - Column("id", Integer, primary_key=True), - Column( - USER_STATUS_COLUMN_NAME, - postgresql.ENUM( - *new_enum_variants, - name=USER_STATUS_ENUM_NAME, - schema=ANOTHER_SCHEMA_NAME, + def get_database_schema(self) -> MetaData: + return MetaData() + + def get_target_schema(self) -> MetaData: + target_schema = MetaData() + + Table( + USER_TABLE_NAME, + target_schema, + Column("id", Integer, primary_key=True), + Column( + USER_STATUS_COLUMN_NAME, + postgresql.ENUM( + *self.new_enum_variants, + name=USER_STATUS_ENUM_NAME, + schema=ANOTHER_SCHEMA_NAME, + ), ), - ), - schema=ANOTHER_SCHEMA_NAME, - ) - - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + schema=ANOTHER_SCHEMA_NAME, + ) + + return target_schema + + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}').create(op.get_bind()) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}').create(op.get_bind()) op.create_table('{USER_TABLE_NAME}', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('status', postgresql.ENUM({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}', create_type=False), nullable=True), + sa.Column('status', postgresql.ENUM({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}', create_type=False), nullable=True), sa.PrimaryKeyConstraint('id'), schema='{ANOTHER_SCHEMA_NAME}' ) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_table('{USER_TABLE_NAME}', schema='{ANOTHER_SCHEMA_NAME}') - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}').drop(op.get_bind()) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}').drop(op.get_bind()) # ### end Alembic commands ### - """, - ) + """ -def test_create_enum_on_create_table_with_another_schema_on_metadata( - connection: "Connection", -): +class TestCreateEnumOnCreateTableWithAnotherSchemaOnMetadata(CompareAndRunTestCase): """Check that library correctly creates enum before its use inside create_table inside another schema, specified on Metadata""" + new_enum_variants = ["active", "passive"] - target_schema = MetaData(schema=ANOTHER_SCHEMA_NAME) - - Table( - USER_TABLE_NAME, - target_schema, - Column("id", Integer, primary_key=True), - Column( - USER_STATUS_COLUMN_NAME, - postgresql.ENUM(*new_enum_variants, name=USER_STATUS_ENUM_NAME, metadata=target_schema), - ), - ) - - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_database_schema(self) -> MetaData: + return MetaData() + + def get_target_schema(self) -> MetaData: + target_schema = MetaData(schema=ANOTHER_SCHEMA_NAME) + + Table( + USER_TABLE_NAME, + target_schema, + Column("id", Integer, primary_key=True), + Column( + USER_STATUS_COLUMN_NAME, + postgresql.ENUM(*self.new_enum_variants, name=USER_STATUS_ENUM_NAME, metadata=target_schema), + ), + ) + + return target_schema + + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}').create(op.get_bind()) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}').create(op.get_bind()) op.create_table('{USER_TABLE_NAME}', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('status', postgresql.ENUM({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}', create_type=False), nullable=True), + sa.Column('status', postgresql.ENUM({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}', create_type=False), nullable=True), sa.PrimaryKeyConstraint('id'), schema='{ANOTHER_SCHEMA_NAME}' ) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_table('{USER_TABLE_NAME}', schema='{ANOTHER_SCHEMA_NAME}') - sa.Enum({', '.join(map(repr, new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}').drop(op.get_bind()) + sa.Enum({', '.join(map(repr, self.new_enum_variants))}, name='{USER_STATUS_ENUM_NAME}', schema='{ANOTHER_SCHEMA_NAME}').drop(op.get_bind()) # ### end Alembic commands ### - """, - ) + """ diff --git a/tests/test_enum_creation/test_drop_column.py b/tests/test_enum_creation/test_drop_column.py index 1c8c615..e188247 100644 --- a/tests/test_enum_creation/test_drop_column.py +++ b/tests/test_enum_creation/test_drop_column.py @@ -1,9 +1,10 @@ from typing import TYPE_CHECKING from alembic_postgresql_enum.operations import DropEnumOp +from tests.base.run_migration_test_abc import CompareAndRunTestCase if TYPE_CHECKING: - from sqlalchemy import Connection + from sqlalchemy import Connection, MetaData from tests.schemas import ( get_schema_with_enum_in_array_variants, @@ -18,7 +19,6 @@ from alembic.autogenerate import api from alembic.operations import ops -from tests.base.render_and_run import compare_and_run from tests.schemas import ( get_schema_with_enum_variants, USER_TABLE_NAME, @@ -31,58 +31,64 @@ if TYPE_CHECKING: from sqlalchemy import Connection +from sqlalchemy import MetaData -def test_delete_enum_after_drop_column(connection: "Connection"): +class TestDeleteEnumAfterDropColumn(CompareAndRunTestCase): """Check that library correctly removes unused enum after drop_column""" + enum_variants_to_delete = ["active", "passive"] - database_schema = get_schema_with_enum_variants(enum_variants_to_delete) - database_schema.create_all(connection) - target_schema = get_schema_without_enum() + def get_database_schema(self) -> MetaData: + return get_schema_with_enum_variants(self.enum_variants_to_delete) + + def get_target_schema(self) -> MetaData: + return get_schema_without_enum() - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_column('{USER_TABLE_NAME}', '{USER_STATUS_COLUMN_NAME}') - sa.Enum({', '.join(map(repr, enum_variants_to_delete))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) + sa.Enum({', '.join(map(repr, self.enum_variants_to_delete))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + # For some reason alembic decided to add redundant autoincrement=False on downgrade + return f""" # ### commands auto generated by Alembic - please adjust! ### - sa.Enum({', '.join(map(repr, enum_variants_to_delete))}, name='{USER_STATUS_ENUM_NAME}').create(op.get_bind()) - op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', postgresql.ENUM({', '.join(map(repr, enum_variants_to_delete))}, name='{USER_STATUS_ENUM_NAME}', create_type=False), autoincrement=False, nullable=True)) + sa.Enum({', '.join(map(repr, self.enum_variants_to_delete))}, name='{USER_STATUS_ENUM_NAME}').create(op.get_bind()) + op.add_column('{USER_TABLE_NAME}', sa.Column('{USER_STATUS_COLUMN_NAME}', postgresql.ENUM({', '.join(map(repr, self.enum_variants_to_delete))}, name='{USER_STATUS_ENUM_NAME}', create_type=False), autoincrement=False, nullable=True)) # ### end Alembic commands ### - """, - ) # For some reason alembic decided to add redundant autoincrement=False on downgrade + """ -def test_delete_enum_after_drop_column_with_array(connection: "Connection"): +class TestDeleteEnumAfterDropColumnWithArray(CompareAndRunTestCase): """Check that library correctly removes unused enum after drop_column. Enum is used in ARRAY""" + enum_variants_to_delete = ["black", "white", "red", "green", "blue", "other"] - database_schema = get_schema_with_enum_in_array_variants(enum_variants_to_delete) - database_schema.create_all(connection) - target_schema = get_car_schema_without_enum() + def get_database_schema(self) -> MetaData: + return get_schema_with_enum_in_array_variants(self.enum_variants_to_delete) - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_target_schema(self) -> MetaData: + return get_car_schema_without_enum() + + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_column('{CAR_TABLE_NAME}', '{CAR_COLORS_COLUMN_NAME}') - sa.Enum({', '.join(map(repr, enum_variants_to_delete))}, name='{CAR_COLORS_ENUM_NAME}').drop(op.get_bind()) + sa.Enum({', '.join(map(repr, self.enum_variants_to_delete))}, name='{CAR_COLORS_ENUM_NAME}').drop(op.get_bind()) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### - sa.Enum({', '.join(map(repr, enum_variants_to_delete))}, name='{CAR_COLORS_ENUM_NAME}').create(op.get_bind()) - op.add_column('{CAR_TABLE_NAME}', sa.Column('{CAR_COLORS_COLUMN_NAME}', postgresql.ARRAY(postgresql.ENUM({', '.join(map(repr, enum_variants_to_delete))}, name='{CAR_COLORS_ENUM_NAME}', create_type=False)), autoincrement=False, nullable=True)) + sa.Enum({', '.join(map(repr, self.enum_variants_to_delete))}, name='{CAR_COLORS_ENUM_NAME}').create(op.get_bind()) + op.add_column('{CAR_TABLE_NAME}', sa.Column('{CAR_COLORS_COLUMN_NAME}', postgresql.ARRAY(postgresql.ENUM({', '.join(map(repr, self.enum_variants_to_delete))}, name='{CAR_COLORS_ENUM_NAME}', create_type=False)), autoincrement=False, nullable=True)) # ### end Alembic commands ### - """, - ) + """ def test_delete_enum_diff_tuple(connection: "Connection"): diff --git a/tests/test_enum_creation/test_drop_table.py b/tests/test_enum_creation/test_drop_table.py index 9af2b39..22ccf74 100644 --- a/tests/test_enum_creation/test_drop_table.py +++ b/tests/test_enum_creation/test_drop_table.py @@ -1,36 +1,34 @@ -from typing import TYPE_CHECKING - from sqlalchemy import MetaData -from tests.base.render_and_run import compare_and_run +from tests.base.run_migration_test_abc import CompareAndRunTestCase from tests.schemas import ( get_schema_with_enum_variants, USER_TABLE_NAME, USER_STATUS_ENUM_NAME, ) -if TYPE_CHECKING: - from sqlalchemy import Connection - -def test_drop_enum_after_drop_table(connection: "Connection"): +class TestDropEnumAfterDropTable(CompareAndRunTestCase): """Check that library correctly drop enum after drop_table""" + dropped_enum_variants = ["active", "passive"] - database_schema = get_schema_with_enum_variants(dropped_enum_variants) - database_schema.create_all(connection) - target_schema = MetaData() + def get_database_schema(self) -> MetaData: + return get_schema_with_enum_variants(self.dropped_enum_variants) - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_target_schema(self) -> MetaData: + return MetaData() + + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_table('{USER_TABLE_NAME}') - sa.Enum({', '.join(map(repr, dropped_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) + sa.Enum({', '.join(map(repr, self.dropped_enum_variants))}, name='{USER_STATUS_ENUM_NAME}').drop(op.get_bind()) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### sa.Enum('active', 'passive', name='user_status').create(op.get_bind()) op.create_table('users', @@ -39,5 +37,4 @@ def test_drop_enum_after_drop_table(connection: "Connection"): sa.PrimaryKeyConstraint('id', name='users_pkey') ) # ### end Alembic commands ### - """, - ) + """ diff --git a/tests/test_without_schema_changes/test_explicit_schema.py b/tests/test_without_schema_changes/test_explicit_schema.py index 0afad2b..614eec6 100644 --- a/tests/test_without_schema_changes/test_explicit_schema.py +++ b/tests/test_without_schema_changes/test_explicit_schema.py @@ -7,7 +7,7 @@ get_defined_enums, get_declared_enums, ) -from tests.base.render_and_run import compare_and_run +from tests.base.run_migration_test_abc import CompareAndRunTestCase from tests.schemas import ANOTHER_SCHEMA_NAME, DEFAULT_SCHEMA if TYPE_CHECKING: @@ -68,23 +68,23 @@ def test_get_declared_enums(connection: "Connection"): } -def test_compare_and_run(connection: "Connection"): - database_schema = my_metadata - database_schema.create_all(connection) +class TestCompareAndRun(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + return my_metadata - target_schema = my_metadata + def get_target_schema(self) -> MetaData: + return my_metadata - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### pass # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### pass # ### end Alembic commands ### - """, - ) + """ diff --git a/tests/test_without_schema_changes/test_implicit_schema_1_4_style.py b/tests/test_without_schema_changes/test_implicit_schema_1_4_style.py index 27d870a..83214b7 100644 --- a/tests/test_without_schema_changes/test_implicit_schema_1_4_style.py +++ b/tests/test_without_schema_changes/test_implicit_schema_1_4_style.py @@ -7,7 +7,7 @@ get_defined_enums, get_declared_enums, ) -from tests.base.render_and_run import compare_and_run +from tests.base.run_migration_test_abc import CompareAndRunTestCase from tests.schemas import ANOTHER_SCHEMA_NAME, DEFAULT_SCHEMA if TYPE_CHECKING: @@ -80,13 +80,15 @@ def test_get_declared_enums(connection: "Connection"): } -def test_compare_and_run_create_table(connection: "Connection"): - target_schema = my_metadata +class TestCompareAndRunCreateTable(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + return MetaData() - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_target_schema(self) -> MetaData: + return my_metadata + + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### sa.Enum('PENDING', 'SUCCESS', 'FAILED', name='test_status').create(op.get_bind()) op.create_table('test', @@ -96,33 +98,34 @@ def test_compare_and_run_create_table(connection: "Connection"): schema='another' ) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_table('test', schema='another') sa.Enum('PENDING', 'SUCCESS', 'FAILED', name='test_status').drop(op.get_bind()) # ### end Alembic commands ### - """, - ) + """ -def test_compare_and_run(connection: "Connection"): - database_schema = my_metadata - database_schema.create_all(connection) +class TestCompareAndRun(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + return my_metadata - target_schema = my_metadata + def get_target_schema(self) -> MetaData: + return my_metadata - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### pass # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### pass # ### end Alembic commands ### - """, - ) + """ diff --git a/tests/test_without_schema_changes/test_implicit_schema_2_0_style.py b/tests/test_without_schema_changes/test_implicit_schema_2_0_style.py index 50750ae..9f1b4f9 100644 --- a/tests/test_without_schema_changes/test_implicit_schema_2_0_style.py +++ b/tests/test_without_schema_changes/test_implicit_schema_2_0_style.py @@ -10,7 +10,7 @@ get_defined_enums, get_declared_enums, ) -from tests.base.render_and_run import compare_and_run +from tests.base.run_migration_test_abc import CompareAndRunTestCase from tests.schemas import ANOTHER_SCHEMA_NAME, DEFAULT_SCHEMA if TYPE_CHECKING: @@ -81,14 +81,15 @@ def test_get_declared_enums(connection: "Connection"): @pytest.mark.skipif(sqlalchemy.__version__.startswith("1."), reason="Table are made in 2.0 style") -def test_compare_and_run_create_table(connection: "Connection"): - my_metadata = get_my_metadata() - target_schema = my_metadata +class TestCompareAndRunCreateTable(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + return MetaData() + + def get_target_schema(self) -> MetaData: + return get_my_metadata() - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### sa.Enum('PENDING', 'SUCCESS', 'FAILED', name='test_status').create(op.get_bind()) op.create_table('test', @@ -98,35 +99,35 @@ def test_compare_and_run_create_table(connection: "Connection"): schema='another' ) # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### op.drop_table('test', schema='another') sa.Enum('PENDING', 'SUCCESS', 'FAILED', name='test_status').drop(op.get_bind()) # ### end Alembic commands ### - """, - ) + """ @pytest.mark.skipif(sqlalchemy.__version__.startswith("1."), reason="Table are made in 2.0 style") -def test_compare_and_run(connection: "Connection"): - my_metadata = get_my_metadata() - database_schema = my_metadata - database_schema.create_all(connection) +class TestCompareAndRun(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + return get_my_metadata() - target_schema = my_metadata + def get_target_schema(self) -> MetaData: + return get_my_metadata() - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### pass # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### pass # ### end Alembic commands ### - """, - ) + """ diff --git a/tests/test_without_schema_changes/test_inherit_schema_true.py b/tests/test_without_schema_changes/test_inherit_schema_true.py index 298aeca..9b0578f 100644 --- a/tests/test_without_schema_changes/test_inherit_schema_true.py +++ b/tests/test_without_schema_changes/test_inherit_schema_true.py @@ -7,7 +7,7 @@ get_defined_enums, get_declared_enums, ) -from tests.base.render_and_run import compare_and_run +from tests.base.run_migration_test_abc import CompareAndRunTestCase from tests.schemas import ANOTHER_SCHEMA_NAME, DEFAULT_SCHEMA if TYPE_CHECKING: @@ -68,23 +68,23 @@ def test_get_declared_enums(connection: "Connection"): } -def test_compare_and_run(connection: "Connection"): - database_schema = my_metadata - database_schema.create_all(connection) +class TestCompareAndRun(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + return my_metadata - target_schema = my_metadata + def get_target_schema(self) -> MetaData: + return my_metadata - compare_and_run( - connection, - target_schema, - expected_upgrade=f""" + def get_expected_upgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### pass # ### end Alembic commands ### - """, - expected_downgrade=f""" + """ + + def get_expected_downgrade(self) -> str: + return f""" # ### commands auto generated by Alembic - please adjust! ### pass # ### end Alembic commands ### - """, - ) + """