Skip to content

Commit

Permalink
Merge pull request #77 from Pogchamp-company/bug/76/error-adding-array
Browse files Browse the repository at this point in the history
Fix error when adding column that uses existing changing enum
  • Loading branch information
RustyGuard authored Jul 13, 2024
2 parents 455f877 + 75789b0 commit 843b51a
Show file tree
Hide file tree
Showing 11 changed files with 293 additions and 21 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test_on_push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ on:
- tests/**
- alembic_postgresql_enum/**
- .github/workflows/test_on_push.yaml
pull_request: { }

jobs:
run_tests:
Expand Down
10 changes: 10 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
FROM python:latest

COPY ./alembic_postgresql_enum ./alembic_postgresql_enum
COPY ./tests ./tests

WORKDIR ./tests

RUN pip install -r requirements.txt

ENTRYPOINT pytest
17 changes: 10 additions & 7 deletions alembic_postgresql_enum/compare_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def compare_enums(
for each defined enum that has changed new entries when compared to its
declared version.
"""
assert (
autogen_context.dialect is not None
and autogen_context.dialect.default_schema_name is not None
and autogen_context.connection is not None
and autogen_context.metadata is not None
)

if autogen_context.dialect.name != "postgresql":
log.warning(
f"This library only supports postgresql, but you are using {autogen_context.dialect.name}, skipping"
Expand All @@ -49,19 +56,15 @@ def compare_enums(
if isinstance(operations_group, CreateTableOp) and operations_group.schema not in schema_names:
schema_names.append(operations_group.schema)

assert (
autogen_context.dialect is not None
and autogen_context.dialect.default_schema_name is not None
and autogen_context.connection is not None
and autogen_context.metadata is not None
)
for schema in schema_names:
default_schema = autogen_context.dialect.default_schema_name
if schema is None:
schema = default_schema

definitions = get_defined_enums(autogen_context.connection, schema)
declarations = get_declared_enums(autogen_context.metadata, schema, default_schema, autogen_context.connection)
declarations = get_declared_enums(
autogen_context.metadata, schema, default_schema, autogen_context.connection, upgrade_ops
)

create_new_enums(definitions, declarations.enum_values, schema, upgrade_ops)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def sync_changed_enums(
enum_name,
list(old_values),
list(new_values),
list(affected_columns),
sorted( # Sort references alphabetically for consistency of generated text
affected_columns,
key=lambda reference: (reference.table_schema, reference.table_name, reference.column_name),
),
)
upgrade_ops.ops.append(op)
12 changes: 10 additions & 2 deletions alembic_postgresql_enum/get_enum_data/declared_enums.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from collections import defaultdict
from enum import Enum
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast, Optional

import sqlalchemy
from alembic.operations.ops import UpgradeOps
from sqlalchemy import MetaData
from sqlalchemy.dialects import postgresql

from alembic_postgresql_enum.get_enum_data.get_default_from_alembic_ops import get_just_added_defaults
from alembic_postgresql_enum.sql_commands.column_default import get_column_default

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,6 +51,7 @@ def get_declared_enums(
schema: str,
default_schema: str,
connection: "Connection",
upgrade_ops: Optional[UpgradeOps] = None,
) -> DeclaredEnumValues:
"""
Return a dict mapping SQLAlchemy declared enumeration types to the set of their values
Expand All @@ -62,6 +64,8 @@ def get_declared_enums(
Default schema name, likely will be "public"
:param connection:
Database connection
:param upgrade_ops:
Upgrade operations in current migration
:returns DeclaredEnumValues:
enum_values: {
"my_enum": tuple(["a", "b", "c"]),
Expand All @@ -75,6 +79,8 @@ def get_declared_enums(
enum_name_to_values = dict()
enum_name_to_table_references: defaultdict[str, Set[TableReference]] = defaultdict(set)

just_added_defaults = get_just_added_defaults(upgrade_ops, default_schema)

if isinstance(metadata, list):
metadata_list = metadata
else:
Expand Down Expand Up @@ -103,6 +109,8 @@ def get_declared_enums(

table_schema = table.schema or default_schema
column_default = get_column_default(connection, table_schema, table.name, column.name)
if (table_schema, table.name, column.name) in just_added_defaults:
column_default = just_added_defaults[table_schema, table.name, column.name]
enum_name_to_table_references[column_type.name].add( # type: ignore[attr-defined]
TableReference(
table_schema=table_schema,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Optional, Dict, Tuple

from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, AlterColumnOp, CreateTableOp
from sqlalchemy import Column

SchemaName = str
TableName = str
ColumnName = str
ColumnLocation = Tuple[SchemaName, TableName, ColumnName]


def _get_default_from_add_column_op(op: AddColumnOp, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]:
if op.column.server_default is None:
raise AttributeError("No new server_default")
return (
(op.schema or default_schema, op.table_name, op.column.name),
op.column.server_default.arg.text, # type: ignore[attr-defined]
)


def _get_default_from_alter_column_op(op: AlterColumnOp, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]:
if op.modify_server_default is False:
raise AttributeError("No new server_default")
return (op.schema or default_schema, op.table_name, op.column_name), op.modify_server_default


def _get_default_from_column(column: Column, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]:
if column.server_default is None:
raise AttributeError("No new server_default")
return (
(column.table.schema or default_schema, column.table.name, column.name),
column.server_default.arg.text, # type: ignore[attr-defined]
)


def get_just_added_defaults(
upgrade_ops: Optional[UpgradeOps], default_schema: str
) -> Dict[ColumnLocation, Optional[str]]:
"""Get all server defaults that will be added in current migration"""
if upgrade_ops is None:
return {}

new_server_defaults = {}

for operations_group in upgrade_ops.ops:
if isinstance(operations_group, ModifyTableOps):
for operation in operations_group.ops:
if isinstance(operation, AddColumnOp):
try:
column_location, column_new_default = _get_default_from_add_column_op(operation, default_schema)
new_server_defaults[column_location] = column_new_default
except AttributeError:
pass

elif isinstance(operation, AlterColumnOp):
try:
column_location, column_new_default = _get_default_from_alter_column_op(
operation, default_schema
)
new_server_defaults[column_location] = column_new_default
except AttributeError:
pass

elif isinstance(operations_group, CreateTableOp):
for column in operations_group.columns:
if isinstance(column, Column):
try:
column_location, column_new_default = _get_default_from_column(column, default_schema)
new_server_defaults[column_location] = column_new_default
except AttributeError:
pass

return new_server_defaults
40 changes: 32 additions & 8 deletions alembic_postgresql_enum/sql_commands/column_default.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import TYPE_CHECKING, Union, List, Tuple

import sqlalchemy
Expand Down Expand Up @@ -54,15 +55,38 @@ def rename_default_if_required(
enum_name: str,
enum_values_to_rename: List[Tuple[str, str]],
) -> str:
is_array = default_value.endswith("[]")
if schema:
new_enum = f"{schema}.{enum_name}"
else:
new_enum = enum_name

if default_value.startswith("ARRAY["):
column_default_value = _replace_strings_in_quotes(default_value, enum_values_to_rename)
column_default_value = re.sub(r"::[.\w]+", f"::{new_enum}", column_default_value)
return column_default_value

if default_value.endswith("[]"):

# remove old type postfix
column_default_value = default_value[: default_value.find("::")]

column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename)

return f"{column_default_value}::{new_enum}[]"

# remove old type postfix
column_default_value = default_value[: default_value.find("::")]

for old_value, new_value in enum_values_to_rename:
column_default_value = column_default_value.replace(f"'{old_value}'", f"'{new_value}'")
column_default_value = column_default_value.replace(f'"{old_value}"', f'"{new_value}"')
column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename)

suffix = "[]" if is_array else ""
if schema:
return f"{column_default_value}::{schema}.{enum_name}{suffix}"
return f"{column_default_value}::{enum_name}{suffix}"
return f"{column_default_value}::{new_enum}"


def _replace_strings_in_quotes(
old_default: str,
enum_values_to_rename: List[Tuple[str, str]],
) -> str:
for old_value, new_value in enum_values_to_rename:
old_default = old_default.replace(f"'{old_value}'", f"'{new_value}'")
old_default = old_default.replace(f'"{old_value}"', f'"{new_value}"')
return old_default
28 changes: 28 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
version: "3.8"

services:
run-tests:
# entrypoint: pytest
build: .
stdin_open: true
tty: true
command:
- pytest
environment:
DATABASE_URI: postgresql://test_user:test_password@db:5432/test_db
depends_on:
- db
links:
- "db:database"
db:
image: postgres:12
environment:
POSTGRES_DB: "test_db"
POSTGRES_USER: "test_user"
POSTGRES_PASSWORD: "test_password"
PGUSER: "postgres"

ports:
- "5432:5432"
volumes:
- ./api/db/postgres-test-data:/var/lib/postgresql/data
15 changes: 13 additions & 2 deletions tests/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
# How to run tests

Create database for testing
# With `docker compose`

Just run:
```commandline
docker compose up --build --exit-code-from run-tests
```

# Manually

## Create database

Start postgres through docker compose:

## Env variables

Expand All @@ -24,4 +35,4 @@ pip install -R tests/requirements.txt
Run tests
```
pytest
```
```
33 changes: 33 additions & 0 deletions tests/sync_enum_values/test_rename_default_if_required.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,36 @@ def test_array_default_value_with_schema():
old_default_value = """'{}'::test.order_status_old[]"""

assert rename_default_if_required("test", old_default_value, "order_status", []) == """'{}'::test.order_status[]"""


def test_caps_array_default_value_without_schema():
old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]"""

assert (
rename_default_if_required("test", old_default_value, "my_enum", [])
== """ARRAY['A'::test.my_enum, 'B'::test.my_enum]"""
)


def test_caps_array_default_value_with_schema():
old_default_value = """ARRAY['A'::test.my_old_enum, 'B'::test.my_old_enum]"""

assert (
rename_default_if_required("test", old_default_value, "my_enum", [])
== """ARRAY['A'::test.my_enum, 'B'::test.my_enum]"""
)


def test_caps_array_another_default_value_without_schema():
old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]"""

assert (
rename_default_if_required("test", old_default_value, "my_enum", [])
== """ARRAY['A'::test.my_enum, 'B'::test.my_enum]"""
)


def test_caps_array_another_default_value_with_schema():
old_default_value = """ARRAY['A', 'B']::test.my_old_enum[]"""

assert rename_default_if_required("test", old_default_value, "my_enum", []) == """ARRAY['A', 'B']::test.my_enum[]"""
Loading

0 comments on commit 843b51a

Please sign in to comment.