Skip to content

Commit

Permalink
Review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
artem.golovin committed Jun 30, 2024
1 parent 757d16b commit 75789b0
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 53 deletions.
13 changes: 7 additions & 6 deletions alembic_postgresql_enum/compare_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def compare_enums(
for each defined enum that has changed new entries when compared to its
declared version.
"""
assert (
autogen_context.dialect is not None
and autogen_context.dialect.default_schema_name is not None
and autogen_context.connection is not None
and autogen_context.metadata is not None
)

if autogen_context.dialect.name != "postgresql":
log.warning(
f"This library only supports postgresql, but you are using {autogen_context.dialect.name}, skipping"
Expand All @@ -49,12 +56,6 @@ def compare_enums(
if isinstance(operations_group, CreateTableOp) and operations_group.schema not in schema_names:
schema_names.append(operations_group.schema)

assert (
autogen_context.dialect is not None
and autogen_context.dialect.default_schema_name is not None
and autogen_context.connection is not None
and autogen_context.metadata is not None
)
for schema in schema_names:
default_schema = autogen_context.dialect.default_schema_name
if schema is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def sync_changed_enums(
enum_name,
list(old_values),
list(new_values),
sorted(
sorted( # Sort references alphabetically for consistency of generated text
affected_columns,
key=lambda reference: (reference.table_schema, reference.table_name, reference.column_name),
),
Expand Down
51 changes: 5 additions & 46 deletions alembic_postgresql_enum/get_enum_data/declared_enums.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from collections import defaultdict
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast, Dict
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast, Optional

import sqlalchemy
from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, CreateTableOp, AlterColumnOp
from sqlalchemy import MetaData, Column
from alembic.operations.ops import UpgradeOps
from sqlalchemy import MetaData
from sqlalchemy.dialects import postgresql

from alembic_postgresql_enum.get_enum_data.get_default_from_alembic_ops import get_just_added_defaults
from alembic_postgresql_enum.sql_commands.column_default import get_column_default

if TYPE_CHECKING:
Expand Down Expand Up @@ -45,54 +46,12 @@ def column_type_is_enum(column_type: Any) -> bool:
return False


def get_just_added_defaults(
upgrade_ops: Union[UpgradeOps, None], default_schema: str
) -> Dict[Tuple[str, str, str], str]:
"""Get all server defaults that will be added in current migration"""
if upgrade_ops is None:
return {}

new_server_defaults = {}

for operations_group in upgrade_ops.ops:
if isinstance(operations_group, ModifyTableOps):
for operation in operations_group.ops:
if isinstance(operation, AddColumnOp):
try:
if operation.column.server_default is None:
continue
new_server_defaults[
operation.schema or default_schema, operation.table_name, operation.column.name
] = operation.column.server_default.arg.text
except AttributeError:
pass
elif isinstance(operation, AlterColumnOp):
if operation.modify_server_default is not False:
new_server_defaults[
operation.schema or default_schema, operation.table_name, operation.column_name
] = operation.modify_server_default

elif isinstance(operations_group, CreateTableOp):
for column in operations_group.columns:
if isinstance(column, Column):
try:
if column.server_default is None:
continue
new_server_defaults[column.table.schema or default_schema, column.table.name, column.name] = (
column.server_default.arg.text
)
except AttributeError:
pass

return new_server_defaults


def get_declared_enums(
metadata: Union[MetaData, List[MetaData]],
schema: str,
default_schema: str,
connection: "Connection",
upgrade_ops: Union[UpgradeOps, None] = None,
upgrade_ops: Optional[UpgradeOps] = None,
) -> DeclaredEnumValues:
"""
Return a dict mapping SQLAlchemy declared enumeration types to the set of their values
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Optional, Dict, Tuple

from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, AlterColumnOp, CreateTableOp
from sqlalchemy import Column

SchemaName = str
TableName = str
ColumnName = str
ColumnLocation = Tuple[SchemaName, TableName, ColumnName]


def _get_default_from_add_column_op(op: AddColumnOp, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]:
if op.column.server_default is None:
raise AttributeError("No new server_default")
return (
(op.schema or default_schema, op.table_name, op.column.name),
op.column.server_default.arg.text, # type: ignore[attr-defined]
)


def _get_default_from_alter_column_op(op: AlterColumnOp, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]:
if op.modify_server_default is False:
raise AttributeError("No new server_default")
return (op.schema or default_schema, op.table_name, op.column_name), op.modify_server_default


def _get_default_from_column(column: Column, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]:
if column.server_default is None:
raise AttributeError("No new server_default")
return (
(column.table.schema or default_schema, column.table.name, column.name),
column.server_default.arg.text, # type: ignore[attr-defined]
)


def get_just_added_defaults(
upgrade_ops: Optional[UpgradeOps], default_schema: str
) -> Dict[ColumnLocation, Optional[str]]:
"""Get all server defaults that will be added in current migration"""
if upgrade_ops is None:
return {}

new_server_defaults = {}

for operations_group in upgrade_ops.ops:
if isinstance(operations_group, ModifyTableOps):
for operation in operations_group.ops:
if isinstance(operation, AddColumnOp):
try:
column_location, column_new_default = _get_default_from_add_column_op(operation, default_schema)
new_server_defaults[column_location] = column_new_default
except AttributeError:
pass

elif isinstance(operation, AlterColumnOp):
try:
column_location, column_new_default = _get_default_from_alter_column_op(
operation, default_schema
)
new_server_defaults[column_location] = column_new_default
except AttributeError:
pass

elif isinstance(operations_group, CreateTableOp):
for column in operations_group.columns:
if isinstance(column, Column):
try:
column_location, column_new_default = _get_default_from_column(column, default_schema)
new_server_defaults[column_location] = column_new_default
except AttributeError:
pass

return new_server_defaults

0 comments on commit 75789b0

Please sign in to comment.