Skip to content

Commit

Permalink
Reformat tests files with black
Browse files Browse the repository at this point in the history
  • Loading branch information
RustyGuard committed Jan 6, 2024
1 parent d14ff56 commit 05221ff
Show file tree
Hide file tree
Showing 23 changed files with 837 additions and 461 deletions.
62 changes: 38 additions & 24 deletions tests/base/render_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
from sqlalchemy import Connection


def compare_and_run(connection: 'Connection', target_schema: Union[MetaData, List[MetaData]], *,
expected_upgrade: str,
expected_downgrade: str):
def compare_and_run(
connection: "Connection",
target_schema: Union[MetaData, List[MetaData]],
*,
expected_upgrade: str,
expected_downgrade: str,
):
"""Compares generated migration script is equal to expected_upgrade and expected_downgrade, then runs it"""
migration_context = create_migration_context(connection, target_schema)

Expand All @@ -26,24 +30,34 @@ def compare_and_run(connection: 'Connection', target_schema: Union[MetaData, Lis
# todo _render_migration_diffs marked as legacy, maybe find something else
autogenerate._render_migration_diffs(migration_context, template_args)

upgrade_code = textwrap.dedent(' ' + template_args["upgrades"])
downgrade_code = textwrap.dedent(' ' + template_args["downgrades"])

expected_upgrade = textwrap.dedent(expected_upgrade).strip('\n ')
expected_downgrade = textwrap.dedent(expected_downgrade).strip('\n ')

assert upgrade_code == expected_upgrade, f'Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}'
assert downgrade_code == expected_downgrade, f'Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}'

exec(upgrade_code, { # todo Use imports from template_args
'op': op,
'sa': sqlalchemy,
'postgresql': postgresql,
'ColumnType': ColumnType,
})
exec(downgrade_code, {
'op': op,
'sa': sqlalchemy,
'postgresql': postgresql,
'ColumnType': ColumnType,
})
upgrade_code = textwrap.dedent(" " + template_args["upgrades"])
downgrade_code = textwrap.dedent(" " + template_args["downgrades"])

expected_upgrade = textwrap.dedent(expected_upgrade).strip("\n ")
expected_downgrade = textwrap.dedent(expected_downgrade).strip("\n ")

assert (
upgrade_code == expected_upgrade
), f"Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}"
assert (
downgrade_code == expected_downgrade
), f"Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}"

exec(
upgrade_code,
{ # todo Use imports from template_args
"op": op,
"sa": sqlalchemy,
"postgresql": postgresql,
"ColumnType": ColumnType,
},
)
exec(
downgrade_code,
{
"op": op,
"sa": sqlalchemy,
"postgresql": postgresql,
"ColumnType": ColumnType,
},
)
11 changes: 8 additions & 3 deletions tests/fixtures/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,25 @@

try:
import dotenv

dotenv.load_dotenv()
except ImportError:
pass
database_uri = os.getenv('DATABASE_URI')
database_uri = os.getenv("DATABASE_URI")


@pytest.fixture
def connection() -> Generator:
engine = create_engine(database_uri)
with engine.connect() as conn:
conn.execute(sqlalchemy.text(f'''
conn.execute(
sqlalchemy.text(
f"""
DROP SCHEMA {DEFAULT_SCHEMA} CASCADE;
CREATE SCHEMA {DEFAULT_SCHEMA};
DROP SCHEMA IF EXISTS {ANOTHER_SCHEMA_NAME} CASCADE;
CREATE SCHEMA {ANOTHER_SCHEMA_NAME};
'''))
"""
)
)
yield conn
31 changes: 21 additions & 10 deletions tests/get_enum_data/test_get_declared_enums.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,42 @@
from typing import TYPE_CHECKING

from alembic_postgresql_enum.get_enum_data import TableReference, get_declared_enums
from tests.schemas import get_schema_with_enum_variants, DEFAULT_SCHEMA, USER_STATUS_ENUM_NAME, USER_TABLE_NAME, \
USER_STATUS_COLUMN_NAME, get_schema_by_declared_enum_values, get_declared_enum_values_with_orders_and_users
from tests.schemas import (
get_schema_with_enum_variants,
DEFAULT_SCHEMA,
USER_STATUS_ENUM_NAME,
USER_TABLE_NAME,
USER_STATUS_COLUMN_NAME,
get_schema_by_declared_enum_values,
get_declared_enum_values_with_orders_and_users,
)

if TYPE_CHECKING:
from sqlalchemy import Connection


def test_with_user_schema(connection: 'Connection'):
def test_with_user_schema(connection: "Connection"):
enum_variants = ["active", "passive"]
declared_schema = get_schema_with_enum_variants(enum_variants)

function_result = get_declared_enums(declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection)
function_result = get_declared_enums(
declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection
)

assert function_result.enum_values == {
USER_STATUS_ENUM_NAME: tuple(enum_variants)
}
assert function_result.enum_values == {USER_STATUS_ENUM_NAME: tuple(enum_variants)}
assert function_result.enum_table_references == {
USER_STATUS_ENUM_NAME: frozenset((TableReference(USER_TABLE_NAME, USER_STATUS_COLUMN_NAME),))
USER_STATUS_ENUM_NAME: frozenset(
(TableReference(USER_TABLE_NAME, USER_STATUS_COLUMN_NAME),)
)
}


def test_with_multiple_enums(connection: 'Connection'):
def test_with_multiple_enums(connection: "Connection"):
declared_enum_values = get_declared_enum_values_with_orders_and_users()
declared_schema = get_schema_by_declared_enum_values(declared_enum_values)

function_result = get_declared_enums(declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection)
function_result = get_declared_enums(
declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection
)

assert function_result == declared_enum_values
17 changes: 10 additions & 7 deletions tests/get_enum_data/test_get_defined_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,26 @@
from sqlalchemy import Connection

from alembic_postgresql_enum.get_enum_data import get_defined_enums
from tests.schemas import get_schema_with_enum_variants, DEFAULT_SCHEMA, USER_STATUS_ENUM_NAME, \
get_schema_by_declared_enum_values, get_declared_enum_values_with_orders_and_users
from tests.schemas import (
get_schema_with_enum_variants,
DEFAULT_SCHEMA,
USER_STATUS_ENUM_NAME,
get_schema_by_declared_enum_values,
get_declared_enum_values_with_orders_and_users,
)


def test_get_defined_enums(connection: 'Connection'):
def test_get_defined_enums(connection: "Connection"):
enum_variants = ["active", "passive"]
defined_schema = get_schema_with_enum_variants(enum_variants)
defined_schema.create_all(connection)

function_result = get_defined_enums(connection, DEFAULT_SCHEMA)

assert function_result == {
USER_STATUS_ENUM_NAME: tuple(enum_variants)
}
assert function_result == {USER_STATUS_ENUM_NAME: tuple(enum_variants)}


def test_with_multiple_enums(connection: 'Connection'):
def test_with_multiple_enums(connection: "Connection"):
declared_enum_values = get_declared_enum_values_with_orders_and_users()
defined_schema = get_schema_by_declared_enum_values(declared_enum_values)

Expand Down
53 changes: 32 additions & 21 deletions tests/get_enum_data/test_type_decorator_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

class ValuesEnum(sqlalchemy.types.TypeDecorator):
"""Custom enum wrapper that forces columns to store enum values, not names"""

impl = sqlalchemy.types.Enum

cache_ok = True
Expand Down Expand Up @@ -42,44 +43,54 @@ def process(value):
return process


ORDER_TABLE_NAME = 'order'
ORDER_DELIVERY_STATUS_COLUMN_NAME = 'delivery_status'
ORDER_DELIVERY_STATUS_ENUM_NAME = 'order_delivery_status'
ORDER_TABLE_NAME = "order"
ORDER_DELIVERY_STATUS_COLUMN_NAME = "delivery_status"
ORDER_DELIVERY_STATUS_ENUM_NAME = "order_delivery_status"


class OrderDeliveryStatus(enum.Enum):
WAITING_FOR_WORKER = 'waiting_for_worker'
WAITING_FOR_WORKER_TO_ARRIVE = 'waiting_for_worker_to_arrive'
WORKER_ARRIVED = 'worker_arrived'
IN_PROGRESS = 'in_progress'
WAITING_FOR_APPROVAL = 'waiting_for_approval'
DISAPPROVED = 'disapproved'
DONE = 'done'
REFUNDED = 'refunded'
BANNED = 'banned'
CANCELED = 'canceled'
WAITING_FOR_WORKER = "waiting_for_worker"
WAITING_FOR_WORKER_TO_ARRIVE = "waiting_for_worker_to_arrive"
WORKER_ARRIVED = "worker_arrived"
IN_PROGRESS = "in_progress"
WAITING_FOR_APPROVAL = "waiting_for_approval"
DISAPPROVED = "disapproved"
DONE = "done"
REFUNDED = "refunded"
BANNED = "banned"
CANCELED = "canceled"


def get_schema_with_custom_enum() -> MetaData:
schema = MetaData()

Table(ORDER_TABLE_NAME,
schema,
Column(ORDER_DELIVERY_STATUS_COLUMN_NAME, ValuesEnum(OrderDeliveryStatus, name=ORDER_DELIVERY_STATUS_ENUM_NAME))
)
Table(
ORDER_TABLE_NAME,
schema,
Column(
ORDER_DELIVERY_STATUS_COLUMN_NAME,
ValuesEnum(OrderDeliveryStatus, name=ORDER_DELIVERY_STATUS_ENUM_NAME),
),
)

return schema


def test_get_declared_enums_for_custom_enum(connection: 'Connection'):
def test_get_declared_enums_for_custom_enum(connection: "Connection"):
declared_schema = get_schema_with_custom_enum()

function_result = get_declared_enums(declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection)
function_result = get_declared_enums(
declared_schema, DEFAULT_SCHEMA, DEFAULT_SCHEMA, connection
)

assert function_result.enum_values == {
# All declared enum variants must be taken from OrderDeliveryStatus values, see ValuesEnum
ORDER_DELIVERY_STATUS_ENUM_NAME: tuple(enum_item.value for enum_item in OrderDeliveryStatus)
ORDER_DELIVERY_STATUS_ENUM_NAME: tuple(
enum_item.value for enum_item in OrderDeliveryStatus
)
}
assert function_result.enum_table_references == {
ORDER_DELIVERY_STATUS_ENUM_NAME: frozenset((TableReference(ORDER_TABLE_NAME, ORDER_DELIVERY_STATUS_COLUMN_NAME),))
ORDER_DELIVERY_STATUS_ENUM_NAME: frozenset(
(TableReference(ORDER_TABLE_NAME, ORDER_DELIVERY_STATUS_COLUMN_NAME),)
)
}
Loading

0 comments on commit 05221ff

Please sign in to comment.