diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index fa69164..5709c06 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -88,8 +88,8 @@ class Base: literal_imports: list[LiteralImport] declarations: list[str] metadata_ref: str - decorator: Optional[str] = None - table_metadata_declaration: Optional[str] = None + decorator: str | None = None + table_metadata_declaration: str | None = None class CodeGenerator(metaclass=ABCMeta): diff --git a/tests/test_cli.py b/tests/test_cli.py index afdad42..4577e06 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,9 +7,10 @@ import pytest -from .conftest import requires_sqlalchemy_1_4 from sqlacodegen.generators import _sqla_version +from .conftest import requires_sqlalchemy_1_4 + if sys.version_info < (3, 8): from importlib_metadata import version else: @@ -82,8 +83,8 @@ def test_cli_declarative(db_path: Path, tmp_path: Path) -> None: if _sqla_version < (2, 0): assert ( - output_path.read_text() - == f"""\ + output_path.read_text() + == f"""\ from sqlalchemy import Column, Integer, Text from {declarative_package} import declarative_base @@ -98,9 +99,9 @@ class Foo(Base): """ ) else: - assert ( - output_path.read_text() - == f"""\ + assert ( + output_path.read_text() + == f"""\ from sqlalchemy import Integer, Text from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column @@ -114,7 +115,7 @@ class Foo(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(Text) """ - ) + ) def test_cli_dataclass(db_path: Path, tmp_path: Path) -> None: @@ -169,7 +170,9 @@ class Foo(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(Text) -""") +""" + ) + @requires_sqlalchemy_1_4 def test_cli_sqlmodels(db_path: Path, tmp_path: Path) -> None: diff --git a/tests/test_generator_dataclass.py b/tests/test_generator_dataclass.py index 55053ae..20cd3bd 100644 --- a/tests/test_generator_dataclass.py +++ b/tests/test_generator_dataclass.py @@ -10,7 +10,7 @@ from sqlacodegen.generators import CodeGenerator, DataclassGenerator -from .conftest import validate_code, requires_sqlalchemy_1_4 +from .conftest import requires_sqlalchemy_1_4, validate_code @requires_sqlalchemy_1_4 diff --git a/tests/test_generator_dataclass2.py b/tests/test_generator_dataclass2.py index d9bf8b4..3da9ec8 100644 --- a/tests/test_generator_dataclass2.py +++ b/tests/test_generator_dataclass2.py @@ -10,7 +10,7 @@ from sqlacodegen.generators import CodeGenerator, DataclassGenerator -from .conftest import validate_code, requires_sqlalchemy_2_0, requires_python_3_9 +from .conftest import requires_python_3_9, requires_sqlalchemy_2_0, validate_code @requires_sqlalchemy_2_0 @@ -99,7 +99,7 @@ def test_onetomany_optional(self, generator: CodeGenerator) -> None: generator.generate(), """\ from typing import Optional - + from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column, relationship diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index b75b39f..79dde06 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -19,7 +19,7 @@ from sqlacodegen.generators import CodeGenerator, DeclarativeGenerator -from .conftest import validate_code, requires_sqlalchemy_1_4 +from .conftest import requires_sqlalchemy_1_4, validate_code @requires_sqlalchemy_1_4 diff --git a/tests/test_generator_declarative2.py b/tests/test_generator_declarative2.py index 39bb016..72b33cc 100644 --- a/tests/test_generator_declarative2.py +++ b/tests/test_generator_declarative2.py @@ -19,7 +19,7 @@ from sqlacodegen.generators import CodeGenerator, DeclarativeGenerator -from .conftest import validate_code, requires_sqlalchemy_2_0 +from .conftest import requires_sqlalchemy_2_0, validate_code @requires_sqlalchemy_2_0 @@ -633,7 +633,7 @@ class SimpleItems(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - + t_container_items = Table( 'container_items', Base.metadata, Column('item_id', ForeignKey('simple_items.id')), @@ -674,7 +674,7 @@ class SimpleItems(Base): parent: Mapped[list['SimpleItems']] = relationship('SimpleItems', secondary='otherschema.child_items', primaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, secondaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, back_populates='child') child: Mapped[list['SimpleItems']] = relationship('SimpleItems', secondary='otherschema.child_items', primaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, secondaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, back_populates='parent') - + t_child_items = Table( 'child_items', Base.metadata, Column('parent_id', ForeignKey('simple_items.id')), @@ -786,21 +786,21 @@ def test_joined_inheritance(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class SimpleSuperItems(Base): __tablename__ = 'simple_super_items' id: Mapped[int] = mapped_column(Integer, primary_key=True) data1: Mapped[Optional[int]] = mapped_column(Integer) - + class SimpleItems(SimpleSuperItems): __tablename__ = 'simple_items' super_item_id: Mapped[int] = mapped_column(ForeignKey('simple_super_items.id'), primary_key=True) data2: Mapped[Optional[int]] = mapped_column(Integer) - + class SimpleSubItems(SimpleItems): __tablename__ = 'simple_sub_items' @@ -831,13 +831,13 @@ def test_joined_inheritance_same_table_name(self, generator: CodeGenerator) -> N class Base(DeclarativeBase): pass - + class Simple(Base): __tablename__ = 'simple' id: Mapped[int] = mapped_column(Integer, primary_key=True) - + class Simple_(Simple): __tablename__ = 'simple' __table_args__ = {'schema': 'altschema'} @@ -863,13 +863,13 @@ def test_use_inflect(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class SimpleItem(Base): __tablename__ = 'simple_items' id: Mapped[int] = mapped_column(Integer, primary_key=True) - + class Singular(Base): __tablename__ = 'singular' @@ -921,7 +921,7 @@ def test_use_inflect_plural( class Base(DeclarativeBase): pass - + class {class_name.capitalize()}(Base): __tablename__ = '{table_name}' @@ -957,7 +957,7 @@ def test_table_kwargs(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class SimpleItems(Base): __tablename__ = 'simple_items' __table_args__ = {'schema': 'testschema'} @@ -982,14 +982,14 @@ def test_table_args_kwargs(self, generator: CodeGenerator) -> None: generator.generate(), """\ from typing import Optional - + from sqlalchemy import Index, Integer, String from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column class Base(DeclarativeBase): pass - + class SimpleItems(Base): __tablename__ = 'simple_items' __table_args__ = ( @@ -1028,7 +1028,7 @@ def test_foreign_key_schema(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class OtherItems(Base): __tablename__ = 'other_items' __table_args__ = {'schema': 'otherschema'} @@ -1037,7 +1037,7 @@ class OtherItems(Base): simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', back_populates='other_item') - + class SimpleItems(Base): __tablename__ = 'simple_items' @@ -1069,7 +1069,7 @@ def test_invalid_attribute_names(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class SimpleItems(Base): __tablename__ = 'simple-items' @@ -1096,7 +1096,7 @@ def test_pascal(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class CustomerAPIPreference(Base): __tablename__ = 'CustomerAPIPreference' @@ -1120,7 +1120,7 @@ def test_underscore(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class CustomerApiPreference(Base): __tablename__ = 'customer_api_preference' @@ -1168,7 +1168,7 @@ def test_pascal_multiple_underscore(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class CustomerAPIPreference(Base): __tablename__ = 'customer_API__Preference' @@ -1223,7 +1223,7 @@ def test_table_comment(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class Simple(Base): __tablename__ = 'simple' __table_args__ = {'comment': "this is a 'comment'"} @@ -1251,7 +1251,7 @@ def test_metadata_column(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class Simple(Base): __tablename__ = 'simple' @@ -1324,7 +1324,7 @@ def test_named_constraints(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class Simple(Base): __tablename__ = 'simple' __table_args__ = ( diff --git a/tests/test_generator_tables.py b/tests/test_generator_tables.py index 49b5659..38a5da5 100644 --- a/tests/test_generator_tables.py +++ b/tests/test_generator_tables.py @@ -5,6 +5,7 @@ import pytest from _pytest.fixtures import FixtureRequest from sqlalchemy.dialects import mysql, postgresql +from sqlalchemy.engine import Engine from sqlalchemy.schema import ( CheckConstraint, Column, @@ -19,13 +20,10 @@ from sqlalchemy.sql.expression import text from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text -from sqlalchemy.engine import Engine -from sqlacodegen.generators import ( - CodeGenerator, - TablesGenerator, -) -from .conftest import validate_code, requires_sqlalchemy_1_4 +from sqlacodegen.generators import CodeGenerator, TablesGenerator + +from .conftest import requires_sqlalchemy_1_4, validate_code @requires_sqlalchemy_1_4 diff --git a/tests/test_generator_tables2.py b/tests/test_generator_tables2.py index c816e85..796b4e8 100644 --- a/tests/test_generator_tables2.py +++ b/tests/test_generator_tables2.py @@ -5,6 +5,7 @@ import pytest from _pytest.fixtures import FixtureRequest from sqlalchemy.dialects import mysql, postgresql +from sqlalchemy.engine import Engine from sqlalchemy.schema import ( CheckConstraint, Column, @@ -19,13 +20,10 @@ from sqlalchemy.sql.expression import text from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text -from sqlalchemy.engine import Engine -from sqlacodegen.generators import ( - CodeGenerator, - TablesGenerator, -) -from .conftest import validate_code, requires_sqlalchemy_2_0 +from sqlacodegen.generators import CodeGenerator, TablesGenerator + +from .conftest import requires_sqlalchemy_2_0, validate_code @requires_sqlalchemy_2_0