From 8e2814ce6ab84ecf4c9fc94fdf24a103148b8ae7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Mar 2023 19:31:47 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/sqlacodegen/generators.py | 12 +++++--- tests/conftest.py | 2 ++ tests/test_cli.py | 19 +++++++----- tests/test_generator_dataclass.py | 2 +- tests/test_generator_dataclass2.py | 4 +-- tests/test_generator_declarative.py | 2 +- tests/test_generator_declarative2.py | 46 ++++++++++++++-------------- tests/test_generator_tables.py | 10 +++--- tests/test_generator_tables2.py | 10 +++--- 9 files changed, 55 insertions(+), 52 deletions(-) diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index cc6acab..21d9485 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -88,7 +88,7 @@ class Base: literal_imports: list[LiteralImport] declarations: list[str] metadata_ref: str - decorator: Optional[str] = None + decorator: str | None = None class CodeGenerator(metaclass=ABCMeta): @@ -1415,10 +1415,10 @@ def render_class_declaration(self, model: ModelClass) -> str: if _sqla_version >= (2, 0): return super().render_class_declaration(model) else: - superclass_part = f"({model.parent_class.name})" if model.parent_class else "" - return ( - f"@mapper_registry.mapped\n@dataclass\nclass {model.name}{superclass_part}:" + superclass_part = ( + f"({model.parent_class.name})" if model.parent_class else "" ) + return f"@mapper_registry.mapped\n@dataclass\nclass {model.name}{superclass_part}:" def render_class_variables(self, model: ModelClass) -> str: if _sqla_version >= (2, 0): @@ -1450,7 +1450,9 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str: kwargs["default"] = None python_type_name = f"Optional[{python_type_name}]" - rendered_column = self.render_column(column, column_attr.name != column.name) + rendered_column = self.render_column( + column, column_attr.name != column.name + ) kwargs["metadata"] = f"{{{self.metadata_key!r}: {rendered_column}}}" rendered_field = render_callable("field", kwargs=kwargs) return f"{column_attr.name}: {python_type_name} = {rendered_field}" diff --git a/tests/conftest.py b/tests/conftest.py index 08c39b7..8e6a308 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,8 +5,10 @@ from sqlalchemy.engine import Engine, create_engine from sqlalchemy.orm import clear_mappers, configure_mappers from sqlalchemy.schema import MetaData + from sqlacodegen.generators import _sqla_version + def validate_code(generated_code: str, expected_code: str) -> None: expected_code = dedent(expected_code) assert generated_code == expected_code diff --git a/tests/test_cli.py b/tests/test_cli.py index afdad42..4577e06 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,9 +7,10 @@ import pytest -from .conftest import requires_sqlalchemy_1_4 from sqlacodegen.generators import _sqla_version +from .conftest import requires_sqlalchemy_1_4 + if sys.version_info < (3, 8): from importlib_metadata import version else: @@ -82,8 +83,8 @@ def test_cli_declarative(db_path: Path, tmp_path: Path) -> None: if _sqla_version < (2, 0): assert ( - output_path.read_text() - == f"""\ + output_path.read_text() + == f"""\ from sqlalchemy import Column, Integer, Text from {declarative_package} import declarative_base @@ -98,9 +99,9 @@ class Foo(Base): """ ) else: - assert ( - output_path.read_text() - == f"""\ + assert ( + output_path.read_text() + == f"""\ from sqlalchemy import Integer, Text from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column @@ -114,7 +115,7 @@ class Foo(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(Text) """ - ) + ) def test_cli_dataclass(db_path: Path, tmp_path: Path) -> None: @@ -169,7 +170,9 @@ class Foo(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(Text) -""") +""" + ) + @requires_sqlalchemy_1_4 def test_cli_sqlmodels(db_path: Path, tmp_path: Path) -> None: diff --git a/tests/test_generator_dataclass.py b/tests/test_generator_dataclass.py index 55053ae..20cd3bd 100644 --- a/tests/test_generator_dataclass.py +++ b/tests/test_generator_dataclass.py @@ -10,7 +10,7 @@ from sqlacodegen.generators import CodeGenerator, DataclassGenerator -from .conftest import validate_code, requires_sqlalchemy_1_4 +from .conftest import requires_sqlalchemy_1_4, validate_code @requires_sqlalchemy_1_4 diff --git a/tests/test_generator_dataclass2.py b/tests/test_generator_dataclass2.py index ee567ab..54a8b7b 100644 --- a/tests/test_generator_dataclass2.py +++ b/tests/test_generator_dataclass2.py @@ -10,7 +10,7 @@ from sqlacodegen.generators import CodeGenerator, DataclassGenerator -from .conftest import validate_code, requires_sqlalchemy_2_0 +from .conftest import requires_sqlalchemy_2_0, validate_code @requires_sqlalchemy_2_0 @@ -98,7 +98,7 @@ def test_onetomany_optional(self, generator: CodeGenerator) -> None: generator.generate(), """\ from typing import Optional - + from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column, relationship diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index b800205..ece59f3 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -19,7 +19,7 @@ from sqlacodegen.generators import CodeGenerator, DeclarativeGenerator -from .conftest import validate_code, requires_sqlalchemy_1_4 +from .conftest import requires_sqlalchemy_1_4, validate_code @requires_sqlalchemy_1_4 diff --git a/tests/test_generator_declarative2.py b/tests/test_generator_declarative2.py index 39bb016..72b33cc 100644 --- a/tests/test_generator_declarative2.py +++ b/tests/test_generator_declarative2.py @@ -19,7 +19,7 @@ from sqlacodegen.generators import CodeGenerator, DeclarativeGenerator -from .conftest import validate_code, requires_sqlalchemy_2_0 +from .conftest import requires_sqlalchemy_2_0, validate_code @requires_sqlalchemy_2_0 @@ -633,7 +633,7 @@ class SimpleItems(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - + t_container_items = Table( 'container_items', Base.metadata, Column('item_id', ForeignKey('simple_items.id')), @@ -674,7 +674,7 @@ class SimpleItems(Base): parent: Mapped[list['SimpleItems']] = relationship('SimpleItems', secondary='otherschema.child_items', primaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, secondaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, back_populates='child') child: Mapped[list['SimpleItems']] = relationship('SimpleItems', secondary='otherschema.child_items', primaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, secondaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, back_populates='parent') - + t_child_items = Table( 'child_items', Base.metadata, Column('parent_id', ForeignKey('simple_items.id')), @@ -786,21 +786,21 @@ def test_joined_inheritance(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class SimpleSuperItems(Base): __tablename__ = 'simple_super_items' id: Mapped[int] = mapped_column(Integer, primary_key=True) data1: Mapped[Optional[int]] = mapped_column(Integer) - + class SimpleItems(SimpleSuperItems): __tablename__ = 'simple_items' super_item_id: Mapped[int] = mapped_column(ForeignKey('simple_super_items.id'), primary_key=True) data2: Mapped[Optional[int]] = mapped_column(Integer) - + class SimpleSubItems(SimpleItems): __tablename__ = 'simple_sub_items' @@ -831,13 +831,13 @@ def test_joined_inheritance_same_table_name(self, generator: CodeGenerator) -> N class Base(DeclarativeBase): pass - + class Simple(Base): __tablename__ = 'simple' id: Mapped[int] = mapped_column(Integer, primary_key=True) - + class Simple_(Simple): __tablename__ = 'simple' __table_args__ = {'schema': 'altschema'} @@ -863,13 +863,13 @@ def test_use_inflect(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class SimpleItem(Base): __tablename__ = 'simple_items' id: Mapped[int] = mapped_column(Integer, primary_key=True) - + class Singular(Base): __tablename__ = 'singular' @@ -921,7 +921,7 @@ def test_use_inflect_plural( class Base(DeclarativeBase): pass - + class {class_name.capitalize()}(Base): __tablename__ = '{table_name}' @@ -957,7 +957,7 @@ def test_table_kwargs(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class SimpleItems(Base): __tablename__ = 'simple_items' __table_args__ = {'schema': 'testschema'} @@ -982,14 +982,14 @@ def test_table_args_kwargs(self, generator: CodeGenerator) -> None: generator.generate(), """\ from typing import Optional - + from sqlalchemy import Index, Integer, String from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column class Base(DeclarativeBase): pass - + class SimpleItems(Base): __tablename__ = 'simple_items' __table_args__ = ( @@ -1028,7 +1028,7 @@ def test_foreign_key_schema(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class OtherItems(Base): __tablename__ = 'other_items' __table_args__ = {'schema': 'otherschema'} @@ -1037,7 +1037,7 @@ class OtherItems(Base): simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', back_populates='other_item') - + class SimpleItems(Base): __tablename__ = 'simple_items' @@ -1069,7 +1069,7 @@ def test_invalid_attribute_names(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class SimpleItems(Base): __tablename__ = 'simple-items' @@ -1096,7 +1096,7 @@ def test_pascal(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class CustomerAPIPreference(Base): __tablename__ = 'CustomerAPIPreference' @@ -1120,7 +1120,7 @@ def test_underscore(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class CustomerApiPreference(Base): __tablename__ = 'customer_api_preference' @@ -1168,7 +1168,7 @@ def test_pascal_multiple_underscore(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class CustomerAPIPreference(Base): __tablename__ = 'customer_API__Preference' @@ -1223,7 +1223,7 @@ def test_table_comment(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class Simple(Base): __tablename__ = 'simple' __table_args__ = {'comment': "this is a 'comment'"} @@ -1251,7 +1251,7 @@ def test_metadata_column(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class Simple(Base): __tablename__ = 'simple' @@ -1324,7 +1324,7 @@ def test_named_constraints(self, generator: CodeGenerator) -> None: class Base(DeclarativeBase): pass - + class Simple(Base): __tablename__ = 'simple' __table_args__ = ( diff --git a/tests/test_generator_tables.py b/tests/test_generator_tables.py index 49b5659..38a5da5 100644 --- a/tests/test_generator_tables.py +++ b/tests/test_generator_tables.py @@ -5,6 +5,7 @@ import pytest from _pytest.fixtures import FixtureRequest from sqlalchemy.dialects import mysql, postgresql +from sqlalchemy.engine import Engine from sqlalchemy.schema import ( CheckConstraint, Column, @@ -19,13 +20,10 @@ from sqlalchemy.sql.expression import text from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text -from sqlalchemy.engine import Engine -from sqlacodegen.generators import ( - CodeGenerator, - TablesGenerator, -) -from .conftest import validate_code, requires_sqlalchemy_1_4 +from sqlacodegen.generators import CodeGenerator, TablesGenerator + +from .conftest import requires_sqlalchemy_1_4, validate_code @requires_sqlalchemy_1_4 diff --git a/tests/test_generator_tables2.py b/tests/test_generator_tables2.py index c816e85..796b4e8 100644 --- a/tests/test_generator_tables2.py +++ b/tests/test_generator_tables2.py @@ -5,6 +5,7 @@ import pytest from _pytest.fixtures import FixtureRequest from sqlalchemy.dialects import mysql, postgresql +from sqlalchemy.engine import Engine from sqlalchemy.schema import ( CheckConstraint, Column, @@ -19,13 +20,10 @@ from sqlalchemy.sql.expression import text from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text -from sqlalchemy.engine import Engine -from sqlacodegen.generators import ( - CodeGenerator, - TablesGenerator, -) -from .conftest import validate_code, requires_sqlalchemy_2_0 +from sqlacodegen.generators import CodeGenerator, TablesGenerator + +from .conftest import requires_sqlalchemy_2_0, validate_code @requires_sqlalchemy_2_0