Skip to content

Commit

Permalink
Merge pull request #75 from Pogchamp-company/develop
Browse files Browse the repository at this point in the history
Version 1.2.0
  • Loading branch information
RustyGuard authored Apr 13, 2024
2 parents 2ad77dd + 455f877 commit c7248b8
Show file tree
Hide file tree
Showing 24 changed files with 493 additions and 407 deletions.
11 changes: 4 additions & 7 deletions alembic_postgresql_enum/add_create_type_false.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,18 @@ def add_create_type_false(upgrade_ops: UpgradeOps):
if isinstance(operations_group, ModifyTableOps):
for operation in operations_group.ops:
if isinstance(operation, AddColumnOp):
column: Column = operation.column

inject_repr_into_enums(column)

inject_repr_into_enums(operation.column)
elif isinstance(operation, DropColumnOp):
column: Column = operation._reverse.column

inject_repr_into_enums(column)
assert operation._reverse is not None
inject_repr_into_enums(operation._reverse.column)

elif isinstance(operations_group, CreateTableOp):
for column in operations_group.columns:
if isinstance(column, Column):
inject_repr_into_enums(column)

elif isinstance(operations_group, DropTableOp):
assert operations_group._reverse is not None
for column in operations_group._reverse.columns:
if isinstance(column, Column):
inject_repr_into_enums(column)
1 change: 1 addition & 0 deletions alembic_postgresql_enum/add_postgres_using_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _postgres_using_alter_column(autogen_context: AutogenContext, op: ops.AlterC


def add_postgres_using_to_alter_operation(op: AlterColumnOp):
assert op.modify_type is not None
op.kw["postgresql_using"] = f"{op.column_name}::{op.modify_type.name}"
log.info("postgresql_using added to %r.%r alteration", op.table_name, op.column_name)
op.__class__ = PostgresUsingAlterColumnOp
Expand Down
16 changes: 16 additions & 0 deletions alembic_postgresql_enum/compare_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Iterable, Union

import alembic
Expand All @@ -16,6 +17,9 @@
from alembic_postgresql_enum.get_enum_data import get_defined_enums, get_declared_enums


log = logging.getLogger(f"alembic.{__name__}")


@alembic.autogenerate.comparators.dispatch_for("schema")
def compare_enums(
autogen_context: AutogenContext,
Expand All @@ -28,6 +32,12 @@ def compare_enums(
for each defined enum that has changed new entries when compared to its
declared version.
"""
if autogen_context.dialect.name != "postgresql":
log.warning(
f"This library only supports postgresql, but you are using {autogen_context.dialect.name}, skipping"
)
return

add_create_type_false(upgrade_ops)
add_postgres_using_to_text(upgrade_ops)

Expand All @@ -39,6 +49,12 @@ def compare_enums(
if isinstance(operations_group, CreateTableOp) and operations_group.schema not in schema_names:
schema_names.append(operations_group.schema)

assert (
autogen_context.dialect is not None
and autogen_context.dialect.default_schema_name is not None
and autogen_context.connection is not None
and autogen_context.metadata is not None
)
for schema in schema_names:
default_schema = autogen_context.dialect.default_schema_name
if schema is None:
Expand Down
3 changes: 2 additions & 1 deletion alembic_postgresql_enum/connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from contextlib import contextmanager
from typing import Iterator

import sqlalchemy


@contextmanager
def get_connection(operations) -> sqlalchemy.engine.Connection:
def get_connection(operations) -> Iterator[sqlalchemy.engine.Connection]:
"""
SQLAlchemy 2.0 changes the operation binding location; bridge function to support
both 1.x and 2.x.
Expand Down
11 changes: 6 additions & 5 deletions alembic_postgresql_enum/get_enum_data/declared_enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import defaultdict
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING
from enum import Enum
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast

import sqlalchemy
from sqlalchemy import MetaData
Expand Down Expand Up @@ -93,16 +94,16 @@ def get_declared_enums(
if not column_type_is_enum(column_type):
continue

column_type_schema = column_type.schema or default_schema
column_type_schema = column_type.schema or default_schema # type: ignore[attr-defined]
if column_type_schema != schema:
continue

if column_type.name not in enum_name_to_values:
enum_name_to_values[column_type.name] = get_enum_values(column_type)
if column_type.name not in enum_name_to_values: # type: ignore[attr-defined]
enum_name_to_values[column_type.name] = get_enum_values(cast(sqlalchemy.Enum, column_type)) # type: ignore[attr-defined]

table_schema = table.schema or default_schema
column_default = get_column_default(connection, table_schema, table.name, column.name)
enum_name_to_table_references[column_type.name].add(
enum_name_to_table_references[column_type.name].add( # type: ignore[attr-defined]
TableReference(
table_schema=table_schema,
table_name=table.name,
Expand Down
6 changes: 3 additions & 3 deletions alembic_postgresql_enum/get_enum_data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def __repr__(self):
class TableReference:
table_name: str
column_name: str
table_schema: Optional[str] = Unspecified # 'Unspecified' default is for migrations from older versions
table_schema: Optional[str] = Unspecified # type: ignore[assignment] # 'Unspecified' default is for migrations from older versions
column_type: ColumnType = ColumnType.COMMON
existing_server_default: str = None
existing_server_default: Optional[str] = None

def __repr__(self):
result_str = "TableReference("
Expand All @@ -48,7 +48,7 @@ def table_name_with_schema(self):
prefix = f"{self.table_schema}."
else:
prefix = ""
return f"{prefix}{self.table_name}"
return f'{prefix}"{self.table_name}"'


EnumNamesToValues = Dict[str, Tuple[str, ...]]
Expand Down
1 change: 1 addition & 0 deletions alembic_postgresql_enum/operations/create_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def reverse(self):

@alembic.autogenerate.render.renderers.dispatch_for(CreateEnumOp)
def render_create_enum_op(autogen_context: AutogenContext, op: CreateEnumOp):
assert autogen_context.dialect is not None
if op.schema != autogen_context.dialect.default_schema_name:
return f"""
sa.Enum({', '.join(map(repr, op.enum_values))}, name='{op.name}', schema='{op.schema}').create(op.get_bind())
Expand Down
1 change: 1 addition & 0 deletions alembic_postgresql_enum/operations/drop_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def reverse(self):

@alembic.autogenerate.render.renderers.dispatch_for(DropEnumOp)
def render_drop_enum_op(autogen_context: AutogenContext, op: DropEnumOp):
assert autogen_context.dialect is not None
if op.schema != autogen_context.dialect.default_schema_name:
return f"""
sa.Enum({', '.join(map(repr, op.enum_values))}, name='{op.name}', schema='{op.schema}').drop(op.get_bind())
Expand Down
10 changes: 10 additions & 0 deletions alembic_postgresql_enum/operations/sync_enum_values.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import List, Tuple, Any, Iterable, TYPE_CHECKING

import alembic.autogenerate
Expand Down Expand Up @@ -31,6 +32,9 @@
from alembic_postgresql_enum.get_enum_data import TableReference, ColumnType


log = logging.getLogger(f"alembic.{__name__}")


@alembic.operations.base.Operations.register_operation("sync_enum_values")
class SyncEnumValuesOp(alembic.operations.ops.MigrateOperation):
operation_name = "change_enum_variants"
Expand Down Expand Up @@ -138,6 +142,12 @@ def sync_enum_values(
]
If there was server default with old_name it will be renamed accordingly
"""
if operations.migration_context.dialect.name != "postgresql":
log.warning(
f"This library only supports postgresql, but you are using {operations.migration_context.dialect.name}, skipping"
)
return

enum_values_to_rename = list(enum_values_to_rename)

with get_connection(operations) as connection:
Expand Down
Empty file.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "alembic-postgresql-enum"
version = "1.1.2"
version = "1.2.0"
description = "Alembic autogenerate support for creation, alteration and deletion of enums"
authors = ["RustyGuard"]
license = "MIT"
Expand Down
4 changes: 4 additions & 0 deletions tests/base/render_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def compare_and_run(
*,
expected_upgrade: str,
expected_downgrade: str,
disable_running: bool = False,
):
"""Compares generated migration script is equal to expected_upgrade and expected_downgrade, then runs it"""
migration_context = create_migration_context(connection, target_schema)
Expand All @@ -39,6 +40,9 @@ def compare_and_run(
assert upgrade_code == expected_upgrade, f"Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}"
assert downgrade_code == expected_downgrade, f"Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}"

if disable_running:
return

exec(
upgrade_code,
{ # todo Use imports from template_args
Expand Down
11 changes: 11 additions & 0 deletions tests/base/run_migration_test_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,21 @@


class CompareAndRunTestCase(ABC):
"""
Base class for all tests that expect specific alembic generated code
"""

disable_running = False

@abstractmethod
def get_database_schema(self) -> MetaData: ...

@abstractmethod
def get_target_schema(self) -> MetaData: ...

def insert_migration_data(self, connection: "Connection", database_schema: MetaData) -> None:
pass

@abstractmethod
def get_expected_upgrade(self) -> str: ...

Expand All @@ -26,10 +35,12 @@ def test_run(self, connection: "Connection"):
target_schema = self.get_target_schema()

database_schema.create_all(connection)
self.insert_migration_data(connection, database_schema)

compare_and_run(
connection,
target_schema,
expected_upgrade=self.get_expected_upgrade(),
expected_downgrade=self.get_expected_downgrade(),
disable_running=self.disable_running,
)
89 changes: 47 additions & 42 deletions tests/test_alter_column/test_text_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,56 +4,61 @@
from sqlalchemy import MetaData, Table, Column, TEXT, insert
from sqlalchemy.dialects import postgresql

from tests.base.run_migration_test_abc import CompareAndRunTestCase

if TYPE_CHECKING:
from sqlalchemy import Connection

from tests.base.render_and_run import compare_and_run


class NewEnum(Enum):
A = "a"
B = "b"
C = "c"


def test_text_column(connection: "Connection"):
database_schema = MetaData()
a_table = Table("a", database_schema, Column("value", TEXT))
database_schema.create_all(connection)
connection.execute(
insert(a_table).values(
[
{"value": NewEnum.A.name},
{"value": NewEnum.B.name},
{"value": NewEnum.B.name},
{"value": NewEnum.C.name},
]
class TestTextColumn(CompareAndRunTestCase):
def get_database_schema(self) -> MetaData:
database_schema = MetaData()
Table("a", database_schema, Column("value", TEXT))
return database_schema

def get_target_schema(self) -> MetaData:
target_schema = MetaData()
Table("a", target_schema, Column("value", postgresql.ENUM(NewEnum)))
return target_schema

def insert_migration_data(self, connection: "Connection", database_schema: MetaData) -> None:
a_table = database_schema.tables["a"]
connection.execute(
insert(a_table).values(
[
{"value": NewEnum.A.name},
{"value": NewEnum.B.name},
{"value": NewEnum.B.name},
{"value": NewEnum.C.name},
]
)
)
)

target_schema = MetaData()
Table("a", target_schema, Column("value", postgresql.ENUM(NewEnum)))

compare_and_run(
connection,
target_schema,
expected_upgrade=f"""
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum('A', 'B', 'C', name='newenum').create(op.get_bind())
op.alter_column('a', 'value',
existing_type=sa.TEXT(),
type_=postgresql.ENUM('A', 'B', 'C', name='newenum'),
existing_nullable=True,
postgresql_using='value::newenum')
# ### end Alembic commands ###
""",
expected_downgrade=f"""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('a', 'value',
existing_type=postgresql.ENUM('A', 'B', 'C', name='newenum'),
type_=sa.TEXT(),
existing_nullable=True)
sa.Enum('A', 'B', 'C', name='newenum').drop(op.get_bind())
# ### end Alembic commands ###
""",
)

def get_expected_upgrade(self) -> str:
return """
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum('A', 'B', 'C', name='newenum').create(op.get_bind())
op.alter_column('a', 'value',
existing_type=sa.TEXT(),
type_=postgresql.ENUM('A', 'B', 'C', name='newenum'),
existing_nullable=True,
postgresql_using='value::newenum')
# ### end Alembic commands ###
"""

def get_expected_downgrade(self) -> str:
return """
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('a', 'value',
existing_type=postgresql.ENUM('A', 'B', 'C', name='newenum'),
type_=sa.TEXT(),
existing_nullable=True)
sa.Enum('A', 'B', 'C', name='newenum').drop(op.get_bind())
# ### end Alembic commands ###
"""
Loading

0 comments on commit c7248b8

Please sign in to comment.