Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 2, 2023
1 parent 962531f commit 9ecf6ae
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 52 deletions.
12 changes: 7 additions & 5 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class Base:
literal_imports: list[LiteralImport]
declarations: list[str]
metadata_ref: str
decorator: Optional[str] = None
decorator: str | None = None


class CodeGenerator(metaclass=ABCMeta):
Expand Down Expand Up @@ -1415,10 +1415,10 @@ def render_class_declaration(self, model: ModelClass) -> str:
if _sqla_version >= (2, 0):
return super().render_class_declaration(model)
else:
superclass_part = f"({model.parent_class.name})" if model.parent_class else ""
return (
f"@mapper_registry.mapped\n@dataclass\nclass {model.name}{superclass_part}:"
superclass_part = (
f"({model.parent_class.name})" if model.parent_class else ""
)
return f"@mapper_registry.mapped\n@dataclass\nclass {model.name}{superclass_part}:"

def render_class_variables(self, model: ModelClass) -> str:
if _sqla_version >= (2, 0):
Expand Down Expand Up @@ -1450,7 +1450,9 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
kwargs["default"] = None
python_type_name = f"Optional[{python_type_name}]"

rendered_column = self.render_column(column, column_attr.name != column.name)
rendered_column = self.render_column(
column, column_attr.name != column.name
)
kwargs["metadata"] = f"{{{self.metadata_key!r}: {rendered_column}}}"
rendered_field = render_callable("field", kwargs=kwargs)
return f"{column_attr.name}: {python_type_name} = {rendered_field}"
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.orm import clear_mappers, configure_mappers
from sqlalchemy.schema import MetaData

from sqlacodegen.generators import _sqla_version


def validate_code(generated_code: str, expected_code: str) -> None:
expected_code = dedent(expected_code)
assert generated_code == expected_code
Expand Down
19 changes: 11 additions & 8 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

import pytest

from .conftest import requires_sqlalchemy_1_4
from sqlacodegen.generators import _sqla_version

from .conftest import requires_sqlalchemy_1_4

if sys.version_info < (3, 8):
from importlib_metadata import version
else:
Expand Down Expand Up @@ -82,8 +83,8 @@ def test_cli_declarative(db_path: Path, tmp_path: Path) -> None:

if _sqla_version < (2, 0):
assert (
output_path.read_text()
== f"""\
output_path.read_text()
== f"""\
from sqlalchemy import Column, Integer, Text
from {declarative_package} import declarative_base
Expand All @@ -98,9 +99,9 @@ class Foo(Base):
"""
)
else:
assert (
output_path.read_text()
== f"""\
assert (
output_path.read_text()
== f"""\
from sqlalchemy import Integer, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
Expand All @@ -114,7 +115,7 @@ class Foo(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(Text)
"""
)
)


def test_cli_dataclass(db_path: Path, tmp_path: Path) -> None:
Expand Down Expand Up @@ -169,7 +170,9 @@ class Foo(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(Text)
""")
"""
)


@requires_sqlalchemy_1_4
def test_cli_sqlmodels(db_path: Path, tmp_path: Path) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_generator_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from sqlacodegen.generators import CodeGenerator, DataclassGenerator

from .conftest import validate_code, requires_sqlalchemy_1_4
from .conftest import requires_sqlalchemy_1_4, validate_code


@requires_sqlalchemy_1_4
Expand Down
4 changes: 2 additions & 2 deletions tests/test_generator_dataclass2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from sqlacodegen.generators import CodeGenerator, DataclassGenerator

from .conftest import validate_code, requires_sqlalchemy_2_0
from .conftest import requires_sqlalchemy_2_0, validate_code


@requires_sqlalchemy_2_0
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_onetomany_optional(self, generator: CodeGenerator) -> None:
generator.generate(),
"""\
from typing import Optional
from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column, relationship
Expand Down
2 changes: 1 addition & 1 deletion tests/test_generator_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from sqlacodegen.generators import CodeGenerator, DeclarativeGenerator

from .conftest import validate_code, requires_sqlalchemy_1_4
from .conftest import requires_sqlalchemy_1_4, validate_code


@requires_sqlalchemy_1_4
Expand Down
46 changes: 23 additions & 23 deletions tests/test_generator_declarative2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from sqlacodegen.generators import CodeGenerator, DeclarativeGenerator

from .conftest import validate_code, requires_sqlalchemy_2_0
from .conftest import requires_sqlalchemy_2_0, validate_code


@requires_sqlalchemy_2_0
Expand Down Expand Up @@ -633,7 +633,7 @@ class SimpleItems(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
t_container_items = Table(
'container_items', Base.metadata,
Column('item_id', ForeignKey('simple_items.id')),
Expand Down Expand Up @@ -674,7 +674,7 @@ class SimpleItems(Base):
parent: Mapped[list['SimpleItems']] = relationship('SimpleItems', secondary='otherschema.child_items', primaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, secondaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, back_populates='child')
child: Mapped[list['SimpleItems']] = relationship('SimpleItems', secondary='otherschema.child_items', primaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, secondaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, back_populates='parent')
t_child_items = Table(
'child_items', Base.metadata,
Column('parent_id', ForeignKey('simple_items.id')),
Expand Down Expand Up @@ -786,21 +786,21 @@ def test_joined_inheritance(self, generator: CodeGenerator) -> None:
class Base(DeclarativeBase):
pass
class SimpleSuperItems(Base):
__tablename__ = 'simple_super_items'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
data1: Mapped[Optional[int]] = mapped_column(Integer)
class SimpleItems(SimpleSuperItems):
__tablename__ = 'simple_items'
super_item_id: Mapped[int] = mapped_column(ForeignKey('simple_super_items.id'), primary_key=True)
data2: Mapped[Optional[int]] = mapped_column(Integer)
class SimpleSubItems(SimpleItems):
__tablename__ = 'simple_sub_items'
Expand Down Expand Up @@ -831,13 +831,13 @@ def test_joined_inheritance_same_table_name(self, generator: CodeGenerator) -> N
class Base(DeclarativeBase):
pass
class Simple(Base):
__tablename__ = 'simple'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
class Simple_(Simple):
__tablename__ = 'simple'
__table_args__ = {'schema': 'altschema'}
Expand All @@ -863,13 +863,13 @@ def test_use_inflect(self, generator: CodeGenerator) -> None:
class Base(DeclarativeBase):
pass
class SimpleItem(Base):
__tablename__ = 'simple_items'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
class Singular(Base):
__tablename__ = 'singular'
Expand Down Expand Up @@ -921,7 +921,7 @@ def test_use_inflect_plural(
class Base(DeclarativeBase):
pass
class {class_name.capitalize()}(Base):
__tablename__ = '{table_name}'
Expand Down Expand Up @@ -957,7 +957,7 @@ def test_table_kwargs(self, generator: CodeGenerator) -> None:
class Base(DeclarativeBase):
pass
class SimpleItems(Base):
__tablename__ = 'simple_items'
__table_args__ = {'schema': 'testschema'}
Expand All @@ -982,14 +982,14 @@ def test_table_args_kwargs(self, generator: CodeGenerator) -> None:
generator.generate(),
"""\
from typing import Optional
from sqlalchemy import Index, Integer, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class SimpleItems(Base):
__tablename__ = 'simple_items'
__table_args__ = (
Expand Down Expand Up @@ -1028,7 +1028,7 @@ def test_foreign_key_schema(self, generator: CodeGenerator) -> None:
class Base(DeclarativeBase):
pass
class OtherItems(Base):
__tablename__ = 'other_items'
__table_args__ = {'schema': 'otherschema'}
Expand All @@ -1037,7 +1037,7 @@ class OtherItems(Base):
simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', back_populates='other_item')
class SimpleItems(Base):
__tablename__ = 'simple_items'
Expand Down Expand Up @@ -1069,7 +1069,7 @@ def test_invalid_attribute_names(self, generator: CodeGenerator) -> None:
class Base(DeclarativeBase):
pass
class SimpleItems(Base):
__tablename__ = 'simple-items'
Expand All @@ -1096,7 +1096,7 @@ def test_pascal(self, generator: CodeGenerator) -> None:
class Base(DeclarativeBase):
pass
class CustomerAPIPreference(Base):
__tablename__ = 'CustomerAPIPreference'
Expand All @@ -1120,7 +1120,7 @@ def test_underscore(self, generator: CodeGenerator) -> None:
class Base(DeclarativeBase):
pass
class CustomerApiPreference(Base):
__tablename__ = 'customer_api_preference'
Expand Down Expand Up @@ -1168,7 +1168,7 @@ def test_pascal_multiple_underscore(self, generator: CodeGenerator) -> None:
class Base(DeclarativeBase):
pass
class CustomerAPIPreference(Base):
__tablename__ = 'customer_API__Preference'
Expand Down Expand Up @@ -1223,7 +1223,7 @@ def test_table_comment(self, generator: CodeGenerator) -> None:
class Base(DeclarativeBase):
pass
class Simple(Base):
__tablename__ = 'simple'
__table_args__ = {'comment': "this is a 'comment'"}
Expand Down Expand Up @@ -1251,7 +1251,7 @@ def test_metadata_column(self, generator: CodeGenerator) -> None:
class Base(DeclarativeBase):
pass
class Simple(Base):
__tablename__ = 'simple'
Expand Down Expand Up @@ -1324,7 +1324,7 @@ def test_named_constraints(self, generator: CodeGenerator) -> None:
class Base(DeclarativeBase):
pass
class Simple(Base):
__tablename__ = 'simple'
__table_args__ = (
Expand Down
10 changes: 4 additions & 6 deletions tests/test_generator_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy.dialects import mysql, postgresql
from sqlalchemy.engine import Engine
from sqlalchemy.schema import (
CheckConstraint,
Column,
Expand All @@ -19,13 +20,10 @@
from sqlalchemy.sql.expression import text
from sqlalchemy.sql.sqltypes import NullType
from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text
from sqlalchemy.engine import Engine

from sqlacodegen.generators import (
CodeGenerator,
TablesGenerator,
)
from .conftest import validate_code, requires_sqlalchemy_1_4
from sqlacodegen.generators import CodeGenerator, TablesGenerator

from .conftest import requires_sqlalchemy_1_4, validate_code


@requires_sqlalchemy_1_4
Expand Down
10 changes: 4 additions & 6 deletions tests/test_generator_tables2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy.dialects import mysql, postgresql
from sqlalchemy.engine import Engine
from sqlalchemy.schema import (
CheckConstraint,
Column,
Expand All @@ -19,13 +20,10 @@
from sqlalchemy.sql.expression import text
from sqlalchemy.sql.sqltypes import NullType
from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text
from sqlalchemy.engine import Engine

from sqlacodegen.generators import (
CodeGenerator,
TablesGenerator,
)
from .conftest import validate_code, requires_sqlalchemy_2_0
from sqlacodegen.generators import CodeGenerator, TablesGenerator

from .conftest import requires_sqlalchemy_2_0, validate_code


@requires_sqlalchemy_2_0
Expand Down

0 comments on commit 9ecf6ae

Please sign in to comment.